Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 88 additions & 13 deletions datafusion/physical-plan/src/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,12 @@ use datafusion_common::{
use datafusion_execution::TaskContext;
use datafusion_expr::Operator;
use datafusion_physical_expr::equivalence::ProjectionMapping;
use datafusion_physical_expr::expressions::{BinaryExpr, Column, lit};
use datafusion_physical_expr::expressions::{BinaryExpr, Column, Literal, lit};
use datafusion_physical_expr::intervals::utils::check_support;
use datafusion_physical_expr::utils::{collect_columns, reassign_expr_columns};
use datafusion_physical_expr::{
AcrossPartitions, AnalysisContext, ConstExpr, ExprBoundaries, PhysicalExpr, analyze,
conjunction, split_conjunction,
AcrossPartitions, AnalysisContext, ConstExpr, EquivalenceProperties, ExprBoundaries,
PhysicalExpr, analyze, conjunction, split_conjunction,
};

use datafusion_physical_expr_common::physical_expr::fmt_sql;
Expand Down Expand Up @@ -347,6 +347,20 @@ impl FilterExec {
})
}

/// Looks up how `expr` is constant across partitions, if it is constant at all.
///
/// The input equivalence properties are consulted first: when `input_eqs`
/// already reports `expr` as constant, that answer is returned unchanged.
/// Otherwise, if `expr` is a `Literal`, it is reported as
/// `AcrossPartitions::Uniform` carrying a clone of the literal's value.
/// Any other expression yields `None`.
fn expr_constant_or_literal(
    expr: &Arc<dyn PhysicalExpr>,
    input_eqs: &EquivalenceProperties,
) -> Option<AcrossPartitions> {
    // Prefer what the equivalence properties already know about `expr`.
    if let Some(known) = input_eqs.is_expr_constant(expr) {
        return Some(known);
    }

    // Fall back to the literal case; a non-literal expression gives `None`.
    let literal = expr.as_any().downcast_ref::<Literal>()?;
    Some(AcrossPartitions::Uniform(Some(literal.value().clone())))
}

