From 024f73262d5a2b04c5b6fd26e57c5174520c3214 Mon Sep 17 00:00:00 2001
From: Jonathan Chen
Date: Wed, 4 Mar 2026 17:17:27 -0600
Subject: [PATCH] feat: `partition_statistics()` for HashJoinExec

---
 .../partition_statistics.rs                   | 253 +++++++++++++++++-
 .../physical-plan/src/joins/hash_join/exec.rs |  60 ++++-
 2 files changed, 298 insertions(+), 15 deletions(-)

diff --git a/datafusion/core/tests/physical_optimizer/partition_statistics.rs b/datafusion/core/tests/physical_optimizer/partition_statistics.rs
index fa021ed3dcce3..a03792fae826a 100644
--- a/datafusion/core/tests/physical_optimizer/partition_statistics.rs
+++ b/datafusion/core/tests/physical_optimizer/partition_statistics.rs
@@ -27,7 +27,9 @@ mod test {
     use datafusion_catalog::TableProvider;
     use datafusion_common::Result;
     use datafusion_common::stats::Precision;
-    use datafusion_common::{ColumnStatistics, ScalarValue, Statistics};
+    use datafusion_common::{
+        ColumnStatistics, JoinType, NullEquality, ScalarValue, Statistics,
+    };
     use datafusion_execution::TaskContext;
     use datafusion_execution::config::SessionConfig;
     use datafusion_expr::{WindowFrame, WindowFunctionDefinition};
@@ -45,7 +47,9 @@ mod test {
     use datafusion_physical_plan::common::compute_record_batch_statistics;
     use datafusion_physical_plan::empty::EmptyExec;
     use datafusion_physical_plan::filter::FilterExec;
-    use datafusion_physical_plan::joins::{CrossJoinExec, NestedLoopJoinExec};
+    use datafusion_physical_plan::joins::{
+        CrossJoinExec, HashJoinExec, NestedLoopJoinExec, PartitionMode,
+    };
     use datafusion_physical_plan::limit::{GlobalLimitExec, LocalLimitExec};
     use datafusion_physical_plan::placeholder_row::PlaceholderRowExec;
     use datafusion_physical_plan::projection::{ProjectionExec, ProjectionExpr};
@@ -1354,4 +1358,249 @@ mod test {
 
         Ok(())
     }
+
+    #[tokio::test]
+    async fn test_hash_join_partition_statistics() -> Result<()> {
+        // Create left table scan and coalesce to 1 partition for CollectLeft mode
+        let left_scan = create_scan_exec_with_statistics(None, Some(2)).await;
+        let left_scan_coalesced = Arc::new(CoalescePartitionsExec::new(left_scan.clone()))
+            as Arc<dyn ExecutionPlan>;
+
+        // Create right table scan with different table name
+        let right_create_table_sql = "CREATE EXTERNAL TABLE t2 (id INT NOT NULL, date DATE) \
+             STORED AS PARQUET LOCATION './tests/data/test_statistics_per_partition'\
+             PARTITIONED BY (date) \
+             WITH ORDER (id ASC);";
+        let right_scan =
+            create_scan_exec_with_statistics(Some(right_create_table_sql), Some(2)).await;
+
+        // Create join condition: t1.id = t2.id
+        let on = vec![(
+            Arc::new(Column::new("id", 0)) as Arc<dyn PhysicalExpr>,
+            Arc::new(Column::new("id", 0)) as Arc<dyn PhysicalExpr>,
+        )];
+
+        // Test CollectLeft mode - left child must have 1 partition
+        let collect_left_join = Arc::new(HashJoinExec::try_new(
+            left_scan_coalesced,
+            Arc::clone(&right_scan),
+            on.clone(),
+            None,
+            &JoinType::Inner,
+            None,
+            PartitionMode::CollectLeft,
+            NullEquality::NullEqualsNothing,
+            false,
+        )?) as Arc<dyn ExecutionPlan>;
+
+        // Test partition statistics for CollectLeft mode
+        let statistics = (0..collect_left_join.output_partitioning().partition_count())
+            .map(|idx| collect_left_join.partition_statistics(Some(idx)))
+            .collect::<Result<Vec<_>>>()?;
+
+        // Check that we have the expected number of partitions
+        assert_eq!(statistics.len(), 2);
+
+        // For collect left mode, the min/max values are from the entire left table and the specific partition of the right table.
+        let expected_p0_statistics = Statistics {
+            num_rows: Precision::Inexact(2),
+            total_byte_size: Precision::Absent,
+            column_statistics: vec![
+                // Left id column: all partitions (id 1..4)
+                ColumnStatistics {
+                    null_count: Precision::Exact(0),
+                    max_value: Precision::Exact(ScalarValue::Int32(Some(4))),
+                    min_value: Precision::Exact(ScalarValue::Int32(Some(1))),
+                    sum_value: Precision::Absent,
+                    distinct_count: Precision::Absent,
+                    byte_size: Precision::Exact(16),
+                },
+                // Left date column: all partitions (2025-03-01..2025-03-04)
+                ColumnStatistics {
+                    null_count: Precision::Exact(0),
+                    max_value: Precision::Exact(ScalarValue::Date32(Some(
+                        DATE_2025_03_04,
+                    ))),
+                    min_value: Precision::Exact(ScalarValue::Date32(Some(
+                        DATE_2025_03_01,
+                    ))),
+                    sum_value: Precision::Absent,
+                    distinct_count: Precision::Absent,
+                    byte_size: Precision::Exact(16),
+                },
+                // Right id column: partition 0 only (id 3..4)
+                ColumnStatistics {
+                    null_count: Precision::Exact(0),
+                    max_value: Precision::Exact(ScalarValue::Int32(Some(4))),
+                    min_value: Precision::Exact(ScalarValue::Int32(Some(3))),
+                    sum_value: Precision::Absent,
+                    distinct_count: Precision::Absent,
+                    byte_size: Precision::Exact(8),
+                },
+                // Right date column: partition 0 only (2025-03-01..2025-03-02)
+                ColumnStatistics {
+                    null_count: Precision::Exact(0),
+                    max_value: Precision::Exact(ScalarValue::Date32(Some(
+                        DATE_2025_03_02,
+                    ))),
+                    min_value: Precision::Exact(ScalarValue::Date32(Some(
+                        DATE_2025_03_01,
+                    ))),
+                    sum_value: Precision::Absent,
+                    distinct_count: Precision::Absent,
+                    byte_size: Precision::Exact(8),
+                },
+            ],
+        };
+        assert_eq!(statistics[0], expected_p0_statistics);
+
+        // Test Partitioned mode
+        let partitioned_join = Arc::new(HashJoinExec::try_new(
+            Arc::clone(&left_scan),
+            Arc::clone(&right_scan),
+            on.clone(),
+            None,
+            &JoinType::Inner,
+            None,
+            PartitionMode::Partitioned,
+            NullEquality::NullEqualsNothing,
+            false,
+        )?) as Arc<dyn ExecutionPlan>;
+
+        // Test partition statistics for Partitioned mode
+        let statistics = (0..partitioned_join.output_partitioning().partition_count())
+            .map(|idx| partitioned_join.partition_statistics(Some(idx)))
+            .collect::<Result<Vec<_>>>()?;
+
+        // Check that we have the expected number of partitions
+        assert_eq!(statistics.len(), 2);
+
+        // For partitioned mode, the min/max values are from the specific partition for each side.
+        let expected_p0_statistics = Statistics {
+            num_rows: Precision::Inexact(2),
+            total_byte_size: Precision::Absent,
+            column_statistics: vec![
+                // Left id column: partition 0 only (id 3..4)
+                ColumnStatistics {
+                    null_count: Precision::Exact(0),
+                    max_value: Precision::Exact(ScalarValue::Int32(Some(4))),
+                    min_value: Precision::Exact(ScalarValue::Int32(Some(3))),
+                    sum_value: Precision::Absent,
+                    distinct_count: Precision::Absent,
+                    byte_size: Precision::Exact(8),
+                },
+                // Left date column: partition 0 only (2025-03-01..2025-03-02)
+                ColumnStatistics {
+                    null_count: Precision::Exact(0),
+                    max_value: Precision::Exact(ScalarValue::Date32(Some(
+                        DATE_2025_03_02,
+                    ))),
+                    min_value: Precision::Exact(ScalarValue::Date32(Some(
+                        DATE_2025_03_01,
+                    ))),
+                    sum_value: Precision::Absent,
+                    distinct_count: Precision::Absent,
+                    byte_size: Precision::Exact(8),
+                },
+                // Right id column: partition 0 only (id 3..4)
+                ColumnStatistics {
+                    null_count: Precision::Exact(0),
+                    max_value: Precision::Exact(ScalarValue::Int32(Some(4))),
+                    min_value: Precision::Exact(ScalarValue::Int32(Some(3))),
+                    sum_value: Precision::Absent,
+                    distinct_count: Precision::Absent,
+                    byte_size: Precision::Exact(8),
+                },
+                // Right date column: partition 0 only (2025-03-01..2025-03-02)
+                ColumnStatistics {
+                    null_count: Precision::Exact(0),
+                    max_value: Precision::Exact(ScalarValue::Date32(Some(
+                        DATE_2025_03_02,
+                    ))),
+                    min_value: Precision::Exact(ScalarValue::Date32(Some(
+                        DATE_2025_03_01,
+                    ))),
+                    sum_value: Precision::Absent,
+                    distinct_count: Precision::Absent,
+                    byte_size: Precision::Exact(8),
+                },
+            ],
+        };
+        assert_eq!(statistics[0], expected_p0_statistics);
+
+        // Test Auto mode - should fall back to getting all partition statistics
+        let auto_join = Arc::new(HashJoinExec::try_new(
+            Arc::clone(&left_scan),
+            Arc::clone(&right_scan),
+            on,
+            None,
+            &JoinType::Inner,
+            None,
+            PartitionMode::Auto,
+            NullEquality::NullEqualsNothing,
+            false,
+        )?) as Arc<dyn ExecutionPlan>;
+
+        // Test partition statistics for Auto mode
+        let statistics = (0..auto_join.output_partitioning().partition_count())
+            .map(|idx| auto_join.partition_statistics(Some(idx)))
+            .collect::<Result<Vec<_>>>()?;
+
+        // Check that we have the expected number of partitions
+        assert_eq!(statistics.len(), 2);
+
+        // For auto mode, the min/max values are from the entire left and right tables.
+        let expected_p0_statistics = Statistics {
+            num_rows: Precision::Inexact(4),
+            total_byte_size: Precision::Absent,
+            column_statistics: vec![
+                // Left id column: all partitions (id 1..4)
+                ColumnStatistics {
+                    null_count: Precision::Exact(0),
+                    max_value: Precision::Exact(ScalarValue::Int32(Some(4))),
+                    min_value: Precision::Exact(ScalarValue::Int32(Some(1))),
+                    sum_value: Precision::Absent,
+                    distinct_count: Precision::Absent,
+                    byte_size: Precision::Exact(16),
+                },
+                // Left date column: all partitions (2025-03-01..2025-03-04)
+                ColumnStatistics {
+                    null_count: Precision::Exact(0),
+                    max_value: Precision::Exact(ScalarValue::Date32(Some(
+                        DATE_2025_03_04,
+                    ))),
+                    min_value: Precision::Exact(ScalarValue::Date32(Some(
+                        DATE_2025_03_01,
+                    ))),
+                    sum_value: Precision::Absent,
+                    distinct_count: Precision::Absent,
+                    byte_size: Precision::Exact(16),
+                },
+                // Right id column: all partitions (id 1..4)
+                ColumnStatistics {
+                    null_count: Precision::Exact(0),
+                    max_value: Precision::Exact(ScalarValue::Int32(Some(4))),
+                    min_value: Precision::Exact(ScalarValue::Int32(Some(1))),
+                    sum_value: Precision::Absent,
+                    distinct_count: Precision::Absent,
+                    byte_size: Precision::Exact(16),
+                },
+                // Right date column: all partitions (2025-03-01..2025-03-04)
+                ColumnStatistics {
+                    null_count: Precision::Exact(0),
+                    max_value: Precision::Exact(ScalarValue::Date32(Some(
+                        DATE_2025_03_04,
+                    ))),
+                    min_value: Precision::Exact(ScalarValue::Date32(Some(
+                        DATE_2025_03_01,
+                    ))),
+                    sum_value: Precision::Absent,
+                    distinct_count: Precision::Absent,
+                    byte_size: Precision::Exact(16),
+                },
+            ],
+        };
+        assert_eq!(statistics[0], expected_p0_statistics);
+        Ok(())
+    }
 }
diff --git a/datafusion/physical-plan/src/joins/hash_join/exec.rs b/datafusion/physical-plan/src/joins/hash_join/exec.rs
index b5e8d3c1e30b1..c07ae2d65173a 100644
--- a/datafusion/physical-plan/src/joins/hash_join/exec.rs
+++ b/datafusion/physical-plan/src/joins/hash_join/exec.rs
@@ -1442,19 +1442,53 @@ impl ExecutionPlan for HashJoinExec {
     }
 
     fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
-        if partition.is_some() {
-            return Ok(Statistics::new_unknown(&self.schema()));
-        }
-        // TODO stats: it is not possible in general to know the output size of joins
-        // There are some special cases though, for example:
-        // - `A LEFT JOIN B ON A.col=B.col` with `COUNT_DISTINCT(B.col)=COUNT(B.col)`
-        let stats = estimate_join_statistics(
-            self.left.partition_statistics(None)?,
-            self.right.partition_statistics(None)?,
-            &self.on,
-            &self.join_type,
-            &self.join_schema,
-        )?;
+        let stats = match (partition, self.mode) {
+            // For CollectLeft mode, the left side is collected into a single partition,
+            // so all left partitions are available to each output partition.
+            // For the right side, we need the specific partition statistics.
+            (Some(partition), PartitionMode::CollectLeft) => {
+                let left_stats = self.left.partition_statistics(None)?;
+                let right_stats = self.right.partition_statistics(Some(partition))?;
+
+                estimate_join_statistics(
+                    left_stats,
+                    right_stats,
+                    &self.on,
+                    &self.join_type,
+                    &self.join_schema,
+                )?
+            }
+
+            // For Partitioned mode, both sides are partitioned, so each output partition
+            // only has access to the corresponding partition from both sides.
+            (Some(partition), PartitionMode::Partitioned) => {
+                let left_stats = self.left.partition_statistics(Some(partition))?;
+                let right_stats = self.right.partition_statistics(Some(partition))?;
+
+                estimate_join_statistics(
+                    left_stats,
+                    right_stats,
+                    &self.on,
+                    &self.join_type,
+                    &self.join_schema,
+                )?
+            }
+
+            // For Auto mode or when no specific partition is requested, fall back to
+            // the current behavior of getting all partition statistics.
+            (None, _) | (Some(_), PartitionMode::Auto) => {
+                // TODO stats: it is not possible in general to know the output size of joins
+                // There are some special cases though, for example:
+                // - `A LEFT JOIN B ON A.col=B.col` with `COUNT_DISTINCT(B.col)=COUNT(B.col)`
+                estimate_join_statistics(
+                    self.left.partition_statistics(None)?,
+                    self.right.partition_statistics(None)?,
+                    &self.on,
+                    &self.join_type,
+                    &self.join_schema,
+                )?
+            }
+        };
         // Project statistics if there is a projection
         let stats = stats.project(self.projection.as_ref());
         // Apply fetch limit to statistics