From abf329533aef4bec73d1997d896149746a25977e Mon Sep 17 00:00:00 2001 From: Helios He Date: Mon, 23 Mar 2026 17:18:34 +0000 Subject: [PATCH] initial commit: fix + test --- .../catalyst/expressions/aggregate/collect.scala | 12 ++++++++++++ .../apache/spark/sql/DataFrameAggregateSuite.scala | 13 +++++++++++++ 2 files changed, 25 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala index 8f6b80ebbe772..ee06147b03944 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala @@ -162,6 +162,12 @@ case class CollectList( s"$prettyName($child)$ignoreNullsStr" } + override def sql(isDistinct: Boolean): String = { + val distinct = if (isDistinct) "DISTINCT " else "" + val nullsStr = if (ignoreNulls) "" else " RESPECT NULLS" + s"$prettyName($distinct${child.sql})$nullsStr" + } + override protected def withNewChildInternal(newChild: Expression): CollectList = copy(child = newChild) } @@ -268,6 +274,12 @@ case class CollectSet( s"$prettyName($child)$ignoreNullsStr" } + override def sql(isDistinct: Boolean): String = { + val distinct = if (isDistinct) "DISTINCT " else "" + val nullsStr = if (ignoreNulls) "" else " RESPECT NULLS" + s"$prettyName($distinct${child.sql})$nullsStr" + } + override protected def withNewChildInternal(newChild: Expression): CollectSet = copy(child = newChild) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 355333d8268f5..a55638b2431c5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -791,6 +791,19 @@ class DataFrameAggregateSuite extends QueryTest Seq(Row(Seq(1.0, 2.0)))) } + test("SPARK-56155: collect functions sql() display RESPECT NULLS") { + val df = Seq((1, Some(2)), (1, None), (1, Some(4))).toDF("a", "b") + val collect_list_result = df.selectExpr("collect_list(b) RESPECT NULLS") + val collect_list_result2 = df.selectExpr("collect_list(b)") + assert(collect_list_result.columns.head == "collect_list(b) RESPECT NULLS") + assert(collect_list_result2.columns.head == "collect_list(b)") + + val collect_set_result = df.selectExpr("collect_set(b) RESPECT NULLS") + val collect_set_result2 = df.selectExpr("collect_set(b)") + assert(collect_set_result.columns.head == "collect_set(b) RESPECT NULLS") + assert(collect_set_result2.columns.head == "collect_set(b)") + } + test("SPARK-14664: Decimal sum/avg over window should work.") { checkAnswer( spark.sql("select sum(a) over () from values 1.0, 2.0, 3.0 T(a)"),