From 768bf0ff930d52658ec95f9706f30a081a631af5 Mon Sep 17 00:00:00 2001 From: Baris Palaska Date: Thu, 5 Mar 2026 16:02:27 +0000 Subject: [PATCH 01/22] forward pass case when Signed-off-by: Baris Palaska --- vortex-array/benches/expr/case_when_bench.rs | 34 +++ vortex-array/src/scalar_fn/fns/case_when.rs | 239 +++++++++++++++++-- vortex-array/src/scalar_fn/fns/zip/mod.rs | 33 ++- 3 files changed, 281 insertions(+), 25 deletions(-) diff --git a/vortex-array/benches/expr/case_when_bench.rs b/vortex-array/benches/expr/case_when_bench.rs index e25ad180fa2..11b4f3f7658 100644 --- a/vortex-array/benches/expr/case_when_bench.rs +++ b/vortex-array/benches/expr/case_when_bench.rs @@ -18,6 +18,7 @@ use vortex_array::expr::eq; use vortex_array::expr::get_item; use vortex_array::expr::gt; use vortex_array::expr::lit; +use vortex_array::expr::lt; use vortex_array::expr::nested_case_when; use vortex_array::expr::root; use vortex_array::session::ArraySession; @@ -185,6 +186,39 @@ fn case_when_all_true(bencher: Bencher, size: usize) { }); } +/// Benchmark n-ary CASE WHEN where the first branch dominates (~90% of rows). +/// This highlights the early-exit and deferred-merge optimizations: subsequent conditions +/// match no remaining rows and are skipped entirely. +#[divan::bench(args = [1000, 10000, 100000])] +fn case_when_nary_early_dominant(bencher: Bencher, size: usize) { + let array = make_struct_array(size); + + // CASE WHEN value < 90% THEN 1 WHEN value < 95% THEN 2 WHEN value < 97.5% THEN 3 ELSE 4 + let t1 = (size as i32 * 9) / 10; + let t2 = (size as i32 * 19) / 20; + let t3 = (size as i32 * 39) / 40; + + let expr = nested_case_when( + vec![ + (lt(get_item("value", root()), lit(t1)), lit(1i32)), + (lt(get_item("value", root()), lit(t2)), lit(2i32)), + (lt(get_item("value", root()), lit(t3)), lit(3i32)), + ], + Some(lit(4i32)), + ); + + bencher + .with_inputs(|| (&expr, &array)) + .bench_refs(|(expr, array)| { + let mut ctx = SESSION.create_execution_ctx(); + array + .apply(expr) + .unwrap() + .execute::(&mut ctx) + .unwrap() + }); +} + /// Benchmark CASE WHEN where all conditions are false. #[divan::bench(args = [1000, 10000, 100000])] fn case_when_all_false(bencher: Bencher, size: usize) { diff --git a/vortex-array/src/scalar_fn/fns/case_when.rs b/vortex-array/src/scalar_fn/fns/case_when.rs index f701548f145..ab425f50226 100644 --- a/vortex-array/src/scalar_fn/fns/case_when.rs +++ b/vortex-array/src/scalar_fn/fns/case_when.rs @@ -11,6 +11,8 @@ use std::sync::Arc; use prost::Message; use vortex_error::VortexResult; use vortex_error::vortex_bail; +use vortex_mask::AllOr; +use vortex_mask::Mask; use vortex_proto::expr as pb; use vortex_session::VortexSession; @@ -19,6 +21,7 @@ use crate::ExecutionCtx; use crate::IntoArray; use crate::arrays::BoolArray; use crate::arrays::ConstantArray; +use crate::builders::builder_with_capacity; use crate::dtype::DType; use crate::expr::Expression; use crate::scalar::Scalar; @@ -191,37 +194,45 @@ impl ScalarFnVTable for CaseWhen { let row_count = args.row_count(); let num_pairs = options.num_when_then_pairs as usize; - let mut result: ArrayRef = if options.has_else { - args.get(num_pairs * 2)? - } else { - let then_dtype = args.get(1)?.dtype().as_nullable(); - ConstantArray::new(Scalar::null(then_dtype), row_count).into_array() - }; + // Track unmatched rows; AND each condition with `remaining` to enforce first-match-wins + // and produce disjoint branch masks. + let mut remaining = Mask::new_true(row_count); + let mut branches: Vec<(Mask, ArrayRef)> = Vec::with_capacity(num_pairs); - for i in (0..num_pairs).rev() { - let condition = args.get(i * 2)?; - let then_value = args.get(i * 2 + 1)?; + for i in 0..num_pairs { + if remaining.all_false() { + break; + } + let condition = args.get(i * 2)?; let cond_bool = condition.execute::(ctx)?; - let mask = cond_bool.to_mask_fill_null_false(); + let cond_mask = cond_bool.to_mask_fill_null_false(); + let effective_mask = &remaining & &cond_mask; - if mask.all_true() { - result = then_value; + if effective_mask.all_false() { continue; } - if mask.all_false() { - continue; - } + let then_value = args.get(i * 2 + 1)?; + remaining = &remaining & &(!&cond_mask); + branches.push((effective_mask, then_value)); + } - result = zip_impl(&then_value, &result, &mask)?; + let else_value: ArrayRef = if options.has_else { + args.get(num_pairs * 2)? + } else { + let then_dtype = args.get(1)?.dtype().as_nullable(); + ConstantArray::new(Scalar::null(then_dtype), row_count).into_array() + }; + + if branches.is_empty() { + return Ok(else_value); } - Ok(result) + merge_case_branches(branches, else_value) } fn is_null_sensitive(&self, _options: &Self::Options) -> bool { - // CaseWhen is null-sensitive because NULL conditions are treated as false true } @@ -230,6 +241,55 @@ impl ScalarFnVTable for CaseWhen { } } +/// Merges disjoint `(mask, then_value)` branch pairs with an `else_value` in a single pass. +/// +/// Branch masks are guaranteed disjoint by the remaining-row tracking in [`CaseWhen::execute`]. +fn merge_case_branches( + branches: Vec<(Mask, ArrayRef)>, + else_value: ArrayRef, +) -> VortexResult { + if branches.len() == 1 { + let (mask, then_value) = &branches[0]; + return zip_impl(then_value, &else_value, mask); + } + + let row_count = else_value.len(); + + let return_type = branches + .iter() + .fold(else_value.dtype().clone(), |acc, (_, arr)| { + acc.union_nullability(arr.dtype().nullability()) + }); + let mut builder = builder_with_capacity(&return_type, row_count); + + // Collect each branch's true-ranges tagged with branch index, then sort by position. + let mut events: Vec<(usize, usize, usize)> = Vec::new(); + for (branch_idx, (mask, _)) in branches.iter().enumerate() { + match mask.slices() { + AllOr::All => events.push((0, row_count, branch_idx)), + AllOr::None => {} + AllOr::Some(slices) => { + for &(start, end) in slices { + events.push((start, end, branch_idx)); + } + } + } + } + events.sort_unstable_by_key(|&(start, ..)| start); + + for (start, end, branch_idx) in &events { + if builder.len() < *start { + builder.extend_from_array(&else_value.slice(builder.len()..*start)?); + } + builder.extend_from_array(&branches[*branch_idx].1.slice(*start..*end)?); + } + if builder.len() < row_count { + builder.extend_from_array(&else_value.slice(builder.len()..row_count)?); + } + + Ok(builder.finish()) +} + #[cfg(test)] mod tests { use std::sync::LazyLock; @@ -246,6 +306,7 @@ mod tests { use crate::arrays::BoolArray; use crate::arrays::PrimitiveArray; use crate::arrays::StructArray; + use crate::assert_arrays_eq; use crate::dtype::DType; use crate::dtype::Nullability; use crate::dtype::PType; @@ -690,6 +751,65 @@ mod tests { assert_eq!(result.as_slice::(), &[100, 100, 100, 100, 100]); } + #[test] + fn test_evaluate_all_true_no_else_returns_correct_dtype() { + // CASE WHEN value > 0 THEN 100 END — condition is always true, no ELSE. + // Result must be Nullable because the implicit ELSE is NULL. + let test_array = StructArray::from_fields(&[("value", buffer![1i32, 2, 3].into_array())]) + .unwrap() + .into_array(); + + let expr = case_when_no_else(gt(get_item("value", root()), lit(0i32)), lit(100i32)); + + let result = evaluate_expr(&expr, &test_array); + assert!( + result.dtype().is_nullable(), + "result dtype must be Nullable, got {:?}", + result.dtype() + ); + assert_eq!( + result.scalar_at(0).unwrap(), + Scalar::from(100i32).cast(result.dtype()).unwrap() + ); + } + + #[test] + fn test_merge_case_branches_widens_nullability_of_later_branch() -> VortexResult<()> { + // When a later THEN branch is Nullable and branches[0] and ELSE are NonNullable, + // the result dtype must still be Nullable. + // + // CASE WHEN value = 0 THEN 10 -- NonNullable + // WHEN value = 1 THEN nullable(20) -- Nullable + // ELSE 0 -- NonNullable + // → result must be Nullable(i32) + let test_array = StructArray::from_fields(&[("value", buffer![0i32, 1, 2].into_array())]) + .unwrap() + .into_array(); + + let nullable_20 = + Scalar::from(20i32).cast(&DType::Primitive(PType::I32, Nullability::Nullable))?; + + let expr = nested_case_when( + vec![ + (eq(get_item("value", root()), lit(0i32)), lit(10i32)), + (eq(get_item("value", root()), lit(1i32)), lit(nullable_20)), + ], + Some(lit(0i32)), + ); + + let result = evaluate_expr(&expr, &test_array); + assert!( + result.dtype().is_nullable(), + "result dtype must be Nullable, got {:?}", + result.dtype() + ); + assert_arrays_eq!( + result, + PrimitiveArray::from_option_iter([Some(10), Some(20), Some(0)]).into_array() + ); + Ok(()) + } + #[test] fn test_evaluate_with_literal_condition() { let test_array = buffer![1i32, 2, 3].into_array(); @@ -893,6 +1013,89 @@ mod tests { assert_eq!(result.as_slice::(), &[1, 1, 1]); } + #[test] + fn test_evaluate_nary_early_exit_when_remaining_empty() { + // After branch 0 claims all rows, remaining becomes all_false. + // The loop breaks before evaluating branch 1's condition. + let test_array = StructArray::from_fields(&[("value", buffer![1i32, 2, 3].into_array())]) + .unwrap() + .into_array(); + + let expr = nested_case_when( + vec![ + (gt(get_item("value", root()), lit(0i32)), lit(100i32)), + // Never evaluated due to early exit; 999 must never appear in output. + (gt(get_item("value", root()), lit(0i32)), lit(999i32)), + ], + Some(lit(0i32)), + ); + + let result = evaluate_expr(&expr, &test_array).to_primitive(); + assert_eq!(result.as_slice::(), &[100, 100, 100]); + } + + #[test] + fn test_evaluate_nary_skips_branch_with_empty_effective_mask() { + // Branch 0 claims value=1. Branch 1 targets the same rows but they are already + // matched → effective_mask is all_false → branch 1 is skipped (THEN not used). + let test_array = StructArray::from_fields(&[("value", buffer![1i32, 2, 3].into_array())]) + .unwrap() + .into_array(); + + let expr = nested_case_when( + vec![ + (eq(get_item("value", root()), lit(1i32)), lit(10i32)), + // Same condition as branch 0 — all matching rows already claimed → skipped. + // 999 must never appear in output. + (eq(get_item("value", root()), lit(1i32)), lit(999i32)), + (eq(get_item("value", root()), lit(2i32)), lit(20i32)), + ], + Some(lit(0i32)), + ); + + let result = evaluate_expr(&expr, &test_array).to_primitive(); + assert_eq!(result.as_slice::(), &[10, 20, 0]); + } + + #[test] + fn test_evaluate_nary_string_output() -> VortexResult<()> { + // Exercises merge_case_branches with a non-primitive (Utf8) builder. + let test_array = + StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4].into_array())]) + .unwrap() + .into_array(); + + // CASE WHEN value > 2 THEN 'high' WHEN value > 0 THEN 'low' ELSE 'none' END + // value=1,2 → 'low' (branch 1 after branch 0 claims 3,4) + // value=3,4 → 'high' (branch 0) + let expr = nested_case_when( + vec![ + (gt(get_item("value", root()), lit(2i32)), lit("high")), + (gt(get_item("value", root()), lit(0i32)), lit("low")), + ], + Some(lit("none")), + ); + + let result = evaluate_expr(&expr, &test_array); + assert_eq!( + result.scalar_at(0)?, + Scalar::utf8("low", Nullability::NonNullable) + ); + assert_eq!( + result.scalar_at(1)?, + Scalar::utf8("low", Nullability::NonNullable) + ); + assert_eq!( + result.scalar_at(2)?, + Scalar::utf8("high", Nullability::NonNullable) + ); + assert_eq!( + result.scalar_at(3)?, + Scalar::utf8("high", Nullability::NonNullable) + ); + Ok(()) + } + #[test] fn test_evaluate_nary_with_nullable_conditions() { let test_array = StructArray::from_fields(&[ diff --git a/vortex-array/src/scalar_fn/fns/zip/mod.rs b/vortex-array/src/scalar_fn/fns/zip/mod.rs index 7a11dbd04dc..8b25fec760c 100644 --- a/vortex-array/src/scalar_fn/fns/zip/mod.rs +++ b/vortex-array/src/scalar_fn/fns/zip/mod.rs @@ -126,11 +126,6 @@ impl ScalarFnVTable for Zip { return if_true.cast(return_dtype)?.execute(ctx); } - let return_dtype = if_true - .dtype() - .clone() - .union_nullability(if_false.dtype().nullability()); - if mask.all_false() { return if_false.cast(return_dtype)?.execute(ctx); } @@ -204,8 +199,14 @@ fn zip_impl_with_builder( mut builder: Box, ) -> VortexResult { match mask.slices() { - AllOr::All => Ok(if_true.to_array()), - AllOr::None => Ok(if_false.to_array()), + AllOr::All => { + builder.extend_from_array(if_true); + Ok(builder.finish()) + } + AllOr::None => { + builder.extend_from_array(if_false); + Ok(builder.finish()) + } AllOr::Some(slices) => { for (start, end) in slices { builder.extend_from_array(&if_false.slice(builder.len()..*start)?); @@ -296,6 +297,24 @@ mod tests { assert_eq!(result.dtype(), if_false.dtype()) } + /// When the mask is all-false and `if_true` is Nullable, the result dtype must be Nullable + /// even though `if_false` is NonNullable. + #[test] + fn test_zip_all_false_widens_nullability() { + let mask = Mask::new_false(4); + let if_true = + PrimitiveArray::from_option_iter([Some(10), Some(20), Some(30), None]).into_array(); + let if_false = buffer![1i32, 2, 3, 4].into_array(); + + let result = mask.into_array().zip(if_true.clone(), if_false).unwrap(); + let expected = + PrimitiveArray::from_option_iter([Some(1), Some(2), Some(3), Some(4)]).into_array(); + + assert_arrays_eq!(result, expected); + // result must be nullable even though if_false was not + assert_eq!(result.dtype(), if_true.dtype()); + } + #[test] #[should_panic] fn test_invalid_lengths() { From 1d4b94706c1d82ccd54183c3ad4876d0e1570f72 Mon Sep 17 00:00:00 2001 From: Baris Palaska Date: Thu, 5 Mar 2026 17:16:31 +0000 Subject: [PATCH 02/22] assert_arrays_eq Signed-off-by: Baris Palaska --- vortex-array/src/scalar_fn/fns/case_when.rs | 136 +++++++------------- 1 file changed, 50 insertions(+), 86 deletions(-) diff --git a/vortex-array/src/scalar_fn/fns/case_when.rs b/vortex-array/src/scalar_fn/fns/case_when.rs index ab425f50226..e5e29eeee08 100644 --- a/vortex-array/src/scalar_fn/fns/case_when.rs +++ b/vortex-array/src/scalar_fn/fns/case_when.rs @@ -301,7 +301,6 @@ mod tests { use super::*; use crate::Canonical; use crate::IntoArray; - use crate::ToCanonical; use crate::VortexSessionExecute as _; use crate::arrays::BoolArray; use crate::arrays::PrimitiveArray; @@ -639,8 +638,8 @@ mod tests { lit(0i32), ); - let result = evaluate_expr(&expr, &test_array).to_primitive(); - assert_eq!(result.as_slice::(), &[0, 0, 100, 100, 100]); + let result = evaluate_expr(&expr, &test_array); + assert_arrays_eq!(result, buffer![0i32, 0, 100, 100, 100].into_array()); } #[test] @@ -659,8 +658,8 @@ mod tests { Some(lit(0i32)), ); - let result = evaluate_expr(&expr, &test_array).to_primitive(); - assert_eq!(result.as_slice::(), &[10, 0, 30, 0, 0]); + let result = evaluate_expr(&expr, &test_array); + assert_arrays_eq!(result, buffer![10i32, 0, 30, 0, 0].into_array()); } #[test] @@ -679,8 +678,8 @@ mod tests { Some(lit(0i32)), ); - let result = evaluate_expr(&expr, &test_array).to_primitive(); - assert_eq!(result.as_slice::(), &[0, 0, 100, 100, 100]); + let result = evaluate_expr(&expr, &test_array); + assert_arrays_eq!(result, buffer![0i32, 0, 100, 100, 100].into_array()); } #[test] @@ -694,26 +693,10 @@ mod tests { let result = evaluate_expr(&expr, &test_array); assert!(result.dtype().is_nullable()); - - assert_eq!( - result.scalar_at(0).unwrap(), - Scalar::null(result.dtype().clone()) - ); - assert_eq!( - result.scalar_at(1).unwrap(), - Scalar::null(result.dtype().clone()) - ); - assert_eq!( - result.scalar_at(2).unwrap(), - Scalar::null(result.dtype().clone()) - ); - assert_eq!( - result.scalar_at(3).unwrap(), - Scalar::from(100i32).cast(result.dtype()).unwrap() - ); - assert_eq!( - result.scalar_at(4).unwrap(), - Scalar::from(100i32).cast(result.dtype()).unwrap() + assert_arrays_eq!( + result, + PrimitiveArray::from_option_iter([None::, None, None, Some(100), Some(100)]) + .into_array() ); } @@ -730,8 +713,8 @@ mod tests { lit(0i32), ); - let result = evaluate_expr(&expr, &test_array).to_primitive(); - assert_eq!(result.as_slice::(), &[0, 0, 0, 0, 0]); + let result = evaluate_expr(&expr, &test_array); + assert_arrays_eq!(result, buffer![0i32, 0, 0, 0, 0].into_array()); } #[test] @@ -747,8 +730,8 @@ mod tests { lit(0i32), ); - let result = evaluate_expr(&expr, &test_array).to_primitive(); - assert_eq!(result.as_slice::(), &[100, 100, 100, 100, 100]); + let result = evaluate_expr(&expr, &test_array); + assert_arrays_eq!(result, buffer![100i32, 100, 100, 100, 100].into_array()); } #[test] @@ -767,9 +750,9 @@ mod tests { "result dtype must be Nullable, got {:?}", result.dtype() ); - assert_eq!( - result.scalar_at(0).unwrap(), - Scalar::from(100i32).cast(result.dtype()).unwrap() + assert_arrays_eq!( + result, + PrimitiveArray::from_option_iter([Some(100i32), Some(100), Some(100)]).into_array() ); } @@ -816,12 +799,7 @@ mod tests { let expr = case_when(lit(true), lit(100i32), lit(0i32)); let result = evaluate_expr(&expr, &test_array); - if let Some(constant) = result.as_constant() { - assert_eq!(constant, Scalar::from(100i32)); - } else { - let prim = result.to_primitive(); - assert_eq!(prim.as_slice::(), &[100, 100, 100]); - } + assert_arrays_eq!(result, buffer![100i32, 100, 100].into_array()); } #[test] @@ -837,10 +815,11 @@ mod tests { lit(false), ); - let result = evaluate_expr(&expr, &test_array).to_bool(); - assert_eq!( - result.to_bit_buffer().iter().collect::>(), - vec![false, false, true, true, true] + let result = evaluate_expr(&expr, &test_array); + assert_arrays_eq!( + result, + BoolArray::from_iter([Some(false), Some(false), Some(true), Some(true), Some(true)]) + .into_array() ); } @@ -855,8 +834,8 @@ mod tests { let expr = case_when(get_item("cond", root()), lit(100i32), lit(0i32)); - let result = evaluate_expr(&expr, &test_array).to_primitive(); - assert_eq!(result.as_slice::(), &[100, 0, 0, 0, 100]); + let result = evaluate_expr(&expr, &test_array); + assert_arrays_eq!(result, buffer![100i32, 0, 0, 0, 100].into_array()); } #[test] @@ -879,8 +858,11 @@ mod tests { ); let result = evaluate_expr(&expr, &test_array); - let prim = result.to_primitive(); - assert_eq!(prim.as_slice::(), &[0, 0, 30, 40, 50]); + assert_arrays_eq!( + result, + PrimitiveArray::from_option_iter([Some(0i32), Some(0), Some(30), Some(40), Some(50)]) + .into_array() + ); } #[test] @@ -894,8 +876,8 @@ mod tests { let expr = case_when(get_item("cond", root()), lit(100i32), lit(0i32)); - let result = evaluate_expr(&expr, &test_array).to_primitive(); - assert_eq!(result.as_slice::(), &[0, 0, 0]); + let result = evaluate_expr(&expr, &test_array); + assert_arrays_eq!(result, buffer![0i32, 0, 0].into_array()); } // ==================== N-ary Evaluate Tests ==================== @@ -918,26 +900,10 @@ mod tests { let result = evaluate_expr(&expr, &test_array); assert!(result.dtype().is_nullable()); - - assert_eq!( - result.scalar_at(0).unwrap(), - Scalar::from(10i32).cast(result.dtype()).unwrap() - ); - assert_eq!( - result.scalar_at(1).unwrap(), - Scalar::null(result.dtype().clone()) - ); - assert_eq!( - result.scalar_at(2).unwrap(), - Scalar::from(30i32).cast(result.dtype()).unwrap() - ); - assert_eq!( - result.scalar_at(3).unwrap(), - Scalar::null(result.dtype().clone()) - ); - assert_eq!( - result.scalar_at(4).unwrap(), - Scalar::null(result.dtype().clone()) + assert_arrays_eq!( + result, + PrimitiveArray::from_option_iter([Some(10i32), None, Some(30), None, None]) + .into_array() ); } @@ -960,8 +926,8 @@ mod tests { Some(lit(0i32)), ); - let result = evaluate_expr(&expr, &test_array).to_primitive(); - assert_eq!(result.as_slice::(), &[10, 20, 30, 40, 50]); + let result = evaluate_expr(&expr, &test_array); + assert_arrays_eq!(result, buffer![10i32, 20, 30, 40, 50].into_array()); } #[test] @@ -981,12 +947,10 @@ mod tests { let result = evaluate_expr(&expr, &test_array); assert!(result.dtype().is_nullable()); - for i in 0..3 { - assert_eq!( - result.scalar_at(i).unwrap(), - Scalar::null(result.dtype().clone()) - ); - } + assert_arrays_eq!( + result, + PrimitiveArray::from_option_iter([None::, None, None]).into_array() + ); } #[test] @@ -1008,9 +972,9 @@ mod tests { Some(lit(0i32)), ); - let result = evaluate_expr(&expr, &test_array).to_primitive(); + let result = evaluate_expr(&expr, &test_array); // First matching condition always wins - assert_eq!(result.as_slice::(), &[1, 1, 1]); + assert_arrays_eq!(result, buffer![1i32, 1, 1].into_array()); } #[test] @@ -1030,8 +994,8 @@ mod tests { Some(lit(0i32)), ); - let result = evaluate_expr(&expr, &test_array).to_primitive(); - assert_eq!(result.as_slice::(), &[100, 100, 100]); + let result = evaluate_expr(&expr, &test_array); + assert_arrays_eq!(result, buffer![100i32, 100, 100].into_array()); } #[test] @@ -1053,8 +1017,8 @@ mod tests { Some(lit(0i32)), ); - let result = evaluate_expr(&expr, &test_array).to_primitive(); - assert_eq!(result.as_slice::(), &[10, 20, 0]); + let result = evaluate_expr(&expr, &test_array); + assert_arrays_eq!(result, buffer![10i32, 20, 0].into_array()); } #[test] @@ -1119,10 +1083,10 @@ mod tests { Some(lit(0i32)), ); - let result = evaluate_expr(&expr, &test_array).to_primitive(); + let result = evaluate_expr(&expr, &test_array); // row 0: cond1=true → 10 // row 1: cond1=NULL(→false), cond2=true → 20 // row 2: cond1=false, cond2=NULL(→false) → else=0 - assert_eq!(result.as_slice::(), &[10, 20, 0]); + assert_arrays_eq!(result, buffer![10i32, 20, 0].into_array()); } } From 91b71a3a9b6aab36e743227b6893ac8bbabd3931 Mon Sep 17 00:00:00 2001 From: Baris Palaska Date: Thu, 5 Mar 2026 19:09:34 +0000 Subject: [PATCH 03/22] add andnot Signed-off-by: Baris Palaska --- vortex-array/src/scalar_fn/fns/case_when.rs | 5 +-- vortex-mask/src/bitops.rs | 44 +++++++++++++++++++++ 2 files changed, 46 insertions(+), 3 deletions(-) diff --git a/vortex-array/src/scalar_fn/fns/case_when.rs b/vortex-array/src/scalar_fn/fns/case_when.rs index e5e29eeee08..043fff5b8d6 100644 --- a/vortex-array/src/scalar_fn/fns/case_when.rs +++ b/vortex-array/src/scalar_fn/fns/case_when.rs @@ -214,7 +214,7 @@ impl ScalarFnVTable for CaseWhen { } let then_value = args.get(i * 2 + 1)?; - remaining = &remaining & &(!&cond_mask); + remaining = remaining.andnot(&cond_mask); branches.push((effective_mask, then_value)); } @@ -818,8 +818,7 @@ mod tests { let result = evaluate_expr(&expr, &test_array); assert_arrays_eq!( result, - BoolArray::from_iter([Some(false), Some(false), Some(true), Some(true), Some(true)]) - .into_array() + BoolArray::from_iter([false, false, true, true, true]).into_array() ); } diff --git a/vortex-mask/src/bitops.rs b/vortex-mask/src/bitops.rs index a03778eafab..6a6125a08f3 100644 --- a/vortex-mask/src/bitops.rs +++ b/vortex-mask/src/bitops.rs @@ -46,6 +46,21 @@ impl BitOr for &Mask { } } +impl Mask { + /// Computes `self & !rhs` (AND NOT), equivalent to set difference. + pub fn andnot(&self, rhs: &Mask) -> Mask { + if self.len() != rhs.len() { + vortex_panic!("Masks must have the same length"); + } + match (self.bit_buffer(), rhs.bit_buffer()) { + (AllOr::None, _) | (_, AllOr::All) => Mask::new_false(self.len()), + (_, AllOr::None) => self.clone(), + (AllOr::All, _) => !rhs, + (AllOr::Some(lhs), AllOr::Some(rhs)) => Mask::from_buffer(lhs & !rhs), + } + } +} + impl Not for Mask { type Output = Mask; @@ -353,6 +368,35 @@ mod tests { assert!(!result.value(3)); // (!(!false) | false) & !true = (false | false) & false = false } + #[test] + fn test_andnot() { + let a = Mask::from_buffer(BitBuffer::from_iter([true, true, false, false])); + let b = Mask::from_buffer(BitBuffer::from_iter([true, false, true, false])); + let result = a.andnot(&b); + assert!(!result.value(0)); // true & !true = false + assert!(result.value(1)); // true & !false = true + assert!(!result.value(2)); // false & !true = false + assert!(!result.value(3)); // false & !false = false + + // andnot(All) = None + let all = Mask::new_true(4); + assert!(a.andnot(&all).all_false()); + + // andnot(None) = self + let none = Mask::new_false(4); + assert_eq!(a.andnot(&none).true_count(), a.true_count()); + + // None.andnot(_) = None + assert!(none.andnot(&a).all_false()); + + // All.andnot(x) = !x + let not_b = !&b; + let all_andnot_b = Mask::new_true(4).andnot(&b); + for i in 0..4 { + assert_eq!(all_andnot_b.value(i), not_b.value(i)); + } + } + #[test] fn test_bitor() { // Test basic OR operations From 2506209844009f1818c6324014ccc12cab2afc51 Mon Sep 17 00:00:00 2001 From: Baris Palaska Date: Thu, 5 Mar 2026 19:16:46 +0000 Subject: [PATCH 04/22] cast in zip_impl Signed-off-by: Baris Palaska --- vortex-array/src/scalar_fn/fns/zip/mod.rs | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/vortex-array/src/scalar_fn/fns/zip/mod.rs b/vortex-array/src/scalar_fn/fns/zip/mod.rs index 8b25fec760c..7ebf24a5aa1 100644 --- a/vortex-array/src/scalar_fn/fns/zip/mod.rs +++ b/vortex-array/src/scalar_fn/fns/zip/mod.rs @@ -184,6 +184,14 @@ pub(crate) fn zip_impl( .dtype() .clone() .union_nullability(if_false.dtype().nullability()); + + if mask.all_true() { + return if_true.cast(return_type); + } + if mask.all_false() { + return if_false.cast(return_type); + } + zip_impl_with_builder( if_true, if_false, @@ -199,13 +207,8 @@ fn zip_impl_with_builder( mut builder: Box, ) -> VortexResult { match mask.slices() { - AllOr::All => { - builder.extend_from_array(if_true); - Ok(builder.finish()) - } - AllOr::None => { - builder.extend_from_array(if_false); - Ok(builder.finish()) + AllOr::All | AllOr::None => { + unreachable!("zip_impl_with_builder called with all-true or all-false mask; handle in zip_impl") } AllOr::Some(slices) => { for (start, end) in slices { From db79be4ddf916e8883a8d890ae2243eedc476d5f Mon Sep 17 00:00:00 2001 From: Baris Palaska Date: Thu, 5 Mar 2026 19:21:50 +0000 Subject: [PATCH 05/22] tests Signed-off-by: Baris Palaska --- vortex-array/src/scalar_fn/fns/zip/mod.rs | 46 ++++++++++++++++++----- 1 file changed, 37 insertions(+), 9 deletions(-) diff --git a/vortex-array/src/scalar_fn/fns/zip/mod.rs b/vortex-array/src/scalar_fn/fns/zip/mod.rs index 7ebf24a5aa1..3d73ec3e639 100644 --- a/vortex-array/src/scalar_fn/fns/zip/mod.rs +++ b/vortex-array/src/scalar_fn/fns/zip/mod.rs @@ -208,7 +208,9 @@ fn zip_impl_with_builder( ) -> VortexResult { match mask.slices() { AllOr::All | AllOr::None => { - unreachable!("zip_impl_with_builder called with all-true or all-false mask; handle in zip_impl") + unreachable!( + "zip_impl_with_builder called with all-true or all-false mask; handle in zip_impl" + ) } AllOr::Some(slices) => { for (start, end) in slices { @@ -231,6 +233,7 @@ mod tests { use vortex_error::VortexResult; use vortex_mask::Mask; + use super::zip_impl; use crate::ArrayRef; use crate::IntoArray; use crate::LEGACY_SESSION; @@ -295,13 +298,9 @@ mod tests { PrimitiveArray::from_option_iter([Some(10), Some(20), Some(30), Some(40)]).into_array(); assert_arrays_eq!(result, expected); - - // result must be nullable even if_true was not assert_eq!(result.dtype(), if_false.dtype()) } - /// When the mask is all-false and `if_true` is Nullable, the result dtype must be Nullable - /// even though `if_false` is NonNullable. #[test] fn test_zip_all_false_widens_nullability() { let mask = Mask::new_false(4); @@ -314,10 +313,42 @@ mod tests { PrimitiveArray::from_option_iter([Some(1), Some(2), Some(3), Some(4)]).into_array(); assert_arrays_eq!(result, expected); - // result must be nullable even though if_false was not assert_eq!(result.dtype(), if_true.dtype()); } + #[test] + fn test_zip_impl_all_true_widens_nullability() -> VortexResult<()> { + let mask = Mask::new_true(4); + let if_true = buffer![10i32, 20, 30, 40].into_array(); + let if_false = + PrimitiveArray::from_option_iter([Some(1), Some(2), Some(3), None]).into_array(); + + let result = zip_impl(&if_true, &if_false, &mask)?; + assert_arrays_eq!( + result, + PrimitiveArray::from_option_iter([Some(10i32), Some(20), Some(30), Some(40)]) + .into_array() + ); + assert_eq!(result.dtype(), if_false.dtype()); + Ok(()) + } + + #[test] + fn test_zip_impl_all_false_widens_nullability() -> VortexResult<()> { + let mask = Mask::new_false(4); + let if_true = + PrimitiveArray::from_option_iter([Some(10), Some(20), Some(30), None]).into_array(); + let if_false = buffer![1i32, 2, 3, 4].into_array(); + + let result = zip_impl(&if_true, &if_false, &mask)?; + assert_arrays_eq!( + result, + PrimitiveArray::from_option_iter([Some(1i32), Some(2), Some(3), Some(4)]).into_array() + ); + assert_eq!(result.dtype(), if_true.dtype()); + Ok(()) + } + #[test] #[should_panic] fn test_invalid_lengths() { @@ -361,7 +392,6 @@ mod tests { buffer: views host 1.60 kB (align=16) (96.56%) "); - // test wrapped in a struct let wrapped1 = StructArray::try_from_iter([("nested", const1)])?.into_array(); let wrapped2 = StructArray::try_from_iter([("nested", const2)])?.into_array(); @@ -405,7 +435,6 @@ mod tests { builder.finish() }; - // [1,2,4,5,7,8,..] let mask = Mask::from_indices(200, (0..100).filter(|i| i % 3 != 0).collect()); let mask_array = mask.clone().into_array(); @@ -416,7 +445,6 @@ mod tests { .unwrap(); assert_eq!(zipped.nbuffers(), 2); - // assert the result is the same as arrow let expected = arrow_zip( mask.into_array() .into_arrow_preferred() From dbc323da75b2b8dedf4e93cca28f7024ebe89ad9 Mon Sep 17 00:00:00 2001 From: Baris Palaska Date: Thu, 5 Mar 2026 19:32:44 +0000 Subject: [PATCH 06/22] public api Signed-off-by: Baris Palaska --- vortex-mask/public-api.lock | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vortex-mask/public-api.lock b/vortex-mask/public-api.lock index b0cce7abbf9..84aa1cf82f5 100644 --- a/vortex-mask/public-api.lock +++ b/vortex-mask/public-api.lock @@ -98,6 +98,10 @@ pub fn vortex_mask::Mask::values(&self) -> core::option::Option<&vortex_mask::Ma impl vortex_mask::Mask +pub fn vortex_mask::Mask::andnot(&self, rhs: &vortex_mask::Mask) -> vortex_mask::Mask + +impl vortex_mask::Mask + pub fn vortex_mask::Mask::intersect_by_rank(&self, mask: &vortex_mask::Mask) -> vortex_mask::Mask impl vortex_mask::Mask From 9d57dffbdfad7c19fcc7739f674c2d261aed6ea2 Mon Sep 17 00:00:00 2001 From: Baris Palaska Date: Fri, 6 Mar 2026 13:37:42 +0000 Subject: [PATCH 07/22] add todo Signed-off-by: Baris Palaska --- vortex-array/src/scalar_fn/fns/case_when.rs | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/vortex-array/src/scalar_fn/fns/case_when.rs b/vortex-array/src/scalar_fn/fns/case_when.rs index 043fff5b8d6..f7384a05932 100644 --- a/vortex-array/src/scalar_fn/fns/case_when.rs +++ b/vortex-array/src/scalar_fn/fns/case_when.rs @@ -191,11 +191,18 @@ impl ScalarFnVTable for CaseWhen { args: &dyn ExecutionArgs, ctx: &mut ExecutionCtx, ) -> VortexResult { + // Inspired by https://datafusion.apache.org/blog/2026/02/02/datafusion_case/ + // + // Implemented: short-circuit early exit; single-pass merge via `merge_case_branches`. + // Partial: single-branch uses `zip_impl` but THEN/ELSE still evaluated on the full batch. + // + // TODO: shrink input to `remaining` rows between WHEN iterations (batch reduction). + // TODO: project to only referenced columns before batch reduction (column projection). + // TODO: evaluate THEN/ELSE on compact matching/non-matching rows and merge without scatter. + // TODO: for constant WHEN/THEN values, compile to a hash table for a single-pass lookup. let row_count = args.row_count(); let num_pairs = options.num_when_then_pairs as usize; - // Track unmatched rows; AND each condition with `remaining` to enforce first-match-wins - // and produce disjoint branch masks. let mut remaining = Mask::new_true(row_count); let mut branches: Vec<(Mask, ArrayRef)> = Vec::with_capacity(num_pairs); From cb59e4fbf0afc158a0c3ce2ca1579f4c0272cd74 Mon Sep 17 00:00:00 2001 From: Baris Palaska Date: Fri, 6 Mar 2026 14:10:05 +0000 Subject: [PATCH 08/22] mask::bitand_not uses fused bitbuffer method, also owned Signed-off-by: Baris Palaska --- vortex-array/src/scalar_fn/fns/case_when.rs | 2 +- vortex-mask/src/bitops.rs | 29 ++++++++++----------- 2 files changed, 15 insertions(+), 16 deletions(-) diff --git a/vortex-array/src/scalar_fn/fns/case_when.rs b/vortex-array/src/scalar_fn/fns/case_when.rs index d64a6680f68..a476a04b8ed 100644 --- a/vortex-array/src/scalar_fn/fns/case_when.rs +++ b/vortex-array/src/scalar_fn/fns/case_when.rs @@ -231,7 +231,7 @@ impl ScalarFnVTable for CaseWhen { } let then_value = args.get(i * 2 + 1)?; - remaining = remaining.andnot(&cond_mask); + remaining = remaining.bitand_not(&cond_mask); branches.push((effective_mask, then_value)); } diff --git a/vortex-mask/src/bitops.rs b/vortex-mask/src/bitops.rs index 6a6125a08f3..3ad681db161 100644 --- a/vortex-mask/src/bitops.rs +++ b/vortex-mask/src/bitops.rs @@ -48,15 +48,15 @@ impl BitOr for &Mask { impl Mask { /// Computes `self & !rhs` (AND NOT), equivalent to set difference. - pub fn andnot(&self, rhs: &Mask) -> Mask { + pub fn bitand_not(self, rhs: &Mask) -> Mask { if self.len() != rhs.len() { vortex_panic!("Masks must have the same length"); } match (self.bit_buffer(), rhs.bit_buffer()) { (AllOr::None, _) | (_, AllOr::All) => Mask::new_false(self.len()), - (_, AllOr::None) => self.clone(), + (_, AllOr::None) => self, (AllOr::All, _) => !rhs, - (AllOr::Some(lhs), AllOr::Some(rhs)) => Mask::from_buffer(lhs & !rhs), + (AllOr::Some(lhs), AllOr::Some(rhs)) => Mask::from_buffer(lhs.bitand_not(rhs)), } } } @@ -369,31 +369,30 @@ mod tests { } #[test] - fn test_andnot() { + fn test_bitand_not() { let a = Mask::from_buffer(BitBuffer::from_iter([true, true, false, false])); let b = Mask::from_buffer(BitBuffer::from_iter([true, false, true, false])); - let result = a.andnot(&b); + let result = a.clone().bitand_not(&b); assert!(!result.value(0)); // true & !true = false assert!(result.value(1)); // true & !false = true assert!(!result.value(2)); // false & !true = false assert!(!result.value(3)); // false & !false = false - // andnot(All) = None - let all = Mask::new_true(4); - assert!(a.andnot(&all).all_false()); + // bitand_not(All) = None + assert!(a.clone().bitand_not(&Mask::new_true(4)).all_false()); - // andnot(None) = self + // bitand_not(None) = self let none = Mask::new_false(4); - assert_eq!(a.andnot(&none).true_count(), a.true_count()); + assert_eq!(a.clone().bitand_not(&none).true_count(), a.true_count()); - // None.andnot(_) = None - assert!(none.andnot(&a).all_false()); + // None.bitand_not(_) = None + assert!(none.bitand_not(&a).all_false()); - // All.andnot(x) = !x + // All.bitand_not(x) = !x let not_b = !&b; - let all_andnot_b = Mask::new_true(4).andnot(&b); + let all_bitand_not_b = Mask::new_true(4).bitand_not(&b); for i in 0..4 { - assert_eq!(all_andnot_b.value(i), not_b.value(i)); + assert_eq!(all_bitand_not_b.value(i), not_b.value(i)); } } From 7091207c7cd84ecb52109440a86964050909bbca Mon Sep 17 00:00:00 2001 From: Baris Palaska Date: Mon, 9 Mar 2026 12:28:31 +0000 Subject: [PATCH 09/22] cleaner Signed-off-by: Baris Palaska --- vortex-array/benches/expr/case_when_bench.rs | 44 +++++++ vortex-array/src/scalar_fn/fns/case_when.rs | 125 ++++++++++++++++--- 2 files changed, 151 insertions(+), 18 deletions(-) diff --git a/vortex-array/benches/expr/case_when_bench.rs b/vortex-array/benches/expr/case_when_bench.rs index 11b4f3f7658..77862a9c7a5 100644 --- a/vortex-array/benches/expr/case_when_bench.rs +++ b/vortex-array/benches/expr/case_when_bench.rs @@ -11,6 +11,7 @@ use vortex_array::ArrayRef; use vortex_array::Canonical; use vortex_array::IntoArray; use vortex_array::VortexSessionExecute; +use vortex_array::arrays::BoolArray; use vortex_array::arrays::StructArray; use vortex_array::expr::case_when; use vortex_array::expr::case_when_no_else; @@ -40,6 +41,22 @@ fn make_struct_array(size: usize) -> ArrayRef { .into_array() } +/// Array with boolean columns cycling through thirds: `c0[i] = i%3==0`, `c1[i] = i%3==1`. +fn make_fragmented_array(size: usize) -> ArrayRef { + StructArray::from_fields(&[ + ( + "c0", + BoolArray::from_iter((0..size).map(|i| i % 3 == 0)).into_array(), + ), + ( + "c1", + BoolArray::from_iter((0..size).map(|i| i % 3 == 1)).into_array(), + ), + ]) + .unwrap() + .into_array() +} + /// Benchmark a simple binary CASE WHEN with varying array sizes. #[divan::bench(args = [1000, 10000, 100000])] fn case_when_simple(bencher: Bencher, size: usize) { @@ -242,3 +259,30 @@ fn case_when_all_false(bencher: Bencher, size: usize) { .unwrap() }); } + +/// Benchmark CASE WHEN cycling through 3 branches per row (triggers merge_row_by_row). +/// Run length = 1; exercises branch 0, branch 1, and the else fallback at every 3rd row. +#[divan::bench(args = [1000, 10000, 100000])] +fn case_when_fragmented(bencher: Bencher, size: usize) { + let array = make_fragmented_array(size); + + // CASE WHEN c0 THEN 0 WHEN c1 THEN 1 ELSE 2 END + let expr = nested_case_when( + vec![ + (get_item("c0", root()), lit(0i32)), + (get_item("c1", root()), lit(1i32)), + ], + Some(lit(2i32)), + ); + + bencher + .with_inputs(|| (&expr, &array)) + .bench_refs(|(expr, array)| { + let mut ctx = SESSION.create_execution_ctx(); + array + .apply(expr) + .unwrap() + .execute::(&mut ctx) + .unwrap() + }); +} diff --git a/vortex-array/src/scalar_fn/fns/case_when.rs b/vortex-array/src/scalar_fn/fns/case_when.rs index a476a04b8ed..3397b5bfc60 100644 --- a/vortex-array/src/scalar_fn/fns/case_when.rs +++ b/vortex-array/src/scalar_fn/fns/case_when.rs @@ -28,6 +28,7 @@ use crate::ExecutionCtx; use crate::IntoArray; use crate::arrays::BoolArray; use crate::arrays::ConstantArray; +use crate::builders::ArrayBuilder; use crate::builders::builder_with_capacity; use crate::dtype::DType; use crate::expr::Expression; @@ -203,12 +204,9 @@ impl ScalarFnVTable for CaseWhen { ) -> VortexResult { // Inspired by https://datafusion.apache.org/blog/2026/02/02/datafusion_case/ // - // Implemented: short-circuit early exit; single-pass merge via `merge_case_branches`. - // Partial: single-branch uses `zip_impl` but THEN/ELSE still evaluated on the full batch. - // // TODO: shrink input to `remaining` rows between WHEN iterations (batch reduction). // TODO: project to only referenced columns before batch reduction (column projection). - // TODO: evaluate THEN/ELSE on compact matching/non-matching rows and merge without scatter. + // TODO: evaluate THEN/ELSE on compact matching/non-matching rows and scatter-merge the results. // TODO: for constant WHEN/THEN values, compile to a hash table for a single-pass lookup. let row_count = args.row_count(); let num_pairs = options.num_when_then_pairs as usize; @@ -260,6 +258,10 @@ impl ScalarFnVTable for CaseWhen { /// Merges disjoint `(mask, then_value)` branch pairs with an `else_value` in a single pass. /// +/// Average run length at which slicing + `extend_from_array` becomes cheaper than `scalar_at`. +/// Measured empirically via benchmarks. +const SLICE_CROSSOVER_RUN_LEN: usize = 4; + /// Branch masks are guaranteed disjoint by the remaining-row tracking in [`CaseWhen::execute`]. fn merge_case_branches( branches: Vec<(Mask, ArrayRef)>, @@ -271,30 +273,82 @@ fn merge_case_branches( } let row_count = else_value.len(); - - let return_type = branches - .iter() - .fold(else_value.dtype().clone(), |acc, (_, arr)| { - acc.union_nullability(arr.dtype().nullability()) - }); - let mut builder = builder_with_capacity(&return_type, row_count); - - // Collect each branch's true-ranges tagged with branch index, then sort by position. - let mut events: Vec<(usize, usize, usize)> = Vec::new(); + let mut spans: Vec<(usize, usize, usize)> = Vec::new(); for (branch_idx, (mask, _)) in branches.iter().enumerate() { match mask.slices() { - AllOr::All => events.push((0, row_count, branch_idx)), + AllOr::All => spans.push((0, row_count, branch_idx)), AllOr::None => {} AllOr::Some(slices) => { for &(start, end) in slices { - events.push((start, end, branch_idx)); + spans.push((start, end, branch_idx)); } } } } - events.sort_unstable_by_key(|&(start, ..)| start); + spans.sort_unstable_by_key(|&(start, ..)| start); + + let output_dtype = branches + .iter() + .fold(else_value.dtype().clone(), |acc, (_, arr)| { + acc.union_nullability(arr.dtype().nullability()) + }); + let builder = builder_with_capacity(&output_dtype, row_count); + + let fragmented = !spans.is_empty() && spans.len() > row_count / SLICE_CROSSOVER_RUN_LEN; + if fragmented { + merge_row_by_row(&branches, &else_value, &spans, row_count, builder) + } else { + merge_run_by_run(&branches, &else_value, &spans, row_count, builder) + } +} + +/// Walks rows with a span cursor, emitting one `scalar_at` per row. +/// Zero per-run allocations; preferred for fragmented masks (avg run < [`SLICE_CROSSOVER_RUN_LEN`]). +fn merge_row_by_row( + branches: &[(Mask, ArrayRef)], + else_value: &ArrayRef, + spans: &[(usize, usize, usize)], + row_count: usize, + mut builder: Box, +) -> VortexResult { + let builder_dtype = builder.dtype().clone(); + let needs_cast = branches + .iter() + .any(|(_, arr)| arr.dtype() != &builder_dtype) + || else_value.dtype() != &builder_dtype; + + let mut cursor = 0; + for row in 0..row_count { + while cursor < spans.len() && spans[cursor].1 <= row { + cursor += 1; + } + let src = if cursor < spans.len() && spans[cursor].0 <= row { + &branches[spans[cursor].2].1 + } else { + else_value + }; + let scalar = src.scalar_at(row)?; + let scalar = if needs_cast && scalar.dtype() != &builder_dtype { + scalar.cast(&builder_dtype)? + } else { + scalar + }; + builder.append_scalar(&scalar)?; + } - for (start, end, branch_idx) in &events { + Ok(builder.finish()) +} + +/// Bulk-copies each span via `extend_from_array`, one `slice()` per run. +/// Preferred when runs are long enough that memcpy dominates over per-run Arc allocation. +fn merge_run_by_run( + branches: &[(Mask, ArrayRef)], + else_value: &ArrayRef, + spans: &[(usize, usize, usize)], + row_count: usize, + mut builder: Box, +) -> VortexResult { + for (start, end, branch_idx) in spans { if builder.len() < *start { builder.extend_from_array(&else_value.slice(builder.len()..*start)?); } @@ -1109,4 +1163,39 @@ mod tests { // row 2: cond1=false, cond2=NULL(→false) → else=0 assert_arrays_eq!(result, buffer![10i32, 20, 0].into_array()); } + + #[test] + fn test_merge_case_branches_alternating_mask() -> VortexResult<()> { + // Exercises the scalar path: alternating rows produce one slice per row (no runs), + // triggering the per-row cursor path in merge_case_branches. + let n = 100usize; + + // Branch 0: even rows → 0, Branch 1: odd rows → 1, Else: never reached. + let branch0_mask = Mask::from_indices(n, (0..n).step_by(2).collect()); + let branch1_mask = Mask::from_indices(n, (1..n).step_by(2).collect()); + + let result = merge_case_branches( + vec![ + ( + branch0_mask, + PrimitiveArray::from_option_iter(vec![Some(0i32); n]).into_array(), + ), + ( + branch1_mask, + PrimitiveArray::from_option_iter(vec![Some(1i32); n]).into_array(), + ), + ], + PrimitiveArray::from_option_iter(vec![Some(99i32); n]).into_array(), + )?; + + // Even rows → 0, odd rows → 1. + let expected: Vec> = (0..n) + .map(|v| if v % 2 == 0 { Some(0) } else { Some(1) }) + .collect(); + assert_arrays_eq!( + result, + PrimitiveArray::from_option_iter(expected).into_array() + ); + Ok(()) + } } From 5b74a426a0297be3f094a651e28560d08f602f0a Mon Sep 17 00:00:00 2001 From: Baris Palaska Date: Mon, 9 Mar 2026 15:01:04 +0000 Subject: [PATCH 10/22] public api Signed-off-by: Baris Palaska --- vortex-mask/public-api.lock | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vortex-mask/public-api.lock b/vortex-mask/public-api.lock index 84aa1cf82f5..1fdf9ccd598 100644 --- a/vortex-mask/public-api.lock +++ b/vortex-mask/public-api.lock @@ -98,7 +98,7 @@ pub fn vortex_mask::Mask::values(&self) -> core::option::Option<&vortex_mask::Ma impl vortex_mask::Mask -pub fn vortex_mask::Mask::andnot(&self, rhs: &vortex_mask::Mask) -> vortex_mask::Mask +pub fn vortex_mask::Mask::bitand_not(self, rhs: &vortex_mask::Mask) -> vortex_mask::Mask impl vortex_mask::Mask From f26fc238ac7aa4f9d191a902cb0810ffadb076d2 Mon Sep 17 00:00:00 2001 From: Baris Palaska Date: Mon, 9 Mar 2026 15:02:29 +0000 Subject: [PATCH 11/22] rm long running bench Signed-off-by: Baris Palaska --- vortex-array/benches/expr/case_when_bench.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vortex-array/benches/expr/case_when_bench.rs b/vortex-array/benches/expr/case_when_bench.rs index 77862a9c7a5..520a250c124 100644 --- a/vortex-array/benches/expr/case_when_bench.rs +++ b/vortex-array/benches/expr/case_when_bench.rs @@ -262,7 +262,7 @@ fn case_when_all_false(bencher: Bencher, size: usize) { /// Benchmark CASE WHEN cycling through 3 branches per row (triggers merge_row_by_row). /// Run length = 1; exercises branch 0, branch 1, and the else fallback at every 3rd row. -#[divan::bench(args = [1000, 10000, 100000])] +#[divan::bench(args = [1000, 10000])] fn case_when_fragmented(bencher: Bencher, size: usize) { let array = make_fragmented_array(size); From 675173c22b3a0940b667dfaa5cf83073733a6924 Mon Sep 17 00:00:00 2001 From: Baris Palaska Date: Tue, 10 Mar 2026 13:52:04 +0000 Subject: [PATCH 12/22] zip_impl_with_builder accepts mask values Signed-off-by: Baris Palaska --- vortex-array/src/scalar_fn/fns/zip/mod.rs | 30 +++++++++-------------- 1 file changed, 11 insertions(+), 19 deletions(-) diff --git a/vortex-array/src/scalar_fn/fns/zip/mod.rs b/vortex-array/src/scalar_fn/fns/zip/mod.rs index 3d73ec3e639..17dd8aea472 100644 --- a/vortex-array/src/scalar_fn/fns/zip/mod.rs +++ b/vortex-array/src/scalar_fn/fns/zip/mod.rs @@ -8,8 +8,8 @@ use std::fmt::Formatter; pub use kernel::*; use vortex_error::VortexResult; use vortex_error::vortex_ensure; -use vortex_mask::AllOr; use vortex_mask::Mask; +use vortex_mask::MaskValues; use vortex_session::VortexSession; use crate::ArrayRef; @@ -195,7 +195,8 @@ pub(crate) fn zip_impl( zip_impl_with_builder( if_true, if_false, - mask, + mask.values() + .expect("zip_impl_with_builder: mask is not all-true or all-false"), builder_with_capacity(&return_type, if_true.len()), ) } @@ -203,26 +204,17 @@ pub(crate) fn zip_impl( fn zip_impl_with_builder( if_true: &ArrayRef, if_false: &ArrayRef, - mask: &Mask, + mask: &MaskValues, mut builder: Box, ) -> VortexResult { - match mask.slices() { - AllOr::All | AllOr::None => { - unreachable!( - "zip_impl_with_builder called with all-true or all-false mask; handle in zip_impl" - ) - } - AllOr::Some(slices) => { - for (start, end) in slices { - builder.extend_from_array(&if_false.slice(builder.len()..*start)?); - builder.extend_from_array(&if_true.slice(*start..*end)?); - } - if builder.len() < if_false.len() { - builder.extend_from_array(&if_false.slice(builder.len()..if_false.len())?); - } - Ok(builder.finish()) - } + for (start, end) in mask.slices() { + builder.extend_from_array(&if_false.slice(builder.len()..*start)?); + builder.extend_from_array(&if_true.slice(*start..*end)?); + } + if builder.len() < if_false.len() { + builder.extend_from_array(&if_false.slice(builder.len()..if_false.len())?); } + Ok(builder.finish()) } #[cfg(test)] From d21575db8bd39dbdcd05fa721a3a742d1b9f012c Mon Sep 17 00:00:00 2001 From: Baris Palaska Date: Tue, 10 Mar 2026 13:57:26 +0000 Subject: [PATCH 13/22] cap vec Signed-off-by: Baris Palaska --- vortex-array/src/scalar_fn/fns/case_when.rs | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/vortex-array/src/scalar_fn/fns/case_when.rs b/vortex-array/src/scalar_fn/fns/case_when.rs index 3ce68a97312..8848820de6f 100644 --- a/vortex-array/src/scalar_fn/fns/case_when.rs +++ b/vortex-array/src/scalar_fn/fns/case_when.rs @@ -273,7 +273,15 @@ fn merge_case_branches( } let row_count = else_value.len(); - let mut spans: Vec<(usize, usize, usize)> = Vec::new(); + let spans_cap: usize = branches + .iter() + .map(|(mask, _)| match mask.slices() { + AllOr::All => 1, + AllOr::None => 0, + AllOr::Some(slices) => slices.len(), + }) + .sum(); + let mut spans: Vec<(usize, usize, usize)> = Vec::with_capacity(spans_cap); for (branch_idx, (mask, _)) in branches.iter().enumerate() { match mask.slices() { AllOr::All => spans.push((0, row_count, branch_idx)), From 6c89ee0e9718cc6786304b3f208b5d740a6b3feb Mon Sep 17 00:00:00 2001 From: Baris Palaska Date: Tue, 10 Mar 2026 14:00:10 +0000 Subject: [PATCH 14/22] bit or Signed-off-by: Baris Palaska --- vortex-array/src/scalar_fn/fns/case_when.rs | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/vortex-array/src/scalar_fn/fns/case_when.rs b/vortex-array/src/scalar_fn/fns/case_when.rs index 8848820de6f..e19a43d34dd 100644 --- a/vortex-array/src/scalar_fn/fns/case_when.rs +++ b/vortex-array/src/scalar_fn/fns/case_when.rs @@ -295,11 +295,12 @@ fn merge_case_branches( } spans.sort_unstable_by_key(|&(start, ..)| start); - let output_dtype = branches + let output_nullability = branches .iter() - .fold(else_value.dtype().clone(), |acc, (_, arr)| { - acc.union_nullability(arr.dtype().nullability()) + .fold(else_value.dtype().nullability(), |acc, (_, arr)| { + acc | arr.dtype().nullability() }); + let output_dtype = else_value.dtype().with_nullability(output_nullability); let builder = builder_with_capacity(&output_dtype, row_count); let fragmented = !spans.is_empty() && spans.len() > row_count / SLICE_CROSSOVER_RUN_LEN; From 9c6597049cf0dc74e13a7a2d74d2e856f179643c Mon Sep 17 00:00:00 2001 From: Baris Palaska Date: Tue, 10 Mar 2026 14:05:07 +0000 Subject: [PATCH 15/22] cast arrays Signed-off-by: Baris Palaska --- vortex-array/src/scalar_fn/fns/case_when.rs | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/vortex-array/src/scalar_fn/fns/case_when.rs b/vortex-array/src/scalar_fn/fns/case_when.rs index e19a43d34dd..1cec21fa4b8 100644 --- a/vortex-array/src/scalar_fn/fns/case_when.rs +++ b/vortex-array/src/scalar_fn/fns/case_when.rs @@ -30,6 +30,7 @@ use crate::arrays::BoolArray; use crate::arrays::ConstantArray; use crate::builders::ArrayBuilder; use crate::builders::builder_with_capacity; +use crate::builtins::ArrayBuiltins; use crate::dtype::DType; use crate::expr::Expression; use crate::scalar::Scalar; @@ -321,10 +322,11 @@ fn merge_row_by_row( mut builder: Box, ) -> VortexResult { let builder_dtype = builder.dtype().clone(); - let needs_cast = branches + let branch_arrays: Vec = branches .iter() - .any(|(_, arr)| arr.dtype() != &builder_dtype) - || else_value.dtype() != &builder_dtype; + .map(|(_, arr)| arr.cast(builder_dtype.clone())) + .collect::>()?; + let else_value = else_value.cast(builder_dtype)?; let mut cursor = 0; for row in 0..row_count { @@ -332,17 +334,11 @@ fn merge_row_by_row( cursor += 1; } let src = if cursor < spans.len() && spans[cursor].0 <= row { - &branches[spans[cursor].2].1 + &branch_arrays[spans[cursor].2] } else { - else_value + &else_value }; - let scalar = src.scalar_at(row)?; - let scalar = if needs_cast && scalar.dtype() != &builder_dtype { - scalar.cast(&builder_dtype)? - } else { - scalar - }; - builder.append_scalar(&scalar)?; + builder.append_scalar(&src.scalar_at(row)?)?; } Ok(builder.finish()) From 56150dc5acde6857347ae82f25195063a3052355 Mon Sep 17 00:00:00 2001 From: Baris Palaska Date: Tue, 10 Mar 2026 14:09:58 +0000 Subject: [PATCH 16/22] early exit when first branch matches all Signed-off-by: Baris Palaska --- vortex-array/src/scalar_fn/fns/case_when.rs | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/vortex-array/src/scalar_fn/fns/case_when.rs b/vortex-array/src/scalar_fn/fns/case_when.rs index 1cec21fa4b8..faaab9c87fb 100644 --- a/vortex-array/src/scalar_fn/fns/case_when.rs +++ b/vortex-array/src/scalar_fn/fns/case_when.rs @@ -274,6 +274,14 @@ fn merge_case_branches( } let row_count = else_value.len(); + + let output_nullability = branches + .iter() + .fold(else_value.dtype().nullability(), |acc, (_, arr)| { + acc | arr.dtype().nullability() + }); + let output_dtype = else_value.dtype().with_nullability(output_nullability); + let spans_cap: usize = branches .iter() .map(|(mask, _)| match mask.slices() { @@ -285,7 +293,7 @@ fn merge_case_branches( let mut spans: Vec<(usize, usize, usize)> = Vec::with_capacity(spans_cap); for (branch_idx, (mask, _)) in branches.iter().enumerate() { match mask.slices() { - AllOr::All => spans.push((0, row_count, branch_idx)), + AllOr::All => return branches[branch_idx].1.cast(output_dtype), AllOr::None => {} AllOr::Some(slices) => { for &(start, end) in slices { @@ -295,13 +303,6 @@ fn merge_case_branches( } } spans.sort_unstable_by_key(|&(start, ..)| start); - - let output_nullability = branches - .iter() - .fold(else_value.dtype().nullability(), |acc, (_, arr)| { - acc | arr.dtype().nullability() - }); - let output_dtype = else_value.dtype().with_nullability(output_nullability); let builder = builder_with_capacity(&output_dtype, row_count); let fragmented = !spans.is_empty() && spans.len() > row_count / SLICE_CROSSOVER_RUN_LEN; From 3908a45324130cbb430d6723f9fe734208052876 Mon Sep 17 00:00:00 2001 From: Baris Palaska Date: Tue, 10 Mar 2026 14:18:43 +0000 Subject: [PATCH 17/22] iterate over spans once on row_by_row Signed-off-by: Baris Palaska --- vortex-array/src/scalar_fn/fns/case_when.rs | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/vortex-array/src/scalar_fn/fns/case_when.rs b/vortex-array/src/scalar_fn/fns/case_when.rs index faaab9c87fb..207c423de1e 100644 --- a/vortex-array/src/scalar_fn/fns/case_when.rs +++ b/vortex-array/src/scalar_fn/fns/case_when.rs @@ -329,17 +329,18 @@ fn merge_row_by_row( .collect::>()?; let else_value = else_value.cast(builder_dtype)?; - let mut cursor = 0; - for row in 0..row_count { - while cursor < spans.len() && spans[cursor].1 <= row { - cursor += 1; + let mut pos = 0; + for &(start, end, branch_idx) in spans { + for row in pos..start { + builder.append_scalar(&else_value.scalar_at(row)?)?; } - let src = if cursor < spans.len() && spans[cursor].0 <= row { - &branch_arrays[spans[cursor].2] - } else { - &else_value - }; - builder.append_scalar(&src.scalar_at(row)?)?; + for row in start..end { + builder.append_scalar(&branch_arrays[branch_idx].scalar_at(row)?)?; + } + pos = end; + } + for row in pos..row_count { + builder.append_scalar(&else_value.scalar_at(row)?)?; } Ok(builder.finish()) From 482b6d08ca1194f01d8e88f57574e4ac47038bd1 Mon Sep 17 00:00:00 2001 From: Baris Palaska Date: Tue, 10 Mar 2026 14:29:39 +0000 Subject: [PATCH 18/22] clippy Signed-off-by: Baris Palaska --- vortex-array/src/scalar_fn/fns/zip/mod.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vortex-array/src/scalar_fn/fns/zip/mod.rs b/vortex-array/src/scalar_fn/fns/zip/mod.rs index 17dd8aea472..2d1db3d22c3 100644 --- a/vortex-array/src/scalar_fn/fns/zip/mod.rs +++ b/vortex-array/src/scalar_fn/fns/zip/mod.rs @@ -8,6 +8,7 @@ use std::fmt::Formatter; pub use kernel::*; use vortex_error::VortexResult; use vortex_error::vortex_ensure; +use vortex_error::VortexExpect as _; use vortex_mask::Mask; use vortex_mask::MaskValues; use vortex_session::VortexSession; @@ -196,7 +197,7 @@ pub(crate) fn zip_impl( if_true, if_false, mask.values() - .expect("zip_impl_with_builder: mask is not all-true or all-false"), + .vortex_expect("zip_impl_with_builder: mask is not all-true or all-false"), builder_with_capacity(&return_type, if_true.len()), ) } From b0e81dcec19f57be00b921b034a79f01d4654d3d Mon Sep 17 00:00:00 2001 From: Baris Palaska Date: Tue, 10 Mar 2026 14:43:37 +0000 Subject: [PATCH 19/22] fmt Signed-off-by: Baris Palaska --- vortex-array/src/scalar_fn/fns/zip/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vortex-array/src/scalar_fn/fns/zip/mod.rs b/vortex-array/src/scalar_fn/fns/zip/mod.rs index 2d1db3d22c3..8dd7cde1d07 100644 --- a/vortex-array/src/scalar_fn/fns/zip/mod.rs +++ b/vortex-array/src/scalar_fn/fns/zip/mod.rs @@ -6,9 +6,9 @@ mod kernel; use std::fmt::Formatter; pub use kernel::*; +use vortex_error::VortexExpect as _; use vortex_error::VortexResult; use vortex_error::vortex_ensure; -use vortex_error::VortexExpect as _; use vortex_mask::Mask; use vortex_mask::MaskValues; use vortex_session::VortexSession; From 50f3ed8d245521881b69a809e0140cb8b353aa18 Mon Sep 17 00:00:00 2001 From: Baris Palaska Date: Tue, 10 Mar 2026 15:05:54 +0000 Subject: [PATCH 20/22] update comments Signed-off-by: Baris Palaska --- vortex-array/src/scalar_fn/fns/case_when.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/vortex-array/src/scalar_fn/fns/case_when.rs b/vortex-array/src/scalar_fn/fns/case_when.rs index 207c423de1e..9fc11761224 100644 --- a/vortex-array/src/scalar_fn/fns/case_when.rs +++ b/vortex-array/src/scalar_fn/fns/case_when.rs @@ -257,12 +257,12 @@ impl ScalarFnVTable for CaseWhen { } } -/// Merges disjoint `(mask, then_value)` branch pairs with an `else_value` in a single pass. -/// /// Average run length at which slicing + `extend_from_array` becomes cheaper than `scalar_at`. /// Measured empirically via benchmarks. const SLICE_CROSSOVER_RUN_LEN: usize = 4; +/// Merges disjoint `(mask, then_value)` branch pairs with an `else_value` into a single array. +/// /// Branch masks are guaranteed disjoint by the remaining-row tracking in [`CaseWhen::execute`]. fn merge_case_branches( branches: Vec<(Mask, ArrayRef)>, @@ -313,7 +313,7 @@ fn merge_case_branches( } } -/// Walks rows with a span cursor, emitting one `scalar_at` per row. +/// Iterates spans directly, emitting one `scalar_at` per row. /// Zero per-run allocations; preferred for fragmented masks (avg run < [`SLICE_CROSSOVER_RUN_LEN`]). fn merge_row_by_row( branches: &[(Mask, ArrayRef)], @@ -346,8 +346,8 @@ fn merge_row_by_row( Ok(builder.finish()) } -/// Bulk-copies each span via `extend_from_array`, one `slice()` per run. -/// Preferred when runs are long enough that memcpy dominates over per-run Arc allocation. +/// Bulk-copies each span via `slice()` + `extend_from_array`. +/// Preferred when runs are long enough that memcpy dominates over per-slice allocation cost. fn merge_run_by_run( branches: &[(Mask, ArrayRef)], else_value: &ArrayRef, From 84b42db256dbe9f81f9e7bb482511626eded2a47 Mon Sep 17 00:00:00 2001 From: Baris Palaska Date: Tue, 10 Mar 2026 15:20:28 +0000 Subject: [PATCH 21/22] fix Signed-off-by: Baris Palaska --- vortex-array/src/scalar_fn/fns/case_when.rs | 35 ++++++++------------- 1 file changed, 13 insertions(+), 22 deletions(-) diff --git a/vortex-array/src/scalar_fn/fns/case_when.rs b/vortex-array/src/scalar_fn/fns/case_when.rs index 9fc11761224..485576711be 100644 --- a/vortex-array/src/scalar_fn/fns/case_when.rs +++ b/vortex-array/src/scalar_fn/fns/case_when.rs @@ -282,15 +282,7 @@ fn merge_case_branches( }); let output_dtype = else_value.dtype().with_nullability(output_nullability); - let spans_cap: usize = branches - .iter() - .map(|(mask, _)| match mask.slices() { - AllOr::All => 1, - AllOr::None => 0, - AllOr::Some(slices) => slices.len(), - }) - .sum(); - let mut spans: Vec<(usize, usize, usize)> = Vec::with_capacity(spans_cap); + let mut spans: Vec<(usize, usize, usize)> = Vec::new(); for (branch_idx, (mask, _)) in branches.iter().enumerate() { match mask.slices() { AllOr::All => return branches[branch_idx].1.cast(output_dtype), @@ -305,30 +297,29 @@ fn merge_case_branches( spans.sort_unstable_by_key(|&(start, ..)| start); let builder = builder_with_capacity(&output_dtype, row_count); + let branch_arrays: Vec = branches + .iter() + .map(|(_, arr)| arr.cast(output_dtype.clone())) + .collect::>()?; + let else_value = else_value.cast(output_dtype)?; + let fragmented = !spans.is_empty() && spans.len() > row_count / SLICE_CROSSOVER_RUN_LEN; if fragmented { - merge_row_by_row(&branches, &else_value, &spans, row_count, builder) + merge_row_by_row(&branch_arrays, &else_value, &spans, row_count, builder) } else { - merge_run_by_run(&branches, &else_value, &spans, row_count, builder) + merge_run_by_run(&branch_arrays, &else_value, &spans, row_count, builder) } } /// Iterates spans directly, emitting one `scalar_at` per row. /// Zero per-run allocations; preferred for fragmented masks (avg run < [`SLICE_CROSSOVER_RUN_LEN`]). fn merge_row_by_row( - branches: &[(Mask, ArrayRef)], + branch_arrays: &[ArrayRef], else_value: &ArrayRef, spans: &[(usize, usize, usize)], row_count: usize, mut builder: Box, ) -> VortexResult { - let builder_dtype = builder.dtype().clone(); - let branch_arrays: Vec = branches - .iter() - .map(|(_, arr)| arr.cast(builder_dtype.clone())) - .collect::>()?; - let else_value = else_value.cast(builder_dtype)?; - let mut pos = 0; for &(start, end, branch_idx) in spans { for row in pos..start { @@ -349,7 +340,7 @@ fn merge_row_by_row( /// Bulk-copies each span via `slice()` + `extend_from_array`. /// Preferred when runs are long enough that memcpy dominates over per-slice allocation cost. fn merge_run_by_run( - branches: &[(Mask, ArrayRef)], + branch_arrays: &[ArrayRef], else_value: &ArrayRef, spans: &[(usize, usize, usize)], row_count: usize, @@ -359,7 +350,7 @@ fn merge_run_by_run( if builder.len() < *start { builder.extend_from_array(&else_value.slice(builder.len()..*start)?); } - builder.extend_from_array(&branches[*branch_idx].1.slice(*start..*end)?); + builder.extend_from_array(&branch_arrays[*branch_idx].slice(*start..*end)?); } if builder.len() < row_count { builder.extend_from_array(&else_value.slice(builder.len()..row_count)?); @@ -380,6 +371,7 @@ mod tests { use crate::Canonical; use crate::IntoArray; use crate::VortexSessionExecute as _; + use crate::arrays::BoolArray; use crate::arrays::PrimitiveArray; use crate::arrays::StructArray; use crate::assert_arrays_eq; @@ -397,7 +389,6 @@ mod tests { use crate::expr::root; use crate::expr::test_harness; use crate::scalar::Scalar; - use crate::scalar_fn::fns::case_when::BoolArray; use crate::session::ArraySession; static SESSION: LazyLock = From a905aaa193e7bb24b37f89310d42d53a695f738b Mon Sep 17 00:00:00 2001 From: Baris Palaska Date: Tue, 10 Mar 2026 15:32:02 +0000 Subject: [PATCH 22/22] early return when no branch matches Signed-off-by: Baris Palaska --- vortex-array/src/scalar_fn/fns/case_when.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/vortex-array/src/scalar_fn/fns/case_when.rs b/vortex-array/src/scalar_fn/fns/case_when.rs index 485576711be..2d24bc805ab 100644 --- a/vortex-array/src/scalar_fn/fns/case_when.rs +++ b/vortex-array/src/scalar_fn/fns/case_when.rs @@ -295,6 +295,11 @@ fn merge_case_branches( } } spans.sort_unstable_by_key(|&(start, ..)| start); + + if spans.is_empty() { + return else_value.cast(output_dtype); + } + let builder = builder_with_capacity(&output_dtype, row_count); let branch_arrays: Vec = branches