20
20
import io .substrait .type .Type ;
21
21
import io .substrait .type .TypeCreator ;
22
22
import java .io .IOException ;
23
- import java .util .ArrayList ;
24
23
import java .util .Arrays ;
25
24
import java .util .List ;
26
25
import org .apache .calcite .rel .RelNode ;
27
26
import org .apache .calcite .rel .RelRoot ;
28
27
import org .apache .calcite .rel .type .RelDataType ;
29
28
import org .apache .calcite .rel .type .RelDataTypeFactory ;
30
29
import org .apache .calcite .rex .RexBuilder ;
31
- import org .apache .calcite .sql .SqlKind ;
32
30
import org .apache .calcite .sql .parser .SqlParseException ;
33
31
import org .apache .calcite .tools .RelBuilder ;
34
32
import org .junit .jupiter .api .Assertions ;
@@ -72,8 +70,9 @@ protected Plan assertProtoPlanRoundrip(String query, SqlToSubstrait s, List<Stri
72
70
var rootRels = s .sqlToRelNode (query , creates );
73
71
assertEquals (rootRels .size (), plan .getRoots ().size ());
74
72
for (int i = 0 ; i < rootRels .size (); i ++) {
75
- var rootRel = SubstraitRelVisitor .convert (rootRels .get (i ), EXTENSION_COLLECTION );
76
- assertEquals (rootRel .getRecordType (), plan .getRoots ().get (i ).getInput ().getRecordType ());
73
+ Plan .Root rootRel = SubstraitRelVisitor .convert (rootRels .get (i ), EXTENSION_COLLECTION );
74
+ assertEquals (
75
+ rootRel .getInput ().getRecordType (), plan .getRoots ().get (i ).getInput ().getRecordType ());
77
76
}
78
77
return plan ;
79
78
}
@@ -85,38 +84,36 @@ protected void assertPlanRoundtrip(Plan plan) {
85
84
assertEquals (protoPlan1 , protoPlan2 );
86
85
}
87
86
88
- protected List < RelNode > assertSqlSubstraitRelRoundTrip (String query ) throws Exception {
87
+ protected RelRoot assertSqlSubstraitRelRoundTrip (String query ) throws Exception {
89
88
return assertSqlSubstraitRelRoundTrip (query , tpchSchemaCreateStatements ());
90
89
}
91
90
92
- protected List < RelNode > assertSqlSubstraitRelRoundTrip (String query , List <String > creates )
91
+ protected RelRoot assertSqlSubstraitRelRoundTrip (String query , List <String > creates )
93
92
throws Exception {
94
93
// sql <--> substrait round trip test.
95
94
// Assert (sql -> calcite -> substrait) and (sql -> substrait -> calcite -> substrait) are same.
96
95
// Return list of sql -> Substrait rel -> Calcite rel.
97
- List <RelNode > relNodeList = new ArrayList <>();
98
96
99
97
var substraitToCalcite = new SubstraitToCalcite (EXTENSION_COLLECTION , typeFactory );
100
98
101
99
SqlToSubstrait s = new SqlToSubstrait ();
102
100
103
101
// 1. SQL -> Calcite RelRoot
104
- for ( RelRoot relRoot : s .sqlToRelNode (query , creates )) {
105
- // 2. Calcite RelRoot -> Substrait Rel
106
- Rel pojo1 = SubstraitRelVisitor . convert ( relRoot , EXTENSION_COLLECTION );
102
+ List < RelRoot > relRoots = s .sqlToRelNode (query , creates );
103
+ assertEquals ( 1 , relRoots . size ());
104
+ RelRoot relRoot1 = relRoots . get ( 0 );
107
105
108
- // 3. Substrait Rel -> Calcite RelNode
109
- RelNode relNode = substraitToCalcite .convert (pojo1 );
106
+ // 2. Calcite RelRoot -> Substrait Rel
107
+ Plan . Root pojo1 = SubstraitRelVisitor .convert (relRoot1 , EXTENSION_COLLECTION );
110
108
111
- relNodeList .add (relNode );
109
+ // 3. Substrait Rel -> Calcite RelNode
110
+ RelRoot relRoot2 = substraitToCalcite .convert (pojo1 );
112
111
113
- // 4. Calcite RelNode -> Substrait Rel
114
- Rel pojo2 =
115
- SubstraitRelVisitor .convert (RelRoot .of (relNode , SqlKind .SELECT ), EXTENSION_COLLECTION );
112
+ // 4. Calcite RelNode -> Substrait Rel
113
+ Plan .Root pojo2 = SubstraitRelVisitor .convert (relRoot2 , EXTENSION_COLLECTION );
116
114
117
- Assertions .assertEquals (pojo1 , pojo2 );
118
- }
119
- return relNodeList ;
115
+ Assertions .assertEquals (pojo1 , pojo2 );
116
+ return relRoot2 ;
120
117
}
121
118
122
119
@ Beta
@@ -140,37 +137,36 @@ protected void assertFullRoundTrip(String query) throws IOException, SqlParseExc
140
137
protected void assertFullRoundTrip (String sqlQuery , List <String > createStatements )
141
138
throws SqlParseException {
142
139
SqlToSubstrait sqlConverter = new SqlToSubstrait ();
143
- List < RelRoot > relRoots = sqlConverter . sqlToRelNode ( sqlQuery , createStatements );
140
+ ExtensionCollector extensionCollector = new ExtensionCollector ( );
144
141
145
- for (RelRoot calcite1 : relRoots ) {
146
- var extensionCollector = new ExtensionCollector ();
142
+ // SQL -> Calcite 1
143
+ List <RelRoot > relRoots = sqlConverter .sqlToRelNode (sqlQuery , createStatements );
144
+ assertEquals (1 , relRoots .size ());
145
+ RelRoot calcite1 = relRoots .get (0 );
147
146
148
- // Calcite 1 -> Substrait POJO 1
149
- io . substrait . relation . Rel pojo1 = SubstraitRelVisitor .convert (calcite1 , EXTENSION_COLLECTION );
147
+ // Calcite 1 -> Substrait POJO 1
148
+ Plan . Root pojo1 = SubstraitRelVisitor .convert (calcite1 , EXTENSION_COLLECTION );
150
149
151
- // Substrait POJO 1 -> Substrait Proto
152
- io .substrait .proto .Rel proto = new RelProtoConverter (extensionCollector ).toProto (pojo1 );
150
+ // Substrait POJO 1 -> Substrait Proto
151
+ io .substrait .proto .RelRoot proto = new RelProtoConverter (extensionCollector ).toProto (pojo1 );
153
152
154
- // Substrait Proto -> Substrait Pojo 2
155
- io .substrait .relation .Rel pojo2 =
156
- new ProtoRelConverter (extensionCollector , EXTENSION_COLLECTION ).from (proto );
153
+ // Substrait Proto -> Substrait Pojo 2
154
+ Plan .Root pojo2 = new ProtoRelConverter (extensionCollector , EXTENSION_COLLECTION ).from (proto );
157
155
158
- // Verify that POJOs are the same
159
- assertEquals (pojo1 , pojo2 );
156
+ // Verify that POJOs are the same
157
+ assertEquals (pojo1 , pojo2 );
160
158
161
- // Substrait POJO 2 -> Calcite 2
162
- RelNode calcite2 = new SubstraitToCalcite (EXTENSION_COLLECTION , typeFactory ).convert (pojo2 );
163
- // It would be ideal to compare calcite1 and calcite2, however there isn't a good mechanism to
164
- // do so
165
- assertNotNull (calcite2 );
159
+ // Substrait POJO 2 -> Calcite 2
160
+ RelRoot calcite2 = new SubstraitToCalcite (EXTENSION_COLLECTION , typeFactory ).convert (pojo2 );
161
+ // It would be ideal to compare calcite1 and calcite2, however there isn't a good mechanism to
162
+ // do so
163
+ assertNotNull (calcite2 );
166
164
167
- // Calcite 2 -> Substrait POJO 3
168
- io .substrait .relation .Rel pojo3 =
169
- SubstraitRelVisitor .convert (RelRoot .of (calcite2 , calcite1 .kind ), EXTENSION_COLLECTION );
165
+ // Calcite 2 -> Substrait POJO 3
166
+ Plan .Root pojo3 = SubstraitRelVisitor .convert (calcite2 , EXTENSION_COLLECTION );
170
167
171
- // Verify that POJOs are the same
172
- assertEquals (pojo1 , pojo3 );
173
- }
168
+ // Verify that POJOs are the same
169
+ assertEquals (pojo1 , pojo3 );
174
170
}
175
171
176
172
/**
@@ -182,6 +178,7 @@ protected void assertFullRoundTrip(String sqlQuery, List<String> createStatement
182
178
* </ul>
183
179
*/
184
180
protected void assertFullRoundTrip (Rel pojo1 ) {
181
+ // TODO: reuse the Plan.Root based assertFullRoundTrip by generating names
185
182
var extensionCollector = new ExtensionCollector ();
186
183
187
184
// Substrait POJO 1 -> Substrait Proto
@@ -198,9 +195,38 @@ protected void assertFullRoundTrip(Rel pojo1) {
198
195
RelNode calcite = new SubstraitToCalcite (EXTENSION_COLLECTION , typeFactory ).convert (pojo2 );
199
196
200
197
// Calcite -> Substrait POJO 3
201
- io .substrait .relation .Rel pojo3 =
202
- // SqlKind.SELECT is used because the majority of our tests are SELECT queries
203
- SubstraitRelVisitor .convert (RelRoot .of (calcite , SqlKind .SELECT ), EXTENSION_COLLECTION );
198
+ io .substrait .relation .Rel pojo3 = SubstraitRelVisitor .convert (calcite , EXTENSION_COLLECTION );
199
+
200
+ // Verify that POJOs are the same
201
+ assertEquals (pojo1 , pojo3 );
202
+ }
203
+
204
+ /**
205
+ * Verifies that the given POJO can be converted:
206
+ *
207
+ * <ul>
208
+ * <li>From POJO to Proto and back
209
+ * <li>From POJO to Calcite and back
210
+ * </ul>
211
+ */
212
+ protected void assertFullRoundTrip (Plan .Root pojo1 ) {
213
+ var extensionCollector = new ExtensionCollector ();
214
+
215
+ // Substrait POJO 1 -> Substrait Proto
216
+ io .substrait .proto .RelRoot proto = new RelProtoConverter (extensionCollector ).toProto (pojo1 );
217
+
218
+ // Substrait Proto -> Substrait Pojo 2
219
+ io .substrait .plan .Plan .Root pojo2 =
220
+ new ProtoRelConverter (extensionCollector , EXTENSION_COLLECTION ).from (proto );
221
+
222
+ // Verify that POJOs are the same
223
+ assertEquals (pojo1 , pojo2 );
224
+
225
+ // Substrait POJO 2 -> Calcite
226
+ RelRoot calcite = new SubstraitToCalcite (EXTENSION_COLLECTION , typeFactory ).convert (pojo2 );
227
+
228
+ // Calcite -> Substrait POJO 3
229
+ io .substrait .plan .Plan .Root pojo3 = SubstraitRelVisitor .convert (calcite , EXTENSION_COLLECTION );
204
230
205
231
// Verify that POJOs are the same
206
232
assertEquals (pojo1 , pojo3 );
0 commit comments