diff --git a/Cargo.lock b/Cargo.lock index af52588e5338e..f20e10e21c4f2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2399,6 +2399,7 @@ dependencies = [ "datafusion-expr", "datafusion-expr-common", "datafusion-functions", + "datafusion-functions-aggregate", "datafusion-functions-window", "datafusion-physical-expr", "datafusion-physical-expr-common", diff --git a/datafusion/core/src/optimizer_rule_reference.md b/datafusion/core/src/optimizer_rule_reference.md index fcbb200c71624..6747f1316f2b5 100644 --- a/datafusion/core/src/optimizer_rule_reference.md +++ b/datafusion/core/src/optimizer_rule_reference.md @@ -76,18 +76,19 @@ in multiple phases. | 5 | `FilterPushdown` | pre-optimization phase | Pushes supported physical filters down toward data sources before distribution and sorting are enforced. | | 6 | `EnforceDistribution` | - | Adds repartitioning only where needed to satisfy physical distribution requirements. | | 7 | `CombinePartialFinalAggregate` | - | Collapses adjacent partial and final aggregates when the distributed shape makes them redundant. | -| 8 | `EnforceSorting` | - | Adds or removes local sorts to satisfy required input orderings. | -| 9 | `OptimizeAggregateOrder` | - | Updates aggregate expressions to use the best ordering once sort requirements are known. | -| 10 | `WindowTopN` | - | Replaces eligible row-number window and filter patterns with per-partition TopK execution. | -| 11 | `ProjectionPushdown` | early pass | Pushes projections toward inputs before later physical rewrites add more limit and TopK structure. | -| 12 | `OutputRequirements` | remove phase | Removes the temporary output-requirement helper nodes after requirement-sensitive planning is done. | -| 13 | `LimitAggregation` | - | Passes a limit hint into eligible aggregations so they can keep fewer accumulator buckets. | -| 14 | `LimitPushPastWindows` | - | Pushes fetch limits through bounded window operators when doing so keeps the result correct. 
| -| 15 | `HashJoinBuffering` | - | Adds buffering on the probe side of hash joins so probing can start before build completion. | -| 16 | `LimitPushdown` | - | Moves physical limits into child operators or fetch-enabled variants to cut data early. | -| 17 | `TopKRepartition` | - | Pushes TopK below hash repartition when the partition key is a prefix of the sort key. | -| 18 | `ProjectionPushdown` | late pass | Runs projection pushdown again after limit and TopK rewrites expose new pruning opportunities. | -| 19 | `PushdownSort` | - | Pushes sort requirements into data sources that can already return sorted output. | -| 20 | `EnsureCooperative` | - | Wraps non-cooperative plan parts so long-running tasks yield fairly. | -| 21 | `FilterPushdown(Post)` | post-optimization phase | Pushes dynamic filters at the end of optimization, after plan references stop moving. | -| 22 | `SanityCheckPlan` | - | Validates that the final physical plan meets ordering, distribution, and infinite-input safety requirements. | +| 8 | `group_join` | - | Fuses eligible aggregate-over-hash-join plans when grouping keys match join keys. | +| 9 | `EnforceSorting` | - | Adds or removes local sorts to satisfy required input orderings. | +| 10 | `OptimizeAggregateOrder` | - | Updates aggregate expressions to use the best ordering once sort requirements are known. | +| 11 | `WindowTopN` | - | Replaces eligible row-number window and filter patterns with per-partition TopK execution. | +| 12 | `ProjectionPushdown` | early pass | Pushes projections toward inputs before later physical rewrites add more limit and TopK structure. | +| 13 | `OutputRequirements` | remove phase | Removes the temporary output-requirement helper nodes after requirement-sensitive planning is done. | +| 14 | `LimitAggregation` | - | Passes a limit hint into eligible aggregations so they can keep fewer accumulator buckets. 
| +| 15 | `LimitPushPastWindows` | - | Pushes fetch limits through bounded window operators when doing so keeps the result correct. | +| 16 | `HashJoinBuffering` | - | Adds buffering on the probe side of hash joins so probing can start before build completion. | +| 17 | `LimitPushdown` | - | Moves physical limits into child operators or fetch-enabled variants to cut data early. | +| 18 | `TopKRepartition` | - | Pushes TopK below hash repartition when the partition key is a prefix of the sort key. | +| 19 | `ProjectionPushdown` | late pass | Runs projection pushdown again after limit and TopK rewrites expose new pruning opportunities. | +| 20 | `PushdownSort` | - | Pushes sort requirements into data sources that can already return sorted output. | +| 21 | `EnsureCooperative` | - | Wraps non-cooperative plan parts so long-running tasks yield fairly. | +| 22 | `FilterPushdown(Post)` | post-optimization phase | Pushes dynamic filters at the end of optimization, after plan references stop moving. | +| 23 | `SanityCheckPlan` | - | Validates that the final physical plan meets ordering, distribution, and infinite-input safety requirements. 
| diff --git a/datafusion/physical-optimizer/Cargo.toml b/datafusion/physical-optimizer/Cargo.toml index 38c8a7c37211f..4b65746ae73a5 100644 --- a/datafusion/physical-optimizer/Cargo.toml +++ b/datafusion/physical-optimizer/Cargo.toml @@ -56,6 +56,7 @@ recursive = { workspace = true, optional = true } [dev-dependencies] datafusion-expr = { workspace = true } datafusion-functions = { workspace = true } +datafusion-functions-aggregate = { workspace = true } datafusion-functions-window = { workspace = true } insta = { workspace = true } tokio = { workspace = true } diff --git a/datafusion/physical-optimizer/src/group_join.rs b/datafusion/physical-optimizer/src/group_join.rs new file mode 100644 index 0000000000000..94916a9c46805 --- /dev/null +++ b/datafusion/physical-optimizer/src/group_join.rs @@ -0,0 +1,412 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`GroupJoinOptimizer`] replaces an `AggregateExec` directly above a +//! `HashJoinExec` with a fused [`GroupJoinExec`] when the aggregate's GROUP BY +//! keys match the join's equi-join keys. +//! +//! Based on: Moerkotte & Neumann, "Accelerating Queries with Group-By and Join +//! by Groupjoin", PVLDB 4(11), 2011. 
+ +use std::sync::Arc; + +use datafusion_common::config::ConfigOptions; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_common::{JoinType, Result}; +use datafusion_physical_expr::physical_exprs_equal; +use datafusion_physical_plan::ExecutionPlan; +use datafusion_physical_plan::aggregates::{AggregateExec, AggregateMode}; +use datafusion_physical_plan::joins::group_join::GroupJoinExec; +use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode}; +use datafusion_physical_plan::projection::ProjectionExec; + +use crate::PhysicalOptimizerRule; + +/// Replaces `AggregateExec(HashJoinExec)` with a fused `GroupJoinExec` when: +/// +/// 1. The aggregate mode is `Single`, `SinglePartitioned`, or `Partial` +/// 2. The aggregate has at least one aggregate expression (not just DISTINCT) +/// 3. The aggregate has no GROUPING SETS +/// 4. The input is a `HashJoinExec` (possibly through a `ProjectionExec`) +/// 5. The join type is `Inner` or `Left` +/// 6. The join has no residual filter (equi-join only) +/// 7. The GROUP BY expressions exactly match the left join keys +/// 8. All aggregate functions support `GroupsAccumulator` +/// 9. The hash join uses `PartitionMode::Partitioned` (CollectLeft joins do +/// not provide the partitioned-by-key input shape `GroupJoinExec` needs) +/// +/// This rule should run after `CombinePartialFinalAggregate` (which may +/// collapse two-phase aggregation into Single mode) and after `JoinSelection` +/// (which decides build/probe sides). +#[derive(Default, Debug)] +pub struct GroupJoinOptimizer {} + +impl GroupJoinOptimizer { + /// Create a new `GroupJoinOptimizer`. 
+ pub fn new() -> Self { + Self {} + } +} + +impl PhysicalOptimizerRule for GroupJoinOptimizer { + fn optimize( + &self, + plan: Arc, + _config: &ConfigOptions, + ) -> Result> { + plan.transform_down(|plan| { + let Some(agg_exec) = plan.downcast_ref::() else { + return Ok(Transformed::no(plan)); + }; + + if !matches!( + agg_exec.mode(), + AggregateMode::Single + | AggregateMode::SinglePartitioned + | AggregateMode::Partial + ) { + return Ok(Transformed::no(plan)); + } + + // Must have actual aggregate functions (not just GROUP BY for DISTINCT) + let aggr_exprs = agg_exec.aggr_expr(); + if aggr_exprs.is_empty() { + return Ok(Transformed::no(plan)); + } + + // No GROUPING SETS + if agg_exec.group_expr().groups().len() > 1 { + return Ok(Transformed::no(plan)); + } + + // Find HashJoinExec (possibly through a ProjectionExec) + let input = agg_exec.input(); + let hash_join: &HashJoinExec; + if let Some(hj) = input.downcast_ref::() { + hash_join = hj; + } else if let Some(proj) = input.downcast_ref::() { + if let Some(hj) = proj.input().downcast_ref::() { + hash_join = hj; + } else { + return Ok(Transformed::no(plan)); + } + } else { + return Ok(Transformed::no(plan)); + }; + + // Inner and Left joins + if !matches!( + hash_join.join_type(), + JoinType::Inner | JoinType::Left + ) { + return Ok(Transformed::no(plan)); + } + + // No residual join filter (equi-join only) + if hash_join.filter().is_some() { + return Ok(Transformed::no(plan)); + } + + // GroupJoinExec requires partitioned inputs keyed by the join + // expressions. CollectLeft hash joins do not provide that shape. 
+ if *hash_join.partition_mode() != PartitionMode::Partitioned { + return Ok(Transformed::no(plan)); + } + + // GROUP BY keys must exactly match left join keys + let group_exprs: Vec<_> = agg_exec + .group_expr() + .expr() + .iter() + .map(|(expr, _)| Arc::clone(expr)) + .collect(); + + let join_on = hash_join.on(); + let left_join_keys: Vec<_> = + join_on.iter().map(|(l, _)| Arc::clone(l)).collect(); + + if group_exprs.len() != left_join_keys.len() { + return Ok(Transformed::no(plan)); + } + + if !physical_exprs_equal(&group_exprs, &left_join_keys) { + return Ok(Transformed::no(plan)); + } + + // All aggregates must support GroupsAccumulator + for agg in aggr_exprs { + if !agg.groups_accumulator_supported() { + return Ok(Transformed::no(plan)); + } + } + + // For Inner joins, skip if any aggregate has a literal argument + // (e.g., COUNT(*) rewritten as count(Int64(1))). These queries + // don't benefit enough from GroupJoin to justify changing the plan. + if *hash_join.join_type() == JoinType::Inner { + let has_literal_arg = aggr_exprs.iter().any(|agg| { + agg.expressions().iter().any(|expr| { + expr.as_ref() + .downcast_ref::() + .is_some() + }) + }); + if has_literal_arg { + return Ok(Transformed::no(plan)); + } + } + + // All preconditions met — create GroupJoinExec + let group_by_with_names: Vec<_> = agg_exec + .group_expr() + .expr() + .iter() + .map(|(expr, name)| (Arc::clone(expr), name.clone())) + .collect(); + + let group_join = GroupJoinExec::try_new_with_aggr_input_schema( + Arc::clone(hash_join.left()), + Arc::clone(hash_join.right()), + join_on.to_vec(), + *hash_join.join_type(), + group_by_with_names, + aggr_exprs.to_vec(), + agg_exec.input_schema(), + )?; + + Ok(Transformed::yes( + Arc::new(group_join) as Arc + )) + }) + .data() + } + + fn name(&self) -> &str { + "group_join" + } + + fn schema_check(&self) -> bool { + false // Schema changes (aggregate output differs from join output) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use 
arrow::datatypes::{DataType, Field, Schema, SchemaRef}; + use datafusion_common::NullEquality; + use datafusion_functions_aggregate::count::count_udaf; + use datafusion_physical_expr::PhysicalExprRef; + use datafusion_physical_expr::aggregate::AggregateExprBuilder; + use datafusion_physical_expr::expressions::{col, lit}; + use datafusion_physical_plan::displayable; + use datafusion_physical_plan::empty::EmptyExec; + use datafusion_physical_plan::joins::PartitionMode; + use datafusion_physical_plan::projection::ProjectionExec; + use insta::assert_snapshot; + + fn left_schema() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new("l_key", DataType::Int32, false), + Field::new("l_value", DataType::Int32, true), + ])) + } + + fn right_schema() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new("r_key", DataType::Int32, false), + Field::new("r_value", DataType::Int32, true), + ])) + } + + fn join( + join_type: JoinType, + left_key: &str, + partition_mode: PartitionMode, + ) -> Result> { + let left_schema = left_schema(); + let right_schema = right_schema(); + let left = Arc::new(EmptyExec::new(Arc::clone(&left_schema))); + let right = Arc::new(EmptyExec::new(Arc::clone(&right_schema))); + + Ok(Arc::new(HashJoinExec::try_new( + left, + right, + vec![(col(left_key, &left_schema)?, col("r_key", &right_schema)?)], + None, + &join_type, + None, + partition_mode, + NullEquality::NullEqualsNull, + false, + )?)) + } + + fn partitioned_join( + join_type: JoinType, + left_key: &str, + ) -> Result> { + join(join_type, left_key, PartitionMode::Partitioned) + } + + fn aggregate( + input: Arc, + group_expr: PhysicalExprRef, + aggr_expr: PhysicalExprRef, + ) -> Result> { + let input_schema = input.schema(); + let aggr_expr = Arc::new( + AggregateExprBuilder::new(count_udaf(), vec![aggr_expr]) + .schema(Arc::clone(&input_schema)) + .alias("count_values") + .build()?, + ); + + Ok(Arc::new(AggregateExec::try_new( + AggregateMode::Single, + 
datafusion_physical_plan::aggregates::PhysicalGroupBy::new_single(vec![( + group_expr, + "l_key".to_string(), + )]), + vec![aggr_expr], + vec![None], + input, + input_schema, + )?)) + } + + fn optimize(plan: Arc) -> Result { + let optimized = + GroupJoinOptimizer::new().optimize(plan, &ConfigOptions::new())?; + Ok(displayable(optimized.as_ref()).indent(true).to_string()) + } + + #[test] + fn rewrites_aggregate_above_inner_hash_join() -> Result<()> { + let join = partitioned_join(JoinType::Inner, "l_key")?; + let join_schema = join.schema(); + let plan = aggregate( + join, + col("l_key", &join_schema)?, + col("r_value", &join_schema)?, + )?; + + assert_snapshot!(optimize(plan)?, @r" + GroupJoinExec: join_type=Inner, on=[(l_key@0, r_key@0)], aggr=[count_values] + EmptyExec + EmptyExec + "); + Ok(()) + } + + #[test] + fn rewrites_through_projection() -> Result<()> { + let join = partitioned_join(JoinType::Left, "l_key")?; + let join_schema = join.schema(); + let projection = Arc::new(ProjectionExec::try_new( + vec![ + (col("l_key", &join_schema)?, "l_key".to_string()), + (col("r_value", &join_schema)?, "r_value".to_string()), + ], + join, + )?); + let projection_schema = projection.schema(); + let plan = aggregate( + projection, + col("l_key", &projection_schema)?, + col("r_value", &projection_schema)?, + )?; + + assert_snapshot!(optimize(plan)?, @r" + GroupJoinExec: join_type=Left, on=[(l_key@0, r_key@0)], aggr=[count_values] + EmptyExec + EmptyExec + "); + Ok(()) + } + + #[test] + fn does_not_rewrite_when_group_by_does_not_match_join_key() -> Result<()> { + let join = partitioned_join(JoinType::Inner, "l_key")?; + let join_schema = join.schema(); + let plan = aggregate( + join, + col("l_value", &join_schema)?, + col("r_value", &join_schema)?, + )?; + + assert_snapshot!(optimize(plan)?, @r" + AggregateExec: mode=Single, gby=[l_value@1 as l_key], aggr=[count_values] + HashJoinExec: mode=Partitioned, join_type=Inner, on=[(l_key@0, r_key@0)], NullsEqual: true + 
EmptyExec + EmptyExec + "); + Ok(()) + } + + #[test] + fn does_not_rewrite_unsupported_join_type() -> Result<()> { + let join = partitioned_join(JoinType::Right, "l_key")?; + let join_schema = join.schema(); + let plan = aggregate( + join, + col("l_key", &join_schema)?, + col("r_value", &join_schema)?, + )?; + + assert_snapshot!(optimize(plan)?, @r" + AggregateExec: mode=Single, gby=[l_key@0 as l_key], aggr=[count_values] + HashJoinExec: mode=Partitioned, join_type=Right, on=[(l_key@0, r_key@0)], NullsEqual: true + EmptyExec + EmptyExec + "); + Ok(()) + } + + #[test] + fn does_not_rewrite_inner_join_with_literal_aggregate_argument() -> Result<()> { + let join = partitioned_join(JoinType::Inner, "l_key")?; + let join_schema = join.schema(); + let plan = aggregate(join, col("l_key", &join_schema)?, lit(1i64))?; + + assert_snapshot!(optimize(plan)?, @r" + AggregateExec: mode=Single, gby=[l_key@0 as l_key], aggr=[count_values] + HashJoinExec: mode=Partitioned, join_type=Inner, on=[(l_key@0, r_key@0)], NullsEqual: true + EmptyExec + EmptyExec + "); + Ok(()) + } + + #[test] + fn does_not_rewrite_collect_left_hash_join() -> Result<()> { + let join = join(JoinType::Left, "l_key", PartitionMode::CollectLeft)?; + let join_schema = join.schema(); + let plan = aggregate( + join, + col("l_key", &join_schema)?, + col("r_value", &join_schema)?, + )?; + + assert_snapshot!(optimize(plan)?, @r" + AggregateExec: mode=Single, gby=[l_key@0 as l_key], aggr=[count_values] + HashJoinExec: mode=CollectLeft, join_type=Left, on=[(l_key@0, r_key@0)], NullsEqual: true + EmptyExec + EmptyExec + "); + Ok(()) + } +} diff --git a/datafusion/physical-optimizer/src/lib.rs b/datafusion/physical-optimizer/src/lib.rs index 5fac8948b7f04..4079c875aba12 100644 --- a/datafusion/physical-optimizer/src/lib.rs +++ b/datafusion/physical-optimizer/src/lib.rs @@ -31,6 +31,7 @@ pub mod enforce_distribution; pub mod enforce_sorting; pub mod ensure_coop; pub mod filter_pushdown; +pub mod group_join; pub mod 
join_selection; pub mod limit_pushdown; pub mod limit_pushdown_past_window; diff --git a/datafusion/physical-optimizer/src/optimizer.rs b/datafusion/physical-optimizer/src/optimizer.rs index 05df642f8446b..17befb93989c2 100644 --- a/datafusion/physical-optimizer/src/optimizer.rs +++ b/datafusion/physical-optimizer/src/optimizer.rs @@ -26,6 +26,7 @@ use crate::enforce_distribution::EnforceDistribution; use crate::enforce_sorting::EnforceSorting; use crate::ensure_coop::EnsureCooperative; use crate::filter_pushdown::FilterPushdown; +use crate::group_join::GroupJoinOptimizer; use crate::join_selection::JoinSelection; use crate::limit_pushdown::LimitPushdown; use crate::limited_distinct_aggregation::LimitedDistinctAggregation; @@ -177,6 +178,11 @@ impl PhysicalOptimizer { Arc::new(EnforceDistribution::new()), // The CombinePartialFinalAggregate rule should be applied after the EnforceDistribution rule Arc::new(CombinePartialFinalAggregate::new()), + // GroupJoinOptimizer fuses Aggregate+HashJoin into GroupJoinExec + // when GROUP BY keys match join keys. Runs after + // CombinePartialFinalAggregate (which creates SinglePartitioned + // aggregates) and JoinSelection (which decides build/probe sides). + Arc::new(GroupJoinOptimizer::new()), // The EnforceSorting rule is for adding essential local sorting to satisfy the required // ordering. Please make sure that the whole plan tree is determined before this rule. // Note that one should always run this rule after running the EnforceDistribution rule diff --git a/datafusion/physical-plan/src/joins/group_join.rs b/datafusion/physical-plan/src/joins/group_join.rs new file mode 100644 index 0000000000000..2c8e69666ce94 --- /dev/null +++ b/datafusion/physical-plan/src/joins/group_join.rs @@ -0,0 +1,581 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`GroupJoinExec`] fuses a hash join and a subsequent group-by aggregation +//! into a single operator when they share the same key. +//! +//! Based on: Moerkotte & Neumann, "Accelerating Queries with Group-By and Join +//! by Groupjoin", PVLDB 4(11), 2011. Strategy 2 (Memoizing GroupJoin) from +//! Fent et al., VLDB Journal 2023. +//! +//! Instead of building two hash tables (one for the join, one for aggregation), +//! GroupJoin builds a single hash table on the build side with aggregate +//! accumulators embedded in each entry. During the probe phase, matching rows +//! update the accumulators in-place, avoiding materialization of the full +//! intermediate join result. 
+ +use std::fmt::{self, Debug}; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use arrow::array::{ArrayRef, BooleanArray, BooleanBufferBuilder}; +use arrow::datatypes::{Field, Schema, SchemaRef}; +use arrow::record_batch::RecordBatch; +use datafusion_common::tree_node::TreeNodeRecursion; +use datafusion_common::{JoinType, Result, internal_err}; +use datafusion_execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext}; +use datafusion_expr::{EmitTo, GroupsAccumulator}; +use datafusion_physical_expr::aggregate::AggregateFunctionExpr; +use datafusion_physical_expr::{ + EquivalenceProperties, GroupsAccumulatorAdapter, PhysicalExpr, PhysicalExprRef, +}; +use log::debug; + +use crate::aggregates::group_values::{GroupValues, new_group_values}; +use crate::aggregates::order::GroupOrdering; +use crate::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; +use crate::{DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, PlanProperties}; + +use futures::{Stream, StreamExt, ready}; + +/// A fused join + group-by operator that combines a hash join and subsequent +/// aggregation into a single operator when they share the same key. +/// +/// # Preconditions +/// +/// These are enforced by the `GroupJoinOptimizer` physical optimizer rule: +/// - The aggregate GROUP BY keys exactly match the join equi-keys +/// - The join type is Inner or Left +/// - All aggregate functions support [`GroupsAccumulator`] +/// - The aggregate has at least one aggregate expression (not just DISTINCT) +/// +/// # Algorithm +/// +/// 1. **Build phase**: Consume the left (build) side, interning each row's +/// group key into a [`GroupValues`] hash table. +/// 2. **Probe phase**: For each right (probe) batch, look up group indices +/// via the same hash table and update [`GroupsAccumulator`]s in-place. 
+/// Probe rows that don't match any build-side group are filtered out +/// (for Inner join) or ignored (for Left join — the build-side group +/// retains its initial accumulator value). +/// 3. **Emit phase**: Scan the hash table and produce one output row per +/// build-side group: the group key columns plus the evaluated aggregates. +/// +#[derive(Debug)] +pub struct GroupJoinExec { + /// Build side (left) — each unique key becomes one output group + left: Arc, + /// Probe side (right) — rows update aggregate accumulators + right: Arc, + /// Equi-join key expressions: (left_expr, right_expr) + on: Vec<(PhysicalExprRef, PhysicalExprRef)>, + /// Join type (Inner or Left) + join_type: JoinType, + /// GROUP BY expressions with output aliases (evaluated against build side) + group_by_exprs: Vec<(PhysicalExprRef, String)>, + /// Aggregate function expressions (e.g., COUNT, SUM) + aggr_expr: Vec>, + /// Input schema used to build aggregate expressions + aggr_input_schema: SchemaRef, + /// Output schema: group-by columns + aggregate outputs + schema: SchemaRef, + /// Execution metrics + metrics: ExecutionPlanMetricsSet, + /// Cached plan properties + cache: Arc, +} + +impl GroupJoinExec { + /// Create a new `GroupJoinExec`. + /// + /// Returns an error if the join type is not supported. + pub fn try_new( + left: Arc, + right: Arc, + on: Vec<(PhysicalExprRef, PhysicalExprRef)>, + join_type: JoinType, + group_by_exprs: Vec<(PhysicalExprRef, String)>, + aggr_expr: Vec>, + ) -> Result { + let aggr_input_schema = right.schema(); + Self::try_new_with_aggr_input_schema( + left, + right, + on, + join_type, + group_by_exprs, + aggr_expr, + aggr_input_schema, + ) + } + + /// Create a new `GroupJoinExec` with the schema used to build aggregate + /// expressions. 
+ pub fn try_new_with_aggr_input_schema( + left: Arc, + right: Arc, + on: Vec<(PhysicalExprRef, PhysicalExprRef)>, + join_type: JoinType, + group_by_exprs: Vec<(PhysicalExprRef, String)>, + aggr_expr: Vec>, + aggr_input_schema: SchemaRef, + ) -> Result { + if !matches!(join_type, JoinType::Inner | JoinType::Left) { + return internal_err!( + "GroupJoinExec only supports Inner/Left joins, got {:?}", + join_type + ); + } + + let left_schema = left.schema(); + let mut fields: Vec> = Vec::new(); + + for (expr, alias) in &group_by_exprs { + let dt = expr.data_type(&left_schema)?; + let nullable = expr.nullable(&left_schema)?; + fields.push(Arc::new(Field::new(alias, dt, nullable))); + } + + for agg in &aggr_expr { + fields.push(Arc::clone(&agg.field())); + } + + let schema = Arc::new(Schema::new(fields)); + + let props = left.properties(); + let cache = Arc::new(PlanProperties::new( + EquivalenceProperties::new(Arc::clone(&schema)), + props.partitioning.clone(), + props.emission_type, + props.boundedness, + )); + + Ok(Self { + left, + right, + on, + join_type, + group_by_exprs, + aggr_expr, + aggr_input_schema, + schema, + metrics: ExecutionPlanMetricsSet::new(), + cache, + }) + } + + /// Build side input. + pub fn left(&self) -> &Arc { + &self.left + } + + /// Probe side input. + pub fn right(&self) -> &Arc { + &self.right + } + + /// Equi-join key expressions. + pub fn on(&self) -> &[(PhysicalExprRef, PhysicalExprRef)] { + &self.on + } + + /// Join type. + pub fn join_type(&self) -> &JoinType { + &self.join_type + } + + /// GROUP BY expressions with output aliases. + pub fn group_by_exprs(&self) -> &[(PhysicalExprRef, String)] { + &self.group_by_exprs + } + + /// Aggregate expressions. + pub fn aggr_expr(&self) -> &[Arc] { + &self.aggr_expr + } + + /// Input schema used to build aggregate expressions. 
+ pub fn aggr_input_schema(&self) -> &SchemaRef { + &self.aggr_input_schema + } +} + +impl DisplayAs for GroupJoinExec { + fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + let on: Vec = + self.on.iter().map(|(l, r)| format!("({l}, {r})")).collect(); + let aggrs: Vec = self + .aggr_expr + .iter() + .map(|a| a.name().to_string()) + .collect(); + write!( + f, + "GroupJoinExec: join_type={:?}, on=[{}], aggr=[{}]", + self.join_type, + on.join(", "), + aggrs.join(", "), + ) + } + DisplayFormatType::TreeRender => { + write!(f, "GroupJoinExec") + } + } + } +} + +impl ExecutionPlan for GroupJoinExec { + fn name(&self) -> &'static str { + "GroupJoinExec" + } + + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } + + fn properties(&self) -> &Arc { + &self.cache + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.left, &self.right] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + Ok(Arc::new(GroupJoinExec::try_new_with_aggr_input_schema( + Arc::clone(&children[0]), + Arc::clone(&children[1]), + self.on.clone(), + self.join_type, + self.group_by_exprs.clone(), + self.aggr_expr.clone(), + Arc::clone(&self.aggr_input_schema), + )?)) + } + fn required_input_distribution(&self) -> Vec { + let left_exprs: Vec = + self.on.iter().map(|(l, _)| Arc::clone(l)).collect(); + let right_exprs: Vec = + self.on.iter().map(|(_, r)| Arc::clone(r)).collect(); + vec![ + Distribution::HashPartitioned(left_exprs), + Distribution::HashPartitioned(right_exprs), + ] + } + + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) + } + + fn apply_expressions( + &self, + f: &mut dyn FnMut(&dyn PhysicalExpr) -> Result, + ) -> Result { + for (left_expr, right_expr) in &self.on { + let r = f(left_expr.as_ref())?; + if r == TreeNodeRecursion::Stop { + return Ok(r); + } + let r = f(right_expr.as_ref())?; + if r == TreeNodeRecursion::Stop { + return 
Ok(r); + } + } + Ok(TreeNodeRecursion::Continue) + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> Result { + let left_stream = self.left.execute(partition, Arc::clone(&context))?; + let right_stream = self.right.execute(partition, Arc::clone(&context))?; + let baseline_metrics = BaselineMetrics::new(&self.metrics, partition); + + let left_schema = self.left.schema(); + let key_fields: Vec> = self + .group_by_exprs + .iter() + .map(|(expr, alias)| { + let dt = expr.data_type(&left_schema).unwrap(); + let nullable = expr.nullable(&left_schema).unwrap(); + Arc::new(Field::new(alias, dt, nullable)) + }) + .collect(); + let key_schema = Arc::new(Schema::new(key_fields)); + + let accumulators: Vec> = self + .aggr_expr + .iter() + .map(|agg_expr| -> Result> { + if agg_expr.groups_accumulator_supported() { + agg_expr.create_groups_accumulator() + } else { + debug!( + "GroupJoinExec: using GroupsAccumulatorAdapter for {}", + agg_expr.name() + ); + let captured = Arc::clone(agg_expr); + let factory = move || captured.create_accumulator(); + Ok(Box::new(GroupsAccumulatorAdapter::new(factory))) + } + }) + .collect::>()?; + + let group_values = new_group_values(key_schema, &GroupOrdering::None)?; + + let group_by_exprs: Vec = self + .group_by_exprs + .iter() + .map(|(expr, _)| Arc::clone(expr)) + .collect(); + + Ok(Box::pin(GroupJoinStream { + left_stream: Some(left_stream), + right_stream, + right_on: self.on.iter().map(|(_, r)| Arc::clone(r)).collect(), + group_by_exprs, + aggr_expr: self.aggr_expr.clone(), + accumulators, + group_values, + schema: Arc::clone(&self.schema), + baseline_metrics, + state: GroupJoinState::CollectBuildSide, + group_indices: Vec::new(), + num_build_groups: 0, + join_type: self.join_type, + visited: BooleanBufferBuilder::new(0), + })) + } +} + +// ── Stream implementation ────────────────────────────────────────────── + +enum GroupJoinState { + CollectBuildSide, + Probe, + Emit, + Done, +} + +struct GroupJoinStream { + 
left_stream: Option, + right_stream: SendableRecordBatchStream, + right_on: Vec, + group_by_exprs: Vec, + aggr_expr: Vec>, + accumulators: Vec>, + group_values: Box, + schema: SchemaRef, + baseline_metrics: BaselineMetrics, + state: GroupJoinState, + group_indices: Vec, + num_build_groups: usize, + join_type: JoinType, + /// Tracks which build-side groups received at least one probe match. + /// For Inner joins, only visited groups are emitted. + /// For Left joins, all build-side groups are emitted (visited or not). + visited: BooleanBufferBuilder, +} + +impl Debug for GroupJoinStream { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("GroupJoinStream").finish_non_exhaustive() + } +} + +impl GroupJoinStream { + fn collect_build_side(&mut self, cx: &mut Context<'_>) -> Poll> { + let stream = self.left_stream.as_mut().unwrap(); + loop { + match ready!(stream.poll_next_unpin(cx)) { + Some(Ok(batch)) => { + if batch.num_rows() == 0 { + continue; + } + let key_arrays: Vec = self + .group_by_exprs + .iter() + .map(|expr| { + expr.evaluate(&batch) + .and_then(|v| v.into_array(batch.num_rows())) + }) + .collect::>()?; + + self.group_indices.clear(); + self.group_values + .intern(&key_arrays, &mut self.group_indices)?; + } + Some(Err(e)) => return Poll::Ready(Err(e)), + None => { + self.left_stream = None; + self.num_build_groups = self.group_values.len(); + // Initialize visited bitmap for Inner join filtering + self.visited = BooleanBufferBuilder::new(self.num_build_groups); + self.visited.append_n(self.num_build_groups, false); + return Poll::Ready(Ok(())); + } + } + } + } + + fn process_probe_batch(&mut self, batch: &RecordBatch) -> Result<()> { + if batch.num_rows() == 0 { + return Ok(()); + } + + let key_arrays: Vec = self + .right_on + .iter() + .map(|expr| { + expr.evaluate(batch) + .and_then(|v| v.into_array(batch.num_rows())) + }) + .collect::>()?; + + self.group_indices.clear(); + self.group_values + .intern(&key_arrays, &mut 
self.group_indices)?; + let total_num_groups = self.group_values.len(); + + // Filter: only update accumulators for rows matching build-side groups. + // Probe rows creating new groups (not in build side) are excluded. + let filter: Option = if total_num_groups > self.num_build_groups { + let mask: Vec = self + .group_indices + .iter() + .map(|&idx| idx < self.num_build_groups) + .collect(); + Some(BooleanArray::from(mask)) + } else { + None + }; + + // Mark build-side groups that received at least one probe match + for &idx in &self.group_indices { + if idx < self.num_build_groups { + self.visited.set_bit(idx, true); + } + } + + for (acc_idx, agg) in self.aggr_expr.iter().enumerate() { + let values: Vec = agg + .expressions() + .iter() + .map(|expr| { + expr.evaluate(batch) + .and_then(|v| v.into_array(batch.num_rows())) + }) + .collect::>()?; + + self.accumulators[acc_idx].update_batch( + &values, + &self.group_indices, + filter.as_ref(), + total_num_groups, + )?; + } + + Ok(()) + } + + fn emit_results(&mut self) -> Result { + if self.num_build_groups == 0 { + return Ok(RecordBatch::new_empty(Arc::clone(&self.schema))); + } + + let emit_to = if self.group_values.len() > self.num_build_groups { + EmitTo::First(self.num_build_groups) + } else { + EmitTo::All + }; + + let mut columns: Vec = self.group_values.emit(emit_to)?; + + for acc in &mut self.accumulators { + columns.push(acc.evaluate(emit_to)?); + } + + let batch = RecordBatch::try_new(Arc::clone(&self.schema), columns)?; + + // For Inner joins, filter to only groups that had at least one probe match + if self.join_type == JoinType::Inner { + let visited_mask = BooleanArray::new(self.visited.finish(), None); + let filtered = arrow::compute::filter_record_batch(&batch, &visited_mask)?; + Ok(filtered) + } else { + // Left joins emit all build-side groups (unmatched get default values) + Ok(batch) + } + } +} + +impl Stream for GroupJoinStream { + type Item = Result; + + fn poll_next( + mut self: 
std::pin::Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + loop { + match self.state { + GroupJoinState::CollectBuildSide => { + match ready!(self.collect_build_side(cx)) { + Ok(()) => self.state = GroupJoinState::Probe, + Err(e) => return Poll::Ready(Some(Err(e))), + } + } + GroupJoinState::Probe => { + match ready!(self.right_stream.poll_next_unpin(cx)) { + Some(Ok(batch)) => { + if let Err(e) = self.process_probe_batch(&batch) { + return Poll::Ready(Some(Err(e))); + } + } + Some(Err(e)) => return Poll::Ready(Some(Err(e))), + None => { + self.state = GroupJoinState::Emit; + } + } + } + GroupJoinState::Emit => { + self.state = GroupJoinState::Done; + let result = self.emit_results(); + if let Ok(ref batch) = result { + self.baseline_metrics.record_output(batch.num_rows()); + } + return Poll::Ready(Some(result)); + } + GroupJoinState::Done => { + return Poll::Ready(None); + } + } + } + } +} + +impl RecordBatchStream for GroupJoinStream { + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } +} diff --git a/datafusion/physical-plan/src/joins/mod.rs b/datafusion/physical-plan/src/joins/mod.rs index 2cdfa1e6ac020..0d99d60bcfcae 100644 --- a/datafusion/physical-plan/src/joins/mod.rs +++ b/datafusion/physical-plan/src/joins/mod.rs @@ -31,6 +31,7 @@ pub use sort_merge_join::SortMergeJoinExec; pub use symmetric_hash_join::SymmetricHashJoinExec; pub mod chain; mod cross_join; +pub mod group_join; mod hash_join; mod nested_loop_join; mod piecewise_merge_join; diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 511e8eb1b012e..b2f130f96a5bd 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -788,6 +788,7 @@ message PhysicalPlanNode { BufferExecNode buffer = 37; ArrowScanExecNode arrow_scan = 38; ScalarSubqueryExecNode scalar_subquery = 39; + GroupJoinExecNode group_join = 40; } } @@ -1189,6 +1190,18 @@ message HashJoinExecNode { bool null_aware = 10; } 
+message GroupJoinExecNode { + PhysicalPlanNode left = 1; + PhysicalPlanNode right = 2; + repeated JoinOn on = 3; + datafusion_common.JoinType join_type = 4; + repeated PhysicalExprNode group_expr = 5; + repeated string group_expr_name = 6; + repeated PhysicalExprNode aggr_expr = 7; + repeated string aggr_expr_name = 8; + datafusion_common.Schema input_schema = 9; +} + enum StreamPartitionMode { SINGLE_PARTITION = 0; PARTITIONED_EXEC = 1; diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index c05d3283eac8e..4a3575fe66e1b 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -8691,6 +8691,241 @@ impl<'de> serde::Deserialize<'de> for GlobalLimitExecNode { deserializer.deserialize_struct("datafusion.GlobalLimitExecNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for GroupJoinExecNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.left.is_some() { + len += 1; + } + if self.right.is_some() { + len += 1; + } + if !self.on.is_empty() { + len += 1; + } + if self.join_type != 0 { + len += 1; + } + if !self.group_expr.is_empty() { + len += 1; + } + if !self.group_expr_name.is_empty() { + len += 1; + } + if !self.aggr_expr.is_empty() { + len += 1; + } + if !self.aggr_expr_name.is_empty() { + len += 1; + } + if self.input_schema.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.GroupJoinExecNode", len)?; + if let Some(v) = self.left.as_ref() { + struct_ser.serialize_field("left", v)?; + } + if let Some(v) = self.right.as_ref() { + struct_ser.serialize_field("right", v)?; + } + if !self.on.is_empty() { + struct_ser.serialize_field("on", &self.on)?; + } + if self.join_type != 0 { + let v = super::datafusion_common::JoinType::try_from(self.join_type) + .map_err(|_| 
serde::ser::Error::custom(format!("Invalid variant {}", self.join_type)))?; + struct_ser.serialize_field("joinType", &v)?; + } + if !self.group_expr.is_empty() { + struct_ser.serialize_field("groupExpr", &self.group_expr)?; + } + if !self.group_expr_name.is_empty() { + struct_ser.serialize_field("groupExprName", &self.group_expr_name)?; + } + if !self.aggr_expr.is_empty() { + struct_ser.serialize_field("aggrExpr", &self.aggr_expr)?; + } + if !self.aggr_expr_name.is_empty() { + struct_ser.serialize_field("aggrExprName", &self.aggr_expr_name)?; + } + if let Some(v) = self.input_schema.as_ref() { + struct_ser.serialize_field("inputSchema", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for GroupJoinExecNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "left", + "right", + "on", + "join_type", + "joinType", + "group_expr", + "groupExpr", + "group_expr_name", + "groupExprName", + "aggr_expr", + "aggrExpr", + "aggr_expr_name", + "aggrExprName", + "input_schema", + "inputSchema", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Left, + Right, + On, + JoinType, + GroupExpr, + GroupExprName, + AggrExpr, + AggrExprName, + InputSchema, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl serde::de::Visitor<'_> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "left" => Ok(GeneratedField::Left), + "right" => Ok(GeneratedField::Right), + "on" => Ok(GeneratedField::On), + "joinType" | "join_type" => 
Ok(GeneratedField::JoinType), + "groupExpr" | "group_expr" => Ok(GeneratedField::GroupExpr), + "groupExprName" | "group_expr_name" => Ok(GeneratedField::GroupExprName), + "aggrExpr" | "aggr_expr" => Ok(GeneratedField::AggrExpr), + "aggrExprName" | "aggr_expr_name" => Ok(GeneratedField::AggrExprName), + "inputSchema" | "input_schema" => Ok(GeneratedField::InputSchema), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GroupJoinExecNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.GroupJoinExecNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut left__ = None; + let mut right__ = None; + let mut on__ = None; + let mut join_type__ = None; + let mut group_expr__ = None; + let mut group_expr_name__ = None; + let mut aggr_expr__ = None; + let mut aggr_expr_name__ = None; + let mut input_schema__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Left => { + if left__.is_some() { + return Err(serde::de::Error::duplicate_field("left")); + } + left__ = map_.next_value()?; + } + GeneratedField::Right => { + if right__.is_some() { + return Err(serde::de::Error::duplicate_field("right")); + } + right__ = map_.next_value()?; + } + GeneratedField::On => { + if on__.is_some() { + return Err(serde::de::Error::duplicate_field("on")); + } + on__ = Some(map_.next_value()?); + } + GeneratedField::JoinType => { + if join_type__.is_some() { + return Err(serde::de::Error::duplicate_field("joinType")); + } + join_type__ = Some(map_.next_value::()? 
as i32); + } + GeneratedField::GroupExpr => { + if group_expr__.is_some() { + return Err(serde::de::Error::duplicate_field("groupExpr")); + } + group_expr__ = Some(map_.next_value()?); + } + GeneratedField::GroupExprName => { + if group_expr_name__.is_some() { + return Err(serde::de::Error::duplicate_field("groupExprName")); + } + group_expr_name__ = Some(map_.next_value()?); + } + GeneratedField::AggrExpr => { + if aggr_expr__.is_some() { + return Err(serde::de::Error::duplicate_field("aggrExpr")); + } + aggr_expr__ = Some(map_.next_value()?); + } + GeneratedField::AggrExprName => { + if aggr_expr_name__.is_some() { + return Err(serde::de::Error::duplicate_field("aggrExprName")); + } + aggr_expr_name__ = Some(map_.next_value()?); + } + GeneratedField::InputSchema => { + if input_schema__.is_some() { + return Err(serde::de::Error::duplicate_field("inputSchema")); + } + input_schema__ = map_.next_value()?; + } + } + } + Ok(GroupJoinExecNode { + left: left__, + right: right__, + on: on__.unwrap_or_default(), + join_type: join_type__.unwrap_or_default(), + group_expr: group_expr__.unwrap_or_default(), + group_expr_name: group_expr_name__.unwrap_or_default(), + aggr_expr: aggr_expr__.unwrap_or_default(), + aggr_expr_name: aggr_expr_name__.unwrap_or_default(), + input_schema: input_schema__, + }) + } + } + deserializer.deserialize_struct("datafusion.GroupJoinExecNode", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for GroupingSetNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -18489,6 +18724,9 @@ impl serde::Serialize for PhysicalPlanNode { physical_plan_node::PhysicalPlanType::ScalarSubquery(v) => { struct_ser.serialize_field("scalarSubquery", v)?; } + physical_plan_node::PhysicalPlanType::GroupJoin(v) => { + struct_ser.serialize_field("groupJoin", v)?; + } } } struct_ser.end() @@ -18561,6 +18799,8 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { "arrowScan", "scalar_subquery", "scalarSubquery", + 
"group_join", + "groupJoin", ]; #[allow(clippy::enum_variant_names)] @@ -18603,6 +18843,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { Buffer, ArrowScan, ScalarSubquery, + GroupJoin, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -18662,6 +18903,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { "buffer" => Ok(GeneratedField::Buffer), "arrowScan" | "arrow_scan" => Ok(GeneratedField::ArrowScan), "scalarSubquery" | "scalar_subquery" => Ok(GeneratedField::ScalarSubquery), + "groupJoin" | "group_join" => Ok(GeneratedField::GroupJoin), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -18948,6 +19190,13 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { return Err(serde::de::Error::duplicate_field("scalarSubquery")); } physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::ScalarSubquery) +; + } + GeneratedField::GroupJoin => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("groupJoin")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::GroupJoin) ; } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index af9b1404bb80a..7f82cf1afce14 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1119,7 +1119,7 @@ pub mod table_reference { pub struct PhysicalPlanNode { #[prost( oneof = "physical_plan_node::PhysicalPlanType", - tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39" + tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40" )] pub physical_plan_type: ::core::option::Option, } @@ -1205,6 +1205,8 @@ pub 
mod physical_plan_node { ArrowScan(super::ArrowScanExecNode), #[prost(message, tag = "39")] ScalarSubquery(::prost::alloc::boxed::Box), + #[prost(message, tag = "40")] + GroupJoin(::prost::alloc::boxed::Box), } } #[derive(Clone, PartialEq, ::prost::Message)] @@ -1774,6 +1776,27 @@ pub struct HashJoinExecNode { pub null_aware: bool, } #[derive(Clone, PartialEq, ::prost::Message)] +pub struct GroupJoinExecNode { + #[prost(message, optional, boxed, tag = "1")] + pub left: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, optional, boxed, tag = "2")] + pub right: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, repeated, tag = "3")] + pub on: ::prost::alloc::vec::Vec, + #[prost(enumeration = "super::datafusion_common::JoinType", tag = "4")] + pub join_type: i32, + #[prost(message, repeated, tag = "5")] + pub group_expr: ::prost::alloc::vec::Vec, + #[prost(string, repeated, tag = "6")] + pub group_expr_name: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, + #[prost(message, repeated, tag = "7")] + pub aggr_expr: ::prost::alloc::vec::Vec, + #[prost(string, repeated, tag = "8")] + pub aggr_expr_name: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, + #[prost(message, optional, tag = "9")] + pub input_schema: ::core::option::Option, +} +#[derive(Clone, PartialEq, ::prost::Message)] pub struct SymmetricHashJoinExecNode { #[prost(message, optional, boxed, tag = "1")] pub left: ::core::option::Option<::prost::alloc::boxed::Box>, diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 5172a552fad4f..5538d8f0614d4 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -72,6 +72,7 @@ use datafusion_physical_plan::empty::EmptyExec; use datafusion_physical_plan::explain::ExplainExec; use datafusion_physical_plan::expressions::PhysicalSortExpr; use datafusion_physical_plan::filter::{FilterExec, FilterExecBuilder}; +use 
datafusion_physical_plan::joins::group_join::GroupJoinExec; use datafusion_physical_plan::joins::utils::{ColumnIndex, JoinFilter}; use datafusion_physical_plan::joins::{ CrossJoinExec, HashJoinExec, NestedLoopJoinExec, PartitionMode, SortMergeJoinExec, @@ -293,6 +294,9 @@ impl protobuf::PhysicalPlanNode { PhysicalPlanType::HashJoin(hashjoin) => { self.try_into_hash_join_physical_plan(hashjoin, ctx, proto_converter) } + PhysicalPlanType::GroupJoin(group_join) => { + self.try_into_group_join_physical_plan(group_join, ctx, proto_converter) + } PhysicalPlanType::SymmetricHashJoin(sym_join) => self .try_into_symmetric_hash_join_physical_plan( sym_join, @@ -425,6 +429,14 @@ impl protobuf::PhysicalPlanNode { ); } + if let Some(exec) = plan.downcast_ref::() { + return protobuf::PhysicalPlanNode::try_from_group_join_exec( + exec, + codec, + proto_converter, + ); + } + if let Some(exec) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_symmetric_hash_join_exec( exec, @@ -1213,72 +1225,13 @@ impl protobuf::PhysicalPlanNode { .iter() .zip(hash_agg.aggr_expr_name.iter()) .map(|(expr, name)| { - let expr_type = expr.expr_type.as_ref().ok_or_else(|| { - proto_error("Unexpected empty aggregate physical expression") - })?; - - match expr_type { - ExprType::AggregateExpr(agg_node) => { - let input_phy_expr: Vec> = agg_node - .expr - .iter() - .map(|e| { - proto_converter.proto_to_physical_expr( - e, - &physical_schema, - ctx, - ) - }) - .collect::>>()?; - let order_bys = agg_node - .ordering_req - .iter() - .map(|e| { - parse_physical_sort_expr( - e, - ctx, - &physical_schema, - proto_converter, - ) - }) - .collect::>()?; - agg_node - .aggregate_function - .as_ref() - .map(|func| match func { - AggregateFunction::UserDefinedAggrFunction(udaf_name) => { - let agg_udf = match &agg_node.fun_definition { - Some(buf) => { - ctx.codec().try_decode_udaf(udaf_name, buf)? 
- } - None => ctx.task_ctx().udaf(udaf_name).or_else( - |_| { - ctx.codec() - .try_decode_udaf(udaf_name, &[]) - }, - )?, - }; - - AggregateExprBuilder::new(agg_udf, input_phy_expr) - .schema(Arc::clone(&physical_schema)) - .alias(name) - .human_display(agg_node.human_display.clone()) - .with_ignore_nulls(agg_node.ignore_nulls) - .with_distinct(agg_node.distinct) - .order_by(order_bys) - .build() - .map(Arc::new) - } - }) - .transpose()? - .ok_or_else(|| { - proto_error( - "Invalid AggregateExpr, missing aggregate_function", - ) - }) - } - _ => internal_err!("Invalid aggregate expression for AggregateExec"), - } + parse_physical_aggr_expr_node( + expr, + name, + &physical_schema, + ctx, + proto_converter, + ) }) .collect::, _>>()?; @@ -1421,6 +1374,87 @@ impl protobuf::PhysicalPlanNode { )?)) } + fn try_into_group_join_physical_plan( + &self, + group_join: &protobuf::GroupJoinExecNode, + ctx: &PhysicalPlanDecodeContext<'_>, + proto_converter: &dyn PhysicalProtoConverterExtension, + ) -> Result> { + let left = into_physical_plan(&group_join.left, ctx, proto_converter)?; + let right = into_physical_plan(&group_join.right, ctx, proto_converter)?; + let left_schema = left.schema(); + let right_schema = right.schema(); + let on = group_join + .on + .iter() + .map(|col| { + let left = proto_converter.proto_to_physical_expr( + col.left + .as_ref() + .ok_or_else(|| proto_error("Missing GroupJoin left key"))?, + left_schema.as_ref(), + ctx, + )?; + let right = proto_converter.proto_to_physical_expr( + col.right + .as_ref() + .ok_or_else(|| proto_error("Missing GroupJoin right key"))?, + right_schema.as_ref(), + ctx, + )?; + Ok((left, right)) + }) + .collect::>>()?; + let join_type = + protobuf::JoinType::try_from(group_join.join_type).map_err(|_| { + proto_error(format!( + "Received a GroupJoinNode message with unknown JoinType {}", + group_join.join_type + )) + })?; + + let group_by_exprs = group_join + .group_expr + .iter() + .zip(group_join.group_expr_name.iter()) + 
.map(|(expr, name)| { + proto_converter + .proto_to_physical_expr(expr, left_schema.as_ref(), ctx) + .map(|expr| (expr, name.clone())) + }) + .collect::>>()?; + + let input_schema = group_join.input_schema.as_ref().ok_or_else(|| { + internal_datafusion_err!("input_schema in GroupJoinNode is missing.") + })?; + let physical_schema = SchemaRef::new(input_schema.try_into()?); + + let aggr_expr = group_join + .aggr_expr + .iter() + .zip(group_join.aggr_expr_name.iter()) + .map(|(expr, name)| { + parse_physical_aggr_expr_node( + expr, + name, + &physical_schema, + ctx, + proto_converter, + ) + }) + .collect::>>()?; + + Ok(Arc::new(GroupJoinExec::try_new_with_aggr_input_schema( + left, + right, + on, + join_type.into(), + group_by_exprs, + aggr_expr, + physical_schema, + )?)) + } + fn try_into_symmetric_hash_join_physical_plan( &self, sym_join: &protobuf::SymmetricHashJoinExecNode, @@ -2481,6 +2515,63 @@ impl protobuf::PhysicalPlanNode { }) } + fn try_from_group_join_exec( + exec: &GroupJoinExec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, + ) -> Result { + let left = proto_converter.execution_plan_to_proto(exec.left(), codec)?; + let right = proto_converter.execution_plan_to_proto(exec.right(), codec)?; + let on = exec + .on() + .iter() + .map(|(left, right)| { + Ok(protobuf::JoinOn { + left: Some(proto_converter.physical_expr_to_proto(left, codec)?), + right: Some(proto_converter.physical_expr_to_proto(right, codec)?), + }) + }) + .collect::>>()?; + let group_expr = exec + .group_by_exprs() + .iter() + .map(|(expr, _)| proto_converter.physical_expr_to_proto(expr, codec)) + .collect::>>()?; + let group_expr_name = exec + .group_by_exprs() + .iter() + .map(|(_, name)| name.clone()) + .collect(); + let aggr_expr = exec + .aggr_expr() + .iter() + .map(|expr| { + serialize_physical_aggr_expr(expr.to_owned(), codec, proto_converter) + }) + .collect::>>()?; + let aggr_expr_name = exec + .aggr_expr() + .iter() + .map(|expr| 
expr.name().to_string()) + .collect(); + + Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::GroupJoin(Box::new( + protobuf::GroupJoinExecNode { + left: Some(Box::new(left)), + right: Some(Box::new(right)), + on, + join_type: protobuf::JoinType::from(*exec.join_type()).into(), + group_expr, + group_expr_name, + aggr_expr, + aggr_expr_name, + input_schema: Some(exec.aggr_input_schema().as_ref().try_into()?), + }, + ))), + }) + } + fn try_from_symmetric_hash_join_exec( exec: &SymmetricHashJoinExec, codec: &dyn PhysicalExtensionCodec, @@ -4116,6 +4207,62 @@ impl PhysicalExtensionCodec for ComposedPhysicalExtensionCodec { } } +fn parse_physical_aggr_expr_node( + expr: &protobuf::PhysicalExprNode, + name: &str, + input_schema: &SchemaRef, + ctx: &PhysicalPlanDecodeContext<'_>, + proto_converter: &dyn PhysicalProtoConverterExtension, +) -> Result> { + let expr_type = expr + .expr_type + .as_ref() + .ok_or_else(|| proto_error("Unexpected empty aggregate physical expression"))?; + + match expr_type { + ExprType::AggregateExpr(agg_node) => { + let input_phy_expr: Vec> = agg_node + .expr + .iter() + .map(|e| proto_converter.proto_to_physical_expr(e, input_schema, ctx)) + .collect::>>()?; + let order_bys = agg_node + .ordering_req + .iter() + .map(|e| parse_physical_sort_expr(e, ctx, input_schema, proto_converter)) + .collect::>()?; + agg_node + .aggregate_function + .as_ref() + .map(|func| match func { + AggregateFunction::UserDefinedAggrFunction(udaf_name) => { + let agg_udf = match &agg_node.fun_definition { + Some(buf) => ctx.codec().try_decode_udaf(udaf_name, buf)?, + None => ctx.task_ctx().udaf(udaf_name).or_else(|_| { + ctx.codec().try_decode_udaf(udaf_name, &[]) + })?, + }; + + AggregateExprBuilder::new(agg_udf, input_phy_expr) + .schema(Arc::clone(input_schema)) + .alias(name) + .human_display(agg_node.human_display.clone()) + .with_ignore_nulls(agg_node.ignore_nulls) + .with_distinct(agg_node.distinct) + .order_by(order_bys) + .build() + 
.map(Arc::new) + } + }) + .transpose()? + .ok_or_else(|| { + proto_error("Invalid AggregateExpr, missing aggregate_function") + }) + } + _ => internal_err!("Invalid aggregate expression for AggregateExec"), + } +} + fn into_physical_plan( node: &Option>, ctx: &PhysicalPlanDecodeContext<'_>, diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index fa342ae9079d5..2bdc9a4433ea5 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -65,6 +65,7 @@ use datafusion::physical_plan::expressions::{ BinaryExpr, Column, NotExpr, PhysicalSortExpr, binary, cast, col, in_list, like, lit, }; use datafusion::physical_plan::filter::{FilterExec, FilterExecBuilder}; +use datafusion::physical_plan::joins::group_join::GroupJoinExec; use datafusion::physical_plan::joins::{ HashJoinExec, NestedLoopJoinExec, PartitionMode, SortMergeJoinExec, StreamJoinPartitionMode, SymmetricHashJoinExec, @@ -314,6 +315,49 @@ fn roundtrip_hash_join() -> Result<()> { Ok(()) } +#[test] +fn roundtrip_group_join() -> Result<()> { + let left_schema = Arc::new(Schema::new(vec![ + Field::new("l_key", DataType::Int64, false), + Field::new("l_value", DataType::Int64, false), + ])); + let right_schema = Arc::new(Schema::new(vec![ + Field::new("r_key", DataType::Int64, false), + Field::new("r_value", DataType::Int64, false), + ])); + let input_schema = Arc::new(Schema::new(vec![ + Field::new("l_key", DataType::Int64, false), + Field::new("l_value", DataType::Int64, false), + Field::new("r_key", DataType::Int64, false), + Field::new("r_value", DataType::Int64, false), + ])); + + let on = vec![( + Arc::new(Column::new("l_key", left_schema.index_of("l_key")?)) as _, + Arc::new(Column::new("r_key", right_schema.index_of("r_key")?)) as _, + )]; + let group_by_exprs = vec![( + Arc::new(Column::new("l_key", left_schema.index_of("l_key")?)) as _, + "l_key".to_string(), + 
)]; + let aggr_expr = vec![Arc::new( + AggregateExprBuilder::new(count_udaf(), vec![col("r_value", &input_schema)?]) + .schema(Arc::clone(&input_schema)) + .alias("count_values") + .build()?, + )]; + + roundtrip_test(Arc::new(GroupJoinExec::try_new_with_aggr_input_schema( + Arc::new(EmptyExec::new(left_schema)), + Arc::new(EmptyExec::new(right_schema)), + on, + JoinType::Inner, + group_by_exprs, + aggr_expr, + input_schema, + )?)) +} + #[test] fn roundtrip_nested_loop_join() -> Result<()> { let field_a = Field::new("col", DataType::Int64, false); diff --git a/datafusion/sqllogictest/test_files/explain.slt b/datafusion/sqllogictest/test_files/explain.slt index 2e8a65385541e..a9392957099bc 100644 --- a/datafusion/sqllogictest/test_files/explain.slt +++ b/datafusion/sqllogictest/test_files/explain.slt @@ -237,6 +237,7 @@ physical_plan after LimitedDistinctAggregation SAME TEXT AS ABOVE physical_plan after FilterPushdown SAME TEXT AS ABOVE physical_plan after EnforceDistribution SAME TEXT AS ABOVE physical_plan after CombinePartialFinalAggregate SAME TEXT AS ABOVE +physical_plan after group_join SAME TEXT AS ABOVE physical_plan after EnforceSorting SAME TEXT AS ABOVE physical_plan after OptimizeAggregateOrder SAME TEXT AS ABOVE physical_plan after WindowTopN SAME TEXT AS ABOVE @@ -318,6 +319,7 @@ physical_plan after LimitedDistinctAggregation SAME TEXT AS ABOVE physical_plan after FilterPushdown SAME TEXT AS ABOVE physical_plan after EnforceDistribution SAME TEXT AS ABOVE physical_plan after CombinePartialFinalAggregate SAME TEXT AS ABOVE +physical_plan after group_join SAME TEXT AS ABOVE physical_plan after EnforceSorting SAME TEXT AS ABOVE physical_plan after OptimizeAggregateOrder SAME TEXT AS ABOVE physical_plan after WindowTopN SAME TEXT AS ABOVE @@ -365,6 +367,7 @@ physical_plan after LimitedDistinctAggregation SAME TEXT AS ABOVE physical_plan after FilterPushdown SAME TEXT AS ABOVE physical_plan after EnforceDistribution SAME TEXT AS ABOVE physical_plan after 
CombinePartialFinalAggregate SAME TEXT AS ABOVE +physical_plan after group_join SAME TEXT AS ABOVE physical_plan after EnforceSorting SAME TEXT AS ABOVE physical_plan after OptimizeAggregateOrder SAME TEXT AS ABOVE physical_plan after WindowTopN SAME TEXT AS ABOVE @@ -612,6 +615,7 @@ physical_plan after LimitedDistinctAggregation SAME TEXT AS ABOVE physical_plan after FilterPushdown SAME TEXT AS ABOVE physical_plan after EnforceDistribution SAME TEXT AS ABOVE physical_plan after CombinePartialFinalAggregate SAME TEXT AS ABOVE +physical_plan after group_join SAME TEXT AS ABOVE physical_plan after EnforceSorting SAME TEXT AS ABOVE physical_plan after OptimizeAggregateOrder SAME TEXT AS ABOVE physical_plan after WindowTopN SAME TEXT AS ABOVE diff --git a/datafusion/sqllogictest/test_files/tpch/plans/q13.slt.part b/datafusion/sqllogictest/test_files/tpch/plans/q13.slt.part index 94e0848bfcce1..fabe59b1af852 100644 --- a/datafusion/sqllogictest/test_files/tpch/plans/q13.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/plans/q13.slt.part @@ -60,10 +60,9 @@ physical_plan 05)--------RepartitionExec: partitioning=Hash([c_count@0], 4), input_partitions=4 06)----------AggregateExec: mode=Partial, gby=[c_count@0 as c_count], aggr=[count(Int64(1))] 07)------------ProjectionExec: expr=[count(orders.o_orderkey)@1 as c_count] -08)--------------AggregateExec: mode=SinglePartitioned, gby=[c_custkey@0 as c_custkey], aggr=[count(orders.o_orderkey)] -09)----------------HashJoinExec: mode=Partitioned, join_type=Left, on=[(c_custkey@0, o_custkey@1)], projection=[c_custkey@0, o_orderkey@1] -10)------------------RepartitionExec: partitioning=Hash([c_custkey@0], 4), input_partitions=1 -11)--------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/customer.tbl]]}, projection=[c_custkey], file_type=csv, has_header=false -12)------------------RepartitionExec: partitioning=Hash([o_custkey@1], 4), input_partitions=4 
-13)--------------------FilterExec: o_comment@2 NOT LIKE %special%requests%, projection=[o_orderkey@0, o_custkey@1] -14)----------------------DataSourceExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:0..4223281], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:4223281..8446562], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:8446562..12669843], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:12669843..16893122]]}, projection=[o_orderkey, o_custkey, o_comment], file_type=csv, has_header=false +08)--------------GroupJoinExec: join_type=Left, on=[(c_custkey@0, o_custkey@1)], aggr=[count(orders.o_orderkey)] +09)----------------RepartitionExec: partitioning=Hash([c_custkey@0], 4), input_partitions=1 +10)------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/customer.tbl]]}, projection=[c_custkey], file_type=csv, has_header=false +11)----------------RepartitionExec: partitioning=Hash([o_custkey@1], 4), input_partitions=4 +12)------------------FilterExec: o_comment@2 NOT LIKE %special%requests%, projection=[o_orderkey@0, o_custkey@1] +13)--------------------DataSourceExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:0..4223281], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:4223281..8446562], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:8446562..12669843], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:12669843..16893122]]}, projection=[o_orderkey, o_custkey, o_comment], file_type=csv, has_header=false