Skip to content

Commit e0a0fda

Browse files
authored
fix(core): convert hints in ProtoRelConverter (#420)
1 parent b6053f5 commit e0a0fda

File tree

2 files changed

+73
-22
lines changed

2 files changed

+73
-22
lines changed

core/src/main/java/io/substrait/relation/ProtoRelConverter.java

Lines changed: 42 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,8 @@ protected NamedWrite newNamedWrite(WriteRel rel) {
179179

180180
builder
181181
.commonExtension(optionalAdvancedExtension(rel.getCommon()))
182-
.remap(optionalRelmap(rel.getCommon()));
182+
.remap(optionalRelmap(rel.getCommon()))
183+
.hint(optionalHint(rel.getCommon()));
183184
return builder.build();
184185
}
185186

@@ -197,7 +198,8 @@ protected Rel newExtensionWrite(WriteRel rel) {
197198

198199
builder
199200
.commonExtension(optionalAdvancedExtension(rel.getCommon()))
200-
.remap(optionalRelmap(rel.getCommon()));
201+
.remap(optionalRelmap(rel.getCommon()))
202+
.hint(optionalHint(rel.getCommon()));
201203
return builder.build();
202204
}
203205

@@ -225,6 +227,7 @@ protected NamedDdl newNamedDdl(DdlRel rel) {
225227
.viewDefinition(optionalViewDefinition(rel))
226228
.commonExtension(optionalAdvancedExtension(rel.getCommon()))
227229
.remap(optionalRelmap(rel.getCommon()))
230+
.hint(optionalHint(rel.getCommon()))
228231
.build();
229232
}
230233

@@ -240,6 +243,7 @@ protected ExtensionDdl newExtensionDdl(DdlRel rel) {
240243
.viewDefinition(optionalViewDefinition(rel))
241244
.commonExtension(optionalAdvancedExtension(rel.getCommon()))
242245
.remap(optionalRelmap(rel.getCommon()))
246+
.hint(optionalHint(rel.getCommon()))
243247
.build();
244248
}
245249

@@ -266,10 +270,10 @@ protected Filter newFilter(FilterRel rel) {
266270
.condition(
267271
new ProtoExpressionConverter(lookup, extensions, input.getRecordType(), this)
268272
.from(rel.getCondition()));
269-
270273
builder
271274
.commonExtension(optionalAdvancedExtension(rel.getCommon()))
272-
.remap(optionalRelmap(rel.getCommon()));
275+
.remap(optionalRelmap(rel.getCommon()))
276+
.hint(optionalHint(rel.getCommon()));
273277
if (rel.hasAdvancedExtension()) {
274278
builder.extension(advancedExtension(rel.getAdvancedExtension()));
275279
}
@@ -317,7 +321,8 @@ protected EmptyScan newEmptyScan(ReadRel rel) {
317321

318322
builder
319323
.commonExtension(optionalAdvancedExtension(rel.getCommon()))
320-
.remap(optionalRelmap(rel.getCommon()));
324+
.remap(optionalRelmap(rel.getCommon()))
325+
.hint(optionalHint(rel.getCommon()));
321326
if (rel.hasAdvancedExtension()) {
322327
builder.extension(advancedExtension(rel.getAdvancedExtension()));
323328
}
@@ -329,7 +334,8 @@ protected ExtensionLeaf newExtensionLeaf(ExtensionLeafRel rel) {
329334
var builder =
330335
ExtensionLeaf.from(detail)
331336
.commonExtension(optionalAdvancedExtension(rel.getCommon()))
332-
.remap(optionalRelmap(rel.getCommon()));
337+
.remap(optionalRelmap(rel.getCommon()))
338+
.hint(optionalHint(rel.getCommon()));
333339
return builder.build();
334340
}
335341

@@ -339,7 +345,8 @@ protected ExtensionSingle newExtensionSingle(ExtensionSingleRel rel) {
339345
var builder =
340346
ExtensionSingle.from(detail, input)
341347
.commonExtension(optionalAdvancedExtension(rel.getCommon()))
342-
.remap(optionalRelmap(rel.getCommon()));
348+
.remap(optionalRelmap(rel.getCommon()))
349+
.hint(optionalHint(rel.getCommon()));
343350
return builder.build();
344351
}
345352

@@ -349,7 +356,8 @@ protected ExtensionMulti newExtensionMulti(ExtensionMultiRel rel) {
349356
var builder =
350357
ExtensionMulti.from(detail, inputs)
351358
.commonExtension(optionalAdvancedExtension(rel.getCommon()))
352-
.remap(optionalRelmap(rel.getCommon()));
359+
.remap(optionalRelmap(rel.getCommon()))
360+
.hint(optionalHint(rel.getCommon()));
353361
if (rel.hasDetail()) {
354362
builder.detail(detailFromExtensionMultiRel(rel.getDetail()));
355363
}
@@ -379,7 +387,8 @@ protected NamedScan newNamedScan(ReadRel rel) {
379387

380388
builder
381389
.commonExtension(optionalAdvancedExtension(rel.getCommon()))
382-
.remap(optionalRelmap(rel.getCommon()));
390+
.remap(optionalRelmap(rel.getCommon()))
391+
.hint(optionalHint(rel.getCommon()));
383392
if (rel.hasAdvancedExtension()) {
384393
builder.extension(advancedExtension(rel.getAdvancedExtension()));
385394
}
@@ -393,7 +402,8 @@ protected ExtensionTable newExtensionTable(ReadRel rel) {
393402

394403
builder
395404
.commonExtension(optionalAdvancedExtension(rel.getCommon()))
396-
.remap(optionalRelmap(rel.getCommon()));
405+
.remap(optionalRelmap(rel.getCommon()))
406+
.hint(optionalHint(rel.getCommon()));
397407
if (rel.hasAdvancedExtension()) {
398408
builder.extension(advancedExtension(rel.getAdvancedExtension()));
399409
}
@@ -427,7 +437,8 @@ protected LocalFiles newLocalFiles(ReadRel rel) {
427437

428438
builder
429439
.commonExtension(optionalAdvancedExtension(rel.getCommon()))
430-
.remap(optionalRelmap(rel.getCommon()));
440+
.remap(optionalRelmap(rel.getCommon()))
441+
.hint(optionalHint(rel.getCommon()));
431442
if (rel.hasAdvancedExtension()) {
432443
builder.extension(advancedExtension(rel.getAdvancedExtension()));
433444
}
@@ -503,7 +514,8 @@ protected VirtualTableScan newVirtualTable(ReadRel rel) {
503514

504515
builder
505516
.commonExtension(optionalAdvancedExtension(rel.getCommon()))
506-
.remap(optionalRelmap(rel.getCommon()));
517+
.remap(optionalRelmap(rel.getCommon()))
518+
.hint(optionalHint(rel.getCommon()));
507519
if (rel.hasAdvancedExtension()) {
508520
builder.extension(advancedExtension(rel.getAdvancedExtension()));
509521
}
@@ -521,7 +533,8 @@ protected Fetch newFetch(FetchRel rel) {
521533

522534
builder
523535
.commonExtension(optionalAdvancedExtension(rel.getCommon()))
524-
.remap(optionalRelmap(rel.getCommon()));
536+
.remap(optionalRelmap(rel.getCommon()))
537+
.hint(optionalHint(rel.getCommon()));
525538
if (rel.hasAdvancedExtension()) {
526539
builder.extension(advancedExtension(rel.getAdvancedExtension()));
527540
}
@@ -619,7 +632,8 @@ protected Aggregate newAggregate(AggregateRel rel) {
619632

620633
builder
621634
.commonExtension(optionalAdvancedExtension(rel.getCommon()))
622-
.remap(optionalRelmap(rel.getCommon()));
635+
.remap(optionalRelmap(rel.getCommon()))
636+
.hint(optionalHint(rel.getCommon()));
623637
if (rel.hasAdvancedExtension()) {
624638
builder.extension(advancedExtension(rel.getAdvancedExtension()));
625639
}
@@ -644,7 +658,8 @@ protected Sort newSort(SortRel rel) {
644658

645659
builder
646660
.commonExtension(optionalAdvancedExtension(rel.getCommon()))
647-
.remap(optionalRelmap(rel.getCommon()));
661+
.remap(optionalRelmap(rel.getCommon()))
662+
.hint(optionalHint(rel.getCommon()));
648663
if (rel.hasAdvancedExtension()) {
649664
builder.extension(advancedExtension(rel.getAdvancedExtension()));
650665
}
@@ -670,7 +685,8 @@ protected Join newJoin(JoinRel rel) {
670685

671686
builder
672687
.commonExtension(optionalAdvancedExtension(rel.getCommon()))
673-
.remap(optionalRelmap(rel.getCommon()));
688+
.remap(optionalRelmap(rel.getCommon()))
689+
.hint(optionalHint(rel.getCommon()));
674690
if (rel.hasAdvancedExtension()) {
675691
builder.extension(advancedExtension(rel.getAdvancedExtension()));
676692
}
@@ -700,7 +716,8 @@ protected Set newSet(SetRel rel) {
700716

701717
builder
702718
.commonExtension(optionalAdvancedExtension(rel.getCommon()))
703-
.remap(optionalRelmap(rel.getCommon()));
719+
.remap(optionalRelmap(rel.getCommon()))
720+
.hint(optionalHint(rel.getCommon()));
704721
if (rel.hasAdvancedExtension()) {
705722
builder.extension(advancedExtension(rel.getAdvancedExtension()));
706723
}
@@ -729,10 +746,10 @@ protected Rel newHashJoin(HashJoinRel rel) {
729746
.postJoinFilter(
730747
Optional.ofNullable(
731748
rel.hasPostJoinFilter() ? unionConverter.from(rel.getPostJoinFilter()) : null));
732-
733749
builder
734750
.commonExtension(optionalAdvancedExtension(rel.getCommon()))
735-
.remap(optionalRelmap(rel.getCommon()));
751+
.remap(optionalRelmap(rel.getCommon()))
752+
.hint(optionalHint(rel.getCommon()));
736753
if (rel.hasAdvancedExtension()) {
737754
builder.extension(advancedExtension(rel.getAdvancedExtension()));
738755
}
@@ -764,7 +781,8 @@ protected Rel newMergeJoin(MergeJoinRel rel) {
764781

765782
builder
766783
.commonExtension(optionalAdvancedExtension(rel.getCommon()))
767-
.remap(optionalRelmap(rel.getCommon()));
784+
.remap(optionalRelmap(rel.getCommon()))
785+
.hint(optionalHint(rel.getCommon()));
768786
if (rel.hasAdvancedExtension()) {
769787
builder.extension(advancedExtension(rel.getAdvancedExtension()));
770788
}
@@ -791,7 +809,8 @@ protected NestedLoopJoin newNestedLoopJoin(NestedLoopJoinRel rel) {
791809

792810
builder
793811
.commonExtension(optionalAdvancedExtension(rel.getCommon()))
794-
.remap(optionalRelmap(rel.getCommon()));
812+
.remap(optionalRelmap(rel.getCommon()))
813+
.hint(optionalHint(rel.getCommon()));
795814
if (rel.hasAdvancedExtension()) {
796815
builder.extension(advancedExtension(rel.getAdvancedExtension()));
797816
}
@@ -827,7 +846,8 @@ protected ConsistentPartitionWindow newConsistentPartitionWindow(
827846

828847
builder
829848
.commonExtension(optionalAdvancedExtension(rel.getCommon()))
830-
.remap(optionalRelmap(rel.getCommon()));
849+
.remap(optionalRelmap(rel.getCommon()))
850+
.hint(optionalHint(rel.getCommon()));
831851
if (rel.hasAdvancedExtension()) {
832852
builder.extension(advancedExtension(rel.getAdvancedExtension()));
833853
}

core/src/test/java/io/substrait/relation/ProtoRelConverterTest.java

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66

77
import io.substrait.TestBase;
88
import io.substrait.extension.AdvancedExtension;
9+
import io.substrait.hint.Hint;
910
import io.substrait.relation.utils.StringHolder;
11+
import java.util.Arrays;
1012
import java.util.Collections;
1113
import org.junit.jupiter.api.Nested;
1214
import org.junit.jupiter.api.Test;
@@ -123,4 +125,33 @@ void extensionTable() {
123125
assertNotEquals(rel, relReturned);
124126
}
125127
}
128+
129+
/** Verify that hints are correctly transmitted in proto<->pojo */
130+
@Nested
131+
class HintsTest {
132+
133+
@Test
134+
void relWithHint() {
135+
Rel relWithHints =
136+
NamedScan.builder()
137+
.from(commonTable)
138+
.hint(Hint.builder().addOutputNames("Test hint").build())
139+
.build();
140+
io.substrait.proto.Rel protoRel = relProtoConverter.toProto(relWithHints);
141+
Rel relReturned = protoRelConverter.from(protoRel);
142+
assertEquals(relWithHints, relReturned);
143+
}
144+
145+
@Test
146+
void relWithHints() {
147+
Rel relWithHints =
148+
NamedScan.builder()
149+
.from(commonTable)
150+
.hint(Hint.builder().addAllOutputNames(Arrays.asList("Hint 1", "Hint 2")).build())
151+
.build();
152+
io.substrait.proto.Rel protoRel = relProtoConverter.toProto(relWithHints);
153+
Rel relReturned = protoRelConverter.from(protoRel);
154+
assertEquals(relWithHints, relReturned);
155+
}
156+
}
126157
}

0 commit comments

Comments
 (0)