Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
768bf0f
forward pass case when
palaska Mar 5, 2026
1b26a69
Merge branch 'develop' into bp/case-when
palaska Mar 5, 2026
1d4b947
assert_arrays_eq
palaska Mar 5, 2026
145fced
Merge branch 'bp/case-when' of github.com:vortex-data/vortex into bp/…
palaska Mar 5, 2026
91b71a3
add andnot
palaska Mar 5, 2026
2506209
cast in zip_impl
palaska Mar 5, 2026
db79be4
tests
palaska Mar 5, 2026
6ea92f0
Merge branch 'develop' into bp/case-when
palaska Mar 5, 2026
dbc323d
public api
palaska Mar 5, 2026
f52fa2c
Merge branch 'develop' into bp/case-when
palaska Mar 5, 2026
9d57dff
add todo
palaska Mar 6, 2026
39f2a04
Merge branch 'develop' into bp/case-when
palaska Mar 6, 2026
9c75c72
Merge branch 'bp/case-when' of github.com:vortex-data/vortex into bp/…
palaska Mar 6, 2026
6994f6e
Merge branch 'develop' into bp/case-when
palaska Mar 6, 2026
cb59e4f
mask::bitand_not uses fused bitbuffer method, also owned
palaska Mar 6, 2026
7091207
cleaner
palaska Mar 9, 2026
5b74a42
public api
palaska Mar 9, 2026
0c427ac
Merge branch 'develop' into bp/case-when
palaska Mar 9, 2026
f26fc23
rm long running bench
palaska Mar 9, 2026
f175876
Merge branch 'bp/case-when' of github.com:vortex-data/vortex into bp/…
palaska Mar 9, 2026
675173c
zip_impl_with_builder accepts mask values
palaska Mar 10, 2026
d21575d
cap vec
palaska Mar 10, 2026
6c89ee0
bit or
palaska Mar 10, 2026
9c65970
cast arrays
palaska Mar 10, 2026
56150dc
early exit when first branch matches all
palaska Mar 10, 2026
3908a45
iterate over spans once on row_by_row
palaska Mar 10, 2026
8f3c513
Merge branch 'develop' into bp/case-when
palaska Mar 10, 2026
482b6d0
clippy
palaska Mar 10, 2026
b0e81dc
fmt
palaska Mar 10, 2026
74497f3
Merge branch 'develop' into bp/case-when
palaska Mar 10, 2026
50f3ed8
update comments
palaska Mar 10, 2026
84b42db
fix
palaska Mar 10, 2026
a905aaa
early return when no branch matches
palaska Mar 10, 2026
ad4c41f
dont pass row_count
palaska Mar 11, 2026
4f74690
scalar_at before cast
palaska Mar 11, 2026
397056e
Merge branch 'develop' into bp/case-when
joseph-isaacs Mar 11, 2026
1172216
leaner benches
palaska Mar 12, 2026
fd32c76
Merge branch 'bp/case-when' of github.com:vortex-data/vortex into bp/…
palaska Mar 12, 2026
b227013
100
palaska Mar 12, 2026
ac72362
Merge branch 'develop' into bp/case-when
palaska Mar 12, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions vortex-array/benches/expr/case_when_bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use vortex_array::expr::eq;
use vortex_array::expr::get_item;
use vortex_array::expr::gt;
use vortex_array::expr::lit;
use vortex_array::expr::lt;
use vortex_array::expr::nested_case_when;
use vortex_array::expr::root;
use vortex_array::session::ArraySession;
Expand Down Expand Up @@ -185,6 +186,39 @@ fn case_when_all_true(bencher: Bencher, size: usize) {
});
}

/// Benchmarks a three-branch CASE WHEN whose `<` thresholds sit at 90%, 95% and 97.5% of
/// `size`, with a literal ELSE.
///
/// All branches compare the same `value` field, and for the benchmarked sizes the thresholds
/// are increasing, so later branches can only claim rows that earlier branches left unmatched.
/// NOTE(review): the "first branch takes ~90% of the rows" premise assumes
/// `make_struct_array` fills `value` with roughly `0..size` — confirm against that helper.
#[divan::bench(args = [1000, 10000, 100000])]
fn case_when_nary_early_dominant(bencher: Bencher, size: usize) {
    let array = make_struct_array(size);

    // Threshold located at `numerator / denominator` of the input size.
    let threshold = |numerator: i32, denominator: i32| (size as i32 * numerator) / denominator;
    // Predicate `value < limit` over the struct's `value` field.
    let value_below = |limit: i32| lt(get_item("value", root()), lit(limit));

    // CASE WHEN value < 90% THEN 1 WHEN value < 95% THEN 2 WHEN value < 97.5% THEN 3 ELSE 4
    let when_then_pairs = vec![
        (value_below(threshold(9, 10)), lit(1i32)),
        (value_below(threshold(19, 20)), lit(2i32)),
        (value_below(threshold(39, 40)), lit(3i32)),
    ];
    let expr = nested_case_when(when_then_pairs, Some(lit(4i32)));

    bencher
        .with_inputs(|| (&expr, &array))
        .bench_refs(|(expr, array)| {
            let mut ctx = SESSION.create_execution_ctx();
            array
                .apply(expr)
                .unwrap()
                .execute::<Canonical>(&mut ctx)
                .unwrap()
        });
}