fn extend_constants(
input: &Arc<dyn ExecutionPlan>,
predicate: &Arc<dyn PhysicalExpr>,
Expand All @@ -359,18 +373,24 @@ impl FilterExec {
if let Some(binary) = conjunction.as_any().downcast_ref::<BinaryExpr>()
&& binary.op() == &Operator::Eq
{
// Filter evaluates to single value for all partitions
if input_eqs.is_expr_constant(binary.left()).is_some() {
let across = input_eqs
.is_expr_constant(binary.right())
.unwrap_or_default();
// Check if either side is constant — either already known
// constant from the input equivalence properties, or a literal
// value (which is inherently constant across all partitions).
let left_const = Self::expr_constant_or_literal(binary.left(), input_eqs);
let right_const =
Self::expr_constant_or_literal(binary.right(), input_eqs);

if let Some(left_across) = left_const {
// LEFT is constant, so RIGHT must also be constant.
// Use RIGHT's known across value if available, otherwise
// propagate LEFT's (e.g. Uniform from a literal).
let across = right_const.unwrap_or(left_across);
res_constants
.push(ConstExpr::new(Arc::clone(binary.right()), across));
} else if input_eqs.is_expr_constant(binary.right()).is_some() {
let across = input_eqs
.is_expr_constant(binary.left())
.unwrap_or_default();
res_constants.push(ConstExpr::new(Arc::clone(binary.left()), across));
} else if let Some(right_across) = right_const {
// RIGHT is constant, so LEFT must also be constant.
res_constants
.push(ConstExpr::new(Arc::clone(binary.left()), right_across));
}
}
}
Expand Down Expand Up @@ -979,6 +999,19 @@ fn collect_columns_from_predicate_inner(
let predicates = split_conjunction(predicate);
predicates.into_iter().for_each(|p| {
if let Some(binary) = p.as_any().downcast_ref::<BinaryExpr>() {
// Only extract pairs where at least one side is a Column reference.
// Pairs like `complex_expr = literal` should not create equivalence
// classes — the literal could appear in many unrelated expressions
// (e.g. sort keys), and normalize_expr's deep traversal would
// replace those occurrences with the complex expression, corrupting
// sort orderings. Constant propagation for such pairs is handled
// separately by `extend_constants`.
let has_direct_column_operand =
binary.left().as_any().downcast_ref::<Column>().is_some()
|| binary.right().as_any().downcast_ref::<Column>().is_some();
if !has_direct_column_operand {
return;
}
match binary.op() {
Operator::Eq => {
eq_predicate_columns.push((binary.left(), binary.right()))
Expand Down Expand Up @@ -2066,4 +2099,46 @@ mod tests {

Ok(())
}

/// Regression test for https://github.com/apache/datafusion/issues/20194
///
/// Checks which equality predicates `collect_columns_from_predicate_inner`
/// reports as equal pairs:
///
/// * `complex_expr = literal` (neither operand is a bare `Column`) must
///   yield no pair, because an equivalence class built from it would let
///   `normalize_expr`'s deep traversal swap the literal inside unrelated
///   expressions (e.g. sort keys) for the complex expression.
/// * `column = literal` must still yield exactly one pair.
#[test]
fn test_collect_columns_skips_non_column_pairs() -> Result<()> {
    let schema = test::aggr_test_schema();

    // Stand-in for `nvl(c2, 0) = 0`: `(c2 IS DISTINCT FROM 0) = 0`.
    // Both operands of the outer `=` are non-column expressions.
    let is_distinct = binary(
        col("c2", &schema)?,
        Operator::IsDistinctFrom,
        lit(0u32),
        &schema,
    )?;
    let non_column_eq = binary(is_distinct, Operator::Eq, lit(0u32), &schema)?;
    let (non_column_pairs, _) = collect_columns_from_predicate_inner(&non_column_eq);
    assert_eq!(
        0,
        non_column_pairs.len(),
        "Should not extract equality pairs where neither side is a Column"
    );

    // A plain `c2 = 0` has a direct column operand and must be kept.
    let column_eq = binary(col("c2", &schema)?, Operator::Eq, lit(0u32), &schema)?;
    let (column_pairs, _) = collect_columns_from_predicate_inner(&column_eq);
    assert_eq!(
        1,
        column_pairs.len(),
        "Should extract equality pairs where one side is a Column"
    );

    Ok(())
}
}
46 changes: 46 additions & 0 deletions datafusion/sqllogictest/test_files/window.slt
Original file line number Diff line number Diff line change
Expand Up @@ -6081,3 +6081,49 @@ WHERE acctbal > (
);
----
1

# Regression test for https://github.com/apache/datafusion/issues/20194
# Window function with CASE WHEN in ORDER BY combined with NVL filter
# should not trigger SanityCheckPlan error from equivalence normalization
# replacing literals in sort expressions with complex filter expressions.
#
# Shape of the test: create two tables, insert one row into each so that the
# join below matches exactly one row, run the window query, then drop both
# tables.
statement ok
CREATE TABLE issue_20194_t1 (
value_1_1 decimal(25) NULL,
value_1_2 int NULL,
value_1_3 bigint NULL
);

statement ok
CREATE TABLE issue_20194_t2 (
value_2_1 bigint NULL,
value_2_2 varchar(140) NULL,
value_2_3 varchar(140) NULL
);

# The t1 row carries join key value_1_3 = 1120.
statement ok
INSERT INTO issue_20194_t1 (value_1_1, value_1_2, value_1_3) VALUES (6774502793, 10040029, 1120);

# The t2 row carries the matching key value_2_1 = 1120 and value_2_3 = '0',
# so it also passes the nvl(...) = '0' join condition.
statement ok
INSERT INTO issue_20194_t2 (value_2_1, value_2_2, value_2_3) VALUES (1120, '0', '0');

# The literal '0' appears both in the nvl(...) = '0' join condition and in the
# CASE WHEN sort keys of the window. Only one joined row exists, so the
# expected ROW_NUMBER() is 1.
query RII
SELECT
t1.value_1_1, t1.value_1_2,
ROW_NUMBER() OVER (
PARTITION BY t1.value_1_1, t1.value_1_2
ORDER BY
CASE WHEN t2.value_2_2 = '0' THEN 1 ELSE 0 END ASC,
CASE WHEN t2.value_2_3 = '0' THEN 1 ELSE 0 END ASC
) AS ord
FROM issue_20194_t1 t1
INNER JOIN issue_20194_t2 t2
ON t1.value_1_3 = t2.value_2_1
AND nvl(t2.value_2_3, '0') = '0';
----
6774502793 10040029 1

# Cleanup: drop the tables created for this regression test.
statement ok
DROP TABLE issue_20194_t1;

statement ok
DROP TABLE issue_20194_t2;
Loading