Skip to content

Commit b7abddd

Browse files
authored
feat(isthmus): convert Calcite RelRoot to Substrait Plan.Root (#370)
RelRoots must be converted to Plan.Roots in order to ensure that names are handled correctly. BREAKING CHANGE: converting a Calcite RelRoot no longer produces a Substrait Rel
1 parent 66447cb commit b7abddd

File tree

7 files changed

+171
-72
lines changed

7 files changed

+171
-72
lines changed

core/src/main/java/io/substrait/relation/ProtoRelConverter.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import io.substrait.extension.ExtensionLookup;
99
import io.substrait.extension.SimpleExtension;
1010
import io.substrait.hint.Hint;
11+
import io.substrait.plan.Plan;
1112
import io.substrait.proto.AggregateRel;
1213
import io.substrait.proto.ConsistentPartitionWindowRel;
1314
import io.substrait.proto.CrossRel;
@@ -61,6 +62,10 @@ public ProtoRelConverter(ExtensionLookup lookup, SimpleExtension.ExtensionCollec
6162
this.protoTypeConverter = new ProtoTypeConverter(lookup, extensions);
6263
}
6364

65+
public Plan.Root from(io.substrait.proto.RelRoot rel) {
66+
return Plan.Root.builder().input(from(rel.getInput())).addAllNames(rel.getNamesList()).build();
67+
}
68+
6469
public Rel from(io.substrait.proto.Rel rel) {
6570
var relType = rel.getRelTypeCase();
6671
switch (relType) {

core/src/main/java/io/substrait/relation/RelProtoConverter.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import io.substrait.expression.proto.ExpressionProtoConverter;
77
import io.substrait.expression.proto.ExpressionProtoConverter.BoundConverter;
88
import io.substrait.extension.ExtensionCollector;
9+
import io.substrait.plan.Plan;
910
import io.substrait.proto.AggregateFunction;
1011
import io.substrait.proto.AggregateRel;
1112
import io.substrait.proto.ConsistentPartitionWindowRel;
@@ -24,6 +25,7 @@
2425
import io.substrait.proto.ReadRel;
2526
import io.substrait.proto.Rel;
2627
import io.substrait.proto.RelCommon;
28+
import io.substrait.proto.RelRoot;
2729
import io.substrait.proto.SetRel;
2830
import io.substrait.proto.SortField;
2931
import io.substrait.proto.SortRel;
@@ -59,6 +61,13 @@ public TypeProtoConverter getTypeProtoConverter() {
5961
return this.typeProtoConverter;
6062
}
6163

64+
public io.substrait.proto.RelRoot toProto(Plan.Root relRoot) {
65+
return RelRoot.newBuilder()
66+
.setInput(toProto(relRoot.getInput()))
67+
.addAllNames(relRoot.getNames())
68+
.build();
69+
}
70+
6271
public io.substrait.proto.Rel toProto(io.substrait.relation.Rel rel) {
6372
return rel.accept(this);
6473
}

isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -66,15 +66,9 @@ private Plan executeInner(String sql, SqlValidator validator, Prepare.CatalogRea
6666
plan.addRelations(
6767
PlanRel.newBuilder()
6868
.setRoot(
69-
io.substrait.proto.RelRoot.newBuilder()
70-
.setInput(
71-
SubstraitRelVisitor.convert(
72-
root, EXTENSION_COLLECTION, featureBoard)
73-
.accept(relProtoConverter))
74-
.addAllNames(
75-
TypeConverter.DEFAULT
76-
.toNamedStruct(root.validatedRowType)
77-
.names())));
69+
relProtoConverter.toProto(
70+
SubstraitRelVisitor.convert(
71+
root, EXTENSION_COLLECTION, featureBoard))));
7872
});
7973
functionCollector.addExtensionsToPlan(plan);
8074
return plan.build();

isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import io.substrait.isthmus.expression.RexExpressionConverter;
1111
import io.substrait.isthmus.expression.ScalarFunctionConverter;
1212
import io.substrait.isthmus.expression.WindowFunctionConverter;
13+
import io.substrait.plan.Plan;
1314
import io.substrait.relation.Aggregate;
1415
import io.substrait.relation.Cross;
1516
import io.substrait.relation.EmptyScan;
@@ -379,20 +380,32 @@ public List<Rel> apply(List<RelNode> inputs) {
379380
.collect(java.util.stream.Collectors.toList());
380381
}
381382

