[SPARK-52354][SQL] Add type coercion to UnionLoop #51063

Status: Closed
@@ -42,13 +42,15 @@ import org.apache.spark.sql.catalyst.plans.logical.{
Project,
ReplaceTable,
Union,
UnionLoop,
Unpivot
}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin
import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper
import org.apache.spark.sql.connector.catalog.procedures.BoundProcedure
import org.apache.spark.sql.errors.DataTypeErrors.cannotMergeIncompatibleDataTypesError
import org.apache.spark.sql.types.DataType

abstract class TypeCoercionBase extends TypeCoercionHelper {
@@ -247,6 +249,25 @@ abstract class TypeCoercionBase extends TypeCoercionHelper {
val attrMapping = s.children.head.output.zip(newChildren.head.output)
s.copy(children = newChildren) -> attrMapping
}

case s: UnionLoop
if s.childrenResolved && s.anchor.output.length == s.recursion.output.length
&& !s.resolved =>
// If the anchor data type is wider than the recursion data type, we cast the recursion
// type to match the anchor type.
// On the other hand, we cannot cast the anchor type into a wider recursion type, as at
// this point the UnionLoopRefs inside the recursion are already resolved with the
// narrower anchor type.
val projectList = s.recursion.output.zip(s.anchor.output.map(_.dataType)).map {
case (attr, dt) =>
val widerType = findWiderTypeForTwo(attr.dataType, dt)
if (widerType.isDefined && widerType.get == dt) {
Alias(Cast(attr, dt), attr.name)()
} else {
throw cannotMergeIncompatibleDataTypesError(dt, attr.dataType)
}
}
s.copy(recursion = Project(projectList, s.recursion)) -> Nil
}
}
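
The new case only ever casts in one direction. Below is a standalone sketch of that decision, using a hypothetical widerOf stand-in for TypeCoercionHelper.findWiderTypeForTwo (covering just the INT/BIGINT pair exercised by the tests in this PR):

import org.apache.spark.sql.types._

// Hypothetical stand-in for findWiderTypeForTwo, not Spark's real helper.
def widerOf(a: DataType, b: DataType): Option[DataType] = (a, b) match {
  case (x, y) if x == y => Some(x)
  case (IntegerType, LongType) | (LongType, IntegerType) => Some(LongType)
  case _ => None
}

// The recursion attribute may be cast to the anchor type only when the anchor
// type is already the wider of the two; otherwise the rule throws
// cannotMergeIncompatibleDataTypesError, as in the case above.
def canCastRecursionToAnchor(recursion: DataType, anchor: DataType): Boolean =
  widerOf(recursion, anchor).contains(anchor)

assert(canCastRecursionToAnchor(IntegerType, LongType))   // BIGINT anchor: cast INT up
assert(!canCastRecursionToAnchor(LongType, IntegerType))  // INT anchor: error path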

@@ -405,8 +405,6 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]]
newValidAttrMapping.filterNot { case (_, a) => existingAttrMappingSet.contains(a) }
}
val resultAttrMapping = if (canGetOutput(plan)) {
// We propagate the attributes mapping to the parent plan node to update attributes, so
// the `newAttr` must be part of this plan's output.
(transferAttrMapping ++ newOtherAttrMapping).filter {
case (_, newAttr) => planAfterRule.outputSet.contains(newAttr)
}
@@ -546,6 +546,23 @@ abstract class UnionBase extends LogicalPlan {
.map(child => rewriteConstraints(children.head.output, child.output, child.constraints))
.reduce(merge(_, _))
}

/**
* Checks whether the child outputs are compatible, using `DataType.equalsStructurally`: every
* child's output must have the same length as the first child's output, and the corresponding
* output data types must be structurally equal.
*
* This method needs to be evaluated after `childrenResolved`.
*/
def allChildrenCompatible: Boolean = childrenResolved && children.tail.forall { child =>
child.output.length == children.head.output.length &&
child.output.zip(children.head.output).forall {
case (l, r) => DataType.equalsStructurally(l.dataType, r.dataType, true)
}
}
}
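
allChildrenCompatible (moved here from Union so UnionLoop can reuse it) leans on DataType.equalsStructurally, which compares shape rather than field names. A quick sketch of what the check tolerates — matching arity and types, with names and (via the `true` flag) nullability free to differ:

import org.apache.spark.sql.types._

val anchorRow = StructType(Seq(StructField("n", IntegerType), StructField("m", LongType)))
val recursionRow = StructType(Seq(
  StructField("a", IntegerType, nullable = false),
  StructField("b", LongType)))

// Same arity and types: structurally equal despite differing names/nullability.
assert(DataType.equalsStructurally(anchorRow, recursionRow, ignoreNullability = true))

// A widened column breaks structural equality; this is exactly the mismatch
// the new TypeCoercion rule must repair (or reject) before `resolved` can hold.
val widened = StructType(Seq(StructField("n", LongType), StructField("m", LongType)))
assert(!DataType.equalsStructurally(anchorRow, widened, ignoreNullability = true))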

/**
@@ -606,20 +623,6 @@ case class Union(
children.length > 1 && !(byName || allowMissingCol) && childrenResolved && allChildrenCompatible
}

/**
* Checks whether the child outputs are compatible by using `DataType.equalsStructurally`. Do
* that by comparing the size of the output with the size of the first child's output and by
* comparing output data types with the data types of the first child's output.
*
* This method needs to be evaluated after `childrenResolved`.
*/
def allChildrenCompatible: Boolean = childrenResolved && children.tail.forall { child =>
child.output.length == children.head.output.length &&
child.output.zip(children.head.output).forall {
case (l, r) => DataType.equalsStructurally(l.dataType, r.dataType, true)
}
}

override protected def withNewChildrenInternal(newChildren: IndexedSeq[LogicalPlan]): Union =
copy(children = newChildren)
}
@@ -59,6 +59,11 @@ case class UnionLoop(
id.toString + limit.map(", " + _.toString).getOrElse("") +
maxDepth.map(", " + _.toString).getOrElse("")
}

override lazy val resolved: Boolean = {
// allChildrenCompatible needs to be evaluated after childrenResolved
childrenResolved && allChildrenCompatible
}
}
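
The comment in `resolved` above is why `childrenResolved` sits on the left of `&&`: output is only well-defined on resolved children, and short-circuiting guarantees allChildrenCompatible never touches it earlier. A toy model (hypothetical classes, not Spark's) of that ordering:

// Toy node: output is undefined until the node is resolved.
final case class ToyNode(isResolved: Boolean, outputTypes: Seq[String] = Nil) {
  def output: Seq[String] = {
    require(isResolved, "output is undefined before resolution")
    outputTypes
  }
}

def childrenResolved(children: Seq[ToyNode]): Boolean = children.forall(_.isResolved)

def compatible(children: Seq[ToyNode]): Boolean =
  children.tail.forall(_.output == children.head.output)

// && short-circuits, so compatible() is never reached with unresolved children.
def resolved(children: Seq[ToyNode]): Boolean =
  childrenResolved(children) && compatible(children)

// One unresolved child: resolved() is simply false, with no exception thrown.
assert(!resolved(Seq(ToyNode(true, Seq("int")), ToyNode(false))))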

/**
@@ -1830,6 +1830,49 @@ SELECT val FROM randoms LIMIT 5
[Analyzer test output redacted due to nondeterminism]


-- !query
WITH RECURSIVE t1(n, m) AS (
SELECT 1, CAST(1 AS BIGINT)
UNION ALL
SELECT n+1, n+1 FROM t1 WHERE n < 5)
SELECT * FROM t1
-- !query analysis
WithCTE
:- CTERelationDef xxxx, false
: +- SubqueryAlias t1
: +- Project [1#x AS n#x, CAST(1 AS BIGINT)#xL AS m#xL]
: +- UnionLoop xxxx
: :- Project [1 AS 1#x, cast(1 as bigint) AS CAST(1 AS BIGINT)#xL]
: : +- OneRowRelation
: +- Project [cast((n + 1)#x as int) AS (n + 1)#x, cast((n + 1)#x as bigint) AS (n + 1)#xL]
: +- Project [(n#x + 1) AS (n + 1)#x, (n#x + 1) AS (n + 1)#x]
: +- Filter (n#x < 5)
: +- SubqueryAlias t1
: +- Project [1#x AS n#x, CAST(1 AS BIGINT)#xL AS m#xL]
: +- UnionLoopRef xxxx, [1#x, CAST(1 AS BIGINT)#xL], false
+- Project [n#x, m#xL]
+- SubqueryAlias t1
+- CTERelationRef xxxx, true, [n#x, m#xL], false, false


-- !query
WITH RECURSIVE t1(n, m) AS (
SELECT 1, 1
UNION ALL
SELECT n+1, CAST(n+1 AS BIGINT) FROM t1 WHERE n < 5)
SELECT * FROM t1
-- !query analysis
org.apache.spark.SparkException
{
"errorClass" : "CANNOT_MERGE_INCOMPATIBLE_DATA_TYPE",
"sqlState" : "42825",
"messageParameters" : {
"left" : "\"INT\"",
"right" : "\"BIGINT\""
}
}


-- !query
WITH RECURSIVE t1(n) AS (
SELECT 1
14 changes: 14 additions & 0 deletions sql/core/src/test/resources/sql-tests/inputs/cte-recursion.sql
@@ -683,6 +683,20 @@ WITH RECURSIVE randoms(val) AS (
)
SELECT val FROM randoms LIMIT 5;

-- Type coercion where the anchor is wider
WITH RECURSIVE t1(n, m) AS (
SELECT 1, CAST(1 AS BIGINT)
UNION ALL
SELECT n+1, n+1 FROM t1 WHERE n < 5)
SELECT * FROM t1;

-- Type coercion where the recursion is wider
WITH RECURSIVE t1(n, m) AS (
SELECT 1, 1
UNION ALL
SELECT n+1, CAST(n+1 AS BIGINT) FROM t1 WHERE n < 5)
SELECT * FROM t1;
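
For completeness, a hypothetical local harness (not part of this PR's test suite, and assuming a Spark build with WITH RECURSIVE support, such as this PR's branch) showing the expected behavior of the two inputs above — the first query widens the recursion's INT to the anchor's BIGINT, while the second fails analysis with CANNOT_MERGE_INCOMPATIBLE_DATA_TYPE:

import org.apache.spark.sql.SparkSession

object RecursiveCoercionDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("demo").getOrCreate()

    // Anchor is wider: the recursion side gets a Cast, schema is <n:int, m:bigint>.
    val widening = spark.sql(
      """WITH RECURSIVE t1(n, m) AS (
        |  SELECT 1, CAST(1 AS BIGINT)
        |  UNION ALL
        |  SELECT n+1, n+1 FROM t1 WHERE n < 5)
        |SELECT * FROM t1""".stripMargin)
    widening.printSchema()
    widening.show()

    spark.stop()
  }
}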

-- Recursive CTE with nullable recursion and non-recursive anchor
WITH RECURSIVE t1(n) AS (
SELECT 1
@@ -1675,6 +1675,42 @@ struct<val:array<int>>
[4,5,1,2,3]


-- !query
WITH RECURSIVE t1(n, m) AS (
SELECT 1, CAST(1 AS BIGINT)
UNION ALL
SELECT n+1, n+1 FROM t1 WHERE n < 5)
SELECT * FROM t1
-- !query schema
struct<n:int,m:bigint>
-- !query output
1 1
2 2
3 3
4 4
5 5


-- !query
WITH RECURSIVE t1(n, m) AS (
SELECT 1, 1
UNION ALL
SELECT n+1, CAST(n+1 AS BIGINT) FROM t1 WHERE n < 5)
SELECT * FROM t1
-- !query schema
struct<>
-- !query output
org.apache.spark.SparkException
{
"errorClass" : "CANNOT_MERGE_INCOMPATIBLE_DATA_TYPE",
"sqlState" : "42825",
"messageParameters" : {
"left" : "\"INT\"",
"right" : "\"BIGINT\""
}
}


-- !query
WITH RECURSIVE t1(n) AS (
SELECT 1