Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
253 changes: 251 additions & 2 deletions datafusion/core/tests/physical_optimizer/partition_statistics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ mod test {
use datafusion_catalog::TableProvider;
use datafusion_common::Result;
use datafusion_common::stats::Precision;
use datafusion_common::{ColumnStatistics, ScalarValue, Statistics};
use datafusion_common::{
ColumnStatistics, JoinType, NullEquality, ScalarValue, Statistics,
};
use datafusion_execution::TaskContext;
use datafusion_execution::config::SessionConfig;
use datafusion_expr::{WindowFrame, WindowFunctionDefinition};
Expand All @@ -45,7 +47,9 @@ mod test {
use datafusion_physical_plan::common::compute_record_batch_statistics;
use datafusion_physical_plan::empty::EmptyExec;
use datafusion_physical_plan::filter::FilterExec;
use datafusion_physical_plan::joins::{CrossJoinExec, NestedLoopJoinExec};
use datafusion_physical_plan::joins::{
CrossJoinExec, HashJoinExec, NestedLoopJoinExec, PartitionMode,
};
use datafusion_physical_plan::limit::{GlobalLimitExec, LocalLimitExec};
use datafusion_physical_plan::placeholder_row::PlaceholderRowExec;
use datafusion_physical_plan::projection::{ProjectionExec, ProjectionExpr};
Expand Down Expand Up @@ -1354,4 +1358,249 @@ mod test {

Ok(())
}

#[tokio::test]
async fn test_hash_join_partition_statistics() -> Result<()> {
// Create left table scan and coalesce to 1 partition for CollectLeft mode
let left_scan = create_scan_exec_with_statistics(None, Some(2)).await;
let left_scan_coalesced = Arc::new(CoalescePartitionsExec::new(left_scan.clone()))
as Arc<dyn ExecutionPlan>;

// Create right table scan with different table name
let right_create_table_sql = "CREATE EXTERNAL TABLE t2 (id INT NOT NULL, date DATE) \
STORED AS PARQUET LOCATION './tests/data/test_statistics_per_partition'\
PARTITIONED BY (date) \
WITH ORDER (id ASC);";
let right_scan =
create_scan_exec_with_statistics(Some(right_create_table_sql), Some(2)).await;

// Create join condition: t1.id = t2.id
let on = vec![(
Arc::new(Column::new("id", 0)) as Arc<dyn PhysicalExpr>,
Arc::new(Column::new("id", 0)) as Arc<dyn PhysicalExpr>,
)];

// Test CollectLeft mode - left child must have 1 partition
let collect_left_join = Arc::new(HashJoinExec::try_new(
left_scan_coalesced,
Arc::clone(&right_scan),
on.clone(),
None,
&JoinType::Inner,
None,
PartitionMode::CollectLeft,
NullEquality::NullEqualsNothing,
false,
)?) as Arc<dyn ExecutionPlan>;

// Test partition statistics for CollectLeft mode
let statistics = (0..collect_left_join.output_partitioning().partition_count())
.map(|idx| collect_left_join.partition_statistics(Some(idx)))
.collect::<Result<Vec<_>>>()?;

// Check that we have the expected number of partitions
assert_eq!(statistics.len(), 2);

// For collect left mode, the min/max values are from the entire left table and the specific partition of the right table.
let expected_p0_statistics = Statistics {
num_rows: Precision::Inexact(2),
total_byte_size: Precision::Absent,
column_statistics: vec![
// Left id column: all partitions (id 1..4)
ColumnStatistics {
null_count: Precision::Exact(0),
max_value: Precision::Exact(ScalarValue::Int32(Some(4))),
min_value: Precision::Exact(ScalarValue::Int32(Some(1))),
sum_value: Precision::Absent,
distinct_count: Precision::Absent,
byte_size: Precision::Exact(16),
},
// Left date column: all partitions (2025-03-01..2025-03-04)
ColumnStatistics {
null_count: Precision::Exact(0),
max_value: Precision::Exact(ScalarValue::Date32(Some(
DATE_2025_03_04,
))),
min_value: Precision::Exact(ScalarValue::Date32(Some(
DATE_2025_03_01,
))),
sum_value: Precision::Absent,
distinct_count: Precision::Absent,
byte_size: Precision::Exact(16),
},
// Right id column: partition 0 only (id 3..4)
ColumnStatistics {
null_count: Precision::Exact(0),
max_value: Precision::Exact(ScalarValue::Int32(Some(4))),
min_value: Precision::Exact(ScalarValue::Int32(Some(3))),
sum_value: Precision::Absent,
distinct_count: Precision::Absent,
byte_size: Precision::Exact(8),
},
// Right date column: partition 0 only (2025-03-01..2025-03-02)
ColumnStatistics {
null_count: Precision::Exact(0),
max_value: Precision::Exact(ScalarValue::Date32(Some(
DATE_2025_03_02,
))),
min_value: Precision::Exact(ScalarValue::Date32(Some(
DATE_2025_03_01,
))),
sum_value: Precision::Absent,
distinct_count: Precision::Absent,
byte_size: Precision::Exact(8),
},
],
};
assert_eq!(statistics[0], expected_p0_statistics);

// Test Partitioned mode
let partitioned_join = Arc::new(HashJoinExec::try_new(
Arc::clone(&left_scan),
Arc::clone(&right_scan),
on.clone(),
None,
&JoinType::Inner,
None,
PartitionMode::Partitioned,
NullEquality::NullEqualsNothing,
false,
)?) as Arc<dyn ExecutionPlan>;

// Test partition statistics for Partitioned mode
let statistics = (0..partitioned_join.output_partitioning().partition_count())
.map(|idx| partitioned_join.partition_statistics(Some(idx)))
.collect::<Result<Vec<_>>>()?;

// Check that we have the expected number of partitions
assert_eq!(statistics.len(), 2);

// For partitioned mode, the min/max values are from the specific partition for each side.
let expected_p0_statistics = Statistics {
num_rows: Precision::Inexact(2),
total_byte_size: Precision::Absent,
column_statistics: vec![
// Left id column: partition 0 only (id 3..4)
ColumnStatistics {
null_count: Precision::Exact(0),
max_value: Precision::Exact(ScalarValue::Int32(Some(4))),
min_value: Precision::Exact(ScalarValue::Int32(Some(3))),
sum_value: Precision::Absent,
distinct_count: Precision::Absent,
byte_size: Precision::Exact(8),
},
// Left date column: partition 0 only (2025-03-01..2025-03-02)
ColumnStatistics {
null_count: Precision::Exact(0),
max_value: Precision::Exact(ScalarValue::Date32(Some(
DATE_2025_03_02,
))),
min_value: Precision::Exact(ScalarValue::Date32(Some(
DATE_2025_03_01,
))),
sum_value: Precision::Absent,
distinct_count: Precision::Absent,
byte_size: Precision::Exact(8),
},
// Right id column: partition 0 only (id 3..4)
ColumnStatistics {
null_count: Precision::Exact(0),
max_value: Precision::Exact(ScalarValue::Int32(Some(4))),
min_value: Precision::Exact(ScalarValue::Int32(Some(3))),
sum_value: Precision::Absent,
distinct_count: Precision::Absent,
byte_size: Precision::Exact(8),
},
// Right date column: partition 0 only (2025-03-01..2025-03-02)
ColumnStatistics {
null_count: Precision::Exact(0),
max_value: Precision::Exact(ScalarValue::Date32(Some(
DATE_2025_03_02,
))),
min_value: Precision::Exact(ScalarValue::Date32(Some(
DATE_2025_03_01,
))),
sum_value: Precision::Absent,
distinct_count: Precision::Absent,
byte_size: Precision::Exact(8),
},
],
};
assert_eq!(statistics[0], expected_p0_statistics);

// Test Auto mode - should fall back to getting all partition statistics
let auto_join = Arc::new(HashJoinExec::try_new(
Arc::clone(&left_scan),
Arc::clone(&right_scan),
on,
None,
&JoinType::Inner,
None,
PartitionMode::Auto,
NullEquality::NullEqualsNothing,
false,
)?) as Arc<dyn ExecutionPlan>;

// Test partition statistics for Auto mode
let statistics = (0..auto_join.output_partitioning().partition_count())
.map(|idx| auto_join.partition_statistics(Some(idx)))
.collect::<Result<Vec<_>>>()?;

// Check that we have the expected number of partitions
assert_eq!(statistics.len(), 2);

// For auto mode, the min/max values are from the entire left and right tables.
let expected_p0_statistics = Statistics {
num_rows: Precision::Inexact(4),
total_byte_size: Precision::Absent,
column_statistics: vec![
// Left id column: all partitions (id 1..4)
ColumnStatistics {
null_count: Precision::Exact(0),
max_value: Precision::Exact(ScalarValue::Int32(Some(4))),
min_value: Precision::Exact(ScalarValue::Int32(Some(1))),
sum_value: Precision::Absent,
distinct_count: Precision::Absent,
byte_size: Precision::Exact(16),
},
// Left date column: all partitions (2025-03-01..2025-03-04)
ColumnStatistics {
null_count: Precision::Exact(0),
max_value: Precision::Exact(ScalarValue::Date32(Some(
DATE_2025_03_04,
))),
min_value: Precision::Exact(ScalarValue::Date32(Some(
DATE_2025_03_01,
))),
sum_value: Precision::Absent,
distinct_count: Precision::Absent,
byte_size: Precision::Exact(16),
},
// Right id column: all partitions (id 1..4)
ColumnStatistics {
null_count: Precision::Exact(0),
max_value: Precision::Exact(ScalarValue::Int32(Some(4))),
min_value: Precision::Exact(ScalarValue::Int32(Some(1))),
sum_value: Precision::Absent,
distinct_count: Precision::Absent,
byte_size: Precision::Exact(16),
},
// Right date column: all partitions (2025-03-01..2025-03-04)
ColumnStatistics {
null_count: Precision::Exact(0),
max_value: Precision::Exact(ScalarValue::Date32(Some(
DATE_2025_03_04,
))),
min_value: Precision::Exact(ScalarValue::Date32(Some(
DATE_2025_03_01,
))),
sum_value: Precision::Absent,
distinct_count: Precision::Absent,
byte_size: Precision::Exact(16),
},
],
};
assert_eq!(statistics[0], expected_p0_statistics);
Ok(())
}
}
60 changes: 47 additions & 13 deletions datafusion/physical-plan/src/joins/hash_join/exec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1442,19 +1442,53 @@ impl ExecutionPlan for HashJoinExec {
}

fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
if partition.is_some() {
return Ok(Statistics::new_unknown(&self.schema()));
}
// TODO stats: it is not possible in general to know the output size of joins
// There are some special cases though, for example:
// - `A LEFT JOIN B ON A.col=B.col` with `COUNT_DISTINCT(B.col)=COUNT(B.col)`
let stats = estimate_join_statistics(
self.left.partition_statistics(None)?,
self.right.partition_statistics(None)?,
&self.on,
&self.join_type,
&self.join_schema,
)?;
let stats = match (partition, self.mode) {
// For CollectLeft mode, the left side is collected into a single partition,
// so all left partitions are available to each output partition.
// For the right side, we need the specific partition statistics.
(Some(partition), PartitionMode::CollectLeft) => {
let left_stats = self.left.partition_statistics(None)?;
let right_stats = self.right.partition_statistics(Some(partition))?;

estimate_join_statistics(
left_stats,
right_stats,
&self.on,
&self.join_type,
&self.join_schema,
)?
}

// For Partitioned mode, both sides are partitioned, so each output partition
// only has access to the corresponding partition from both sides.
(Some(partition), PartitionMode::Partitioned) => {
let left_stats = self.left.partition_statistics(Some(partition))?;
let right_stats = self.right.partition_statistics(Some(partition))?;

estimate_join_statistics(
left_stats,
right_stats,
&self.on,
&self.join_type,
&self.join_schema,
)?
}

// For Auto mode or when no specific partition is requested, fall back to
// the current behavior of getting all partition statistics.
(None, _) | (Some(_), PartitionMode::Auto) => {
// TODO stats: it is not possible in general to know the output size of joins
// There are some special cases though, for example:
// - `A LEFT JOIN B ON A.col=B.col` with `COUNT_DISTINCT(B.col)=COUNT(B.col)`
estimate_join_statistics(
self.left.partition_statistics(None)?,
self.right.partition_statistics(None)?,
&self.on,
&self.join_type,
&self.join_schema,
)?
}
};
// Project statistics if there is a projection
let stats = stats.project(self.projection.as_ref());
// Apply fetch limit to statistics
Expand Down