Skip to content

Commit 6d94059

Browse files
authored
feat: support output order remapping (#132)
* refactor: expand access to SubstraitRelNodeConverter fields * feat: add deriveRecordType to Cross * feat: substrait builder dsl * feat: self-contained Substrait to Calcite converter * chore: bump junit-jupiter * test: check for application of remappings * feat: apply remaps * refactor: move RelOutputTest to SubstraitRelNodeConverterTest
1 parent 2ab6272 commit 6d94059

File tree

14 files changed

+740
-47
lines changed

14 files changed

+740
-47
lines changed

build.gradle.kts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ repositories { mavenCentral() }
1616
java { toolchain { languageVersion.set(JavaLanguageVersion.of(17)) } }
1717

1818
dependencies {
19-
testImplementation("org.junit.jupiter:junit-jupiter-api:5.6.0")
19+
testImplementation("org.junit.jupiter:junit-jupiter-api:5.9.2")
2020
testRuntimeOnly("org.junit.jupiter:junit-jupiter-engine")
2121
implementation("org.slf4j:slf4j-jdk14:1.7.30")
2222
annotationProcessor("org.immutables:value:2.8.8")

core/build.gradle.kts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,8 @@ signing {
6868
}
6969

7070
dependencies {
71-
testImplementation("org.junit.jupiter:junit-jupiter-api:5.6.0")
72-
testImplementation("org.junit.jupiter:junit-jupiter-params:5.6.0")
71+
testImplementation("org.junit.jupiter:junit-jupiter-api:5.9.2")
72+
testImplementation("org.junit.jupiter:junit-jupiter-params:5.9.2")
7373
testRuntimeOnly("org.junit.jupiter:junit-jupiter-engine")
7474
implementation("com.google.protobuf:protobuf-java:3.17.3")
7575
implementation("com.fasterxml.jackson.core:jackson-databind:2.13.4")
Lines changed: 288 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,288 @@
1+
package io.substrait.dsl;
2+
3+
import com.github.bsideup.jabel.Desugar;
4+
import io.substrait.expression.AggregateFunctionInvocation;
5+
import io.substrait.expression.Expression;
6+
import io.substrait.expression.FieldReference;
7+
import io.substrait.expression.ImmutableFieldReference;
8+
import io.substrait.function.SimpleExtension;
9+
import io.substrait.plan.ImmutableRoot;
10+
import io.substrait.plan.Plan;
11+
import io.substrait.proto.AggregateFunction;
12+
import io.substrait.relation.Aggregate;
13+
import io.substrait.relation.Cross;
14+
import io.substrait.relation.Fetch;
15+
import io.substrait.relation.Filter;
16+
import io.substrait.relation.Join;
17+
import io.substrait.relation.NamedScan;
18+
import io.substrait.relation.Project;
19+
import io.substrait.relation.Rel;
20+
import io.substrait.relation.Set;
21+
import io.substrait.relation.Sort;
22+
import io.substrait.type.NamedStruct;
23+
import io.substrait.type.Type;
24+
import io.substrait.type.TypeCreator;
25+
import java.util.Arrays;
26+
import java.util.List;
27+
import java.util.Optional;
28+
import java.util.function.Function;
29+
import java.util.stream.Collectors;
30+
import java.util.stream.Stream;
31+
32+
public class SubstraitBuilder {
33+
static final TypeCreator R = TypeCreator.of(false);
34+
static final TypeCreator N = TypeCreator.of(true);
35+
private final SimpleExtension.ExtensionCollection extensions;
36+
37+
public SubstraitBuilder(SimpleExtension.ExtensionCollection extensions) {
38+
this.extensions = extensions;
39+
}
40+
41+
// Relations
42+
public Aggregate aggregate(
43+
Function<Rel, Aggregate.Grouping> groupingFn,
44+
Function<Rel, List<AggregateFunctionInvocation>> measuresFn,
45+
Rel input) {
46+
Function<Rel, List<Aggregate.Grouping>> groupingsFn =
47+
groupingFn.andThen(g -> Stream.of(g).collect(Collectors.toList()));
48+
return aggregate(groupingsFn, measuresFn, Optional.empty(), input);
49+
}
50+
51+
public Aggregate aggregate(
52+
Function<Rel, Aggregate.Grouping> groupingFn,
53+
Function<Rel, List<AggregateFunctionInvocation>> measuresFn,
54+
Rel.Remap remap,
55+
Rel input) {
56+
Function<Rel, List<Aggregate.Grouping>> groupingsFn =
57+
groupingFn.andThen(g -> Stream.of(g).collect(Collectors.toList()));
58+
return aggregate(groupingsFn, measuresFn, Optional.of(remap), input);
59+
}
60+
61+
private Aggregate aggregate(
62+
Function<Rel, List<Aggregate.Grouping>> groupingsFn,
63+
Function<Rel, List<AggregateFunctionInvocation>> measuresFn,
64+
Optional<Rel.Remap> remap,
65+
Rel input) {
66+
var groupings = groupingsFn.apply(input);
67+
var measures =
68+
measuresFn.apply(input).stream()
69+
.map(m -> Aggregate.Measure.builder().function(m).build())
70+
.collect(java.util.stream.Collectors.toList());
71+
return Aggregate.builder()
72+
.groupings(groupings)
73+
.measures(measures)
74+
.remap(remap)
75+
.input(input)
76+
.build();
77+
}
78+
79+
public Cross cross(Rel left, Rel right) {
80+
return cross(left, right, Optional.empty());
81+
}
82+
83+
public Cross cross(Rel left, Rel right, Rel.Remap remap) {
84+
return cross(left, right, Optional.of(remap));
85+
}
86+
87+
private Cross cross(Rel left, Rel right, Optional<Rel.Remap> remap) {
88+
return Cross.builder().left(left).right(right).remap(remap).build();
89+
}
90+
91+
public Fetch fetch(long offset, long count, Rel input) {
92+
return fetch(offset, count, Optional.empty(), input);
93+
}
94+
95+
public Fetch fetch(long offset, long count, Rel.Remap remap, Rel input) {
96+
return fetch(offset, count, Optional.of(remap), input);
97+
}
98+
99+
private Fetch fetch(long offset, long count, Optional<Rel.Remap> remap, Rel input) {
100+
return Fetch.builder().offset(offset).count(count).input(input).remap(remap).build();
101+
}
102+
103+
public Filter filter(Function<Rel, Expression> conditionFn, Rel input) {
104+
return filter(conditionFn, Optional.empty(), input);
105+
}
106+
107+
public Filter filter(Function<Rel, Expression> conditionFn, Rel.Remap remap, Rel input) {
108+
return filter(conditionFn, Optional.of(remap), input);
109+
}
110+
111+
private Filter filter(
112+
Function<Rel, Expression> conditionFn, Optional<Rel.Remap> remap, Rel input) {
113+
var condition = conditionFn.apply(input);
114+
return Filter.builder().input(input).condition(condition).remap(remap).build();
115+
}
116+
117+
@Desugar
118+
public record JoinInput(Rel left, Rel right) {}
119+
120+
public Join innerJoin(Function<JoinInput, Expression> conditionFn, Rel left, Rel right) {
121+
return join(conditionFn, Join.JoinType.INNER, left, right);
122+
}
123+
124+
public Join innerJoin(
125+
Function<JoinInput, Expression> conditionFn, Rel.Remap remap, Rel left, Rel right) {
126+
return join(conditionFn, Join.JoinType.INNER, remap, left, right);
127+
}
128+
129+
public Join join(
130+
Function<JoinInput, Expression> conditionFn, Join.JoinType joinType, Rel left, Rel right) {
131+
return join(conditionFn, joinType, Optional.empty(), left, right);
132+
}
133+
134+
public Join join(
135+
Function<JoinInput, Expression> conditionFn,
136+
Join.JoinType joinType,
137+
Rel.Remap remap,
138+
Rel left,
139+
Rel right) {
140+
return join(conditionFn, joinType, Optional.of(remap), left, right);
141+
}
142+
143+
private Join join(
144+
Function<JoinInput, Expression> conditionFn,
145+
Join.JoinType joinType,
146+
Optional<Rel.Remap> remap,
147+
Rel left,
148+
Rel right) {
149+
var condition = conditionFn.apply(new JoinInput(left, right));
150+
return Join.builder()
151+
.left(left)
152+
.right(right)
153+
.condition(condition)
154+
.joinType(joinType)
155+
.remap(remap)
156+
.build();
157+
}
158+
159+
public NamedScan namedScan(
160+
Iterable<String> tableName, Iterable<String> columnNames, Iterable<Type> types) {
161+
return namedScan(tableName, columnNames, types, Optional.empty());
162+
}
163+
164+
public NamedScan namedScan(
165+
Iterable<String> tableName,
166+
Iterable<String> columnNames,
167+
Iterable<Type> types,
168+
Rel.Remap remap) {
169+
return namedScan(tableName, columnNames, types, Optional.of(remap));
170+
}
171+
172+
private NamedScan namedScan(
173+
Iterable<String> tableName,
174+
Iterable<String> columnNames,
175+
Iterable<Type> types,
176+
Optional<Rel.Remap> remap) {
177+
var struct = Type.Struct.builder().addAllFields(types).nullable(false).build();
178+
var namedStruct = NamedStruct.of(columnNames, struct);
179+
return NamedScan.builder().names(tableName).initialSchema(namedStruct).remap(remap).build();
180+
}
181+
182+
public Project project(Function<Rel, Iterable<? extends Expression>> expressionsFn, Rel input) {
183+
return project(expressionsFn, Optional.empty(), input);
184+
}
185+
186+
public Project project(
187+
Function<Rel, Iterable<? extends Expression>> expressionsFn, Rel.Remap remap, Rel input) {
188+
return project(expressionsFn, Optional.of(remap), input);
189+
}
190+
191+
private Project project(
192+
Function<Rel, Iterable<? extends Expression>> expressionsFn,
193+
Optional<Rel.Remap> remap,
194+
Rel input) {
195+
var expressions = expressionsFn.apply(input);
196+
return Project.builder().input(input).expressions(expressions).remap(remap).build();
197+
}
198+
199+
public Set set(Set.SetOp op, Rel... inputs) {
200+
return set(op, Optional.empty(), inputs);
201+
}
202+
203+
public Set set(Set.SetOp op, Rel.Remap remap, Rel... inputs) {
204+
return set(op, Optional.of(remap), inputs);
205+
}
206+
207+
private Set set(Set.SetOp op, Optional<Rel.Remap> remap, Rel... inputs) {
208+
return Set.builder().setOp(op).remap(remap).addAllInputs(Arrays.asList(inputs)).build();
209+
}
210+
211+
public Sort sort(Function<Rel, Iterable<? extends Expression.SortField>> sortFieldFn, Rel input) {
212+
return sort(sortFieldFn, Optional.empty(), input);
213+
}
214+
215+
public Sort sort(
216+
Function<Rel, Iterable<? extends Expression.SortField>> sortFieldFn,
217+
Rel.Remap remap,
218+
Rel input) {
219+
return sort(sortFieldFn, Optional.of(remap), input);
220+
}
221+
222+
private Sort sort(
223+
Function<Rel, Iterable<? extends Expression.SortField>> sortFieldFn,
224+
Optional<Rel.Remap> remap,
225+
Rel input) {
226+
var condition = sortFieldFn.apply(input);
227+
return Sort.builder().input(input).sortFields(condition).remap(remap).build();
228+
}
229+
230+
// Expressions
231+
232+
public Expression.BoolLiteral bool(boolean v) {
233+
return Expression.BoolLiteral.builder().value(v).build();
234+
}
235+
236+
public FieldReference fieldReference(Rel input, int index) {
237+
return ImmutableFieldReference.newInputRelReference(index, input);
238+
}
239+
240+
public List<FieldReference> fieldReferences(Rel input, int... indexes) {
241+
return Arrays.stream(indexes)
242+
.mapToObj(index -> fieldReference(input, index))
243+
.collect(java.util.stream.Collectors.toList());
244+
}
245+
246+
public List<Expression.SortField> sortFields(Rel input, int... indexes) {
247+
return Arrays.stream(indexes)
248+
.mapToObj(
249+
index ->
250+
Expression.SortField.builder()
251+
.expr(ImmutableFieldReference.newInputRelReference(index, input))
252+
.direction(Expression.SortDirection.ASC_NULLS_LAST)
253+
.build())
254+
.collect(java.util.stream.Collectors.toList());
255+
}
256+
257+
// Aggregate Functions
258+
259+
public Aggregate.Grouping grouping(Rel input, int... indexes) {
260+
var columns = fieldReferences(input, indexes);
261+
return Aggregate.Grouping.builder().addAllExpressions(columns).build();
262+
}
263+
264+
public AggregateFunctionInvocation count(Rel input, int field) {
265+
var declaration =
266+
extensions.getAggregateFunction(
267+
SimpleExtension.FunctionAnchor.of("/functions_aggregate_generic.yaml", "count:any"));
268+
return AggregateFunctionInvocation.builder()
269+
.arguments(fieldReferences(input, field))
270+
.outputType(R.I64)
271+
.declaration(declaration)
272+
.aggregationPhase(Expression.AggregationPhase.INITIAL_TO_RESULT)
273+
.invocation(AggregateFunction.AggregationInvocation.AGGREGATION_INVOCATION_ALL)
274+
.build();
275+
}
276+
277+
// Scalar Functions
278+
279+
// Misc
280+
281+
public Plan.Root root(Rel rel) {
282+
return ImmutableRoot.builder().input(rel).build();
283+
}
284+
285+
public Rel.Remap remap(Integer... fields) {
286+
return Rel.Remap.of(Arrays.asList(fields));
287+
}
288+
}

core/src/main/java/io/substrait/relation/Cross.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,21 @@
11
package io.substrait.relation;
22

3+
import io.substrait.type.Type;
4+
import io.substrait.type.TypeCreator;
5+
import java.util.stream.Stream;
36
import org.immutables.value.Value;
47

58
@Value.Immutable
69
public abstract class Cross extends BiRel {
710

11+
@Override
12+
protected Type.Struct deriveRecordType() {
13+
return TypeCreator.REQUIRED.struct(
14+
Stream.concat(
15+
getLeft().getRecordType().fields().stream(),
16+
getRight().getRecordType().fields().stream()));
17+
}
18+
819
@Override
920
public <O, E extends Exception> O accept(RelVisitor<O, E> visitor) throws E {
1021
return visitor.visit(this);

core/src/main/java/io/substrait/relation/ProtoRelConverter.java

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -336,15 +336,7 @@ private Join newJoin(JoinRel rel) {
336336
private Rel newCross(CrossRel rel) {
337337
Rel left = from(rel.getLeft());
338338
Rel right = from(rel.getRight());
339-
Type.Struct leftStruct = left.getRecordType();
340-
Type.Struct rightStruct = right.getRecordType();
341-
Type.Struct unionedStruct = Type.Struct.builder().from(leftStruct).from(rightStruct).build();
342-
return Cross.builder()
343-
.left(left)
344-
.right(right)
345-
.deriveRecordType(unionedStruct)
346-
.remap(optionalRelmap(rel.getCommon()))
347-
.build();
339+
return Cross.builder().left(left).right(right).remap(optionalRelmap(rel.getCommon())).build();
348340
}
349341

350342
private Set newSet(SetRel rel) {

core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,6 @@ public Optional<Rel> visit(Cross cross) throws RuntimeException {
163163
.from(cross)
164164
.left(left.orElse(cross.getLeft()))
165165
.right(right.orElse(cross.getRight()))
166-
.deriveRecordType(unionedStruct)
167166
.build());
168167
}
169168

core/src/main/java/io/substrait/type/NamedStruct.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ public interface NamedStruct {
1010

1111
List<String> names();
1212

13-
public static NamedStruct of(List<String> names, Type.Struct type) {
13+
static NamedStruct of(Iterable<String> names, Type.Struct type) {
1414
return ImmutableNamedStruct.builder().addAllNames(names).struct(type).build();
1515
}
1616

isthmus/build.gradle.kts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ dependencies {
7777
implementation(project(":core"))
7878
implementation("org.apache.calcite:calcite-core:${CALCITE_VERSION}")
7979
implementation("org.apache.calcite:calcite-server:${CALCITE_VERSION}")
80-
implementation("org.junit.jupiter:junit-jupiter:5.7.0")
80+
implementation("org.junit.jupiter:junit-jupiter:5.9.2")
8181
implementation("org.reflections:reflections:0.9.12")
8282
implementation("com.google.guava:guava:29.0-jre")
8383
implementation("org.graalvm.sdk:graal-sdk:22.0.0.2")

0 commit comments

Comments
 (0)