Skip to content

Commit 4c81833

Browse files
authored
feat: support parsing of SQL queries with APPLY (#106)
* feat: support parsing of SQL queries with APPLY This change adds support for parsing of SQL queries with APPLY (join with correlated subquery), and to build OuterReferences map of correlated variables present in the query's join predicates. The OuterRefs will be used while constructing Substrait plans to bind correlated variables. The change also adds few example queries which depend on APPLY / LATERAL operators. This change still does not map calcite-correlated-join to Substrait, as the spec for APPLY is still not approved. As such, while the parsing of calcite query plans will succeed after this change, the unit tests and run time conversion will continue to fail in the final step of building the Substrait plan. Additional changes are needed to support APPLY. Refs #substrait-io/substrait/issues/357 * fix: unit test cases to validate correlated vars This change addresses review comments, the unit tests validate the outer reference map built from calcite plans of APPLY queries. * fix: add test for nested APPLY This change addresses review comments. A new test case to validate nested APPLY join parsing is added. Also added validation of depth information in existing tests.
1 parent 178695f commit 4c81833

File tree

6 files changed

+247
-21
lines changed

6 files changed

+247
-21
lines changed

isthmus/src/main/java/io/substrait/isthmus/OuterReferenceResolver.java

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import java.util.Map;
88
import org.apache.calcite.rel.RelNode;
99
import org.apache.calcite.rel.core.CorrelationId;
10+
import org.apache.calcite.rel.logical.LogicalCorrelate;
1011
import org.apache.calcite.rel.logical.LogicalFilter;
1112
import org.apache.calcite.rel.logical.LogicalProject;
1213
import org.apache.calcite.rex.*;
@@ -48,6 +49,32 @@ public RelNode visit(LogicalFilter filter) throws RuntimeException {
4849
return super.visit(filter);
4950
}
5051

52+
@Override
53+
public RelNode visit(LogicalCorrelate correlate) throws RuntimeException {
54+
for (CorrelationId id : correlate.getVariablesSet()) {
55+
if (!nestedDepth.containsKey(id)) {
56+
nestedDepth.put(id, 0);
57+
}
58+
}
59+
60+
apply(correlate.getLeft());
61+
62+
// Correlated join is a special case. The right-rel is a correlated sub-query but not a REX. So,
63+
// the RexVisitor cannot be applied to it to correctly compute the depth map. Hence, we need to
64+
// manually compute the depth map for the right-rel.
65+
for (Map.Entry<CorrelationId, Integer> entry : nestedDepth.entrySet()) {
66+
nestedDepth.put(entry.getKey(), entry.getValue() + 1);
67+
}
68+
69+
apply(correlate.getRight()); // look inside sub-queries
70+
71+
for (Map.Entry<CorrelationId, Integer> entry : nestedDepth.entrySet()) {
72+
nestedDepth.put(entry.getKey(), entry.getValue() - 1);
73+
}
74+
75+
return correlate;
76+
}
77+
5178
@Override
5279
public RelNode visitOther(RelNode other) throws RuntimeException {
5380
for (RelNode child : other.getInputs()) {

isthmus/src/main/java/io/substrait/isthmus/OuterReferenceResolver.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ WHERE p_size <
1414
AND PS.ps_suppkey = l.l_suppkey))
1515
1616
17-
Filter --- $coor0
17+
Filter --- $corr0
1818
/ \ condition
1919
/ p_size < RexSubquery
2020
Scan(P) |
@@ -23,7 +23,7 @@ WHERE p_size <
2323
|
2424
Project
2525
|
26-
Filter --- $coor2
26+
Filter --- $corr2
2727
/ \
2828
/ \
2929
Scan (L) \

isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
package io.substrait.isthmus;
22

3-
import io.substrait.function.ImmutableSimpleExtension;
43
import io.substrait.function.SimpleExtension;
54
import io.substrait.type.NamedStruct;
65
import java.io.IOException;
@@ -67,15 +66,17 @@ protected SqlConverterBase(FeatureBoard features) {
6766
new ProxyingMetadataHandlerProvider(DefaultRelMetadataProvider.INSTANCE);
6867
return new RelMetadataQuery(handler);
6968
});
70-
parserConfig = SqlParser.Config.DEFAULT.withParserFactory(SqlDdlParserImpl.FACTORY);
7169
featureBoard = features == null ? FEATURES_DEFAULT : features;
70+
parserConfig =
71+
SqlParser.Config.DEFAULT
72+
.withParserFactory(SqlDdlParserImpl.FACTORY)
73+
.withConformance(featureBoard.sqlConformanceMode());
7274
}
7375

