diff --git a/vortex-array/benches/expr/case_when_bench.rs b/vortex-array/benches/expr/case_when_bench.rs index e25ad180fa2..520a250c124 100644 --- a/vortex-array/benches/expr/case_when_bench.rs +++ b/vortex-array/benches/expr/case_when_bench.rs @@ -11,6 +11,7 @@ use vortex_array::ArrayRef; use vortex_array::Canonical; use vortex_array::IntoArray; use vortex_array::VortexSessionExecute; +use vortex_array::arrays::BoolArray; use vortex_array::arrays::StructArray; use vortex_array::expr::case_when; use vortex_array::expr::case_when_no_else; @@ -18,6 +19,7 @@ use vortex_array::expr::eq; use vortex_array::expr::get_item; use vortex_array::expr::gt; use vortex_array::expr::lit; +use vortex_array::expr::lt; use vortex_array::expr::nested_case_when; use vortex_array::expr::root; use vortex_array::session::ArraySession; @@ -39,6 +41,22 @@ fn make_struct_array(size: usize) -> ArrayRef { .into_array() } +/// Array with boolean columns cycling through thirds: `c0[i] = i%3==0`, `c1[i] = i%3==1`. +fn make_fragmented_array(size: usize) -> ArrayRef { + StructArray::from_fields(&[ + ( + "c0", + BoolArray::from_iter((0..size).map(|i| i % 3 == 0)).into_array(), + ), + ( + "c1", + BoolArray::from_iter((0..size).map(|i| i % 3 == 1)).into_array(), + ), + ]) + .unwrap() + .into_array() +} + /// Benchmark a simple binary CASE WHEN with varying array sizes. #[divan::bench(args = [1000, 10000, 100000])] fn case_when_simple(bencher: Bencher, size: usize) { @@ -185,6 +203,39 @@ fn case_when_all_true(bencher: Bencher, size: usize) { }); } +/// Benchmark n-ary CASE WHEN where the first branch dominates (~90% of rows). +/// This highlights the early-exit and deferred-merge optimizations: subsequent conditions +/// match no remaining rows and are skipped entirely. +#[divan::bench(args = [1000, 10000, 100000])] +fn case_when_nary_early_dominant(bencher: Bencher, size: usize) { + let array = make_struct_array(size); + + // CASE WHEN value < 90% THEN 1 WHEN value < 95% THEN 2 WHEN value < 97.5% THEN 3 ELSE 4 + let t1 = (size as i32 * 9) / 10; + let t2 = (size as i32 * 19) / 20; + let t3 = (size as i32 * 39) / 40; + + let expr = nested_case_when( + vec![ + (lt(get_item("value", root()), lit(t1)), lit(1i32)), + (lt(get_item("value", root()), lit(t2)), lit(2i32)), + (lt(get_item("value", root()), lit(t3)), lit(3i32)), + ], + Some(lit(4i32)), + ); + + bencher + .with_inputs(|| (&expr, &array)) + .bench_refs(|(expr, array)| { + let mut ctx = SESSION.create_execution_ctx(); + array + .apply(expr) + .unwrap() + .execute::(&mut ctx) + .unwrap() + }); +} + /// Benchmark CASE WHEN where all conditions are false. #[divan::bench(args = [1000, 10000, 100000])] fn case_when_all_false(bencher: Bencher, size: usize) { @@ -208,3 +259,30 @@ fn case_when_all_false(bencher: Bencher, size: usize) { .unwrap() }); } + +/// Benchmark CASE WHEN cycling through 3 branches per row (triggers merge_row_by_row). +/// Run length = 1; exercises branch 0, branch 1, and the else fallback at every 3rd row. +#[divan::bench(args = [1000, 10000])] +fn case_when_fragmented(bencher: Bencher, size: usize) { + let array = make_fragmented_array(size); + + // CASE WHEN c0 THEN 0 WHEN c1 THEN 1 ELSE 2 END + let expr = nested_case_when( + vec![ + (get_item("c0", root()), lit(0i32)), + (get_item("c1", root()), lit(1i32)), + ], + Some(lit(2i32)), + ); + + bencher + .with_inputs(|| (&expr, &array)) + .bench_refs(|(expr, array)| { + let mut ctx = SESSION.create_execution_ctx(); + array + .apply(expr) + .unwrap() + .execute::(&mut ctx) + .unwrap() + }); +} diff --git a/vortex-array/src/scalar_fn/fns/case_when.rs b/vortex-array/src/scalar_fn/fns/case_when.rs index f05ee45bdad..2d24bc805ab 100644 --- a/vortex-array/src/scalar_fn/fns/case_when.rs +++ b/vortex-array/src/scalar_fn/fns/case_when.rs @@ -18,6 +18,8 @@ use std::sync::Arc; use prost::Message; use vortex_error::VortexResult; use vortex_error::vortex_bail; +use vortex_mask::AllOr; +use vortex_mask::Mask; use vortex_proto::expr as pb; use vortex_session::VortexSession; @@ -26,6 +28,9 @@ use crate::ExecutionCtx; use crate::IntoArray; use crate::arrays::BoolArray; use crate::arrays::ConstantArray; +use crate::builders::ArrayBuilder; +use crate::builders::builder_with_capacity; +use crate::builtins::ArrayBuiltins; use crate::dtype::DType; use crate::expr::Expression; use crate::scalar::Scalar; @@ -198,43 +203,52 @@ impl ScalarFnVTable for CaseWhen { args: &dyn ExecutionArgs, ctx: &mut ExecutionCtx, ) -> VortexResult { + // Inspired by https://datafusion.apache.org/blog/2026/02/02/datafusion_case/ + // + // TODO: shrink input to `remaining` rows between WHEN iterations (batch reduction). + // TODO: project to only referenced columns before batch reduction (column projection). + // TODO: evaluate THEN/ELSE on compact matching/non-matching rows and scatter-merge the results. + // TODO: for constant WHEN/THEN values, compile to a hash table for a single-pass lookup. let row_count = args.row_count(); let num_pairs = options.num_when_then_pairs as usize; - let mut result: ArrayRef = if options.has_else { - args.get(num_pairs * 2)? - } else { - let then_dtype = args.get(1)?.dtype().as_nullable(); - ConstantArray::new(Scalar::null(then_dtype), row_count).into_array() - }; + let mut remaining = Mask::new_true(row_count); + let mut branches: Vec<(Mask, ArrayRef)> = Vec::with_capacity(num_pairs); - // TODO(perf): this reverse-zip approach touches every row for every condition. - // A left-to-right filter approach could maintain an "unmatched" mask, narrow it - // as conditions match, and exit early once all rows are resolved. - for i in (0..num_pairs).rev() { - let condition = args.get(i * 2)?; - let then_value = args.get(i * 2 + 1)?; + for i in 0..num_pairs { + if remaining.all_false() { + break; + } + let condition = args.get(i * 2)?; let cond_bool = condition.execute::(ctx)?; - let mask = cond_bool.to_mask_fill_null_false(); + let cond_mask = cond_bool.to_mask_fill_null_false(); + let effective_mask = &remaining & &cond_mask; - if mask.all_true() { - result = then_value; + if effective_mask.all_false() { continue; } - if mask.all_false() { - continue; - } + let then_value = args.get(i * 2 + 1)?; + remaining = remaining.bitand_not(&cond_mask); + branches.push((effective_mask, then_value)); + } - result = zip_impl(&then_value, &result, &mask)?; + let else_value: ArrayRef = if options.has_else { + args.get(num_pairs * 2)? + } else { + let then_dtype = args.get(1)?.dtype().as_nullable(); + ConstantArray::new(Scalar::null(then_dtype), row_count).into_array() + }; + + if branches.is_empty() { + return Ok(else_value); } - Ok(result) + merge_case_branches(branches, else_value) } fn is_null_sensitive(&self, _options: &Self::Options) -> bool { - // CaseWhen is null-sensitive because NULL conditions are treated as false true } @@ -243,6 +257,113 @@ impl ScalarFnVTable for CaseWhen { } } +/// Average run length at which slicing + `extend_from_array` becomes cheaper than `scalar_at`. +/// Measured empirically via benchmarks. +const SLICE_CROSSOVER_RUN_LEN: usize = 4; + +/// Merges disjoint `(mask, then_value)` branch pairs with an `else_value` into a single array. +/// +/// Branch masks are guaranteed disjoint by the remaining-row tracking in [`CaseWhen::execute`]. +fn merge_case_branches( + branches: Vec<(Mask, ArrayRef)>, + else_value: ArrayRef, +) -> VortexResult { + if branches.len() == 1 { + let (mask, then_value) = &branches[0]; + return zip_impl(then_value, &else_value, mask); + } + + let row_count = else_value.len(); + + let output_nullability = branches + .iter() + .fold(else_value.dtype().nullability(), |acc, (_, arr)| { + acc | arr.dtype().nullability() + }); + let output_dtype = else_value.dtype().with_nullability(output_nullability); + + let mut spans: Vec<(usize, usize, usize)> = Vec::new(); + for (branch_idx, (mask, _)) in branches.iter().enumerate() { + match mask.slices() { + AllOr::All => return branches[branch_idx].1.cast(output_dtype), + AllOr::None => {} + AllOr::Some(slices) => { + for &(start, end) in slices { + spans.push((start, end, branch_idx)); + } + } + } + } + spans.sort_unstable_by_key(|&(start, ..)| start); + + if spans.is_empty() { + return else_value.cast(output_dtype); + } + + let builder = builder_with_capacity(&output_dtype, row_count); + + let branch_arrays: Vec = branches + .iter() + .map(|(_, arr)| arr.cast(output_dtype.clone())) + .collect::>()?; + let else_value = else_value.cast(output_dtype)?; + + let fragmented = !spans.is_empty() && spans.len() > row_count / SLICE_CROSSOVER_RUN_LEN; + if fragmented { + merge_row_by_row(&branch_arrays, &else_value, &spans, row_count, builder) + } else { + merge_run_by_run(&branch_arrays, &else_value, &spans, row_count, builder) + } +} + +/// Iterates spans directly, emitting one `scalar_at` per row. +/// Zero per-run allocations; preferred for fragmented masks (avg run < [`SLICE_CROSSOVER_RUN_LEN`]). +fn merge_row_by_row( + branch_arrays: &[ArrayRef], + else_value: &ArrayRef, + spans: &[(usize, usize, usize)], + row_count: usize, + mut builder: Box, +) -> VortexResult { + let mut pos = 0; + for &(start, end, branch_idx) in spans { + for row in pos..start { + builder.append_scalar(&else_value.scalar_at(row)?)?; + } + for row in start..end { + builder.append_scalar(&branch_arrays[branch_idx].scalar_at(row)?)?; + } + pos = end; + } + for row in pos..row_count { + builder.append_scalar(&else_value.scalar_at(row)?)?; + } + + Ok(builder.finish()) +} + +/// Bulk-copies each span via `slice()` + `extend_from_array`. +/// Preferred when runs are long enough that memcpy dominates over per-slice allocation cost. +fn merge_run_by_run( + branch_arrays: &[ArrayRef], + else_value: &ArrayRef, + spans: &[(usize, usize, usize)], + row_count: usize, + mut builder: Box, +) -> VortexResult { + for (start, end, branch_idx) in spans { + if builder.len() < *start { + builder.extend_from_array(&else_value.slice(builder.len()..*start)?); + } + builder.extend_from_array(&branch_arrays[*branch_idx].slice(*start..*end)?); + } + if builder.len() < row_count { + builder.extend_from_array(&else_value.slice(builder.len()..row_count)?); + } + + Ok(builder.finish()) +} + #[cfg(test)] mod tests { use std::sync::LazyLock; @@ -254,10 +375,11 @@ mod tests { use super::*; use crate::Canonical; use crate::IntoArray; - use crate::ToCanonical; use crate::VortexSessionExecute as _; + use crate::arrays::BoolArray; use crate::arrays::PrimitiveArray; use crate::arrays::StructArray; + use crate::assert_arrays_eq; use crate::dtype::DType; use crate::dtype::Nullability; use crate::dtype::PType; @@ -272,7 +394,6 @@ mod tests { use crate::expr::root; use crate::expr::test_harness; use crate::scalar::Scalar; - use crate::scalar_fn::fns::case_when::BoolArray; use crate::session::ArraySession; static SESSION: LazyLock = @@ -595,8 +716,8 @@ mod tests { lit(0i32), ); - let result = evaluate_expr(&expr, &test_array).to_primitive(); - assert_eq!(result.as_slice::(), &[0, 0, 100, 100, 100]); + let result = evaluate_expr(&expr, &test_array); + assert_arrays_eq!(result, buffer![0i32, 0, 100, 100, 100].into_array()); } #[test] @@ -615,8 +736,8 @@ mod tests { Some(lit(0i32)), ); - let result = evaluate_expr(&expr, &test_array).to_primitive(); - assert_eq!(result.as_slice::(), &[10, 0, 30, 0, 0]); + let result = evaluate_expr(&expr, &test_array); + assert_arrays_eq!(result, buffer![10i32, 0, 30, 0, 0].into_array()); } #[test] @@ -635,8 +756,8 @@ mod tests { Some(lit(0i32)), ); - let result = evaluate_expr(&expr, &test_array).to_primitive(); - assert_eq!(result.as_slice::(), &[0, 0, 100, 100, 100]); + let result = evaluate_expr(&expr, &test_array); + assert_arrays_eq!(result, buffer![0i32, 0, 100, 100, 100].into_array()); } #[test] @@ -650,26 +771,10 @@ mod tests { let result = evaluate_expr(&expr, &test_array); assert!(result.dtype().is_nullable()); - - assert_eq!( - result.scalar_at(0).unwrap(), - Scalar::null(result.dtype().clone()) - ); - assert_eq!( - result.scalar_at(1).unwrap(), - Scalar::null(result.dtype().clone()) - ); - assert_eq!( - result.scalar_at(2).unwrap(), - Scalar::null(result.dtype().clone()) - ); - assert_eq!( - result.scalar_at(3).unwrap(), - Scalar::from(100i32).cast(result.dtype()).unwrap() - ); - assert_eq!( - result.scalar_at(4).unwrap(), - Scalar::from(100i32).cast(result.dtype()).unwrap() + assert_arrays_eq!( + result, + PrimitiveArray::from_option_iter([None::, None, None, Some(100), Some(100)]) + .into_array() ); } @@ -686,8 +791,8 @@ mod tests { lit(0i32), ); - let result = evaluate_expr(&expr, &test_array).to_primitive(); - assert_eq!(result.as_slice::(), &[0, 0, 0, 0, 0]); + let result = evaluate_expr(&expr, &test_array); + assert_arrays_eq!(result, buffer![0i32, 0, 0, 0, 0].into_array()); } #[test] @@ -703,8 +808,67 @@ mod tests { lit(0i32), ); - let result = evaluate_expr(&expr, &test_array).to_primitive(); - assert_eq!(result.as_slice::(), &[100, 100, 100, 100, 100]); + let result = evaluate_expr(&expr, &test_array); + assert_arrays_eq!(result, buffer![100i32, 100, 100, 100, 100].into_array()); + } + + #[test] + fn test_evaluate_all_true_no_else_returns_correct_dtype() { + // CASE WHEN value > 0 THEN 100 END — condition is always true, no ELSE. + // Result must be Nullable because the implicit ELSE is NULL. + let test_array = StructArray::from_fields(&[("value", buffer![1i32, 2, 3].into_array())]) + .unwrap() + .into_array(); + + let expr = case_when_no_else(gt(get_item("value", root()), lit(0i32)), lit(100i32)); + + let result = evaluate_expr(&expr, &test_array); + assert!( + result.dtype().is_nullable(), + "result dtype must be Nullable, got {:?}", + result.dtype() + ); + assert_arrays_eq!( + result, + PrimitiveArray::from_option_iter([Some(100i32), Some(100), Some(100)]).into_array() + ); + } + + #[test] + fn test_merge_case_branches_widens_nullability_of_later_branch() -> VortexResult<()> { + // When a later THEN branch is Nullable and branches[0] and ELSE are NonNullable, + // the result dtype must still be Nullable. + // + // CASE WHEN value = 0 THEN 10 -- NonNullable + // WHEN value = 1 THEN nullable(20) -- Nullable + // ELSE 0 -- NonNullable + // → result must be Nullable(i32) + let test_array = StructArray::from_fields(&[("value", buffer![0i32, 1, 2].into_array())]) + .unwrap() + .into_array(); + + let nullable_20 = + Scalar::from(20i32).cast(&DType::Primitive(PType::I32, Nullability::Nullable))?; + + let expr = nested_case_when( + vec![ + (eq(get_item("value", root()), lit(0i32)), lit(10i32)), + (eq(get_item("value", root()), lit(1i32)), lit(nullable_20)), + ], + Some(lit(0i32)), + ); + + let result = evaluate_expr(&expr, &test_array); + assert!( + result.dtype().is_nullable(), + "result dtype must be Nullable, got {:?}", + result.dtype() + ); + assert_arrays_eq!( + result, + PrimitiveArray::from_option_iter([Some(10), Some(20), Some(0)]).into_array() + ); + Ok(()) } #[test] @@ -713,12 +877,7 @@ mod tests { let expr = case_when(lit(true), lit(100i32), lit(0i32)); let result = evaluate_expr(&expr, &test_array); - if let Some(constant) = result.as_constant() { - assert_eq!(constant, Scalar::from(100i32)); - } else { - let prim = result.to_primitive(); - assert_eq!(prim.as_slice::(), &[100, 100, 100]); - } + assert_arrays_eq!(result, buffer![100i32, 100, 100].into_array()); } #[test] @@ -734,10 +893,10 @@ mod tests { lit(false), ); - let result = evaluate_expr(&expr, &test_array).to_bool(); - assert_eq!( - result.to_bit_buffer().iter().collect::>(), - vec![false, false, true, true, true] + let result = evaluate_expr(&expr, &test_array); + assert_arrays_eq!( + result, + BoolArray::from_iter([false, false, true, true, true]).into_array() ); } @@ -752,8 +911,8 @@ mod tests { let expr = case_when(get_item("cond", root()), lit(100i32), lit(0i32)); - let result = evaluate_expr(&expr, &test_array).to_primitive(); - assert_eq!(result.as_slice::(), &[100, 0, 0, 0, 100]); + let result = evaluate_expr(&expr, &test_array); + assert_arrays_eq!(result, buffer![100i32, 0, 0, 0, 100].into_array()); } #[test] @@ -776,8 +935,11 @@ mod tests { ); let result = evaluate_expr(&expr, &test_array); - let prim = result.to_primitive(); - assert_eq!(prim.as_slice::(), &[0, 0, 30, 40, 50]); + assert_arrays_eq!( + result, + PrimitiveArray::from_option_iter([Some(0i32), Some(0), Some(30), Some(40), Some(50)]) + .into_array() + ); } #[test] @@ -791,8 +953,8 @@ mod tests { let expr = case_when(get_item("cond", root()), lit(100i32), lit(0i32)); - let result = evaluate_expr(&expr, &test_array).to_primitive(); - assert_eq!(result.as_slice::(), &[0, 0, 0]); + let result = evaluate_expr(&expr, &test_array); + assert_arrays_eq!(result, buffer![0i32, 0, 0].into_array()); } // ==================== N-ary Evaluate Tests ==================== @@ -815,26 +977,10 @@ mod tests { let result = evaluate_expr(&expr, &test_array); assert!(result.dtype().is_nullable()); - - assert_eq!( - result.scalar_at(0).unwrap(), - Scalar::from(10i32).cast(result.dtype()).unwrap() - ); - assert_eq!( - result.scalar_at(1).unwrap(), - Scalar::null(result.dtype().clone()) - ); - assert_eq!( - result.scalar_at(2).unwrap(), - Scalar::from(30i32).cast(result.dtype()).unwrap() - ); - assert_eq!( - result.scalar_at(3).unwrap(), - Scalar::null(result.dtype().clone()) - ); - assert_eq!( - result.scalar_at(4).unwrap(), - Scalar::null(result.dtype().clone()) + assert_arrays_eq!( + result, + PrimitiveArray::from_option_iter([Some(10i32), None, Some(30), None, None]) + .into_array() ); } @@ -857,8 +1003,8 @@ mod tests { Some(lit(0i32)), ); - let result = evaluate_expr(&expr, &test_array).to_primitive(); - assert_eq!(result.as_slice::(), &[10, 20, 30, 40, 50]); + let result = evaluate_expr(&expr, &test_array); + assert_arrays_eq!(result, buffer![10i32, 20, 30, 40, 50].into_array()); } #[test] @@ -878,12 +1024,10 @@ mod tests { let result = evaluate_expr(&expr, &test_array); assert!(result.dtype().is_nullable()); - for i in 0..3 { - assert_eq!( - result.scalar_at(i).unwrap(), - Scalar::null(result.dtype().clone()) - ); - } + assert_arrays_eq!( + result, + PrimitiveArray::from_option_iter([None::, None, None]).into_array() + ); } #[test] @@ -905,9 +1049,92 @@ mod tests { Some(lit(0i32)), ); - let result = evaluate_expr(&expr, &test_array).to_primitive(); + let result = evaluate_expr(&expr, &test_array); // First matching condition always wins - assert_eq!(result.as_slice::(), &[1, 1, 1]); + assert_arrays_eq!(result, buffer![1i32, 1, 1].into_array()); + } + + #[test] + fn test_evaluate_nary_early_exit_when_remaining_empty() { + // After branch 0 claims all rows, remaining becomes all_false. + // The loop breaks before evaluating branch 1's condition. + let test_array = StructArray::from_fields(&[("value", buffer![1i32, 2, 3].into_array())]) + .unwrap() + .into_array(); + + let expr = nested_case_when( + vec![ + (gt(get_item("value", root()), lit(0i32)), lit(100i32)), + // Never evaluated due to early exit; 999 must never appear in output. + (gt(get_item("value", root()), lit(0i32)), lit(999i32)), + ], + Some(lit(0i32)), + ); + + let result = evaluate_expr(&expr, &test_array); + assert_arrays_eq!(result, buffer![100i32, 100, 100].into_array()); + } + + #[test] + fn test_evaluate_nary_skips_branch_with_empty_effective_mask() { + // Branch 0 claims value=1. Branch 1 targets the same rows but they are already + // matched → effective_mask is all_false → branch 1 is skipped (THEN not used). + let test_array = StructArray::from_fields(&[("value", buffer![1i32, 2, 3].into_array())]) + .unwrap() + .into_array(); + + let expr = nested_case_when( + vec![ + (eq(get_item("value", root()), lit(1i32)), lit(10i32)), + // Same condition as branch 0 — all matching rows already claimed → skipped. + // 999 must never appear in output. + (eq(get_item("value", root()), lit(1i32)), lit(999i32)), + (eq(get_item("value", root()), lit(2i32)), lit(20i32)), + ], + Some(lit(0i32)), + ); + + let result = evaluate_expr(&expr, &test_array); + assert_arrays_eq!(result, buffer![10i32, 20, 0].into_array()); + } + + #[test] + fn test_evaluate_nary_string_output() -> VortexResult<()> { + // Exercises merge_case_branches with a non-primitive (Utf8) builder. + let test_array = + StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4].into_array())]) + .unwrap() + .into_array(); + + // CASE WHEN value > 2 THEN 'high' WHEN value > 0 THEN 'low' ELSE 'none' END + // value=1,2 → 'low' (branch 1 after branch 0 claims 3,4) + // value=3,4 → 'high' (branch 0) + let expr = nested_case_when( + vec![ + (gt(get_item("value", root()), lit(2i32)), lit("high")), + (gt(get_item("value", root()), lit(0i32)), lit("low")), + ], + Some(lit("none")), + ); + + let result = evaluate_expr(&expr, &test_array); + assert_eq!( + result.scalar_at(0)?, + Scalar::utf8("low", Nullability::NonNullable) + ); + assert_eq!( + result.scalar_at(1)?, + Scalar::utf8("low", Nullability::NonNullable) + ); + assert_eq!( + result.scalar_at(2)?, + Scalar::utf8("high", Nullability::NonNullable) + ); + assert_eq!( + result.scalar_at(3)?, + Scalar::utf8("high", Nullability::NonNullable) + ); + Ok(()) } #[test] @@ -933,10 +1160,45 @@ mod tests { Some(lit(0i32)), ); - let result = evaluate_expr(&expr, &test_array).to_primitive(); + let result = evaluate_expr(&expr, &test_array); // row 0: cond1=true → 10 // row 1: cond1=NULL(→false), cond2=true → 20 // row 2: cond1=false, cond2=NULL(→false) → else=0 - assert_eq!(result.as_slice::(), &[10, 20, 0]); + assert_arrays_eq!(result, buffer![10i32, 20, 0].into_array()); + } + + #[test] + fn test_merge_case_branches_alternating_mask() -> VortexResult<()> { + // Exercises the scalar path: alternating rows produce one slice per row (no runs), + // triggering the per-row cursor path in merge_case_branches. + let n = 100usize; + + // Branch 0: even rows → 0, Branch 1: odd rows → 1, Else: never reached. + let branch0_mask = Mask::from_indices(n, (0..n).step_by(2).collect()); + let branch1_mask = Mask::from_indices(n, (1..n).step_by(2).collect()); + + let result = merge_case_branches( + vec![ + ( + branch0_mask, + PrimitiveArray::from_option_iter(vec![Some(0i32); n]).into_array(), + ), + ( + branch1_mask, + PrimitiveArray::from_option_iter(vec![Some(1i32); n]).into_array(), + ), + ], + PrimitiveArray::from_option_iter(vec![Some(99i32); n]).into_array(), + )?; + + // Even rows → 0, odd rows → 1. + let expected: Vec> = (0..n) + .map(|v| if v % 2 == 0 { Some(0) } else { Some(1) }) + .collect(); + assert_arrays_eq!( + result, + PrimitiveArray::from_option_iter(expected).into_array() + ); + Ok(()) } } diff --git a/vortex-array/src/scalar_fn/fns/zip/mod.rs b/vortex-array/src/scalar_fn/fns/zip/mod.rs index 7a11dbd04dc..8dd7cde1d07 100644 --- a/vortex-array/src/scalar_fn/fns/zip/mod.rs +++ b/vortex-array/src/scalar_fn/fns/zip/mod.rs @@ -6,10 +6,11 @@ mod kernel; use std::fmt::Formatter; pub use kernel::*; +use vortex_error::VortexExpect as _; use vortex_error::VortexResult; use vortex_error::vortex_ensure; -use vortex_mask::AllOr; use vortex_mask::Mask; +use vortex_mask::MaskValues; use vortex_session::VortexSession; use crate::ArrayRef; @@ -126,11 +127,6 @@ impl ScalarFnVTable for Zip { return if_true.cast(return_dtype)?.execute(ctx); } - let return_dtype = if_true - .dtype() - .clone() - .union_nullability(if_false.dtype().nullability()); - if mask.all_false() { return if_false.cast(return_dtype)?.execute(ctx); } @@ -189,10 +185,19 @@ pub(crate) fn zip_impl( .dtype() .clone() .union_nullability(if_false.dtype().nullability()); + + if mask.all_true() { + return if_true.cast(return_type); + } + if mask.all_false() { + return if_false.cast(return_type); + } + zip_impl_with_builder( if_true, if_false, - mask, + mask.values() + .vortex_expect("zip_impl_with_builder: mask is not all-true or all-false"), builder_with_capacity(&return_type, if_true.len()), ) } @@ -200,23 +205,17 @@ pub(crate) fn zip_impl( fn zip_impl_with_builder( if_true: &ArrayRef, if_false: &ArrayRef, - mask: &Mask, + mask: &MaskValues, mut builder: Box, ) -> VortexResult { - match mask.slices() { - AllOr::All => Ok(if_true.to_array()), - AllOr::None => Ok(if_false.to_array()), - AllOr::Some(slices) => { - for (start, end) in slices { - builder.extend_from_array(&if_false.slice(builder.len()..*start)?); - builder.extend_from_array(&if_true.slice(*start..*end)?); - } - if builder.len() < if_false.len() { - builder.extend_from_array(&if_false.slice(builder.len()..if_false.len())?); - } - Ok(builder.finish()) - } + for (start, end) in mask.slices() { + builder.extend_from_array(&if_false.slice(builder.len()..*start)?); + builder.extend_from_array(&if_true.slice(*start..*end)?); } + if builder.len() < if_false.len() { + builder.extend_from_array(&if_false.slice(builder.len()..if_false.len())?); + } + Ok(builder.finish()) } #[cfg(test)] @@ -227,6 +226,7 @@ mod tests { use vortex_error::VortexResult; use vortex_mask::Mask; + use super::zip_impl; use crate::ArrayRef; use crate::IntoArray; use crate::LEGACY_SESSION; @@ -291,11 +291,57 @@ mod tests { PrimitiveArray::from_option_iter([Some(10), Some(20), Some(30), Some(40)]).into_array(); assert_arrays_eq!(result, expected); - - // result must be nullable even if_true was not assert_eq!(result.dtype(), if_false.dtype()) } + #[test] + fn test_zip_all_false_widens_nullability() { + let mask = Mask::new_false(4); + let if_true = + PrimitiveArray::from_option_iter([Some(10), Some(20), Some(30), None]).into_array(); + let if_false = buffer![1i32, 2, 3, 4].into_array(); + + let result = mask.into_array().zip(if_true.clone(), if_false).unwrap(); + let expected = + PrimitiveArray::from_option_iter([Some(1), Some(2), Some(3), Some(4)]).into_array(); + + assert_arrays_eq!(result, expected); + assert_eq!(result.dtype(), if_true.dtype()); + } + + #[test] + fn test_zip_impl_all_true_widens_nullability() -> VortexResult<()> { + let mask = Mask::new_true(4); + let if_true = buffer![10i32, 20, 30, 40].into_array(); + let if_false = + PrimitiveArray::from_option_iter([Some(1), Some(2), Some(3), None]).into_array(); + + let result = zip_impl(&if_true, &if_false, &mask)?; + assert_arrays_eq!( + result, + PrimitiveArray::from_option_iter([Some(10i32), Some(20), Some(30), Some(40)]) + .into_array() + ); + assert_eq!(result.dtype(), if_false.dtype()); + Ok(()) + } + + #[test] + fn test_zip_impl_all_false_widens_nullability() -> VortexResult<()> { + let mask = Mask::new_false(4); + let if_true = + PrimitiveArray::from_option_iter([Some(10), Some(20), Some(30), None]).into_array(); + let if_false = buffer![1i32, 2, 3, 4].into_array(); + + let result = zip_impl(&if_true, &if_false, &mask)?; + assert_arrays_eq!( + result, + PrimitiveArray::from_option_iter([Some(1i32), Some(2), Some(3), Some(4)]).into_array() + ); + assert_eq!(result.dtype(), if_true.dtype()); + Ok(()) + } + #[test] #[should_panic] fn test_invalid_lengths() { @@ -339,7 +385,6 @@ mod tests { buffer: views host 1.60 kB (align=16) (96.56%) "); - // test wrapped in a struct let wrapped1 = StructArray::try_from_iter([("nested", const1)])?.into_array(); let wrapped2 = StructArray::try_from_iter([("nested", const2)])?.into_array(); @@ -383,7 +428,6 @@ mod tests { builder.finish() }; - // [1,2,4,5,7,8,..] let mask = Mask::from_indices(200, (0..100).filter(|i| i % 3 != 0).collect()); let mask_array = mask.clone().into_array(); @@ -394,7 +438,6 @@ mod tests { .unwrap(); assert_eq!(zipped.nbuffers(), 2); - // assert the result is the same as arrow let expected = arrow_zip( mask.into_array() .into_arrow_preferred() diff --git a/vortex-mask/public-api.lock b/vortex-mask/public-api.lock index b0cce7abbf9..1fdf9ccd598 100644 --- a/vortex-mask/public-api.lock +++ b/vortex-mask/public-api.lock @@ -98,6 +98,10 @@ pub fn vortex_mask::Mask::values(&self) -> core::option::Option<&vortex_mask::Ma impl vortex_mask::Mask +pub fn vortex_mask::Mask::bitand_not(self, rhs: &vortex_mask::Mask) -> vortex_mask::Mask + +impl vortex_mask::Mask + pub fn vortex_mask::Mask::intersect_by_rank(&self, mask: &vortex_mask::Mask) -> vortex_mask::Mask impl vortex_mask::Mask diff --git a/vortex-mask/src/bitops.rs b/vortex-mask/src/bitops.rs index a03778eafab..3ad681db161 100644 --- a/vortex-mask/src/bitops.rs +++ b/vortex-mask/src/bitops.rs @@ -46,6 +46,21 @@ impl BitOr for &Mask { } } +impl Mask { + /// Computes `self & !rhs` (AND NOT), equivalent to set difference. + pub fn bitand_not(self, rhs: &Mask) -> Mask { + if self.len() != rhs.len() { + vortex_panic!("Masks must have the same length"); + } + match (self.bit_buffer(), rhs.bit_buffer()) { + (AllOr::None, _) | (_, AllOr::All) => Mask::new_false(self.len()), + (_, AllOr::None) => self, + (AllOr::All, _) => !rhs, + (AllOr::Some(lhs), AllOr::Some(rhs)) => Mask::from_buffer(lhs.bitand_not(rhs)), + } + } +} + impl Not for Mask { type Output = Mask; @@ -353,6 +368,34 @@ mod tests { assert!(!result.value(3)); // (!(!false) | false) & !true = (false | false) & false = false } + #[test] + fn test_bitand_not() { + let a = Mask::from_buffer(BitBuffer::from_iter([true, true, false, false])); + let b = Mask::from_buffer(BitBuffer::from_iter([true, false, true, false])); + let result = a.clone().bitand_not(&b); + assert!(!result.value(0)); // true & !true = false + assert!(result.value(1)); // true & !false = true + assert!(!result.value(2)); // false & !true = false + assert!(!result.value(3)); // false & !false = false + + // bitand_not(All) = None + assert!(a.clone().bitand_not(&Mask::new_true(4)).all_false()); + + // bitand_not(None) = self + let none = Mask::new_false(4); + assert_eq!(a.clone().bitand_not(&none).true_count(), a.true_count()); + + // None.bitand_not(_) = None + assert!(none.bitand_not(&a).all_false()); + + // All.bitand_not(x) = !x + let not_b = !&b; + let all_bitand_not_b = Mask::new_true(4).bitand_not(&b); + for i in 0..4 { + assert_eq!(all_bitand_not_b.value(i), not_b.value(i)); + } + } + #[test] fn test_bitor() { // Test basic OR operations