Skip to content

Commit 3dfe665

Browse files
authored
feat(core): handle all Hint fields (#469)
1 parent 3707fe2 commit 3dfe665

File tree

4 files changed

+261
-22
lines changed

4 files changed

+261
-22
lines changed

core/src/main/java/io/substrait/hint/Hint.java

Lines changed: 76 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,91 @@
11
package io.substrait.hint;
22

33
import io.substrait.proto.RelCommon;
4+
import io.substrait.relation.HasExtension;
45
import java.util.List;
56
import java.util.Optional;
67
import org.immutables.value.Value;
78

89
@Value.Immutable
9-
public abstract class Hint {
10+
public abstract class Hint implements HasExtension {
1011
public abstract Optional<String> getAlias();
1112

1213
public abstract List<String> getOutputNames();
1314

14-
public RelCommon.Hint toProto() {
15-
RelCommon.Hint.Builder builder =
16-
RelCommon.Hint.newBuilder().addAllOutputNames(getOutputNames());
17-
getAlias().ifPresent(builder::setAlias);
18-
return builder.build();
15+
public abstract Optional<Stats> getStats();
16+
17+
public abstract Optional<RuntimeConstraint> getRuntimeConstraint();
18+
19+
public abstract List<LoadedComputation> getLoadedComputations();
20+
21+
public abstract List<SavedComputation> getSavedComputations();
22+
23+
public enum ComputationType {
24+
COMPUTATION_TYPE_UNSPECIFIED(RelCommon.Hint.ComputationType.COMPUTATION_TYPE_UNSPECIFIED),
25+
COMPUTATION_TYPE_HASHTABLE(RelCommon.Hint.ComputationType.COMPUTATION_TYPE_HASHTABLE),
26+
COMPUTATION_TYPE_BLOOM_FILTER(RelCommon.Hint.ComputationType.COMPUTATION_TYPE_BLOOM_FILTER),
27+
COMPUTATION_TYPE_UNKNOWN(RelCommon.Hint.ComputationType.COMPUTATION_TYPE_UNKNOWN);
28+
29+
private final RelCommon.Hint.ComputationType proto;
30+
31+
ComputationType(RelCommon.Hint.ComputationType compType) {
32+
this.proto = compType;
33+
}
34+
35+
public RelCommon.Hint.ComputationType toProto() {
36+
return this.proto;
37+
}
38+
39+
public static ComputationType fromProto(RelCommon.Hint.ComputationType proto) {
40+
for (final ComputationType compTypePojo : values()) {
41+
if (compTypePojo.proto == proto) {
42+
return compTypePojo;
43+
}
44+
}
45+
throw new IllegalArgumentException("Unknown computation type: " + proto);
46+
}
47+
}
48+
49+
@Value.Immutable
50+
public abstract static class Stats implements HasExtension {
51+
public abstract double rowCount();
52+
53+
public abstract double recordSize();
54+
55+
public static ImmutableStats.Builder builder() {
56+
return ImmutableStats.builder();
57+
}
58+
}
59+
60+
@Value.Immutable
61+
public abstract static class SavedComputation {
62+
public abstract int computationId();
63+
64+
public abstract ComputationType computationType();
65+
66+
public static ImmutableSavedComputation.Builder builder() {
67+
return ImmutableSavedComputation.builder();
68+
}
69+
}
70+
71+
@Value.Immutable
72+
public abstract static class LoadedComputation {
73+
public abstract int computationId();
74+
75+
public abstract ComputationType computationType();
76+
77+
public static ImmutableLoadedComputation.Builder builder() {
78+
return ImmutableLoadedComputation.builder();
79+
}
80+
}
81+
82+
@Value.Immutable
83+
public abstract static class RuntimeConstraint implements HasExtension {
84+
// NOTE: marked as todo in substrait spec 0.74.0
85+
86+
public static ImmutableRuntimeConstraint.Builder builder() {
87+
return ImmutableRuntimeConstraint.builder();
88+
}
1989
}
2090

2191
public static ImmutableHint.Builder builder() {

core/src/main/java/io/substrait/relation/ProtoRelConverter.java

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,11 @@
66
import io.substrait.extension.ExtensionLookup;
77
import io.substrait.extension.SimpleExtension;
88
import io.substrait.hint.Hint;
9+
import io.substrait.hint.Hint.ComputationType;
10+
import io.substrait.hint.Hint.LoadedComputation;
11+
import io.substrait.hint.Hint.RuntimeConstraint;
12+
import io.substrait.hint.Hint.SavedComputation;
13+
import io.substrait.hint.Hint.Stats;
914
import io.substrait.plan.Plan;
1015
import io.substrait.proto.AggregateRel;
1116
import io.substrait.proto.ConsistentPartitionWindowRel;
@@ -882,14 +887,53 @@ protected static Optional<Rel.Remap> optionalRelmap(io.substrait.proto.RelCommon
882887
relCommon.hasEmit() ? Rel.Remap.of(relCommon.getEmit().getOutputMappingList()) : null);
883888
}
884889

885-
protected static Optional<Hint> optionalHint(io.substrait.proto.RelCommon relCommon) {
890+
protected Optional<Hint> optionalHint(io.substrait.proto.RelCommon relCommon) {
886891
if (!relCommon.hasHint()) return Optional.empty();
887892
io.substrait.proto.RelCommon.Hint hint = relCommon.getHint();
888893
io.substrait.hint.ImmutableHint.Builder builder =
889894
Hint.builder().addAllOutputNames(hint.getOutputNamesList());
890895
if (!hint.getAlias().isEmpty()) {
891896
builder.alias(hint.getAlias());
892897
}
898+
if (hint.hasAdvancedExtension()) {
899+
builder.extension(advancedExtension(hint.getAdvancedExtension()));
900+
}
901+
if (hint.hasStats()) {
902+
io.substrait.proto.RelCommon.Hint.Stats stats = hint.getStats();
903+
io.substrait.hint.ImmutableStats.Builder statsBuilder = Stats.builder();
904+
statsBuilder.recordSize(stats.getRecordSize()).rowCount(stats.getRowCount());
905+
if (stats.hasAdvancedExtension()) {
906+
statsBuilder.extension(advancedExtension(stats.getAdvancedExtension()));
907+
}
908+
builder.stats(statsBuilder.build());
909+
}
910+
if (hint.hasConstraint()) {
911+
io.substrait.proto.RelCommon.Hint.RuntimeConstraint constraint = hint.getConstraint();
912+
io.substrait.hint.ImmutableRuntimeConstraint.Builder constraintBuilder =
913+
RuntimeConstraint.builder();
914+
if (constraint.hasAdvancedExtension()) {
915+
constraintBuilder.extension(advancedExtension(constraint.getAdvancedExtension()));
916+
}
917+
builder.runtimeConstraint(constraintBuilder.build());
918+
}
919+
920+
hint.getLoadedComputationsList()
921+
.forEach(
922+
loadedComp ->
923+
builder.addLoadedComputations(
924+
LoadedComputation.builder()
925+
.computationId(loadedComp.getComputationIdReference())
926+
.computationType(ComputationType.fromProto(loadedComp.getType()))
927+
.build()));
928+
hint.getSavedComputationsList()
929+
.forEach(
930+
savedComp ->
931+
builder.addSavedComputations(
932+
SavedComputation.builder()
933+
.computationId(savedComp.getComputationId())
934+
.computationType(ComputationType.fromProto(savedComp.getType()))
935+
.build()));
936+
893937
return Optional.of(builder.build());
894938
}
895939

core/src/main/java/io/substrait/relation/RelProtoConverter.java

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,11 @@
3030
import io.substrait.proto.ReadRel;
3131
import io.substrait.proto.Rel;
3232
import io.substrait.proto.RelCommon;
33+
import io.substrait.proto.RelCommon.Hint;
34+
import io.substrait.proto.RelCommon.Hint.LoadedComputation;
35+
import io.substrait.proto.RelCommon.Hint.RuntimeConstraint;
36+
import io.substrait.proto.RelCommon.Hint.SavedComputation;
37+
import io.substrait.proto.RelCommon.Hint.Stats;
3338
import io.substrait.proto.RelRoot;
3439
import io.substrait.proto.SetRel;
3540
import io.substrait.proto.SortField;
@@ -635,7 +640,48 @@ private RelCommon common(io.substrait.relation.Rel rel) {
635640
builder.setDirect(RelCommon.Direct.getDefaultInstance());
636641
}
637642

638-
rel.getHint().ifPresent(md -> builder.setHint(md.toProto()));
643+
if (rel.getHint().isPresent()) {
644+
io.substrait.hint.Hint hint = rel.getHint().get();
645+
Hint.Builder hintBuilder = Hint.newBuilder();
646+
647+
hint.getAlias().ifPresent(hintBuilder::setAlias);
648+
hintBuilder.addAllOutputNames(hint.getOutputNames());
649+
650+
if (hint.getStats().isPresent()) {
651+
io.substrait.hint.Hint.Stats stats = hint.getStats().get();
652+
Stats.Builder statsBuilder = Stats.newBuilder();
653+
654+
stats.getExtension().ifPresent(ae -> statsBuilder.setAdvancedExtension(ae.toProto(this)));
655+
hintBuilder.setStats(
656+
statsBuilder.setRowCount(stats.rowCount()).setRecordSize(stats.recordSize()));
657+
}
658+
659+
if (hint.getRuntimeConstraint().isPresent()) {
660+
io.substrait.hint.Hint.RuntimeConstraint rc = hint.getRuntimeConstraint().get();
661+
RuntimeConstraint.Builder rcBuilder = RuntimeConstraint.newBuilder();
662+
663+
rc.getExtension().ifPresent(ae -> rcBuilder.setAdvancedExtension(ae.toProto(this)));
664+
hintBuilder.setConstraint(rcBuilder);
665+
}
666+
667+
hint.getLoadedComputations()
668+
.forEach(
669+
loadedComp ->
670+
hintBuilder.addLoadedComputations(
671+
LoadedComputation.newBuilder()
672+
.setComputationIdReference(loadedComp.computationId())
673+
.setType(loadedComp.computationType().toProto())));
674+
675+
hint.getSavedComputations()
676+
.forEach(
677+
savedComp ->
678+
hintBuilder.addSavedComputations(
679+
SavedComputation.newBuilder()
680+
.setComputationId(savedComp.computationId())
681+
.setType(savedComp.computationType().toProto())));
682+
683+
builder.setHint(hintBuilder.build());
684+
}
639685

640686
return builder.build();
641687
}

core/src/test/java/io/substrait/relation/ProtoRelConverterTest.java

Lines changed: 93 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,13 @@
77
import io.substrait.TestBase;
88
import io.substrait.extension.AdvancedExtension;
99
import io.substrait.hint.Hint;
10+
import io.substrait.hint.Hint.ComputationType;
11+
import io.substrait.hint.Hint.LoadedComputation;
12+
import io.substrait.hint.Hint.RuntimeConstraint;
13+
import io.substrait.hint.Hint.SavedComputation;
14+
import io.substrait.hint.Hint.Stats;
15+
import io.substrait.hint.ImmutableRuntimeConstraint;
16+
import io.substrait.hint.ImmutableStats;
1017
import io.substrait.relation.utils.StringHolder;
1118
import java.util.Arrays;
1219
import java.util.Collections;
@@ -130,28 +137,100 @@ void extensionTable() {
130137
@Nested
131138
class HintsTest {
132139

140+
Stats createStats(boolean includeEmptyOptimization) {
141+
ImmutableStats.Builder builder = Stats.builder();
142+
builder.rowCount(42).recordSize(42);
143+
if (includeEmptyOptimization) {
144+
builder.extension(AdvancedExtension.builder().addOptimizations().build());
145+
}
146+
return builder.build();
147+
}
148+
149+
LoadedComputation createLoadedComputation() {
150+
return LoadedComputation.builder()
151+
.computationId(1)
152+
.computationType(ComputationType.COMPUTATION_TYPE_UNKNOWN)
153+
.build();
154+
}
155+
156+
SavedComputation createSavedComputation() {
157+
return SavedComputation.builder()
158+
.computationId(1)
159+
.computationType(ComputationType.COMPUTATION_TYPE_UNKNOWN)
160+
.build();
161+
}
162+
163+
RuntimeConstraint createRuntimeConstraint(boolean includeEmptyOptimization) {
164+
ImmutableRuntimeConstraint.Builder builder = RuntimeConstraint.builder();
165+
if (includeEmptyOptimization) {
166+
builder.extension(AdvancedExtension.builder().addOptimizations().build());
167+
}
168+
return builder.build();
169+
}
170+
171+
@Test
172+
void relWithCompleteHint() {
173+
Hint test =
174+
Hint.builder()
175+
.alias("TestHint")
176+
.addAllOutputNames(Arrays.asList("Hint 1", "Hint 2"))
177+
.stats(createStats(true))
178+
.addAllLoadedComputations(
179+
Arrays.asList(createLoadedComputation(), createLoadedComputation()))
180+
.addAllSavedComputations(
181+
Arrays.asList(createSavedComputation(), createSavedComputation()))
182+
.runtimeConstraint(createRuntimeConstraint(true))
183+
.build();
184+
185+
Rel relWithCompleteHint = NamedScan.builder().from(commonTable).hint(test).build();
186+
io.substrait.proto.Rel protoRel = relProtoConverter.toProto(relWithCompleteHint);
187+
Rel relReturned = protoRelConverter.from(protoRel);
188+
assertEquals(relWithCompleteHint, relReturned);
189+
}
190+
133191
@Test
134-
void relWithHint() {
135-
Rel relWithHints =
136-
NamedScan.builder()
137-
.from(commonTable)
138-
.hint(Hint.builder().addOutputNames("Test hint").build())
192+
void relWithLoadedComputationHint() {
193+
Hint test =
194+
Hint.builder()
195+
.alias("TestHint")
196+
.addAllOutputNames(Arrays.asList("Hint 1", "Hint 2"))
197+
.stats(createStats(false))
198+
.addAllLoadedComputations(
199+
Arrays.asList(createLoadedComputation(), createLoadedComputation()))
200+
.runtimeConstraint(createRuntimeConstraint(false))
139201
.build();
140-
io.substrait.proto.Rel protoRel = relProtoConverter.toProto(relWithHints);
202+
203+
Rel relWithLoadedComputationHint = NamedScan.builder().from(commonTable).hint(test).build();
204+
io.substrait.proto.Rel protoRel = relProtoConverter.toProto(relWithLoadedComputationHint);
141205
Rel relReturned = protoRelConverter.from(protoRel);
142-
assertEquals(relWithHints, relReturned);
206+
assertEquals(relWithLoadedComputationHint, relReturned);
143207
}
144208

145209
@Test
146-
void relWithHints() {
147-
Rel relWithHints =
148-
NamedScan.builder()
149-
.from(commonTable)
150-
.hint(Hint.builder().addAllOutputNames(Arrays.asList("Hint 1", "Hint 2")).build())
210+
void relWithSavedComputationHint() {
211+
Hint test =
212+
Hint.builder()
213+
.alias("TestHint")
214+
.addAllOutputNames(Arrays.asList("Hint 1", "Hint 2"))
215+
.stats(createStats(false))
216+
.addAllSavedComputations(
217+
Arrays.asList(createSavedComputation(), createSavedComputation()))
218+
.runtimeConstraint(createRuntimeConstraint(false))
151219
.build();
152-
io.substrait.proto.Rel protoRel = relProtoConverter.toProto(relWithHints);
220+
221+
Rel relWithSavedComputationHint = NamedScan.builder().from(commonTable).hint(test).build();
222+
io.substrait.proto.Rel protoRel = relProtoConverter.toProto(relWithSavedComputationHint);
223+
Rel relReturned = protoRelConverter.from(protoRel);
224+
assertEquals(relWithSavedComputationHint, relReturned);
225+
}
226+
227+
@Test
228+
void relWithMinimalHint() {
229+
Hint test = Hint.builder().build();
230+
Rel relWithMinimalHint = NamedScan.builder().from(commonTable).hint(test).build();
231+
io.substrait.proto.Rel protoRel = relProtoConverter.toProto(relWithMinimalHint);
153232
Rel relReturned = protoRelConverter.from(protoRel);
154-
assertEquals(relWithHints, relReturned);
233+
assertEquals(relWithMinimalHint, relReturned);
155234
}
156235
}
157236
}

0 commit comments

Comments
 (0)