-
Notifications
You must be signed in to change notification settings - Fork 149
CaseWhen uses forward pass with a remaining mask #6804
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
768bf0f
1b26a69
1d4b947
145fced
91b71a3
2506209
db79be4
6ea92f0
dbc323d
f52fa2c
9d57dff
39f2a04
9c75c72
6994f6e
cb59e4f
7091207
5b74a42
0c427ac
f26fc23
f175876
675173c
d21575d
6c89ee0
9c65970
56150dc
3908a45
8f3c513
482b6d0
b0e81dc
74497f3
50f3ed8
84b42db
a905aaa
ad4c41f
4f74690
397056e
1172216
fd32c76
b227013
ac72362
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -11,6 +11,8 @@ use std::sync::Arc; | |
| use prost::Message; | ||
| use vortex_error::VortexResult; | ||
| use vortex_error::vortex_bail; | ||
| use vortex_mask::AllOr; | ||
| use vortex_mask::Mask; | ||
| use vortex_proto::expr as pb; | ||
| use vortex_session::VortexSession; | ||
|
|
||
|
|
@@ -19,6 +21,7 @@ use crate::ExecutionCtx; | |
| use crate::IntoArray; | ||
| use crate::arrays::BoolArray; | ||
| use crate::arrays::ConstantArray; | ||
| use crate::builders::builder_with_capacity; | ||
| use crate::dtype::DType; | ||
| use crate::expr::Expression; | ||
| use crate::scalar::Scalar; | ||
|
|
@@ -191,37 +194,45 @@ impl ScalarFnVTable for CaseWhen { | |
| let row_count = args.row_count(); | ||
| let num_pairs = options.num_when_then_pairs as usize; | ||
|
|
||
| let mut result: ArrayRef = if options.has_else { | ||
| args.get(num_pairs * 2)? | ||
| } else { | ||
| let then_dtype = args.get(1)?.dtype().as_nullable(); | ||
| ConstantArray::new(Scalar::null(then_dtype), row_count).into_array() | ||
| }; | ||
| // Track unmatched rows; AND each condition with `remaining` to enforce first-match-wins | ||
| // and produce disjoint branch masks. | ||
| let mut remaining = Mask::new_true(row_count); | ||
| let mut branches: Vec<(Mask, ArrayRef)> = Vec::with_capacity(num_pairs); | ||
|
|
||
| for i in (0..num_pairs).rev() { | ||
| let condition = args.get(i * 2)?; | ||
| let then_value = args.get(i * 2 + 1)?; | ||
| for i in 0..num_pairs { | ||
| if remaining.all_false() { | ||
| break; | ||
| } | ||
|
|
||
| let condition = args.get(i * 2)?; | ||
| let cond_bool = condition.execute::<BoolArray>(ctx)?; | ||
| let mask = cond_bool.to_mask_fill_null_false(); | ||
| let cond_mask = cond_bool.to_mask_fill_null_false(); | ||
| let effective_mask = &remaining & &cond_mask; | ||
|
|
||
| if mask.all_true() { | ||
| result = then_value; | ||
| if effective_mask.all_false() { | ||
| continue; | ||
| } | ||
|
|
||
| if mask.all_false() { | ||
| continue; | ||
| } | ||
| let then_value = args.get(i * 2 + 1)?; | ||
| remaining = &remaining & &(!&cond_mask); | ||
|
palaska marked this conversation as resolved.
Outdated
|
||
| branches.push((effective_mask, then_value)); | ||
| } | ||
|
|
||
| result = zip_impl(&then_value, &result, &mask)?; | ||
| let else_value: ArrayRef = if options.has_else { | ||
| args.get(num_pairs * 2)? | ||
| } else { | ||
| let then_dtype = args.get(1)?.dtype().as_nullable(); | ||
| ConstantArray::new(Scalar::null(then_dtype), row_count).into_array() | ||
|
palaska marked this conversation as resolved.
|
||
| }; | ||
|
|
||
| if branches.is_empty() { | ||
| return Ok(else_value); | ||
| } | ||
|
|
||
| Ok(result) | ||
| merge_case_branches(branches, else_value) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This really feels like we want an expr that is n-way merge, but that is future work I think |
||
| } | ||
|
|
||
| fn is_null_sensitive(&self, _options: &Self::Options) -> bool { | ||
| // CaseWhen is null-sensitive because NULL conditions are treated as false | ||
| true | ||
| } | ||
|
|
||
|
|
@@ -230,6 +241,55 @@ impl ScalarFnVTable for CaseWhen { | |
| } | ||
| } | ||
|
|
||
| /// Merges disjoint `(mask, then_value)` branch pairs with an `else_value` in a single pass. | ||
| /// | ||
| /// Branch masks are guaranteed disjoint by the remaining-row tracking in [`CaseWhen::execute`]. | ||
| fn merge_case_branches( | ||
| branches: Vec<(Mask, ArrayRef)>, | ||
| else_value: ArrayRef, | ||
| ) -> VortexResult<ArrayRef> { | ||
| if branches.len() == 1 { | ||
| let (mask, then_value) = &branches[0]; | ||
| return zip_impl(then_value, &else_value, mask); | ||
| } | ||
|
|
||
| let row_count = else_value.len(); | ||
|
|
||
| let return_type = branches | ||
| .iter() | ||
| .fold(else_value.dtype().clone(), |acc, (_, arr)| { | ||
| acc.union_nullability(arr.dtype().nullability()) | ||
| }); | ||
| let mut builder = builder_with_capacity(&return_type, row_count); | ||
|
|
||
| // Collect each branch's true-ranges tagged with branch index, then sort by position. | ||
| let mut events: Vec<(usize, usize, usize)> = Vec::new(); | ||
| for (branch_idx, (mask, _)) in branches.iter().enumerate() { | ||
| match mask.slices() { | ||
| AllOr::All => events.push((0, row_count, branch_idx)), | ||
| AllOr::None => {} | ||
| AllOr::Some(slices) => { | ||
| for &(start, end) in slices { | ||
| events.push((start, end, branch_idx)); | ||
| } | ||
| } | ||
| } | ||
| } | ||
| events.sort_unstable_by_key(|&(start, ..)| start); | ||
|
|
||
| for (start, end, branch_idx) in &events { | ||
| if builder.len() < *start { | ||
| builder.extend_from_array(&else_value.slice(builder.len()..*start)?); | ||
| } | ||
| builder.extend_from_array(&branches[*branch_idx].1.slice(*start..*end)?); | ||
| } | ||
| if builder.len() < row_count { | ||
| builder.extend_from_array(&else_value.slice(builder.len()..row_count)?); | ||
| } | ||
|
|
||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this might be very expensive if the slices are small
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. based on my benchmark sweep (again 😄 ), 4 as the average length of the runs seems to be a good cutoff point to choose between strategies. I know you wanted this optimization to live in zip but as we talked disjointness is not guaranteed there so I feel like it's fine to have it here? It'd be great if you can double check if my merge_row_by_row implementation is optimal. |
||
| Ok(builder.finish()) | ||
| } | ||
|
|
||
| #[cfg(test)] | ||
| mod tests { | ||
| use std::sync::LazyLock; | ||
|
|
@@ -246,6 +306,7 @@ mod tests { | |
| use crate::arrays::BoolArray; | ||
| use crate::arrays::PrimitiveArray; | ||
| use crate::arrays::StructArray; | ||
| use crate::assert_arrays_eq; | ||
| use crate::dtype::DType; | ||
| use crate::dtype::Nullability; | ||
| use crate::dtype::PType; | ||
|
|
@@ -690,6 +751,65 @@ mod tests { | |
| assert_eq!(result.as_slice::<i32>(), &[100, 100, 100, 100, 100]); | ||
| } | ||
|
|
||
| #[test] | ||
| fn test_evaluate_all_true_no_else_returns_correct_dtype() { | ||
| // CASE WHEN value > 0 THEN 100 END — condition is always true, no ELSE. | ||
| // Result must be Nullable because the implicit ELSE is NULL. | ||
| let test_array = StructArray::from_fields(&[("value", buffer![1i32, 2, 3].into_array())]) | ||
| .unwrap() | ||
| .into_array(); | ||
|
|
||
| let expr = case_when_no_else(gt(get_item("value", root()), lit(0i32)), lit(100i32)); | ||
|
|
||
| let result = evaluate_expr(&expr, &test_array); | ||
| assert!( | ||
| result.dtype().is_nullable(), | ||
| "result dtype must be Nullable, got {:?}", | ||
| result.dtype() | ||
| ); | ||
| assert_eq!( | ||
| result.scalar_at(0).unwrap(), | ||
| Scalar::from(100i32).cast(result.dtype()).unwrap() | ||
| ); | ||
| } | ||
|
|
||
| #[test] | ||
| fn test_merge_case_branches_widens_nullability_of_later_branch() -> VortexResult<()> { | ||
| // When a later THEN branch is Nullable and branches[0] and ELSE are NonNullable, | ||
| // the result dtype must still be Nullable. | ||
| // | ||
| // CASE WHEN value = 0 THEN 10 -- NonNullable | ||
| // WHEN value = 1 THEN nullable(20) -- Nullable | ||
| // ELSE 0 -- NonNullable | ||
| // → result must be Nullable(i32) | ||
| let test_array = StructArray::from_fields(&[("value", buffer![0i32, 1, 2].into_array())]) | ||
| .unwrap() | ||
| .into_array(); | ||
|
|
||
| let nullable_20 = | ||
| Scalar::from(20i32).cast(&DType::Primitive(PType::I32, Nullability::Nullable))?; | ||
|
|
||
| let expr = nested_case_when( | ||
| vec![ | ||
| (eq(get_item("value", root()), lit(0i32)), lit(10i32)), | ||
| (eq(get_item("value", root()), lit(1i32)), lit(nullable_20)), | ||
| ], | ||
| Some(lit(0i32)), | ||
| ); | ||
|
|
||
| let result = evaluate_expr(&expr, &test_array); | ||
| assert!( | ||
| result.dtype().is_nullable(), | ||
| "result dtype must be Nullable, got {:?}", | ||
| result.dtype() | ||
| ); | ||
| assert_arrays_eq!( | ||
| result, | ||
| PrimitiveArray::from_option_iter([Some(10), Some(20), Some(0)]).into_array() | ||
| ); | ||
| Ok(()) | ||
| } | ||
|
|
||
| #[test] | ||
| fn test_evaluate_with_literal_condition() { | ||
| let test_array = buffer![1i32, 2, 3].into_array(); | ||
|
|
@@ -893,6 +1013,89 @@ mod tests { | |
| assert_eq!(result.as_slice::<i32>(), &[1, 1, 1]); | ||
| } | ||
|
|
||
| #[test] | ||
| fn test_evaluate_nary_early_exit_when_remaining_empty() { | ||
| // After branch 0 claims all rows, remaining becomes all_false. | ||
| // The loop breaks before evaluating branch 1's condition. | ||
| let test_array = StructArray::from_fields(&[("value", buffer![1i32, 2, 3].into_array())]) | ||
| .unwrap() | ||
| .into_array(); | ||
|
|
||
| let expr = nested_case_when( | ||
| vec![ | ||
| (gt(get_item("value", root()), lit(0i32)), lit(100i32)), | ||
| // Never evaluated due to early exit; 999 must never appear in output. | ||
| (gt(get_item("value", root()), lit(0i32)), lit(999i32)), | ||
| ], | ||
| Some(lit(0i32)), | ||
| ); | ||
|
|
||
| let result = evaluate_expr(&expr, &test_array).to_primitive(); | ||
| assert_eq!(result.as_slice::<i32>(), &[100, 100, 100]); | ||
|
palaska marked this conversation as resolved.
Outdated
|
||
| } | ||
|
|
||
| #[test] | ||
| fn test_evaluate_nary_skips_branch_with_empty_effective_mask() { | ||
| // Branch 0 claims value=1. Branch 1 targets the same rows but they are already | ||
| // matched → effective_mask is all_false → branch 1 is skipped (THEN not used). | ||
| let test_array = StructArray::from_fields(&[("value", buffer![1i32, 2, 3].into_array())]) | ||
| .unwrap() | ||
| .into_array(); | ||
|
|
||
| let expr = nested_case_when( | ||
| vec![ | ||
| (eq(get_item("value", root()), lit(1i32)), lit(10i32)), | ||
| // Same condition as branch 0 — all matching rows already claimed → skipped. | ||
| // 999 must never appear in output. | ||
| (eq(get_item("value", root()), lit(1i32)), lit(999i32)), | ||
| (eq(get_item("value", root()), lit(2i32)), lit(20i32)), | ||
| ], | ||
| Some(lit(0i32)), | ||
| ); | ||
|
|
||
| let result = evaluate_expr(&expr, &test_array).to_primitive(); | ||
| assert_eq!(result.as_slice::<i32>(), &[10, 20, 0]); | ||
| } | ||
|
|
||
| #[test] | ||
| fn test_evaluate_nary_string_output() -> VortexResult<()> { | ||
| // Exercises merge_case_branches with a non-primitive (Utf8) builder. | ||
| let test_array = | ||
| StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4].into_array())]) | ||
| .unwrap() | ||
| .into_array(); | ||
|
|
||
| // CASE WHEN value > 2 THEN 'high' WHEN value > 0 THEN 'low' ELSE 'none' END | ||
| // value=1,2 → 'low' (branch 1 after branch 0 claims 3,4) | ||
| // value=3,4 → 'high' (branch 0) | ||
| let expr = nested_case_when( | ||
| vec![ | ||
| (gt(get_item("value", root()), lit(2i32)), lit("high")), | ||
| (gt(get_item("value", root()), lit(0i32)), lit("low")), | ||
| ], | ||
| Some(lit("none")), | ||
| ); | ||
|
|
||
| let result = evaluate_expr(&expr, &test_array); | ||
| assert_eq!( | ||
| result.scalar_at(0)?, | ||
| Scalar::utf8("low", Nullability::NonNullable) | ||
| ); | ||
| assert_eq!( | ||
| result.scalar_at(1)?, | ||
| Scalar::utf8("low", Nullability::NonNullable) | ||
| ); | ||
| assert_eq!( | ||
| result.scalar_at(2)?, | ||
| Scalar::utf8("high", Nullability::NonNullable) | ||
| ); | ||
| assert_eq!( | ||
| result.scalar_at(3)?, | ||
| Scalar::utf8("high", Nullability::NonNullable) | ||
| ); | ||
| Ok(()) | ||
|
palaska marked this conversation as resolved.
|
||
| } | ||
|
|
||
| #[test] | ||
| fn test_evaluate_nary_with_nullable_conditions() { | ||
| let test_array = StructArray::from_fields(&[ | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.