diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
index 5aef82b64ed32..f508dd2905935 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
@@ -258,12 +258,18 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
       def patchAggregateFunctionChildren(
           af: AggregateFunction)(
           attrs: Expression => Option[Expression]): AggregateFunction = {
-        val newChildren = af.children.map(c => attrs(c).getOrElse(c))
+        val newChildren = af.children.map {
+          case so: SortOrder =>
+            so.copy(child = attrs(so.child).getOrElse(so.child))
+          case c =>
+            attrs(c).getOrElse(c)
+        }
         af.withNewChildren(newChildren).asInstanceOf[AggregateFunction]
       }
 
       // Setup unique distinct aggregate children.
-      val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq.distinct
+      val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq
+        .filter(_.isInstanceOf[AttributeReference]).distinct
       val distinctAggChildAttrMap = distinctAggChildren.map { e =>
         e.canonicalized -> AttributeReference(e.sql, e.dataType, nullable = true)()
       }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 6ce0a657d5b9d..4de5566f863a0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -661,6 +661,17 @@ class DataFrameAggregateSuite extends QueryTest
     )
   }
 
+  test("ListAgg should be able to handle multiple distinct aggregations") {
+    withTable("tbl") {
+      sql("CREATE TABLE tbl (col1 STRING, col2 STRING) USING parquet")
+      sql("INSERT INTO tbl VALUES ('A', 'x'), ('A', 'y'), ('B', 'y')")
+      val lag1 = "LISTAGG(DISTINCT col1) WITHIN GROUP (ORDER BY col1)"
+      val lag2 = "LISTAGG(DISTINCT col2) WITHIN GROUP (ORDER BY col2)"
+      val query = sql(s"SELECT $lag1, $lag2 FROM tbl")
+      checkAnswer(query, Seq(Row("AB", "xy")))
+    }
+  }
+
   test("SPARK-31500: collect_set() of BinaryType returns duplicate elements") {
     val bytesTest1 = "test1".getBytes
     val bytesTest2 = "test2".getBytes
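
Note: the first hunk changes how patchAggregateFunctionChildren substitutes distinct-aggregate children: when a child is a SortOrder (as produced by LISTAGG ... WITHIN GROUP (ORDER BY ...)), the attribute rewrite is applied to the SortOrder's inner expression rather than to the SortOrder node itself. The following is a minimal, self-contained Scala sketch of that substitution pattern only; Expr, Attr and SortOrder here are stand-in case classes for illustration, not Spark's catalyst expression classes.

object SortOrderPatchSketch {
  sealed trait Expr
  case class Attr(name: String) extends Expr
  case class SortOrder(child: Expr) extends Expr

  // attrs maps a distinct-aggregate child to its projected attribute, if one exists.
  // For a SortOrder child, the rewrite descends into its inner expression so the
  // ordering keeps pointing at the projected attribute.
  def patchChildren(children: Seq[Expr])(attrs: Expr => Option[Expr]): Seq[Expr] =
    children.map {
      case so: SortOrder => so.copy(child = attrs(so.child).getOrElse(so.child))
      case c => attrs(c).getOrElse(c)
    }

  def main(args: Array[String]): Unit = {
    val mapping = Map[Expr, Expr](Attr("col1") -> Attr("col1'"))
    val patched = patchChildren(Seq(Attr("col1"), SortOrder(Attr("col1"))))(mapping.get)
    // Prints: List(Attr(col1'), SortOrder(Attr(col1')))
    println(patched)
  }
}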