Skip to content

Commit a8d1f07

Browse files
andygrove and claude committed
perf: Optimize Grace Hash Join with take() kernel and whole-batch passthrough
Two performance optimizations:

1. Replace `interleave_record_batch` with Arrow's `take()` kernel in `take_partition` — `take()` is SIMD-optimized and avoids the per-row `(batch_idx, row_idx)` tuple overhead in the single-batch case.
2. Skip `take_partition` entirely when all rows of a batch map to one partition — use the batch directly via a cheap clone instead of copying it through `take()`.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent af71b50 commit a8d1f07

1 file changed

Lines changed: 37 additions & 13 deletions

File tree

native/core/src/execution/operators/grace_hash_join.rs

Lines changed: 37 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ use std::io::{BufReader, BufWriter};
3131
use std::sync::Arc;
3232

3333
use ahash::RandomState;
34-
use arrow::compute::{concat_batches, interleave_record_batch};
34+
use arrow::array::UInt32Array;
35+
use arrow::compute::{concat_batches, take};
3536
use arrow::datatypes::SchemaRef;
3637
use arrow::ipc::reader::StreamReader;
3738
use arrow::ipc::writer::StreamWriter;
@@ -646,7 +647,7 @@ impl ScratchSpace {
646647
(self.partition_starts[partition_id + 1] - self.partition_starts[partition_id]) as usize
647648
}
648649

649-
/// Extract a sub-batch for a partition using `interleave_record_batch`.
650+
/// Extract a sub-batch for a partition using Arrow's `take()` kernel.
650651
fn take_partition(
651652
&self,
652653
batch: &RecordBatch,
@@ -656,13 +657,13 @@ impl ScratchSpace {
656657
if row_indices.is_empty() {
657658
return Ok(None);
658659
}
659-
let indices: Vec<(usize, usize)> = row_indices
660+
let indices_array = UInt32Array::from(row_indices.to_vec());
661+
let columns: Vec<_> = batch
662+
.columns()
660663
.iter()
661-
.map(|&idx| (0usize, idx as usize))
662-
.collect();
663-
let batches = [batch];
664-
let result = interleave_record_batch(&batches, &indices)?;
665-
Ok(Some(result))
664+
.map(|col| take(col.as_ref(), &indices_array, None))
665+
.collect::<Result<Vec<_>, _>>()?;
666+
Ok(Some(RecordBatch::try_new(batch.schema(), columns)?))
666667
}
667668
}
668669

@@ -722,9 +723,14 @@ async fn partition_build_side(
722723
continue;
723724
}
724725

725-
let sub_batch = scratch.take_partition(&batch, part_idx)?.unwrap();
726-
// Estimate size proportionally rather than calling get_array_memory_size per sub-batch
727726
let sub_rows = scratch.partition_len(part_idx);
727+
// Use entire batch directly when all rows go to one partition
728+
let sub_batch = if sub_rows == total_rows {
729+
batch.clone()
730+
} else {
731+
scratch.take_partition(&batch, part_idx)?.unwrap()
732+
};
733+
// Estimate size proportionally rather than calling get_array_memory_size per sub-batch
728734
let batch_size = if total_rows > 0 {
729735
(total_batch_size as u64 * sub_rows as u64 / total_rows as u64) as usize
730736
} else {
@@ -859,6 +865,7 @@ async fn partition_probe_side(
859865
metrics.input_batches.add(1);
860866
metrics.input_rows.add(batch.num_rows());
861867

868+
let total_rows = batch.num_rows();
862869
scratch.compute_partitions(&batch, keys, num_partitions, 0)?;
863870

864871
#[allow(clippy::needless_range_loop)]
@@ -867,7 +874,12 @@ async fn partition_probe_side(
867874
continue;
868875
}
869876

870-
let sub_batch = scratch.take_partition(&batch, part_idx)?.unwrap();
877+
// Use entire batch directly when all rows go to one partition
878+
let sub_batch = if scratch.partition_len(part_idx) == total_rows {
879+
batch.clone()
880+
} else {
881+
scratch.take_partition(&batch, part_idx)?.unwrap()
882+
};
871883

872884
if partitions[part_idx].build_spilled() {
873885
// Build side was spilled, so spill probe side too
@@ -1194,9 +1206,15 @@ fn repartition_and_join(
11941206
let mut build_sub: Vec<Vec<RecordBatch>> =
11951207
(0..num_sub_partitions).map(|_| Vec::new()).collect();
11961208
for batch in &build_batches {
1209+
let total_rows = batch.num_rows();
11971210
scratch.compute_partitions(batch, &build_keys, num_sub_partitions, recursion_level)?;
11981211
for (i, sub_vec) in build_sub.iter_mut().enumerate() {
1199-
if let Some(sub) = scratch.take_partition(batch, i)? {
1212+
if scratch.partition_len(i) == 0 {
1213+
continue;
1214+
}
1215+
if scratch.partition_len(i) == total_rows {
1216+
sub_vec.push(batch.clone());
1217+
} else if let Some(sub) = scratch.take_partition(batch, i)? {
12001218
sub_vec.push(sub);
12011219
}
12021220
}
@@ -1206,9 +1224,15 @@ fn repartition_and_join(
12061224
let mut probe_sub: Vec<Vec<RecordBatch>> =
12071225
(0..num_sub_partitions).map(|_| Vec::new()).collect();
12081226
for batch in &probe_batches {
1227+
let total_rows = batch.num_rows();
12091228
scratch.compute_partitions(batch, &probe_keys, num_sub_partitions, recursion_level)?;
12101229
for (i, sub_vec) in probe_sub.iter_mut().enumerate() {
1211-
if let Some(sub) = scratch.take_partition(batch, i)? {
1230+
if scratch.partition_len(i) == 0 {
1231+
continue;
1232+
}
1233+
if scratch.partition_len(i) == total_rows {
1234+
sub_vec.push(batch.clone());
1235+
} else if let Some(sub) = scratch.take_partition(batch, i)? {
12121236
sub_vec.push(sub);
12131237
}
12141238
}

0 commit comments

Comments (0)