7476
protected static final SimpleExtension.ExtensionCollection EXTENSION_COLLECTION;
7577

7678
static {
77-
SimpleExtension.ExtensionCollection defaults =
78-
ImmutableSimpleExtension.ExtensionCollection.builder().build();
79+
SimpleExtension.ExtensionCollection defaults;
7980
try {
8081
defaults = SimpleExtension.loadDefaults();
8182
} catch (IOException e) {

isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package io.substrait.isthmus;
22

3+
import com.google.common.annotations.VisibleForTesting;
34
import io.substrait.expression.proto.FunctionCollector;
45
import io.substrait.proto.Plan;
56
import io.substrait.proto.PlanRel;
@@ -13,6 +14,7 @@
1314
import org.apache.calcite.rel.RelRoot;
1415
import org.apache.calcite.rel.type.RelDataTypeFactory;
1516
import org.apache.calcite.schema.Schema;
17+
import org.apache.calcite.sql.SqlNode;
1618
import org.apache.calcite.sql.parser.SqlParseException;
1719
import org.apache.calcite.sql.parser.SqlParser;
1820
import org.apache.calcite.sql.validate.SqlValidator;
@@ -95,6 +97,17 @@ private List<RelRoot> sqlToRelNode(
9597
if (!featureBoard.allowsSqlBatch() && parsedList.size() > 1) {
9698
throw new UnsupportedOperationException("SQL must contain only a single statement: " + sql);
9799
}
100+
SqlToRelConverter converter = createSqlToRelConverter(validator, catalogReader);
101+
List<RelRoot> roots =
102+
parsedList.stream()
103+
.map(parsed -> getBestExpRelRoot(converter, parsed))
104+
.collect(java.util.stream.Collectors.toList());
105+
return roots;
106+
}
107+
108+
@VisibleForTesting
109+
SqlToRelConverter createSqlToRelConverter(
110+
SqlValidator validator, CalciteCatalogReader catalogReader) {
98111
SqlToRelConverter converter =
99112
new SqlToRelConverter(
100113
null,
@@ -103,20 +116,18 @@ private List<RelRoot> sqlToRelNode(
103116
relOptCluster,
104117
StandardConvertletTable.INSTANCE,
105118
converterConfig);
106-
List<RelRoot> roots =
107-
parsedList.stream()
108-
.map(
109-
parsed -> {
110-
RelRoot root = converter.convertQuery(parsed, true, true);
111-
{
112-
var program = HepProgram.builder().build();
113-
HepPlanner hepPlanner = new HepPlanner(program);
114-
hepPlanner.setRoot(root.rel);
115-
root = root.withRel(hepPlanner.findBestExp());
116-
}
117-
return root;
118-
})
119-
.collect(java.util.stream.Collectors.toList());
120-
return roots;
119+
return converter;
120+
}
121+
122+
@VisibleForTesting
123+
static RelRoot getBestExpRelRoot(SqlToRelConverter converter, SqlNode parsed) {
124+
RelRoot root = converter.convertQuery(parsed, true, true);
125+
{
126+
var program = HepProgram.builder().build();
127+
HepPlanner hepPlanner = new HepPlanner(program);
128+
hepPlanner.setRoot(root.rel);
129+
root = root.withRel(hepPlanner.findBestExp());
130+
}
131+
return root;
121132
}
122133
}

isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
@SuppressWarnings("UnstableApiUsage")
6262
@Value.Enclosing
6363
public class SubstraitRelVisitor extends RelNodeVisitor<Rel, RuntimeException> {
64+
6465
static final org.slf4j.Logger logger =
6566
org.slf4j.LoggerFactory.getLogger(SubstraitRelVisitor.class);
6667
private static final FeatureBoard FEATURES_DEFAULT = ImmutableFeatureBoard.builder().build();
@@ -196,6 +197,19 @@ public Rel visit(LogicalJoin join) {
196197

197198
@Override
198199
public Rel visit(LogicalCorrelate correlate) {
200+
// left input of correlated-join is similar to the left input of a logical join
201+
apply(correlate.getLeft());
202+
203+
// right input of correlated-join is similar to a correlated sub-query
204+
apply(correlate.getRight());
205+
206+
var joinType =
207+
switch (correlate.getJoinType()) {
208+
case INNER -> Join.JoinType.INNER; // corresponds to CROSS APPLY join
209+
case LEFT -> Join.JoinType.LEFT; // corresponds to OUTER APPLY join
210+
default -> throw new IllegalArgumentException(
211+
"Invalid correlated join type: " + correlate.getJoinType());
212+
};
199213
return super.visit(correlate);
200214
}
201215

Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
package io.substrait.isthmus;
2+
3+
import java.util.Map;
4+
import org.apache.calcite.adapter.tpcds.TpcdsSchema;
5+
import org.apache.calcite.rel.RelRoot;
6+
import org.apache.calcite.rex.RexFieldAccess;
7+
import org.apache.calcite.sql.parser.SqlParseException;
8+
import org.apache.calcite.sql.parser.SqlParser;
9+
import org.apache.calcite.sql.validate.SqlConformanceEnum;
10+
import org.junit.jupiter.api.Assertions;
11+
import org.junit.jupiter.api.Test;
12+
13+
public class ApplyJoinPlanTest {
14+
15+
private static RelRoot getCalcitePlan(SqlToSubstrait s, TpcdsSchema schema, String sql)
16+
throws SqlParseException {
17+
var pair = s.registerSchema("tpcds", schema);
18+
var converter = s.createSqlToRelConverter(pair.left, pair.right);
19+
SqlParser parser = SqlParser.create(sql, s.parserConfig);
20+
var root = s.getBestExpRelRoot(converter, parser.parseQuery());
21+
return root;
22+
}
23+
24+
private static void validateOuterRef(
25+
Map<RexFieldAccess, Integer> fieldAccessDepthMap, String refName, String colName, int depth) {
26+
var entry =
27+
fieldAccessDepthMap.entrySet().stream()
28+
.filter(f -> f.getKey().getReferenceExpr().toString().equals(refName))
29+
.filter(f -> f.getKey().getField().getName().equals(colName))
30+
.filter(f -> f.getValue() == depth)
31+
.findFirst();
32+
Assertions.assertTrue(entry.isPresent());
33+
}
34+
35+
private static Map<RexFieldAccess, Integer> buildOuterFieldRefMap(RelRoot root) {
36+
final OuterReferenceResolver resolver = new OuterReferenceResolver();
37+
var fieldAccessDepthMap = resolver.getFieldAccessDepthMap();
38+
Assertions.assertEquals(0, fieldAccessDepthMap.size());
39+
resolver.apply(root.rel);
40+
return fieldAccessDepthMap;
41+
}
42+
43+
@Test
44+
public void lateralJoinQuery() throws SqlParseException {
45+
TpcdsSchema schema = new TpcdsSchema(1.0);
46+
String sql;
47+
sql =
48+
"""
49+
SELECT ss_sold_date_sk, ss_item_sk, ss_customer_sk
50+
FROM store_sales CROSS JOIN LATERAL
51+
(select i_item_sk from item where item.i_item_sk = store_sales.ss_item_sk)""";
52+
53+
/* the calcite plan for the above query is:
54+
LogicalProject(SS_SOLD_DATE_SK=[$0], SS_ITEM_SK=[$2], SS_CUSTOMER_SK=[$3])
55+
LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{2}])
56+
LogicalTableScan(table=[[tpcds, STORE_SALES]])
57+
LogicalProject(I_ITEM_SK=[$0])
58+
LogicalFilter(condition=[=($0, $cor0.SS_ITEM_SK)])
59+
LogicalTableScan(table=[[tpcds, ITEM]])
60+
*/
61+
62+
// validate outer reference map
63+
RelRoot root = getCalcitePlan(new SqlToSubstrait(), schema, sql);
64+
Map<RexFieldAccess, Integer> fieldAccessDepthMap = buildOuterFieldRefMap(root);
65+
Assertions.assertEquals(1, fieldAccessDepthMap.size());
66+
validateOuterRef(fieldAccessDepthMap, "$cor0", "SS_ITEM_SK", 1);
67+
68+
// TODO validate end to end conversion
69+
var sE2E = new SqlToSubstrait();
70+
Assertions.assertThrows(
71+
UnsupportedOperationException.class,
72+
() -> sE2E.execute(sql, "tpcds", schema),
73+
"Lateral join is not supported");
74+
}
75+
76+
@Test
77+
public void outerApplyQuery() throws SqlParseException {
78+
TpcdsSchema schema = new TpcdsSchema(1.0);
79+
String sql;
80+
sql =
81+
"""
82+
SELECT ss_sold_date_sk, ss_item_sk, ss_customer_sk
83+
FROM store_sales OUTER APPLY
84+
(select i_item_sk from item where item.i_item_sk = store_sales.ss_item_sk)""";
85+
86+
FeatureBoard featureBoard =
87+
ImmutableFeatureBoard.builder()
88+
.sqlConformanceMode(SqlConformanceEnum.SQL_SERVER_2008)
89+
.build();
90+
SqlToSubstrait s = new SqlToSubstrait(featureBoard);
91+
RelRoot root = getCalcitePlan(s, schema, sql);
92+
93+
Map<RexFieldAccess, Integer> fieldAccessDepthMap = buildOuterFieldRefMap(root);
94+
Assertions.assertEquals(1, fieldAccessDepthMap.size());
95+
validateOuterRef(fieldAccessDepthMap, "$cor0", "SS_ITEM_SK", 1);
96+
97+
// TODO validate end to end conversion
98+
Assertions.assertThrows(
99+
UnsupportedOperationException.class,
100+
() -> s.execute(sql, "tpcds", schema),
101+
"APPLY is not supported");
102+
}
103+
104+
@Test
105+
public void nestedApplyJoinQuery() throws SqlParseException {
106+
TpcdsSchema schema = new TpcdsSchema(1.0);
107+
String sql;
108+
sql =
109+
"""
110+
SELECT ss_sold_date_sk, ss_item_sk, ss_customer_sk
111+
FROM store_sales CROSS APPLY
112+
( SELECT i_item_sk
113+
FROM item CROSS APPLY
114+
( SELECT p_promo_sk
115+
FROM promotion
116+
WHERE p_item_sk = i_item_sk AND p_item_sk = ss_item_sk )
117+
WHERE item.i_item_sk = store_sales.ss_item_sk )""";
118+
119+
/* the calcite plan for the above query is:
120+
LogicalProject(SS_SOLD_DATE_SK=[$0], SS_ITEM_SK=[$2], SS_CUSTOMER_SK=[$3])
121+
LogicalCorrelate(correlation=[$cor2], joinType=[inner], requiredColumns=[{2}])
122+
LogicalTableScan(table=[[tpcds, STORE_SALES]])
123+
LogicalProject(I_ITEM_SK=[$0])
124+
LogicalFilter(condition=[=($0, $cor2.SS_ITEM_SK)])
125+
LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{0}])
126+
LogicalTableScan(table=[[tpcds, ITEM]])
127+
LogicalProject(P_PROMO_SK=[$0])
128+
LogicalFilter(condition=[AND(=($4, $cor0.I_ITEM_SK), =($4, $cor2.SS_ITEM_SK))])
129+
LogicalTableScan(table=[[tpcds, PROMOTION]])
130+
*/
131+
FeatureBoard featureBoard =
132+
ImmutableFeatureBoard.builder()
133+
.sqlConformanceMode(SqlConformanceEnum.SQL_SERVER_2008)
134+
.build();
135+
SqlToSubstrait s = new SqlToSubstrait(featureBoard);
136+
RelRoot root = getCalcitePlan(s, schema, sql);
137+
138+
Map<RexFieldAccess, Integer> fieldAccessDepthMap = buildOuterFieldRefMap(root);
139+
Assertions.assertEquals(3, fieldAccessDepthMap.size());
140+
validateOuterRef(fieldAccessDepthMap, "$cor2", "SS_ITEM_SK", 1);
141+
validateOuterRef(fieldAccessDepthMap, "$cor2", "SS_ITEM_SK", 2);
142+
validateOuterRef(fieldAccessDepthMap, "$cor0", "I_ITEM_SK", 1);
143+
144+
// TODO validate end to end conversion
145+
Assertions.assertThrows(
146+
UnsupportedOperationException.class,
147+
() -> s.execute(sql, "tpcds", schema),
148+
"APPLY is not supported");
149+
}
150+
151+
@Test
152+
public void crossApplyQuery() throws SqlParseException {
153+
TpcdsSchema schema = new TpcdsSchema(1.0);
154+
String sql;
155+
sql =
156+
"""
157+
SELECT ss_sold_date_sk, ss_item_sk, ss_customer_sk
158+
FROM store_sales CROSS APPLY
159+
(select i_item_sk from item where item.i_item_sk = store_sales.ss_item_sk)""";
160+
161+
FeatureBoard featureBoard =
162+
ImmutableFeatureBoard.builder()
163+
.sqlConformanceMode(SqlConformanceEnum.SQL_SERVER_2008)
164+
.build();
165+
SqlToSubstrait s = new SqlToSubstrait(featureBoard);
166+
167+
// TODO validate end to end conversion
168+
Assertions.assertThrows(
169+
UnsupportedOperationException.class,
170+
() -> s.execute(sql, "tpcds", schema),
171+
"APPLY is not supported");
172+
}
173+
}

0 commit comments

Comments
 (0)