diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 9ca6941a61ce6..0858c1ad84081 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -4144,6 +4144,8 @@ fn calc_func_dependencies_for_project( exprs: &[Expr], input: &LogicalPlan, ) -> Result { + const COMPUTED_EXPR_INDEX: usize = usize::MAX; + let input_fields = input.schema().field_names(); // Calculate expression indices (if present) in the input schema. let proj_indices = exprs @@ -4161,30 +4163,33 @@ fn calc_func_dependencies_for_project( Ok::<_, DataFusionError>( wildcard_fields .into_iter() - .filter_map(|(qualifier, f)| { + .map(|(qualifier, f)| { let flat_name = qualifier .map(|t| format!("{}.{}", t, f.name())) .unwrap_or_else(|| f.name().clone()); - input_fields.iter().position(|item| *item == flat_name) + input_fields + .iter() + .position(|item| *item == flat_name) + .unwrap_or(COMPUTED_EXPR_INDEX) }) .collect::>(), ) } Expr::Alias(alias) => { let name = format!("{}", alias.expr); - Ok(input_fields + let input_index = input_fields .iter() .position(|item| *item == name) - .map(|i| vec![i]) - .unwrap_or(vec![])) + .unwrap_or(COMPUTED_EXPR_INDEX); + Ok(vec![input_index]) } _ => { let name = format!("{expr}"); - Ok(input_fields + let input_index = input_fields .iter() .position(|item| *item == name) - .map(|i| vec![i]) - .unwrap_or(vec![])) + .unwrap_or(COMPUTED_EXPR_INDEX); + Ok(vec![input_index]) } }) .collect::>>()? @@ -4947,6 +4952,30 @@ mod tests { ]) } + #[test] + fn projection_with_leading_computed_column_preserves_pk() -> Result<()> { + let constraints = + Constraints::new_unverified(vec![Constraint::PrimaryKey(vec![0])]); + let source = Arc::new( + LogicalTableSource::new(Arc::new(employee_schema())) + .with_constraints(constraints), + ); + let plan = LogicalPlanBuilder::scan("employee_csv", source, None)? + .project(vec![ + lit(1i32).alias("__common_expr_1"), + col("id"), + col("first_name"), + col("salary"), + ])? + .build()?; + + let deps = plan.schema().functional_dependencies(); + assert_eq!(deps.len(), 1); + assert_eq!(deps[0].source_indices, vec![1]); + + Ok(()) + } + fn i32_split_point(value: i32) -> SplitPoint { SplitPoint::new(vec![ScalarValue::Int32(Some(value))]) }