diff --git a/datafusion/common/src/functional_dependencies.rs b/datafusion/common/src/functional_dependencies.rs index 24ca33c0c2c90..47cf6e7b0e2a8 100644 --- a/datafusion/common/src/functional_dependencies.rs +++ b/datafusion/common/src/functional_dependencies.rs @@ -328,11 +328,16 @@ impl FunctionalDependencies { /// This function joins this set of functional dependencies with the `other` /// according to the given `join_type`. + /// + /// `uniquely_determined_columns` contains a set of columns on each side of the relation ( + /// self first) which are used in the ON clause of the join in an equality comparison, + /// which guarantees they are uniquely determined by the source side of the join. pub fn join( &self, other: &FunctionalDependencies, join_type: &JoinType, left_cols_len: usize, + uniquely_determined_columns: Option<(HashSet, HashSet)>, ) -> FunctionalDependencies { // Get mutable copies of left and right side dependencies: let mut right_func_dependencies = other.clone(); @@ -343,6 +348,61 @@ impl FunctionalDependencies { // Add offset to right schema: right_func_dependencies.add_offset(left_cols_len); + let (fixed_left, fixed_right) = + uniquely_determined_columns.unwrap_or_default(); + + // Computes the list of columns on both sides of the relation that can be directly + // derived from the join clause, because all the source indices of their dependencies + // appear in the join clause. 
+ let left_dependent_columns = left_func_dependencies + .deps + .iter() + .filter(|dep| { + dep.source_indices + .iter() + .all(|index| fixed_left.contains(index)) + }) + .flat_map(|dep| dep.target_indices.clone()) + .collect::>(); + let right_dependent_columns = right_func_dependencies + .deps + .iter() + .filter(|dep| { + dep.source_indices + .iter() + .all(|index| fixed_right.contains(&(*index - left_cols_len))) + }) + .flat_map(|dep| dep.target_indices.clone()) + .collect::>(); + + // Update dependencies on each side of the join to add columns from the other side, + // if their source appears fully in the condition of the join + left_func_dependencies + .deps + .iter_mut() + .filter(|dep| { + dep.source_indices + .iter() + .all(|index| fixed_left.contains(index)) + }) + .for_each(|dep| { + dep.target_indices + .extend_from_slice(right_dependent_columns.as_slice()) + }); + + right_func_dependencies + .deps + .iter_mut() + .filter(|dep| { + dep.source_indices + .iter() + .all(|index| fixed_right.contains(&(*index - left_cols_len))) + }) + .for_each(|dep| { + dep.target_indices + .extend_from_slice(left_dependent_columns.as_slice()) + }); + // Result may have multiple values, update the dependency mode: left_func_dependencies = left_func_dependencies.with_dependency(Dependency::Multi); diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 2ecb12c30afad..f71cb6f921d04 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -1154,8 +1154,18 @@ impl LogicalPlanBuilder { .zip(right_keys) .map(|(l, r)| (Expr::Column(l), Expr::Column(r))) .collect(); - let join_schema = - build_join_schema(self.plan.schema(), right.schema(), &join_type)?; + + let join_schema = build_join_schema( + self.plan.schema(), + right.schema(), + &join_type, + Some(Join::get_uniquely_determined_columns( + self.plan.as_ref(), + &right, + &on, + &filter, + )), + )?; // Inner type without 
join condition is cross join if join_type != JoinType::Inner && on.is_empty() && filter.is_none() { @@ -1672,10 +1682,18 @@ fn mark_field(schema: &DFSchema) -> (Option, Arc) { /// Creates a schema for a join operation. /// The fields from the left side are first +/// +/// `uniquely_determined_columns` contains a set of columns on each side of the relation ( +/// left first) which are used in the ON clause of the join in an equality comparison, +/// which guarantees they are uniquely determined by the source side of the join. pub fn build_join_schema( left: &DFSchema, right: &DFSchema, join_type: &JoinType, + uniquely_determined_columns: Option<( + datafusion_common::HashSet, + datafusion_common::HashSet, + )>, ) -> Result { fn nullify_fields<'a>( fields: impl Iterator, &'a Arc)>, @@ -1755,6 +1773,7 @@ pub fn build_join_schema( right.functional_dependencies(), join_type, left.fields().len(), + uniquely_determined_columns, ); let (schema1, schema2) = match join_type { @@ -2912,13 +2931,13 @@ mod tests { )?; let join_schema = - build_join_schema(&left_schema, &right_schema, &JoinType::Left)?; + build_join_schema(&left_schema, &right_schema, &JoinType::Left, None)?; assert_eq!( join_schema.metadata(), &HashMap::from([("key".to_string(), "left".to_string())]) ); let join_schema = - build_join_schema(&left_schema, &right_schema, &JoinType::Right)?; + build_join_schema(&left_schema, &right_schema, &JoinType::Right, None)?; assert_eq!( join_schema.metadata(), &HashMap::from([("key".to_string(), "right".to_string())]) @@ -2992,4 +3011,56 @@ mod tests { ] ); } + + #[test] + fn test_join_dependencies() -> Result<()> { + let left_schema = DFSchema::new_with_metadata( + vec![ + (None, Arc::new(Field::new("a", DataType::Int32, false))), + (None, Arc::new(Field::new("b", DataType::Int32, false))), + ], + HashMap::new(), + )?; + let left_table = table_source_with_constraints( + left_schema.as_arrow(), + Constraints::new_unverified(vec![Constraint::PrimaryKey(vec![0])]), + ); + + 
let right_schema = DFSchema::new_with_metadata( + vec![ + (None, Arc::new(Field::new("c", DataType::Int32, false))), + (None, Arc::new(Field::new("d", DataType::Int32, false))), + ], + HashMap::new(), + )?; + let right_table = table_source_with_constraints( + right_schema.as_arrow(), + Constraints::new_unverified(vec![Constraint::PrimaryKey(vec![0])]), + ); + + let plan = LogicalPlanBuilder::scan("A", left_table, Some(vec![0, 1]))? + .join_on( + LogicalPlanBuilder::scan("B", right_table, Some(vec![0, 1]))?.build()?, + JoinType::Left, + vec![binary_expr(col("a"), Operator::Eq, col("c"))], + )? + .project(vec![ + SelectExpr::Expression(col("a")), + SelectExpr::Expression(col("b")), + SelectExpr::Expression(col("c")), + SelectExpr::Expression(col("d")), + ])? + .build()?; + + let deps = plan.schema().functional_dependencies(); + + assert_eq!(deps.len(), 2); + assert_eq!(deps[0].source_indices, vec![0]); + assert_eq!(deps[0].target_indices, vec![0, 1, 2, 3]); + + assert_eq!(deps[1].source_indices, vec![2]); + assert_eq!(deps[1].target_indices, vec![2, 3, 0, 1]); + + Ok(()) + } } diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 9ca6941a61ce6..6054958419035 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -667,9 +667,6 @@ impl LogicalPlan { null_equality, null_aware, }) => { - let schema = - build_join_schema(left.schema(), right.schema(), &join_type)?; - let new_on: Vec<_> = on .into_iter() .map(|equi_expr| { @@ -678,6 +675,18 @@ impl LogicalPlan { }) .collect(); + let schema = build_join_schema( + left.schema(), + right.schema(), + &join_type, + Some(Join::get_uniquely_determined_columns( + left.as_ref(), + right.as_ref(), + &new_on, + &filter, + )), + )?; + Ok(LogicalPlan::Join(Join { left, right, @@ -944,7 +953,6 @@ impl LogicalPlan { .. 
}) => { let (left, right) = self.only_two_inputs(inputs)?; - let schema = build_join_schema(left.schema(), right.schema(), join_type)?; let equi_expr_count = on.len() * 2; assert!(expr.len() >= equi_expr_count); @@ -973,6 +981,18 @@ impl LogicalPlan { new_on.push((left.unalias(), right.unalias())); } + let schema = build_join_schema( + left.schema(), + right.schema(), + join_type, + Some(Join::get_uniquely_determined_columns( + &left, + &right, + &new_on, + &filter_expr, + )), + )?; + Ok(LogicalPlan::Join(Join { left: Arc::new(left), right: Arc::new(right), @@ -4239,6 +4259,76 @@ pub struct Join { } impl Join { + /// Returns columns which are unique on each side of the join, or that are identical to a + /// column on the other side of the join + pub fn get_uniquely_determined_columns( + left_plan: &LogicalPlan, + right_plan: &LogicalPlan, + on: &Vec<(Expr, Expr)>, + filter: &Option, + ) -> ( + datafusion_common::HashSet, + datafusion_common::HashSet, + ) { + let mut left_set = datafusion_common::HashSet::new(); + let mut right_set = datafusion_common::HashSet::new(); + + let eq_exprs = on.into_iter().cloned().chain( + filter + .iter() + .flat_map(|expr| split_conjunction(expr)) + .filter_map(|expr| match expr { + Expr::BinaryExpr(BinaryExpr { + left, + op: Operator::Eq, + right, + }) => Some((*left.clone(), *right.clone())), + _ => None, + }), + ); + + for (left, right) in eq_exprs { + // This is a no-op filter expression + if left == right { + continue; + } + + match (left.as_ref(), right.as_ref()) { + (Expr::Column(c1), Expr::Column(c2)) => { + if let Some(left_index) = left_plan.schema().maybe_index_of_column(c1) + { + if let Some(right_index) = + right_plan.schema().maybe_index_of_column(c2) + { + left_set.insert(left_index); + right_set.insert(right_index); + } + } else { + if let Some(left_index) = + left_plan.schema().maybe_index_of_column(c2) + { + let right_index = + right_plan.schema().index_of_column(c1).unwrap(); + left_set.insert(left_index); + 
right_set.insert(right_index); + } + } + } + (Expr::Column(c), _) | (_, Expr::Column(c)) => { + if let Some(left_index) = left_plan.schema().maybe_index_of_column(c) + { + left_set.insert(left_index); + } else { + right_set.insert(right_plan.schema().index_of_column(c).unwrap()); + } + } + _ => continue, + } + } + + (left_set, right_set) + } + /// Creates a new Join operator with automatically computed schema. /// /// This constructor computes the schema based on the join type and inputs, @@ -4269,7 +4359,17 @@ impl Join { null_equality: NullEquality, null_aware: bool, ) -> Result { - let join_schema = build_join_schema(left.schema(), right.schema(), &join_type)?; + let join_schema = build_join_schema( + left.schema(), + right.schema(), + &join_type, + Some(Self::get_uniquely_determined_columns( + left.as_ref(), + right.as_ref(), + &on, + &filter, + )), + )?; Ok(Join { left, @@ -4324,6 +4424,12 @@ impl Join { left_sch.schema(), right_sch.schema(), &original_join.join_type, + Some(Self::get_uniquely_determined_columns( + left.as_ref(), + right.as_ref(), + &on, + &original_join.filter, + )), )?; Ok(( diff --git a/datafusion/optimizer/src/eliminate_cross_join.rs b/datafusion/optimizer/src/eliminate_cross_join.rs index 95b70da443d88..2a5ae8d999fd0 100644 --- a/datafusion/optimizer/src/eliminate_cross_join.rs +++ b/datafusion/optimizer/src/eliminate_cross_join.rs @@ -374,6 +374,12 @@ fn find_inner_join( left_input.schema(), right_input.schema(), &JoinType::Inner, + Some(Join::get_uniquely_determined_columns( + &left_input, + &right_input, + &join_keys, + &None, + )), )?); return Ok(LogicalPlan::Join(Join { @@ -397,6 +403,7 @@ fn find_inner_join( left_input.schema(), right.schema(), &JoinType::Inner, + None, )?); Ok(LogicalPlan::Join(Join { @@ -1402,6 +1409,7 @@ mod tests { t1.schema(), t2.schema(), &JoinType::Inner, + None, )?); let inner_join = LogicalPlan::Join(Join { diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs 
b/datafusion/optimizer/src/optimize_projections/mod.rs index acdbf71d05d5c..bde970d950be6 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -1024,7 +1024,8 @@ mod tests { let left_schema = left_child.schema(); let right_schema = right_child.schema(); let schema = Arc::new( - build_join_schema(left_schema, right_schema, &JoinType::Inner).unwrap(), + build_join_schema(left_schema, right_schema, &JoinType::Inner, None) + .unwrap(), ); Self { exprs: vec![], diff --git a/datafusion/sql/src/select.rs b/datafusion/sql/src/select.rs index b0099b8a1dcc3..ad8d8eafc6249 100644 --- a/datafusion/sql/src/select.rs +++ b/datafusion/sql/src/select.rs @@ -15,10 +15,6 @@ // specific language governing permissions and limitations // under the License. -use std::collections::HashSet; -use std::ops::ControlFlow; -use std::sync::Arc; - use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use crate::query::to_order_by_exprs_with_select; use crate::utils::{ @@ -28,11 +24,17 @@ use crate::utils::{ rewrite_recursive_unnests_bottom_up, substitute_top_level_alias, substitute_top_level_aliases_in_sorts, }; +use arrow::datatypes::FieldRef; +use std::collections::HashSet; +use std::ops::ControlFlow; +use std::sync::Arc; use arrow::datatypes::DataType; use datafusion_common::error::DataFusionErrorBuilder; use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; -use datafusion_common::{Column, DFSchema, DFSchemaRef, Result, not_impl_err, plan_err}; +use datafusion_common::{ + Column, Constraint, DFSchema, DFSchemaRef, HashMap, Result, not_impl_err, plan_err, +}; use datafusion_common::{RecursionUnnestOption, UnnestOptions}; use datafusion_expr::ExprSchemable; use datafusion_expr::builder::get_struct_unnested_columns;