diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs index 93df300bb50b4..bcba7408e923f 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -21,17 +21,18 @@ mod required_indices; use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; +use arrow::array::Array; use std::collections::HashSet; use std::sync::Arc; use datafusion_common::{ - Column, DFSchema, HashMap, JoinType, Result, assert_eq_or_internal_err, + Column, DFSchema, HashMap, JoinType, Result, ScalarValue, assert_eq_or_internal_err, get_required_group_by_exprs_indices, internal_datafusion_err, internal_err, }; use datafusion_expr::expr::Alias; use datafusion_expr::{ - Aggregate, Distinct, EmptyRelation, Expr, Projection, TableScan, Unnest, Window, - logical_plan::LogicalPlan, + Aggregate, Distinct, EmptyRelation, Expr, Extension, Projection, TableScan, Unnest, + Window, logical_plan::LogicalPlan, }; use crate::optimize_projections::required_indices::RequiredIndices; @@ -132,6 +133,8 @@ fn optimize_projections( config: &dyn OptimizerConfig, indices: RequiredIndices, ) -> Result> { + let volatile_in_plan = plan.expressions().iter().any(Expr::is_volatile); + // Recursively rewrite any nodes that may be able to avoid computation given // their parents' required indices. 
match plan { @@ -141,137 +144,23 @@ fn optimize_projections( }); } LogicalPlan::Aggregate(aggregate) => { - // Split parent requirements to GROUP BY and aggregate sections: - let n_group_exprs = aggregate.group_expr_len()?; - // Offset aggregate indices so that they point to valid indices at - // `aggregate.aggr_expr`: - let (group_by_reqs, aggregate_reqs) = indices.split_off(n_group_exprs); - - // Get absolutely necessary GROUP BY fields: - let group_by_expr_existing = aggregate - .group_expr - .iter() - .map(|group_by_expr| group_by_expr.schema_name().to_string()) - .collect::>(); - - let new_group_bys = if let Some(simplest_groupby_indices) = - get_required_group_by_exprs_indices( - aggregate.input.schema(), - &group_by_expr_existing, - ) { - // Some of the fields in the GROUP BY may be required by the - // parent even if these fields are unnecessary in terms of - // functional dependency. - group_by_reqs - .append(&simplest_groupby_indices) - .get_at_indices(&aggregate.group_expr) - } else { - aggregate.group_expr - }; - - // Only use the absolutely necessary aggregate expressions required - // by the parent: - let new_aggr_expr = aggregate_reqs.get_at_indices(&aggregate.aggr_expr); - - if new_group_bys.is_empty() && new_aggr_expr.is_empty() { - // Global aggregation with no aggregate functions always produces 1 row and no columns. - return Ok(Transformed::yes(LogicalPlan::EmptyRelation( - EmptyRelation { - produce_one_row: true, - schema: Arc::new(DFSchema::empty()), - }, - ))); - } - - let all_exprs_iter = new_group_bys.iter().chain(new_aggr_expr.iter()); - let schema = aggregate.input.schema(); - let necessary_indices = - RequiredIndices::new().with_exprs(schema, all_exprs_iter); - let necessary_exprs = necessary_indices.get_required_exprs(schema); - - return optimize_projections( - Arc::unwrap_or_clone(aggregate.input), + return optimize_aggregate_projections( + aggregate, config, - necessary_indices, - )? 
- .transform_data(|aggregate_input| { - // Simplify the input of the aggregation by adding a projection so - // that its input only contains absolutely necessary columns for - // the aggregate expressions. Note that necessary_indices refer to - // fields in `aggregate.input.schema()`. - add_projection_on_top_if_helpful(aggregate_input, necessary_exprs) - })? - .map_data(|aggregate_input| { - // Create a new aggregate plan with the updated input and only the - // absolutely necessary fields: - Aggregate::try_new( - Arc::new(aggregate_input), - new_group_bys, - new_aggr_expr, - ) - .map(LogicalPlan::Aggregate) - }); + indices, + volatile_in_plan, + ); } LogicalPlan::Window(window) => { - let input_schema = Arc::clone(window.input.schema()); - // Split parent requirements to child and window expression sections: - let n_input_fields = input_schema.fields().len(); - // Offset window expression indices so that they point to valid - // indices at `window.window_expr`: - let (child_reqs, window_reqs) = indices.split_off(n_input_fields); - - // Only use window expressions that are absolutely necessary according - // to parent requirements: - let new_window_expr = window_reqs.get_at_indices(&window.window_expr); - - // Get all the required column indices at the input, either by the - // parent or window expression requirements. - let required_indices = child_reqs.with_exprs(&input_schema, &new_window_expr); - - return optimize_projections( - Arc::unwrap_or_clone(window.input), + return optimize_window_projections( + window, config, - required_indices.clone(), - )? - .transform_data(|window_child| { - if new_window_expr.is_empty() { - // When no window expression is necessary, use the input directly: - Ok(Transformed::no(window_child)) - } else { - // Calculate required expressions at the input of the window. 
- // Please note that we use `input_schema`, because `required_indices` - // refers to that schema - let required_exprs = - required_indices.get_required_exprs(&input_schema); - let window_child = - add_projection_on_top_if_helpful(window_child, required_exprs)? - .data; - Window::try_new(new_window_expr, Arc::new(window_child)) - .map(LogicalPlan::Window) - .map(Transformed::yes) - } - }); + indices, + volatile_in_plan, + ); } LogicalPlan::TableScan(table_scan) => { - let TableScan { - table_name, - source, - projection, - filters, - fetch, - projected_schema: _, - } = table_scan; - - // Get indices referred to in the original (schema with all fields) - // given projected indices. - let projection = match &projection { - Some(projection) => indices.into_mapped_indices(|idx| projection[idx]), - None => indices.into_inner(), - }; - let new_scan = - TableScan::try_new(table_name, source, Some(projection), filters, fetch)?; - - return Ok(Transformed::yes(LogicalPlan::TableScan(new_scan))); + return optimize_table_scan_projections(table_scan, indices); } // Other node types are handled below _ => {} @@ -279,7 +168,7 @@ fn optimize_projections( // For other plan node types, calculate indices for columns they use and // try to rewrite their children - let mut child_required_indices: Vec = match &plan { + let child_required_indices: Vec = match &plan { LogicalPlan::Sort(_) | LogicalPlan::Filter(_) | LogicalPlan::Repartition(_) @@ -290,25 +179,14 @@ fn optimize_projections( // that appear in this plan's expressions to its child. All these // operators benefit from "small" inputs, so the projection_beneficial // flag is `true`. - plan.inputs() - .into_iter() - .map(|input| { - indices - .clone() - .with_projection_beneficial() - .with_plan_exprs(&plan, input.schema()) - }) - .collect::>()? + build_plan_input_requirements(&plan, &indices, volatile_in_plan, true)? 
} LogicalPlan::Limit(_) => { // Pass index requirements from the parent as well as column indices // that appear in this plan's expressions to its child. These operators // do not benefit from "small" inputs, so the projection_beneficial // flag is `false`. - plan.inputs() - .into_iter() - .map(|input| indices.clone().with_plan_exprs(&plan, input.schema())) - .collect::>()? + build_plan_input_requirements(&plan, &indices, volatile_in_plan, false)? } LogicalPlan::Copy(_) | LogicalPlan::Ddl(_) @@ -316,42 +194,26 @@ fn optimize_projections( | LogicalPlan::Explain(_) | LogicalPlan::Analyze(_) | LogicalPlan::Subquery(_) - | LogicalPlan::Statement(_) - | LogicalPlan::Distinct(Distinct::All(_)) => { + | LogicalPlan::Statement(_) => { // These plans require all their fields, and their children should // be treated as final plans -- otherwise, we may have schema a // mismatch. // TODO: For some subquery variants (e.g. a subquery arising from an // EXISTS expression), we may not need to require all indices. - plan.inputs() - .into_iter() - .map(RequiredIndices::new_for_all_exprs) - .collect() + build_all_expr_input_requirements(&plan, volatile_in_plan, true)? } LogicalPlan::Extension(extension) => { - let Some(necessary_children_indices) = - extension.node.necessary_children_exprs(indices.indices()) + let Some(child_requirements) = build_extension_input_requirements( + &plan, + extension, + &indices, + volatile_in_plan, + )? else { // Requirements from parent cannot be routed down to user defined logical plan safely return Ok(Transformed::no(plan)); }; - let children = extension.node.inputs(); - assert_eq_or_internal_err!( - children.len(), - necessary_children_indices.len(), - "Inconsistent length between children and necessary children indices. \ - Make sure `.necessary_children_exprs` implementation of the \ - `UserDefinedLogicalNode` is consistent with actual children length \ - for the node." 
- ); - children - .into_iter() - .zip(necessary_children_indices) - .map(|(child, necessary_indices)| { - RequiredIndices::new_from_indices(necessary_indices) - .with_plan_exprs(&plan, child.schema()) - }) - .collect::>>()? + child_requirements } LogicalPlan::EmptyRelation(_) | LogicalPlan::Values(_) @@ -373,30 +235,18 @@ fn optimize_projections( return Ok(Transformed::no(plan)); } - plan.inputs() - .into_iter() - .map(|input| { - indices - .clone() - .with_projection_beneficial() - .with_plan_exprs(&plan, input.schema()) - }) - .collect::>>()? + build_plan_input_requirements(&plan, &indices, volatile_in_plan, true)? } - LogicalPlan::Join(join) => { - let left_len = join.left.schema().fields().len(); - let (left_req_indices, right_req_indices) = - split_join_requirements(left_len, indices, &join.join_type); - let left_indices = - left_req_indices.with_plan_exprs(&plan, join.left.schema())?; - let right_indices = - right_req_indices.with_plan_exprs(&plan, join.right.schema())?; - // Joins benefit from "small" input tables (lower memory usage). - // Therefore, each child benefits from projection: - vec![ - left_indices.with_projection_beneficial(), - right_indices.with_projection_beneficial(), - ] + LogicalPlan::Join(join) => build_join_input_requirements( + &plan, + join.left.schema(), + join.right.schema(), + &join.join_type, + indices, + volatile_in_plan, + )?, + LogicalPlan::Distinct(Distinct::All(_)) => { + build_all_expr_input_requirements(&plan, volatile_in_plan, false)? } // these nodes are explicitly rewritten in the match statement above LogicalPlan::Projection(_) @@ -407,28 +257,294 @@ fn optimize_projections( "OptimizeProjection: should have handled in the match statement above" ); } - LogicalPlan::Unnest(Unnest { - input, - dependency_indices, - .. 
- }) => { - // at least provide the indices for the exec-columns as a starting point - let required_indices = - RequiredIndices::new().with_plan_exprs(&plan, input.schema())?; - - // Add additional required indices from the parent - let mut additional_necessary_child_indices = Vec::new(); - indices.indices().iter().for_each(|idx| { - if let Some(index) = dependency_indices.get(*idx) { - additional_necessary_child_indices.push(*index); - } - }); - vec![required_indices.append(&additional_necessary_child_indices)] + LogicalPlan::Unnest(unnest) => { + if can_eliminate_unnest(unnest, &indices) { + let child_required_indices = + build_unnest_child_requirements(unnest, &indices); + let transformed_input = optimize_projections( + Arc::unwrap_or_clone(Arc::clone(&unnest.input)), + config, + child_required_indices, + )?; + return Ok(Transformed::yes(transformed_input.data)); + } + build_unnest_fallback_requirements(&plan, unnest, &indices, volatile_in_plan)? } }; - // Required indices are currently ordered (child0, child1, ...) 
- // but the loop pops off the last element, so we need to reverse the order + let transformed_plan = rewrite_plan_children(plan, config, child_required_indices)?; + + // If any of the children are transformed, we need to potentially update the plan's schema + if transformed_plan.transformed { + transformed_plan.map_data(|plan| plan.recompute_schema()) + } else { + Ok(transformed_plan) + } +} + +fn optimize_aggregate_projections( + aggregate: Aggregate, + config: &dyn OptimizerConfig, + indices: RequiredIndices, + volatile_in_plan: bool, +) -> Result> { + let has_volatile_ancestor = indices.has_volatile_ancestor(); + let n_group_exprs = aggregate.group_expr_len()?; + let (group_by_reqs, aggregate_reqs) = indices.split_off(n_group_exprs); + + let group_by_expr_existing = aggregate + .group_expr + .iter() + .map(|group_by_expr| group_by_expr.schema_name().to_string()) + .collect::>(); + + let new_group_bys = if let Some(simplest_groupby_indices) = + get_required_group_by_exprs_indices( + aggregate.input.schema(), + &group_by_expr_existing, + ) { + group_by_reqs + .append(&simplest_groupby_indices) + .get_at_indices(&aggregate.group_expr) + } else { + aggregate.group_expr + }; + + let new_aggr_exprs = aggregate_reqs.get_at_indices(&aggregate.aggr_expr); + + if new_group_bys.is_empty() && new_aggr_exprs.is_empty() { + return Ok(Transformed::yes(LogicalPlan::EmptyRelation( + EmptyRelation { + produce_one_row: true, + schema: Arc::new(DFSchema::empty()), + }, + ))); + } + + let all_exprs_iter = new_group_bys.iter().chain(new_aggr_exprs.iter()); + let schema = aggregate.input.schema(); + let necessary_indices = RequiredIndices::new().with_exprs(schema, all_exprs_iter); + let necessary_exprs = necessary_indices.get_required_exprs(schema); + let necessary_indices = + with_child_multiplicity(necessary_indices, !new_aggr_exprs.is_empty()) + .with_volatile_ancestor_if(has_volatile_ancestor) + .with_plan_volatile(volatile_in_plan); + + optimize_projections( + 
Arc::unwrap_or_clone(aggregate.input), + config, + necessary_indices, + )? + .transform_data(|aggregate_input| { + add_projection_on_top_if_helpful(aggregate_input, necessary_exprs) + })? + .map_data(|aggregate_input| { + Aggregate::try_new(Arc::new(aggregate_input), new_group_bys, new_aggr_exprs) + .map(LogicalPlan::Aggregate) + }) +} + +fn optimize_window_projections( + window: Window, + config: &dyn OptimizerConfig, + indices: RequiredIndices, + volatile_in_plan: bool, +) -> Result> { + let has_volatile_ancestor = indices.has_volatile_ancestor(); + let input_schema = Arc::clone(window.input.schema()); + let n_input_fields = input_schema.fields().len(); + let (child_reqs, window_reqs) = indices.split_off(n_input_fields); + + let new_window_exprs = window_reqs.get_at_indices(&window.window_expr); + + let required_indices = child_reqs.with_exprs(&input_schema, &new_window_exprs); + let required_indices = + with_child_multiplicity(required_indices, !new_window_exprs.is_empty()) + .with_volatile_ancestor_if(has_volatile_ancestor) + .with_plan_volatile(volatile_in_plan); + + optimize_projections( + Arc::unwrap_or_clone(window.input), + config, + required_indices.clone(), + )? 
+ .transform_data(|window_child| { + if new_window_exprs.is_empty() { + Ok(Transformed::no(window_child)) + } else { + let required_exprs = required_indices.get_required_exprs(&input_schema); + let window_child = + add_projection_on_top_if_helpful(window_child, required_exprs)?.data; + Window::try_new(new_window_exprs, Arc::new(window_child)) + .map(LogicalPlan::Window) + .map(Transformed::yes) + } + }) +} + +fn optimize_table_scan_projections( + table_scan: TableScan, + indices: RequiredIndices, +) -> Result> { + let TableScan { + table_name, + source, + projection, + filters, + fetch, + projected_schema: _, + } = table_scan; + + let projection = match &projection { + Some(projection) => indices.into_mapped_indices(|idx| projection[idx]), + None => indices.into_inner(), + }; + let new_scan = + TableScan::try_new(table_name, source, Some(projection), filters, fetch)?; + + Ok(Transformed::yes(LogicalPlan::TableScan(new_scan))) +} + +fn with_child_multiplicity( + required_indices: RequiredIndices, + multiplicity_sensitive: bool, +) -> RequiredIndices { + // This switch encodes semantic safety, not a performance preference. + // If ancestors can observe row-count changes, keep the child multiplicity-sensitive; + // otherwise mark it multiplicity-insensitive to allow more aggressive rewrites. 
+ if multiplicity_sensitive { + required_indices.for_multiplicity_sensitive_child() + } else { + required_indices.for_multiplicity_insensitive_child() + } +} + +fn build_plan_input_requirements( + plan: &LogicalPlan, + indices: &RequiredIndices, + volatile_in_plan: bool, + projection_beneficial: bool, +) -> Result> { + plan.inputs() + .into_iter() + .map(|input| { + let required = if projection_beneficial { + indices.clone().with_projection_beneficial() + } else { + indices.clone() + }; + let required = required.with_plan_exprs(plan, input.schema())?; + Ok(required.with_plan_volatile(volatile_in_plan)) + }) + .collect::>() +} + +fn build_all_expr_input_requirements( + plan: &LogicalPlan, + volatile_in_plan: bool, + multiplicity_sensitive: bool, +) -> Result> { + plan.inputs() + .into_iter() + .map(|input| { + let mut required = RequiredIndices::new_for_all_exprs(input); + if !multiplicity_sensitive { + required = required.for_multiplicity_insensitive_child(); + } + Ok(required.with_plan_volatile(volatile_in_plan)) + }) + .collect::>() +} + +fn build_extension_input_requirements( + plan: &LogicalPlan, + extension: &Extension, + indices: &RequiredIndices, + volatile_in_plan: bool, +) -> Result>> { + let Some(necessary_children_indices) = + extension.node.necessary_children_exprs(indices.indices()) + else { + return Ok(None); + }; + + let children = extension.node.inputs(); + assert_eq_or_internal_err!( + children.len(), + necessary_children_indices.len(), + "Inconsistent length between children and necessary children indices. \ + Make sure `.necessary_children_exprs` implementation of the \ + `UserDefinedLogicalNode` is consistent with actual children length \ + for the node." 
+ ); + + children + .into_iter() + .zip(necessary_children_indices) + .map(|(child, necessary_indices)| { + let required = RequiredIndices::new_from_indices(necessary_indices) + .with_plan_exprs(plan, child.schema())?; + Ok(required.with_plan_volatile(volatile_in_plan)) + }) + .collect::>>() + .map(Some) +} + +fn build_unnest_fallback_requirements( + plan: &LogicalPlan, + unnest: &Unnest, + indices: &RequiredIndices, + volatile_in_plan: bool, +) -> Result> { + let mut required_indices = + RequiredIndices::new().with_plan_exprs(plan, unnest.input.schema())?; + required_indices = required_indices + .for_multiplicity_sensitive_child() + .with_volatile_ancestor_if(indices.has_volatile_ancestor()) + .with_plan_volatile(volatile_in_plan); + + let additional_necessary_child_indices = indices + .indices() + .iter() + .filter_map(|idx| unnest.dependency_indices.get(*idx).copied()) + .collect::>(); + + Ok(vec![ + required_indices.append(&additional_necessary_child_indices), + ]) +} + +fn build_join_input_requirements( + plan: &LogicalPlan, + left_schema: &datafusion_common::DFSchemaRef, + right_schema: &datafusion_common::DFSchemaRef, + join_type: &JoinType, + indices: RequiredIndices, + volatile_in_plan: bool, +) -> Result> { + let left_len = left_schema.fields().len(); + let (left_req_indices, right_req_indices) = + split_join_requirements(left_len, indices, join_type); + let left_indices = left_req_indices.with_plan_exprs(plan, left_schema)?; + let right_indices = right_req_indices.with_plan_exprs(plan, right_schema)?; + let left_indices = left_indices + .for_multiplicity_sensitive_child() + .with_plan_volatile(volatile_in_plan); + let right_indices = right_indices + .for_multiplicity_sensitive_child() + .with_plan_volatile(volatile_in_plan); + + Ok(vec![ + left_indices.with_projection_beneficial(), + right_indices.with_projection_beneficial(), + ]) +} + +fn rewrite_plan_children( + plan: LogicalPlan, + config: &dyn OptimizerConfig, + mut child_required_indices: Vec, +) 
-> Result> { child_required_indices.reverse(); assert_eq_or_internal_err!( child_required_indices.len(), @@ -436,8 +552,7 @@ fn optimize_projections( "OptimizeProjection: child_required_indices length mismatch with plan inputs" ); - // Rewrite children of the plan - let transformed_plan = plan.map_children(|child| { + plan.map_children(|child| { let required_indices = child_required_indices.pop().ok_or_else(|| { internal_datafusion_err!( "Unexpected number of required_indices in OptimizeProjections rule" @@ -456,14 +571,7 @@ fn optimize_projections( } }, ) - })?; - - // If any of the children are transformed, we need to potentially update the plan's schema - if transformed_plan.transformed { - transformed_plan.map_data(|plan| plan.recompute_schema()) - } else { - Ok(transformed_plan) - } + }) } /// Merges consecutive projections. @@ -837,8 +945,14 @@ fn rewrite_projection_given_requirements( let exprs_used = indices.get_at_indices(&expr); - let required_indices = + let mut required_indices = RequiredIndices::new().with_exprs(input.schema(), exprs_used.iter()); + if !indices.multiplicity_sensitive() { + required_indices = required_indices.for_multiplicity_insensitive_child(); + } + if indices.has_volatile_ancestor() { + required_indices = required_indices.with_volatile_ancestor(); + } // rewrite the children projection, and if they are changed rewrite the // projection down @@ -909,6 +1023,116 @@ fn plan_contains_other_subqueries(plan: &LogicalPlan, cte_name: &str) -> bool { .any(|child| plan_contains_other_subqueries(child, cte_name)) } +fn can_eliminate_unnest(unnest: &Unnest, indices: &RequiredIndices) -> bool { + if indices.multiplicity_sensitive() || indices.has_volatile_ancestor() { + return false; + } + + // List unnest can drop rows for empty lists even when preserve_nulls=true. + // Allow elimination only when list inputs are provably non-empty. 
+ if !list_unnest_rows_are_preserved(unnest) { + return false; + } + + // preserve_nulls only affects list unnest semantics. For struct-only unnest, + // row cardinality is unchanged and this option is not semantically relevant. + + indices + .indices() + .iter() + .all(|&output_idx| unnest_output_is_passthrough(unnest, output_idx)) +} + +fn list_unnest_rows_are_preserved(unnest: &Unnest) -> bool { + if unnest.list_type_columns.is_empty() { + return true; + } + + // To preserve row cardinality we need strict evidence that every unnested + // list input yields at least one element per row. + let LogicalPlan::Projection(input_projection) = unnest.input.as_ref() else { + return false; + }; + + unnest + .list_type_columns + .iter() + .all(|(input_idx, list_column)| { + list_column.depth == 1 + && input_projection + .expr + .get(*input_idx) + .is_some_and(expr_is_provably_non_empty_list) + }) +} + +fn expr_is_provably_non_empty_list(expr: &Expr) -> bool { + let expr = strip_alias(expr); + if expr.is_volatile() { + return false; + } + + match expr { + Expr::ScalarFunction(func) => { + func.name() == "make_array" && !func.args.is_empty() + } + Expr::Literal(scalar, _) => scalar_value_is_non_empty_list(scalar), + _ => false, + } +} + +fn strip_alias(expr: &Expr) -> &Expr { + match expr { + Expr::Alias(alias) => strip_alias(alias.expr.as_ref()), + _ => expr, + } +} + +fn scalar_value_is_non_empty_list(scalar: &ScalarValue) -> bool { + match scalar { + ScalarValue::List(arr) => !arr.is_null(0) && arr.value_length(0) > 0, + ScalarValue::LargeList(arr) => !arr.is_null(0) && arr.value_length(0) > 0, + ScalarValue::FixedSizeList(arr) => !arr.is_null(0) && arr.value_length() > 0, + _ => false, + } +} + +fn unnest_output_is_passthrough(unnest: &Unnest, output_idx: usize) -> bool { + let Some(&dependency_idx) = unnest.dependency_indices.get(output_idx) else { + return false; + }; + + if dependency_idx >= unnest.input.schema().fields().len() { + return false; + } + + 
unnest.schema.qualified_field(output_idx) + == unnest.input.schema().qualified_field(dependency_idx) +} + +fn build_unnest_child_requirements( + unnest: &Unnest, + indices: &RequiredIndices, +) -> RequiredIndices { + let child_indices = indices + .indices() + .iter() + .filter_map(|&output_idx| unnest.dependency_indices.get(output_idx).copied()) + .collect::>(); + let mut child_required_indices = RequiredIndices::new_from_indices(child_indices); + if indices.projection_beneficial() { + child_required_indices = child_required_indices.with_projection_beneficial(); + } + if indices.has_volatile_ancestor() { + child_required_indices = child_required_indices.with_volatile_ancestor(); + } + if !indices.multiplicity_sensitive() { + child_required_indices = + child_required_indices.for_multiplicity_insensitive_child(); + } + child_required_indices +} + fn expr_contains_subquery(expr: &Expr) -> bool { expr.exists(|e| match e { Expr::ScalarSubquery(_) | Expr::Exists(_) | Expr::InSubquery(_) => Ok(true), @@ -953,7 +1177,7 @@ mod tests { use crate::{OptimizerContext, OptimizerRule}; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::{ - Column, DFSchema, DFSchemaRef, JoinType, Result, TableReference, + Column, DFSchema, DFSchemaRef, JoinType, Result, TableReference, UnnestOptions, }; use datafusion_expr::ExprFunctionExt; use datafusion_expr::{ @@ -2274,6 +2498,103 @@ mod tests { ) } + #[test] + fn eliminate_struct_unnest_when_only_group_keys_are_required() -> Result<()> { + let schema = Schema::new(vec![ + Field::new("id", DataType::UInt32, false), + Field::new( + "user", + DataType::Struct( + vec![ + Field::new("name", DataType::Utf8, true), + Field::new("score", DataType::Int32, true), + ] + .into(), + ), + true, + ), + ]); + let plan = scan_empty(Some("test"), &schema, None)? + .unnest_column("user")? + .aggregate(vec![col("id")], Vec::::new())? + .project(vec![col("id")])? 
+ .build()?; + + let optimized = optimize(plan)?; + let formatted = format!("{}", optimized.display_indent()); + assert!(!formatted.contains("Unnest:")); + Ok(()) + } + + #[test] + fn keep_list_unnest_when_group_keys_are_only_required_outputs() -> Result<()> { + let schema = Schema::new(vec![ + Field::new("id", DataType::UInt32, false), + Field::new( + "vals", + DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + true, + ), + ]); + let plan = scan_empty(Some("test"), &schema, None)? + .unnest_column("vals")? + .aggregate(vec![col("id")], Vec::::new())? + .project(vec![col("id")])? + .build()?; + + let optimized = optimize(plan)?; + let formatted = format!("{}", optimized.display_indent()); + assert!(formatted.contains("Unnest:")); + Ok(()) + } + + #[test] + fn keep_unnest_when_count_depends_on_row_multiplicity() -> Result<()> { + let schema = Schema::new(vec![ + Field::new("id", DataType::UInt32, false), + Field::new( + "vals", + DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + true, + ), + ]); + let plan = scan_empty(Some("test"), &schema, None)? + .unnest_column("vals")? + .aggregate(vec![col("id")], vec![count(lit(1)).alias("cnt")])? + .project(vec![col("id"), col("cnt")])? + .build()?; + + let optimized = optimize(plan)?; + let formatted = format!("{}", optimized.display_indent()); + assert!(formatted.contains("Unnest:")); + Ok(()) + } + + #[test] + fn keep_unnest_when_preserve_nulls_is_disabled() -> Result<()> { + let schema = Schema::new(vec![ + Field::new("id", DataType::UInt32, false), + Field::new( + "vals", + DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + true, + ), + ]); + let plan = scan_empty(Some("test"), &schema, None)? + .unnest_column_with_options( + "vals", + UnnestOptions::new().with_preserve_nulls(false), + )? + .aggregate(vec![col("id")], Vec::::new())? + .project(vec![col("id")])? 
+ .build()?; + + let optimized = optimize(plan)?; + let formatted = format!("{}", optimized.display_indent()); + assert!(formatted.contains("Unnest:")); + Ok(()) + } + #[test] fn test_window() -> Result<()> { let table_scan = test_table_scan()?; diff --git a/datafusion/optimizer/src/optimize_projections/required_indices.rs b/datafusion/optimizer/src/optimize_projections/required_indices.rs index c1e0885c9b5f2..1abbe36c45d32 100644 --- a/datafusion/optimizer/src/optimize_projections/required_indices.rs +++ b/datafusion/optimizer/src/optimize_projections/required_indices.rs @@ -34,13 +34,33 @@ use datafusion_expr::{Expr, LogicalPlan}; /// Indices are always in order and without duplicates. For example, if these /// indices were added `[3, 2, 4, 3, 6, 1]`, the instance would be represented /// by `[1, 2, 3, 4, 6]`. -#[derive(Debug, Clone, Default)] +#[derive(Debug, Clone)] pub(super) struct RequiredIndices { /// The indices of the required columns in the indices: Vec, /// If putting a projection above children is beneficial for the parent. /// Defaults to false. projection_beneficial: bool, + /// Whether ancestors can observe row multiplicity changes. + /// + /// "Multiplicity" means how many rows a child produces, including duplicate + /// rows. If this is `true`, rewrites must preserve row counts exactly because + /// upstream expressions (for example, `COUNT` or window functions) may depend + /// on them. + multiplicity_sensitive: bool, + /// Whether any ancestor expression is volatile. 
+ has_volatile_ancestor: bool, +} + +impl Default for RequiredIndices { + fn default() -> Self { + Self { + indices: Vec::new(), + projection_beneficial: false, + multiplicity_sensitive: true, + has_volatile_ancestor: false, + } + } } impl RequiredIndices { @@ -54,6 +74,8 @@ impl RequiredIndices { Self { indices: (0..plan.schema().fields().len()).collect(), projection_beneficial: false, + multiplicity_sensitive: true, + has_volatile_ancestor: false, } } @@ -62,6 +84,8 @@ impl RequiredIndices { Self { indices, projection_beneficial: false, + multiplicity_sensitive: true, + has_volatile_ancestor: false, } .compact() } @@ -77,6 +101,62 @@ impl RequiredIndices { self } + /// Mark this requirement as multiplicity-insensitive. + pub fn with_multiplicity_insensitive(mut self) -> Self { + self.multiplicity_sensitive = false; + self + } + + /// Mark this requirement as multiplicity-sensitive. + pub fn with_multiplicity_sensitive(mut self) -> Self { + self.multiplicity_sensitive = true; + self + } + + /// Return whether ancestors can observe multiplicity changes. + pub fn multiplicity_sensitive(&self) -> bool { + self.multiplicity_sensitive + } + + /// Mark this requirement as having volatile ancestors. + pub fn with_volatile_ancestor(mut self) -> Self { + self.has_volatile_ancestor = true; + self + } + + /// Conditionally mark this requirement as having volatile ancestors. + pub fn with_volatile_ancestor_if(mut self, value: bool) -> Self { + if value { + self.has_volatile_ancestor = true; + } + self + } + + /// Propagate volatile-plan context into this requirement. + /// + /// This keeps call sites declarative and centralizes state-transition logic. + pub fn with_plan_volatile(mut self, volatile_in_plan: bool) -> Self { + if volatile_in_plan { + self.has_volatile_ancestor = true; + } + self + } + + /// Transition this requirement for a multiplicity-sensitive child. 
+ pub fn for_multiplicity_sensitive_child(self) -> Self { + self.with_multiplicity_sensitive() + } + + /// Transition this requirement for a multiplicity-insensitive child. + pub fn for_multiplicity_insensitive_child(self) -> Self { + self.with_multiplicity_insensitive() + } + + /// Return whether a volatile expression exists in the ancestor chain. + pub fn has_volatile_ancestor(&self) -> bool { + self.has_volatile_ancestor + } + /// Return the value of projection beneficial flag pub fn projection_beneficial(&self) -> bool { self.projection_beneficial @@ -173,10 +253,14 @@ impl RequiredIndices { Self { indices: l, projection_beneficial, + multiplicity_sensitive: self.multiplicity_sensitive, + has_volatile_ancestor: self.has_volatile_ancestor, }, Self { indices: r, projection_beneficial, + multiplicity_sensitive: self.multiplicity_sensitive, + has_volatile_ancestor: self.has_volatile_ancestor, }, ) } diff --git a/datafusion/sqllogictest/test_files/optimizer_unnest_prune.slt b/datafusion/sqllogictest/test_files/optimizer_unnest_prune.slt new file mode 100644 index 0000000000000..7b99bccfccaa6 --- /dev/null +++ b/datafusion/sqllogictest/test_files/optimizer_unnest_prune.slt @@ -0,0 +1,136 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +############################### +# Unnest Pruning Safety Tests # +############################### + +statement ok +CREATE TABLE unnest_prune_t +AS VALUES + (1, [1, 2]), + (2, []), + (3, [3]), + (4, null) +; + +statement ok +set datafusion.explain.logical_plan_only = true; + +# Safe case: struct unnest is cardinality-preserving and unnested outputs are dead. +# Unnest should be removed. +statement ok +CREATE TABLE unnest_prune_struct_t +AS VALUES + (1, struct('a', 10)), + (2, struct('b', 20)) +; + +query TT +EXPLAIN SELECT id +FROM ( + SELECT column1 AS id, unnest(column2) + FROM unnest_prune_struct_t +) q +GROUP BY id; +---- +logical_plan +01)Aggregate: groupBy=[[q.id]], aggr=[[]] +02)--SubqueryAlias: q +03)----Projection: unnest_prune_struct_t.column1 AS id +04)------TableScan: unnest_prune_struct_t projection=[column1] + +# Safe case: deterministic non-empty make_array unnest never drops rows; it only duplicates them, which the aggregate-free GROUP BY cannot observe. +# Unnest should be removed. +query TT +EXPLAIN SELECT id +FROM ( + SELECT column1 AS id, unnest(make_array(1, 2, 3)) AS elem + FROM unnest_prune_t +) q +GROUP BY id; +---- +logical_plan +01)Aggregate: groupBy=[[q.id]], aggr=[[]] +02)--SubqueryAlias: q +03)----Projection: unnest_prune_t.column1 AS id +04)------TableScan: unnest_prune_t projection=[column1] + +# Empty-list/null semantics are cardinality-sensitive even if unnested column is dead. +# Unnest must remain.
+query TT +EXPLAIN SELECT id +FROM ( + SELECT column1 AS id, unnest(column2) AS elem + FROM unnest_prune_t +) q +GROUP BY id; +---- +logical_plan +01)Aggregate: groupBy=[[q.id]], aggr=[[]] +02)--SubqueryAlias: q +03)----Projection: id +04)------Unnest: lists[__unnest_placeholder(unnest_prune_t.column2)|depth=1] structs[] +05)--------Projection: unnest_prune_t.column1 AS id, unnest_prune_t.column2 AS __unnest_placeholder(unnest_prune_t.column2) +06)----------TableScan: unnest_prune_t projection=[column1, column2] + +# Count(*) is explicitly multiplicity-sensitive. Unnest must remain. +query TT +EXPLAIN SELECT id, count(*) AS cnt +FROM ( + SELECT column1 AS id, unnest(column2) AS elem + FROM unnest_prune_t +) q +GROUP BY id; +---- +logical_plan +01)Projection: q.id, count(Int64(1)) AS count(*) AS cnt +02)--Aggregate: groupBy=[[q.id]], aggr=[[count(Int64(1))]] +03)----SubqueryAlias: q +04)------Projection: id +05)--------Unnest: lists[__unnest_placeholder(unnest_prune_t.column2)|depth=1] structs[] +06)----------Projection: unnest_prune_t.column1 AS id, unnest_prune_t.column2 AS __unnest_placeholder(unnest_prune_t.column2) +07)------------TableScan: unnest_prune_t projection=[column1, column2] + +statement ok +set datafusion.explain.logical_plan_only = false; + +# Correctness check for empty-list/null behavior +query I +SELECT id +FROM ( + SELECT column1 AS id, unnest(column2) AS elem + FROM unnest_prune_t +) q +GROUP BY id +ORDER BY id; +---- +1 +3 + +# Correctness check for multiplicity-sensitive count path +query II +SELECT id, count(*) AS cnt +FROM ( + SELECT column1 AS id, unnest(column2) AS elem + FROM unnest_prune_t +) q +GROUP BY id +ORDER BY id; +---- +1 2 +3 1