/// Benchmark CASE WHEN where all conditions are false.
#[divan::bench(args = [1000, 10000, 100000])]
fn case_when_all_false(bencher: Bencher, size: usize) {
Expand Down
239 changes: 221 additions & 18 deletions vortex-array/src/scalar_fn/fns/case_when.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ use std::sync::Arc;
use prost::Message;
use vortex_error::VortexResult;
use vortex_error::vortex_bail;
use vortex_mask::AllOr;
use vortex_mask::Mask;
use vortex_proto::expr as pb;
use vortex_session::VortexSession;

Expand All @@ -19,6 +21,7 @@ use crate::ExecutionCtx;
use crate::IntoArray;
use crate::arrays::BoolArray;
use crate::arrays::ConstantArray;
use crate::builders::builder_with_capacity;
use crate::dtype::DType;
use crate::expr::Expression;
use crate::scalar::Scalar;
Expand Down Expand Up @@ -191,37 +194,45 @@ impl ScalarFnVTable for CaseWhen {
let row_count = args.row_count();
let num_pairs = options.num_when_then_pairs as usize;

let mut result: ArrayRef = if options.has_else {
args.get(num_pairs * 2)?
} else {
let then_dtype = args.get(1)?.dtype().as_nullable();
ConstantArray::new(Scalar::null(then_dtype), row_count).into_array()
};
// Track unmatched rows; AND each condition with `remaining` to enforce first-match-wins
// and produce disjoint branch masks.
let mut remaining = Mask::new_true(row_count);
let mut branches: Vec<(Mask, ArrayRef)> = Vec::with_capacity(num_pairs);

for i in (0..num_pairs).rev() {
let condition = args.get(i * 2)?;
let then_value = args.get(i * 2 + 1)?;
for i in 0..num_pairs {
if remaining.all_false() {
break;
}

let condition = args.get(i * 2)?;
let cond_bool = condition.execute::<BoolArray>(ctx)?;
let mask = cond_bool.to_mask_fill_null_false();
let cond_mask = cond_bool.to_mask_fill_null_false();
let effective_mask = &remaining & &cond_mask;

if mask.all_true() {
result = then_value;
if effective_mask.all_false() {
continue;
Comment thread
palaska marked this conversation as resolved.
}

if mask.all_false() {
continue;
}
let then_value = args.get(i * 2 + 1)?;
remaining = &remaining & &(!&cond_mask);
Comment thread
palaska marked this conversation as resolved.
Outdated
branches.push((effective_mask, then_value));
}

result = zip_impl(&then_value, &result, &mask)?;
let else_value: ArrayRef = if options.has_else {
args.get(num_pairs * 2)?
} else {
let then_dtype = args.get(1)?.dtype().as_nullable();
ConstantArray::new(Scalar::null(then_dtype), row_count).into_array()
Comment thread
palaska marked this conversation as resolved.
};

if branches.is_empty() {
return Ok(else_value);
}

Ok(result)
merge_case_branches(branches, else_value)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This really feels like we want an expr that is n-way merge, but that is future work I think

}

/// Always returns `true`; `_options` is not consulted.
fn is_null_sensitive(&self, _options: &Self::Options) -> bool {
// CaseWhen is null-sensitive because NULL conditions are treated as false
// (conditions are converted to masks with nulls filled as false during execution).
true
}

Expand All @@ -230,6 +241,55 @@ impl ScalarFnVTable for CaseWhen {
}
}

/// Merges disjoint `(mask, then_value)` branch pairs with an `else_value` in a single pass.
///
/// Branch masks are guaranteed disjoint by the remaining-row tracking in [`CaseWhen::execute`].
///
/// A single branch is delegated directly to `zip_impl`. Otherwise the true-ranges of every
/// mask are collected, ordered by start position and appended to a builder; rows that fall
/// between ranges (and any rows after the last range) are taken from `else_value`.
///
/// The builder's dtype is the dtype of `else_value` with its nullability unioned with the
/// nullability of every branch value.
///
/// NOTE(review): every `then_value` is sliced with row ranges derived from `else_value.len()`,
/// so this presumably requires all branch arrays to have that same length — confirm against
/// the caller.
fn merge_case_branches(
branches: Vec<(Mask, ArrayRef)>,
else_value: ArrayRef,
) -> VortexResult<ArrayRef> {
// With exactly one branch a plain two-way zip against the else value is sufficient.
if branches.len() == 1 {
let (mask, then_value) = &branches[0];
return zip_impl(then_value, &else_value, mask);
}

// The else value defines the number of output rows.
let row_count = else_value.len();

// Start from the else dtype and widen its nullability with each branch's nullability.
let return_type = branches
.iter()
.fold(else_value.dtype().clone(), |acc, (_, arr)| {
acc.union_nullability(arr.dtype().nullability())
});
let mut builder = builder_with_capacity(&return_type, row_count);

// Collect each branch's true-ranges tagged with branch index, then sort by position.
// Each event is `(start, end, branch_idx)` and is consumed below as the range `start..end`.
let mut events: Vec<(usize, usize, usize)> = Vec::new();
for (branch_idx, (mask, _)) in branches.iter().enumerate() {
match mask.slices() {
// Entirely-true mask: a single range covering every row.
AllOr::All => events.push((0, row_count, branch_idx)),
// Entirely-false mask: contributes nothing.
AllOr::None => {}
AllOr::Some(slices) => {
for &(start, end) in slices {
events.push((start, end, branch_idx));
}
}
}
}
events.sort_unstable_by_key(|&(start, ..)| start);

// Walk the ranges in row order. This relies on the masks being disjoint, so that the
// builder length never runs past the start of the next range.
for (start, end, branch_idx) in &events {
// Rows between what has been written so far and this range matched no branch.
if builder.len() < *start {
builder.extend_from_array(&else_value.slice(builder.len()..*start)?);
}
builder.extend_from_array(&branches[*branch_idx].1.slice(*start..*end)?);
}
// Trailing rows after the last range also come from the else value.
if builder.len() < row_count {
builder.extend_from_array(&else_value.slice(builder.len()..row_count)?);
}

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this might be very expensive if the slices are small

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

based on my benchmark sweep (again 😄 ), 4 as the average length of the runs seems to be a good cutoff point to choose between strategies. I know you wanted this optimization to live in zip but as we talked disjointness is not guaranteed there so I feel like it's fine to have it here?

It'd be great if you can double check if my merge_row_by_row implementation is optimal.

Ok(builder.finish())
}

#[cfg(test)]
mod tests {
use std::sync::LazyLock;
Expand All @@ -246,6 +306,7 @@ mod tests {
use crate::arrays::BoolArray;
use crate::arrays::PrimitiveArray;
use crate::arrays::StructArray;
use crate::assert_arrays_eq;
use crate::dtype::DType;
use crate::dtype::Nullability;
use crate::dtype::PType;
Expand Down Expand Up @@ -690,6 +751,65 @@ mod tests {
assert_eq!(result.as_slice::<i32>(), &[100, 100, 100, 100, 100]);
}

/// `CASE WHEN value > 0 THEN 100 END` over all-positive input: every row matches and there is
/// no ELSE, yet the implicit NULL fallback must still make the result dtype nullable.
#[test]
fn test_evaluate_all_true_no_else_returns_correct_dtype() {
    let values = buffer![1i32, 2, 3].into_array();
    let input = StructArray::from_fields(&[("value", values)])
        .unwrap()
        .into_array();

    // The condition holds for every row of `input`.
    let always_true = gt(get_item("value", root()), lit(0i32));
    let expr = case_when_no_else(always_true, lit(100i32));

    let output = evaluate_expr(&expr, &input);
    assert!(
        output.dtype().is_nullable(),
        "result dtype must be Nullable, got {:?}",
        output.dtype()
    );

    // The first row holds the THEN literal, typed with the result dtype.
    let first_row = output.scalar_at(0).unwrap();
    let expected_first = Scalar::from(100i32).cast(output.dtype()).unwrap();
    assert_eq!(first_row, expected_first);
}

/// A Nullable THEN value in a later branch must widen the result dtype to Nullable even when
/// the first branch and the ELSE value are both NonNullable.
#[test]
fn test_merge_case_branches_widens_nullability_of_later_branch() -> VortexResult<()> {
    // CASE WHEN value = 0 THEN 10            -- NonNullable
    //      WHEN value = 1 THEN nullable(20)  -- Nullable
    //      ELSE 0                            -- NonNullable
    // → result must be Nullable(i32)
    let values = buffer![0i32, 1, 2].into_array();
    let input = StructArray::from_fields(&[("value", values)])
        .unwrap()
        .into_array();

    let nullable_i32 = DType::Primitive(PType::I32, Nullability::Nullable);
    let twenty = Scalar::from(20i32).cast(&nullable_i32)?;

    // Predicate `value = target` over the struct's `value` field.
    let value_equals = |target: i32| eq(get_item("value", root()), lit(target));
    let when_then_pairs = vec![
        (value_equals(0), lit(10i32)),
        (value_equals(1), lit(twenty)),
    ];
    let expr = nested_case_when(when_then_pairs, Some(lit(0i32)));

    let output = evaluate_expr(&expr, &input);
    assert!(
        output.dtype().is_nullable(),
        "result dtype must be Nullable, got {:?}",
        output.dtype()
    );

    let expected = PrimitiveArray::from_option_iter([Some(10), Some(20), Some(0)]).into_array();
    assert_arrays_eq!(output, expected);
    Ok(())
}

#[test]
fn test_evaluate_with_literal_condition() {
let test_array = buffer![1i32, 2, 3].into_array();
Expand Down Expand Up @@ -893,6 +1013,89 @@ mod tests {
assert_eq!(result.as_slice::<i32>(), &[1, 1, 1]);
}

/// Two branches share a condition that is true for every input row; the first branch must win
/// for all rows, so the second branch's THEN value (999) and the ELSE value (0) never appear.
#[test]
fn test_evaluate_nary_early_exit_when_remaining_empty() {
// After branch 0 claims all rows, remaining becomes all_false.
// The loop breaks before evaluating branch 1's condition.
let test_array = StructArray::from_fields(&[("value", buffer![1i32, 2, 3].into_array())])
.unwrap()
.into_array();

let expr = nested_case_when(
vec![
(gt(get_item("value", root()), lit(0i32)), lit(100i32)),
// Never evaluated due to early exit; 999 must never appear in output.
(gt(get_item("value", root()), lit(0i32)), lit(999i32)),
],
Some(lit(0i32)),
);

let result = evaluate_expr(&expr, &test_array).to_primitive();
assert_eq!(result.as_slice::<i32>(), &[100, 100, 100]);
Comment thread
palaska marked this conversation as resolved.
Outdated
}

/// A branch whose condition only matches rows already claimed by an earlier branch must
/// contribute nothing, while later branches and the ELSE value still apply to the other rows.
#[test]
fn test_evaluate_nary_skips_branch_with_empty_effective_mask() {
    let values = buffer![1i32, 2, 3].into_array();
    let input = StructArray::from_fields(&[("value", values)])
        .unwrap()
        .into_array();

    // Predicate `value = target` over the struct's `value` field.
    let value_equals = |target: i32| eq(get_item("value", root()), lit(target));

    let when_then_pairs = vec![
        // Branch 0 claims the row where value = 1.
        (value_equals(1), lit(10i32)),
        // Branch 1 repeats branch 0's condition, so its rows are already taken and its
        // THEN value (999) must never appear in the output.
        (value_equals(1), lit(999i32)),
        // Branch 2 claims the row where value = 2.
        (value_equals(2), lit(20i32)),
    ];
    let expr = nested_case_when(when_then_pairs, Some(lit(0i32)));

    let output = evaluate_expr(&expr, &input).to_primitive();
    assert_eq!(output.as_slice::<i32>(), &[10, 20, 0]);
}

/// Checks first-match-wins ordering for a two-branch CASE WHEN whose THEN/ELSE values are
/// string literals, asserting each output row as a non-nullable Utf8 scalar.
#[test]
fn test_evaluate_nary_string_output() -> VortexResult<()> {
// Exercises merge_case_branches with a non-primitive (Utf8) builder.
let test_array =
StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4].into_array())])
.unwrap()
.into_array();

// CASE WHEN value > 2 THEN 'high' WHEN value > 0 THEN 'low' ELSE 'none' END
// value=1,2 → 'low' (branch 1 after branch 0 claims 3,4)
// value=3,4 → 'high' (branch 0)
// Every input value is > 0, so the ELSE literal 'none' is never selected here.
let expr = nested_case_when(
vec![
(gt(get_item("value", root()), lit(2i32)), lit("high")),
(gt(get_item("value", root()), lit(0i32)), lit("low")),
],
Some(lit("none")),
);

let result = evaluate_expr(&expr, &test_array);
assert_eq!(
result.scalar_at(0)?,
Scalar::utf8("low", Nullability::NonNullable)
);
assert_eq!(
result.scalar_at(1)?,
Scalar::utf8("low", Nullability::NonNullable)
);
assert_eq!(
result.scalar_at(2)?,
Scalar::utf8("high", Nullability::NonNullable)
);
assert_eq!(
result.scalar_at(3)?,
Scalar::utf8("high", Nullability::NonNullable)
);
Ok(())
Comment thread
palaska marked this conversation as resolved.
}

#[test]
fn test_evaluate_nary_with_nullable_conditions() {
let test_array = StructArray::from_fields(&[
Expand Down
Loading
Loading