382-
public static Rel convert(RelRoot root, SimpleExtension.ExtensionCollection extensions) {
383-
return convert(root.rel, extensions, FEATURES_DEFAULT);
383+
public static Plan.Root convert(RelRoot relRoot, SimpleExtension.ExtensionCollection extensions) {
384+
return convert(relRoot, extensions, FEATURES_DEFAULT);
384385
}
385386

386-
public static Rel convert(
387-
RelRoot root, SimpleExtension.ExtensionCollection extensions, FeatureBoard features) {
388-
return convert(root.rel, extensions, features);
387+
public static Plan.Root convert(
388+
RelRoot relRoot, SimpleExtension.ExtensionCollection extensions, FeatureBoard features) {
389+
SubstraitRelVisitor visitor =
390+
new SubstraitRelVisitor(relRoot.rel.getCluster().getTypeFactory(), extensions, features);
391+
visitor.popFieldAccessDepthMap(relRoot.rel);
392+
Rel rel = visitor.apply(relRoot.project());
393+
394+
// Avoid using the field names of relRoot.validatedRowType directly, because if there are
395+
// nested types (e.g. ROW, MAP, etc.) only the typeConverter will pad the names correctly
396+
List<String> names = visitor.typeConverter.toNamedStruct(relRoot.validatedRowType).names();
397+
return Plan.Root.builder().input(rel).names(names).build();
389398
}
390399

391-
private static Rel convert(
392-
RelNode rel, SimpleExtension.ExtensionCollection extensions, FeatureBoard features) {
400+
public static Rel convert(RelNode relNode, SimpleExtension.ExtensionCollection extensions) {
401+
return convert(relNode, extensions, FEATURES_DEFAULT);
402+
}
403+
404+
public static Rel convert(
405+
RelNode relNode, SimpleExtension.ExtensionCollection extensions, FeatureBoard features) {
393406
SubstraitRelVisitor visitor =
394-
new SubstraitRelVisitor(rel.getCluster().getTypeFactory(), extensions, features);
395-
visitor.popFieldAccessDepthMap(rel);
396-
return visitor.apply(rel);
407+
new SubstraitRelVisitor(relNode.getCluster().getTypeFactory(), extensions, features);
408+
visitor.popFieldAccessDepthMap(relNode);
409+
return visitor.apply(relNode);
397410
}
398411
}
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
package io.substrait.isthmus;
2+
3+
import static io.substrait.isthmus.SqlConverterBase.EXTENSION_COLLECTION;
4+
import static org.junit.jupiter.api.Assertions.assertEquals;
5+
6+
import io.substrait.plan.Plan;
7+
import io.substrait.relation.NamedScan;
8+
import java.util.List;
9+
import org.junit.jupiter.api.Test;
10+
11+
public class NameRoundtripTest extends PlanTestBase {
12+
13+
@Test
14+
void preserveNamesFromSql() throws Exception {
15+
List<String> creates = List.of("CREATE TABLE foo(a BIGINT, b BIGINT)");
16+
17+
SqlToSubstrait s = new SqlToSubstrait();
18+
var substraitToCalcite = new SubstraitToCalcite(EXTENSION_COLLECTION, typeFactory);
19+
20+
String query = """
21+
SELECT "a", "B" FROM foo GROUP BY a, b
22+
""";
23+
List<String> expectedNames = List.of("a", "B");
24+
25+
List<org.apache.calcite.rel.RelRoot> calciteRelRoots = s.sqlToRelNode(query, creates);
26+
assertEquals(1, calciteRelRoots.size());
27+
28+
org.apache.calcite.rel.RelRoot calciteRelRoot1 = calciteRelRoots.get(0);
29+
assertEquals(expectedNames, calciteRelRoot1.validatedRowType.getFieldNames());
30+
31+
io.substrait.plan.Plan.Root substraitRelRoot =
32+
SubstraitRelVisitor.convert(calciteRelRoot1, EXTENSION_COLLECTION);
33+
assertEquals(expectedNames, substraitRelRoot.getNames());
34+
35+
org.apache.calcite.rel.RelRoot calciteRelRoot2 = substraitToCalcite.convert(substraitRelRoot);
36+
assertEquals(expectedNames, calciteRelRoot2.validatedRowType.getFieldNames());
37+
}
38+
39+
@Test
40+
void preserveNamesFromSubstrait() {
41+
NamedScan rel =
42+
substraitBuilder.namedScan(
43+
List.of("foo"),
44+
List.of("i64", "struct", "struct0", "struct1"),
45+
List.of(R.I64, R.struct(R.FP64, R.STRING)));
46+
47+
Plan.Root planRoot =
48+
Plan.Root.builder().input(rel).names(List.of("i", "s", "s0", "s1")).build();
49+
assertFullRoundTrip(planRoot);
50+
}
51+
}

isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java

Lines changed: 71 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,13 @@
2020
import io.substrait.type.Type;
2121
import io.substrait.type.TypeCreator;
2222
import java.io.IOException;
23-
import java.util.ArrayList;
2423
import java.util.Arrays;
2524
import java.util.List;
2625
import org.apache.calcite.rel.RelNode;
2726
import org.apache.calcite.rel.RelRoot;
2827
import org.apache.calcite.rel.type.RelDataType;
2928
import org.apache.calcite.rel.type.RelDataTypeFactory;
3029
import org.apache.calcite.rex.RexBuilder;
31-
import org.apache.calcite.sql.SqlKind;
3230
import org.apache.calcite.sql.parser.SqlParseException;
3331
import org.apache.calcite.tools.RelBuilder;
3432
import org.junit.jupiter.api.Assertions;
@@ -72,8 +70,9 @@ protected Plan assertProtoPlanRoundrip(String query, SqlToSubstrait s, List<Stri
7270
var rootRels = s.sqlToRelNode(query, creates);
7371
assertEquals(rootRels.size(), plan.getRoots().size());
7472
for (int i = 0; i < rootRels.size(); i++) {
75-
var rootRel = SubstraitRelVisitor.convert(rootRels.get(i), EXTENSION_COLLECTION);
76-
assertEquals(rootRel.getRecordType(), plan.getRoots().get(i).getInput().getRecordType());
73+
Plan.Root rootRel = SubstraitRelVisitor.convert(rootRels.get(i), EXTENSION_COLLECTION);
74+
assertEquals(
75+
rootRel.getInput().getRecordType(), plan.getRoots().get(i).getInput().getRecordType());
7776
}
7877
return plan;
7978
}
@@ -85,38 +84,36 @@ protected void assertPlanRoundtrip(Plan plan) {
8584
assertEquals(protoPlan1, protoPlan2);
8685
}
8786

88-
protected List<RelNode> assertSqlSubstraitRelRoundTrip(String query) throws Exception {
87+
protected RelRoot assertSqlSubstraitRelRoundTrip(String query) throws Exception {
8988
return assertSqlSubstraitRelRoundTrip(query, tpchSchemaCreateStatements());
9089
}
9190

92-
protected List<RelNode> assertSqlSubstraitRelRoundTrip(String query, List<String> creates)
91+
protected RelRoot assertSqlSubstraitRelRoundTrip(String query, List<String> creates)
9392
throws Exception {
9493
// sql <--> substrait round trip test.
9594
// Assert (sql -> calcite -> substrait) and (sql -> substrait -> calcite -> substrait) are same.
9695
// Return the Calcite RelRoot produced by sql -> Substrait rel -> Calcite rel.
97-
List<RelNode> relNodeList = new ArrayList<>();
9896

9997
var substraitToCalcite = new SubstraitToCalcite(EXTENSION_COLLECTION, typeFactory);
10098

10199
SqlToSubstrait s = new SqlToSubstrait();
102100

103101
// 1. SQL -> Calcite RelRoot
104-
for (RelRoot relRoot : s.sqlToRelNode(query, creates)) {
105-
// 2. Calcite RelRoot -> Substrait Rel
106-
Rel pojo1 = SubstraitRelVisitor.convert(relRoot, EXTENSION_COLLECTION);
102+
List<RelRoot> relRoots = s.sqlToRelNode(query, creates);
103+
assertEquals(1, relRoots.size());
104+
RelRoot relRoot1 = relRoots.get(0);
107105

108-
// 3. Substrait Rel -> Calcite RelNode
109-
RelNode relNode = substraitToCalcite.convert(pojo1);
106+
// 2. Calcite RelRoot -> Substrait Plan.Root
107+
Plan.Root pojo1 = SubstraitRelVisitor.convert(relRoot1, EXTENSION_COLLECTION);
110108

111-
relNodeList.add(relNode);
109+
// 3. Substrait Plan.Root -> Calcite RelRoot
110+
RelRoot relRoot2 = substraitToCalcite.convert(pojo1);
112111

113-
// 4. Calcite RelNode -> Substrait Rel
114-
Rel pojo2 =
115-
SubstraitRelVisitor.convert(RelRoot.of(relNode, SqlKind.SELECT), EXTENSION_COLLECTION);
112+
// 4. Calcite RelRoot -> Substrait Plan.Root
113+
Plan.Root pojo2 = SubstraitRelVisitor.convert(relRoot2, EXTENSION_COLLECTION);
116114

117-
Assertions.assertEquals(pojo1, pojo2);
118-
}
119-
return relNodeList;
115+
Assertions.assertEquals(pojo1, pojo2);
116+
return relRoot2;
120117
}
121118

122119
@Beta
@@ -140,37 +137,36 @@ protected void assertFullRoundTrip(String query) throws IOException, SqlParseExc
140137
protected void assertFullRoundTrip(String sqlQuery, List<String> createStatements)
141138
throws SqlParseException {
142139
SqlToSubstrait sqlConverter = new SqlToSubstrait();
143-
List<RelRoot> relRoots = sqlConverter.sqlToRelNode(sqlQuery, createStatements);
140+
ExtensionCollector extensionCollector = new ExtensionCollector();
144141

145-
for (RelRoot calcite1 : relRoots) {
146-
var extensionCollector = new ExtensionCollector();
142+
// SQL -> Calcite 1
143+
List<RelRoot> relRoots = sqlConverter.sqlToRelNode(sqlQuery, createStatements);
144+
assertEquals(1, relRoots.size());
145+
RelRoot calcite1 = relRoots.get(0);
147146

148-
// Calcite 1 -> Substrait POJO 1
149-
io.substrait.relation.Rel pojo1 = SubstraitRelVisitor.convert(calcite1, EXTENSION_COLLECTION);
147+
// Calcite 1 -> Substrait POJO 1
148+
Plan.Root pojo1 = SubstraitRelVisitor.convert(calcite1, EXTENSION_COLLECTION);
150149

151-
// Substrait POJO 1 -> Substrait Proto
152-
io.substrait.proto.Rel proto = new RelProtoConverter(extensionCollector).toProto(pojo1);
150+
// Substrait POJO 1 -> Substrait Proto
151+
io.substrait.proto.RelRoot proto = new RelProtoConverter(extensionCollector).toProto(pojo1);
153152

154-
// Substrait Proto -> Substrait Pojo 2
155-
io.substrait.relation.Rel pojo2 =
156-
new ProtoRelConverter(extensionCollector, EXTENSION_COLLECTION).from(proto);
153+
// Substrait Proto -> Substrait Pojo 2
154+
Plan.Root pojo2 = new ProtoRelConverter(extensionCollector, EXTENSION_COLLECTION).from(proto);
157155

158-
// Verify that POJOs are the same
159-
assertEquals(pojo1, pojo2);
156+
// Verify that POJOs are the same
157+
assertEquals(pojo1, pojo2);
160158

161-
// Substrait POJO 2 -> Calcite 2
162-
RelNode calcite2 = new SubstraitToCalcite(EXTENSION_COLLECTION, typeFactory).convert(pojo2);
163-
// It would be ideal to compare calcite1 and calcite2, however there isn't a good mechanism to
164-
// do so
165-
assertNotNull(calcite2);
159+
// Substrait POJO 2 -> Calcite 2
160+
RelRoot calcite2 = new SubstraitToCalcite(EXTENSION_COLLECTION, typeFactory).convert(pojo2);
161+
// It would be ideal to compare calcite1 and calcite2, however there isn't a good mechanism to
162+
// do so
163+
assertNotNull(calcite2);
166164

167-
// Calcite 2 -> Substrait POJO 3
168-
io.substrait.relation.Rel pojo3 =
169-
SubstraitRelVisitor.convert(RelRoot.of(calcite2, calcite1.kind), EXTENSION_COLLECTION);
165+
// Calcite 2 -> Substrait POJO 3
166+
Plan.Root pojo3 = SubstraitRelVisitor.convert(calcite2, EXTENSION_COLLECTION);
170167

171-
// Verify that POJOs are the same
172-
assertEquals(pojo1, pojo3);
173-
}
168+
// Verify that POJOs are the same
169+
assertEquals(pojo1, pojo3);
174170
}
175171

176172
/**
@@ -182,6 +178,7 @@ protected void assertFullRoundTrip(String sqlQuery, List<String> createStatement
182178
* </ul>
183179
*/
184180
protected void assertFullRoundTrip(Rel pojo1) {
181+
// TODO: reuse the Plan.Root based assertFullRoundTrip by generating names
185182
var extensionCollector = new ExtensionCollector();
186183

187184
// Substrait POJO 1 -> Substrait Proto
@@ -198,9 +195,38 @@ protected void assertFullRoundTrip(Rel pojo1) {
198195
RelNode calcite = new SubstraitToCalcite(EXTENSION_COLLECTION, typeFactory).convert(pojo2);
199196

200197
// Calcite -> Substrait POJO 3
201-
io.substrait.relation.Rel pojo3 =
202-
// SqlKind.SELECT is used because the majority of our tests are SELECT queries
203-
SubstraitRelVisitor.convert(RelRoot.of(calcite, SqlKind.SELECT), EXTENSION_COLLECTION);
198+
io.substrait.relation.Rel pojo3 = SubstraitRelVisitor.convert(calcite, EXTENSION_COLLECTION);
199+
200+
// Verify that POJOs are the same
201+
assertEquals(pojo1, pojo3);
202+
}
203+
204+
/**
205+
* Verifies that the given POJO can be converted:
206+
*
207+
* <ul>
208+
* <li>From POJO to Proto and back
209+
* <li>From POJO to Calcite and back
210+
* </ul>
211+
*/
212+
protected void assertFullRoundTrip(Plan.Root pojo1) {
213+
var extensionCollector = new ExtensionCollector();
214+
215+
// Substrait POJO 1 -> Substrait Proto
216+
io.substrait.proto.RelRoot proto = new RelProtoConverter(extensionCollector).toProto(pojo1);
217+
218+
// Substrait Proto -> Substrait Pojo 2
219+
io.substrait.plan.Plan.Root pojo2 =
220+
new ProtoRelConverter(extensionCollector, EXTENSION_COLLECTION).from(proto);
221+
222+
// Verify that POJOs are the same
223+
assertEquals(pojo1, pojo2);
224+
225+
// Substrait POJO 2 -> Calcite
226+
RelRoot calcite = new SubstraitToCalcite(EXTENSION_COLLECTION, typeFactory).convert(pojo2);
227+
228+
// Calcite -> Substrait POJO 3
229+
io.substrait.plan.Plan.Root pojo3 = SubstraitRelVisitor.convert(calcite, EXTENSION_COLLECTION);
204230

205231
// Verify that POJOs are the same
206232
assertEquals(pojo1, pojo3);

isthmus/src/test/java/io/substrait/isthmus/Substrait2SqlTest.java

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
package io.substrait.isthmus;
22

3-
import static org.junit.jupiter.api.Assertions.assertTrue;
3+
import static org.junit.jupiter.api.Assertions.assertEquals;
4+
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
45

56
import io.substrait.isthmus.utils.SetUtils;
67
import io.substrait.relation.Set;
7-
import java.util.List;
88
import org.apache.calcite.rel.RelNode;
9+
import org.apache.calcite.rel.RelRoot;
910
import org.apache.calcite.rel.logical.LogicalAggregate;
1011
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
1112
import org.junit.jupiter.api.Test;
@@ -142,15 +143,15 @@ public void tpch_q1_variant() throws Exception {
142143
@Test
143144
public void simpleTestApproxCountDistinct() throws Exception {
144145
String query = "select approx_count_distinct(l_tax) from lineitem";
145-
List<RelNode> relNodeList = assertSqlSubstraitRelRoundTrip(query);
146+
RelRoot relRoot = assertSqlSubstraitRelRoundTrip(query);
147+
RelNode relNode = relRoot.project();
146148

147149
// Assert converted Calcite RelNode has `approx_count_distinct`
148-
RelNode relNode = relNodeList.get(0);
149-
assertTrue(relNode instanceof LogicalAggregate);
150+
assertInstanceOf(LogicalAggregate.class, relNode);
150151
LogicalAggregate aggregate = (LogicalAggregate) relNode;
151-
assertTrue(
152-
aggregate.getAggCallList().get(0).getAggregation()
153-
== SqlStdOperatorTable.APPROX_COUNT_DISTINCT);
152+
assertEquals(
153+
SqlStdOperatorTable.APPROX_COUNT_DISTINCT,
154+
aggregate.getAggCallList().get(0).getAggregation());
154155
}
155156

156157
@Test

0 commit comments

Comments
 (0)