CaseWhen uses forward pass with a remaining mask #6804
Conversation
Signed-off-by: Baris Palaska <barispalaska@gmail.com>
Merging this PR will improve performance by 59.97%
Cool
mind if we get this one in first: #6806
Signed-off-by: Baris Palaska <barispalaska@gmail.com> Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
```rust
for (start, end, branch_idx) in &events {
    if builder.len() < *start {
        builder.extend_from_array(&else_value.slice(builder.len()..*start)?);
    }
    builder.extend_from_array(&branches[*branch_idx].1.slice(*start..*end)?);
}
if builder.len() < row_count {
    builder.extend_from_array(&else_value.slice(builder.len()..row_count)?);
}
```
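As a simplified model of the slice-merge loop above (plain vectors instead of Vortex arrays and builders; `merge_by_slices` and its signature are illustrative, not the PR's code): copy else-runs up to each span's start, then the matched branch run, then the tail.

```rust
// Vector-based sketch of the slice-merge strategy. The real code extends
// an array builder from array slices; names here are hypothetical.
fn merge_by_slices(
    branches: &[Vec<i64>],
    else_value: &[i64],
    events: &[(usize, usize, usize)], // sorted, disjoint (start, end, branch_idx)
    row_count: usize,
) -> Vec<i64> {
    let mut out = Vec::with_capacity(row_count);
    for &(start, end, branch_idx) in events {
        // Fill the gap before this span from the else branch.
        if out.len() < start {
            out.extend_from_slice(&else_value[out.len()..start]);
        }
        // Copy the matched run from the branch array.
        out.extend_from_slice(&branches[branch_idx][start..end]);
    }
    // Tail: any rows after the last span come from the else branch.
    if out.len() < row_count {
        out.extend_from_slice(&else_value[out.len()..row_count]);
    }
    out
}

fn main() {
    let branches = vec![vec![10, 11, 12, 13, 14]];
    let else_v = vec![0; 5];
    // Rows 1..3 come from branch 0, everything else from the else value.
    assert_eq!(merge_by_slices(&branches, &else_v, &[(1, 3, 0)], 5), vec![0, 11, 12, 0, 0]);
}
```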
this might be very expensive if the slices are small
based on my benchmark sweep (again 😄), an average run length of 4 seems to be a good cutoff point for choosing between the two strategies. I know you wanted this optimization to live in zip, but as we discussed, disjointness isn't guaranteed there, so I feel it's fine to have it here?
It'd be great if you could double-check whether my merge_row_by_row implementation is optimal.
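The crossover heuristic described here can be sketched as follows. The constant name mirrors `SLICE_CROSSOVER_RUN_LEN` from the diff, but the helper function and the span representation are assumptions for illustration, not the PR's exact code.

```rust
// Hypothetical sketch of the strategy crossover: below an average
// matched-run length of 4 (calibrated on primitive arrays, per the
// discussion), row-by-row merging beats per-run slicing.
const SLICE_CROSSOVER_RUN_LEN: usize = 4;

/// Spans are half-open `(start, end, branch_idx)` ranges.
fn prefer_slicing(spans: &[(usize, usize, usize)]) -> bool {
    if spans.is_empty() {
        return true; // nothing matched; trivially fine to slice the else branch
    }
    let total_matched: usize = spans.iter().map(|&(s, e, _)| e - s).sum();
    // Average run length decides the strategy.
    total_matched / spans.len() >= SLICE_CROSSOVER_RUN_LEN
}

fn main() {
    // Long runs: slicing amortizes per-call overhead.
    assert!(prefer_slicing(&[(0, 10, 0), (10, 20, 1)]));
    // Fragmented mask (runs of length 1): fall back to row-by-row.
    assert!(!prefer_slicing(&[(0, 1, 0), (2, 3, 1), (4, 5, 0)]));
}
```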
```rust
/// Walks rows with a span cursor, emitting one `scalar_at` per row.
/// Zero per-run allocations; preferred for fragmented masks (avg run < [`SLICE_CROSSOVER_RUN_LEN`]).
fn merge_row_by_row(
```
Is this benchmarked?
Yes, but only for primitive arrays (the average run length of 4 is calibrated on primitive arrays). I guess the cost will differ based on the underlying encodings of the arrays; a lower cutoff will probably work better when scalar_at is more expensive. Not sure what the best move here is.
```diff
 }

-Ok(result)
+merge_case_branches(branches, else_value)
```
This really feels like we want an expr that is an n-way merge, but that is future work I think
```rust
        }
    }
}

spans.sort_unstable_by_key(|&(start, ..)| start);
```
This works since each range is globally disjoint?
```rust
for row in 0..row_count {
    while cursor < spans.len() && spans[cursor].1 <= row {
        cursor += 1;
    }
```
This seems very inefficient? Didn't you just sort this?
Can we do something other than append_scalar(branch_arr.scalar_at(...)) once per row? I guess you're saying it's inefficient because of the inner while loop, which just increments a counter. I refactored it so it's better structured now, but it still does n scalar_at calls.
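Since the spans are sorted and disjoint, the cursor only ever advances, so the whole walk is linear despite the inner while loop. A self-contained model of the cursor walk (plain `Vec<i64>` instead of arrays, a plain push instead of `append_scalar(scalar_at(...))`; the function shape is illustrative):

```rust
// Simplified model of merge_row_by_row. The cursor passes each span at
// most once, so total work is O(row_count + spans.len()), even though
// there is a nested while loop.
fn merge_row_by_row(
    branches: &[Vec<i64>],
    else_value: &[i64],
    spans: &[(usize, usize, usize)], // sorted, disjoint (start, end, branch_idx)
    row_count: usize,
) -> Vec<i64> {
    let mut out = Vec::with_capacity(row_count);
    let mut cursor = 0;
    for row in 0..row_count {
        // Skip spans that ended at or before this row.
        while cursor < spans.len() && spans[cursor].1 <= row {
            cursor += 1;
        }
        match spans.get(cursor) {
            // Row falls inside the current span: take it from that branch.
            Some(&(start, _, branch)) if start <= row => out.push(branches[branch][row]),
            // Otherwise the row comes from the else branch.
            _ => out.push(else_value[row]),
        }
    }
    out
}

fn main() {
    let branches = vec![vec![10, 11, 12, 13, 14], vec![20, 21, 22, 23, 24]];
    let else_v = vec![0; 5];
    let spans = [(1, 2, 0), (3, 5, 1)];
    assert_eq!(merge_row_by_row(&branches, &else_v, &spans, 5), vec![0, 11, 0, 23, 24]);
}
```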
```rust
branch_arrays: &[ArrayRef],
else_value: &ArrayRef,
spans: &[(usize, usize, usize)],
row_count: usize,
```
Each condition is ANDed with remaining (unmatched rows), producing disjoint branch masks by construction. This enables early exit when all rows are claimed and skips evaluating THEN expressions for branches with no remaining matches.
Matched branches are merged in a single builder pass instead of N intermediate arrays.
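The forward pass with a remaining mask can be sketched over boolean vectors (the real code operates on Vortex masks; all names here are illustrative, not the PR's API):

```rust
// Sketch: AND each condition with `remaining`, so branch masks are
// disjoint by construction, and stop early once every row is claimed.
fn case_when_masks(conditions: &[Vec<bool>], row_count: usize) -> Vec<Vec<bool>> {
    let mut remaining = vec![true; row_count];
    let mut branch_masks = Vec::new();
    for cond in conditions {
        // This branch only claims rows no earlier branch claimed.
        let mask: Vec<bool> = cond.iter().zip(&remaining).map(|(&c, &r)| c && r).collect();
        // Clear claimed rows from the remaining mask.
        for (r, &m) in remaining.iter_mut().zip(&mask) {
            *r = *r && !m;
        }
        branch_masks.push(mask);
        // Early exit: all rows claimed, later THEN branches never evaluate.
        if remaining.iter().all(|&r| !r) {
            break;
        }
    }
    branch_masks
}

fn main() {
    let conds = vec![vec![true, false, true], vec![true, true, false]];
    let masks = case_when_masks(&conds, 3);
    assert_eq!(masks[0], vec![true, false, true]);
    // Row 0 was already claimed by branch 0, so branch 1's mask excludes it.
    assert_eq!(masks[1], vec![false, true, false]);
}
```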
Inspired by DataFusion's CASE WHEN optimization blog post.
Future work
simplify?)

Also fixes a nullability bug in `zip_impl_with_builder` where `AllOr::All`/`AllOr::None` short-circuits bypassed the builder's declared return type, producing a wrong dtype when `if_true` and `if_false` had different nullability.

Benches


Before
After
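The nullability rule behind the `zip_impl_with_builder` fix described above can be illustrated in isolation (the enum and helper here are stand-ins, not Vortex's actual `DType` machinery):

```rust
// Illustration of the invariant the fix enforces: zip's result dtype must
// be nullable if *either* input is nullable, even when an all-true or
// all-false mask lets the implementation return one side unchanged.
#[derive(Debug, PartialEq, Clone, Copy)]
enum Nullability {
    NonNullable,
    Nullable,
}

fn zip_result_nullability(if_true: Nullability, if_false: Nullability) -> Nullability {
    if if_true == Nullability::Nullable || if_false == Nullability::Nullable {
        Nullability::Nullable
    } else {
        Nullability::NonNullable
    }
}

fn main() {
    // A short-circuit that returned `if_true` as-is would report
    // NonNullable here; the declared result type must stay Nullable.
    assert_eq!(
        zip_result_nullability(Nullability::NonNullable, Nullability::Nullable),
        Nullability::Nullable
    );
    assert_eq!(
        zip_result_nullability(Nullability::NonNullable, Nullability::NonNullable),
        Nullability::NonNullable
    );
}
```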