@@ -31,7 +31,8 @@ use std::io::{BufReader, BufWriter};
 use std::sync::Arc;
 
 use ahash::RandomState;
-use arrow::compute::{concat_batches, interleave_record_batch};
+use arrow::array::UInt32Array;
+use arrow::compute::{concat_batches, take};
 use arrow::datatypes::SchemaRef;
 use arrow::ipc::reader::StreamReader;
 use arrow::ipc::writer::StreamWriter;
@@ -646,7 +647,7 @@ impl ScratchSpace {
         (self.partition_starts[partition_id + 1] - self.partition_starts[partition_id]) as usize
     }
 
-    /// Extract a sub-batch for a partition using `interleave_record_batch`.
+    /// Extract a sub-batch for a partition using Arrow's `take()` kernel.
     fn take_partition(
         &self,
         batch: &RecordBatch,
@@ -656,13 +657,13 @@ impl ScratchSpace {
         if row_indices.is_empty() {
             return Ok(None);
         }
-        let indices: Vec<(usize, usize)> = row_indices
+        let indices_array = UInt32Array::from(row_indices.to_vec());
+        let columns: Vec<_> = batch
+            .columns()
             .iter()
-            .map(|&idx| (0usize, idx as usize))
-            .collect();
-        let batches = [batch];
-        let result = interleave_record_batch(&batches, &indices)?;
-        Ok(Some(result))
+            .map(|col| take(col.as_ref(), &indices_array, None))
+            .collect::<Result<Vec<_>, _>>()?;
+        Ok(Some(RecordBatch::try_new(batch.schema(), columns)?))
     }
 }
668669
@@ -722,9 +723,14 @@ async fn partition_build_side(
             continue;
         }
 
-        let sub_batch = scratch.take_partition(&batch, part_idx)?.unwrap();
-        // Estimate size proportionally rather than calling get_array_memory_size per sub-batch
         let sub_rows = scratch.partition_len(part_idx);
+        // Use entire batch directly when all rows go to one partition
+        let sub_batch = if sub_rows == total_rows {
+            batch.clone()
+        } else {
+            scratch.take_partition(&batch, part_idx)?.unwrap()
+        };
+        // Estimate size proportionally rather than calling get_array_memory_size per sub-batch
         let batch_size = if total_rows > 0 {
             (total_batch_size as u64 * sub_rows as u64 / total_rows as u64) as usize
         } else {
@@ -859,6 +865,7 @@ async fn partition_probe_side(
         metrics.input_batches.add(1);
         metrics.input_rows.add(batch.num_rows());
 
+        let total_rows = batch.num_rows();
         scratch.compute_partitions(&batch, keys, num_partitions, 0)?;
 
         #[allow(clippy::needless_range_loop)]
@@ -867,7 +874,12 @@
             continue;
         }
 
-        let sub_batch = scratch.take_partition(&batch, part_idx)?.unwrap();
+        // Use entire batch directly when all rows go to one partition
+        let sub_batch = if scratch.partition_len(part_idx) == total_rows {
+            batch.clone()
+        } else {
+            scratch.take_partition(&batch, part_idx)?.unwrap()
+        };
 
         if partitions[part_idx].build_spilled() {
             // Build side was spilled, so spill probe side too
@@ -1194,9 +1206,15 @@ fn repartition_and_join(
     let mut build_sub: Vec<Vec<RecordBatch>> =
         (0..num_sub_partitions).map(|_| Vec::new()).collect();
     for batch in &build_batches {
+        let total_rows = batch.num_rows();
         scratch.compute_partitions(batch, &build_keys, num_sub_partitions, recursion_level)?;
         for (i, sub_vec) in build_sub.iter_mut().enumerate() {
-            if let Some(sub) = scratch.take_partition(batch, i)? {
+            if scratch.partition_len(i) == 0 {
+                continue;
+            }
+            if scratch.partition_len(i) == total_rows {
+                sub_vec.push(batch.clone());
+            } else if let Some(sub) = scratch.take_partition(batch, i)? {
                 sub_vec.push(sub);
             }
         }
@@ -1206,9 +1224,15 @@ fn repartition_and_join(
     let mut probe_sub: Vec<Vec<RecordBatch>> =
         (0..num_sub_partitions).map(|_| Vec::new()).collect();
     for batch in &probe_batches {
+        let total_rows = batch.num_rows();
         scratch.compute_partitions(batch, &probe_keys, num_sub_partitions, recursion_level)?;
         for (i, sub_vec) in probe_sub.iter_mut().enumerate() {
-            if let Some(sub) = scratch.take_partition(batch, i)? {
+            if scratch.partition_len(i) == 0 {
+                continue;
+            }
+            if scratch.partition_len(i) == total_rows {
+                sub_vec.push(batch.clone());
+            } else if let Some(sub) = scratch.take_partition(batch, i)? {
                 sub_vec.push(sub);
             }
         }
0 commit comments