Skip to content

Commit a1fbe53

Browse files
feat(isthmus): enable CHAR upcasting in Calcite function calls (#338)
Configures the SubstraitTypeSystem, which extends RelDataTypeSystemImpl, to set shouldConvertRaggedUnionTypesToVarying to true, which has the effect of making the least restrictive type of CHAR types of different lengths VARCHAR. This means that when processing SQL queries with functions like concat, Calcite will upcast all of the input arguments to the correct least restrictive type, which better matches the expectations of Substrait.
1 parent 492c4f6 commit a1fbe53

File tree

8 files changed

+85
-122
lines changed

8 files changed

+85
-122
lines changed

isthmus/src/main/java/io/substrait/isthmus/SubstraitTypeSystem.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,11 @@ public int getMaxNumericPrecision() {
3939
return 38;
4040
}
4141

42+
@Override
43+
public boolean shouldConvertRaggedUnionTypesToVarying() {
44+
return true;
45+
}
46+
4247
public static RelDataTypeFactory createTypeFactory() {
4348
return new JavaTypeFactoryImpl(TYPE_SYSTEM);
4449
}

isthmus/src/main/java/io/substrait/isthmus/expression/AggregateFunctionConverter.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ public AggregateFunctionConverter(
5252
protected AggregateFunctionInvocation generateBinding(
5353
WrappedAggregateCall call,
5454
SimpleExtension.AggregateFunctionVariant function,
55-
List<FunctionArg> arguments,
55+
List<? extends FunctionArg> arguments,
5656
Type outputType) {
5757
AggregateCall agg = call.getUnderlying();
5858

isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java

Lines changed: 48 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import java.util.List;
2626
import java.util.Locale;
2727
import java.util.Map;
28+
import java.util.Objects;
2829
import java.util.Optional;
2930
import java.util.Set;
3031
import java.util.function.Function;
@@ -149,7 +150,6 @@ protected class FunctionFinder {
149150
private final SqlOperator operator;
150151
private final List<F> functions;
151152
private final Map<String, F> directMap;
152-
private final SignatureMatcher<F> matcher;
153153
private final Optional<SingularArgumentMatcher<F>> singularInputType;
154154
private final Util.IntRange argRange;
155155

@@ -161,7 +161,6 @@ public FunctionFinder(String name, SqlOperator operator, List<F> functions) {
161161
Util.IntRange.of(
162162
functions.stream().mapToInt(t -> t.getRange().getStartInclusive()).min().getAsInt(),
163163
functions.stream().mapToInt(t -> t.getRange().getEndExclusive()).max().getAsInt());
164-
this.matcher = getSignatureMatcher(operator, functions);
165164
this.singularInputType = getSingularInputType(functions);
166165
var directMap = ImmutableMap.<String, F>builder();
167166
for (var func : functions) {
@@ -178,21 +177,18 @@ public boolean allowedArgCount(int count) {
178177
return argRange.within(count);
179178
}
180179

181-
private static <F extends SimpleExtension.Function> SignatureMatcher<F> getSignatureMatcher(
182-
SqlOperator operator, List<F> functions) {
183-
return (inputTypes, outputType) -> {
184-
for (F function : functions) {
185-
List<SimpleExtension.Argument> args = function.requiredArguments();
186-
// Make sure that arguments & return are within bounds and match the types
187-
if (function.returnType() instanceof ParameterizedType
188-
&& isMatch(outputType, (ParameterizedType) function.returnType())
189-
&& inputTypesSatisfyDefinedArguments(inputTypes, args)) {
190-
return Optional.of(function);
191-
}
180+
private Optional<F> signatureMatch(List<Type> inputTypes, Type outputType) {
181+
for (F function : functions) {
182+
List<SimpleExtension.Argument> args = function.requiredArguments();
183+
// Make sure that arguments & return are within bounds and match the types
184+
if (function.returnType() instanceof ParameterizedType
185+
&& isMatch(outputType, (ParameterizedType) function.returnType())
186+
&& inputTypesMatchDefinedArguments(inputTypes, args)) {
187+
return Optional.of(function);
192188
}
189+
}
193190

194-
return Optional.empty();
195-
};
191+
return Optional.empty();
196192
}
197193

198194
/**
@@ -208,7 +204,7 @@ && inputTypesSatisfyDefinedArguments(inputTypes, args)) {
208204
* @param args expected arguments as defined in a {@link SimpleExtension.Function}
209205
* @return true if the {@code inputTypes} satisfy the {@code args}, false otherwise
210206
*/
211-
private static boolean inputTypesSatisfyDefinedArguments(
207+
private static boolean inputTypesMatchDefinedArguments(
212208
List<Type> inputTypes, List<SimpleExtension.Argument> args) {
213209

214210
Map<String, Set<Type>> wildcardToType = new HashMap<>();
@@ -318,7 +314,7 @@ private Stream<String> matchKeys(List<RexNode> rexOperands, List<String> opTypes
318314

319315
assert (rexOperands.size() == opTypes.size());
320316

321-
if (rexOperands.size() == 0) {
317+
if (rexOperands.isEmpty()) {
322318
return Stream.of("");
323319
} else {
324320
List<List<String>> argTypeLists =
@@ -357,13 +353,12 @@ public Optional<T> attemptMatch(C call, Function<RexNode, Expression> topLevelCo
357353
// try to do a direct match
358354
List<String> typeStrings =
359355
opTypes.stream().map(t -> t.accept(ToTypeString.INSTANCE)).collect(Collectors.toList());
360-
Stream<String> possibleKeys =
361-
matchKeys(call.getOperands().collect(Collectors.toList()), typeStrings);
356+
Stream<String> possibleKeys = matchKeys(operandsList, typeStrings);
362357

363358
Optional<String> directMatchKey =
364359
possibleKeys
365360
.map(argList -> name + ":" + argList)
366-
.filter(k -> directMap.containsKey(k))
361+
.filter(directMap::containsKey)
367362
.findFirst();
368363

369364
if (directMatchKey.isPresent()) {
@@ -376,14 +371,13 @@ public Optional<T> attemptMatch(C call, Function<RexNode, Expression> topLevelCo
376371
RexNode r = operandsList.get(i);
377372
Expression o = operands.get(i);
378373
if (EnumConverter.isEnumValue(r)) {
379-
return EnumConverter.fromRex(variant, (RexLiteral) r, i)
380-
.orElseGet(() -> null);
374+
return EnumConverter.fromRex(variant, (RexLiteral) r, i).orElse(null);
381375
} else {
382376
return o;
383377
}
384378
})
385379
.collect(Collectors.toList());
386-
boolean allArgsMapped = funcArgs.stream().filter(e -> e == null).findFirst().isEmpty();
380+
boolean allArgsMapped = funcArgs.stream().filter(Objects::isNull).findFirst().isEmpty();
387381
if (allArgsMapped) {
388382
return Optional.of(generateBinding(call, variant, funcArgs, outputType));
389383
} else {
@@ -413,53 +407,35 @@ private Optional<T> matchByLeastRestrictive(
413407
return Optional.empty();
414408
}
415409
Type type = typeConverter.toSubstrait(leastRestrictive);
416-
var out = singularInputType.get().tryMatch(type, outputType);
417-
418-
if (out.isPresent()) {
419-
var declaration = out.get();
420-
var coercedArgs = coerceArguments(operands, type);
421-
declaration.validateOutputType(coercedArgs, outputType);
422-
return Optional.of(
423-
generateBinding(
424-
call,
425-
out.get(),
426-
coercedArgs.stream().map(FunctionArg.class::cast).collect(Collectors.toList()),
427-
outputType));
428-
}
429-
return Optional.empty();
410+
var out = singularInputType.orElseThrow().tryMatch(type, outputType);
411+
412+
return out.map(
413+
declaration -> {
414+
var coercedArgs = coerceArguments(operands, type);
415+
declaration.validateOutputType(coercedArgs, outputType);
416+
return generateBinding(call, out.get(), coercedArgs, outputType);
417+
});
430418
}
431419

432-
private Optional<T> matchCoerced(C call, Type outputType, List<Expression> operands) {
433-
420+
private Optional<T> matchCoerced(C call, Type outputType, List<Expression> expressions) {
434421
// Convert the operands to the proper Substrait type
435-
List<Type> allTypes =
422+
List<Type> operandTypes =
436423
call.getOperands()
437424
.map(RexNode::getType)
438425
.map(typeConverter::toSubstrait)
439426
.collect(Collectors.toList());
440427

441-
// See if all the input types match the function
442-
Optional<F> matchFunction = this.matcher.tryMatch(allTypes, outputType);
443-
if (matchFunction.isPresent()) {
444-
List<Expression> coerced =
445-
Streams.zip(
446-
operands.stream(),
447-
call.getOperands(),
448-
(a, b) -> {
449-
Type type = typeConverter.toSubstrait(b.getType());
450-
return coerceArgument(a, type);
451-
})
452-
.collect(Collectors.toList());
453-
454-
return Optional.of(
455-
generateBinding(
456-
call,
457-
matchFunction.get(),
458-
coerced.stream().map(FunctionArg.class::cast).collect(Collectors.toList()),
459-
outputType));
428+
// See if all the input types can be made to match the function
429+
Optional<F> matchFunction = signatureMatch(operandTypes, outputType);
430+
if (matchFunction.isEmpty()) {
431+
return Optional.empty();
460432
}
461433

462-
return Optional.empty();
434+
var coercedArgs =
435+
Streams.zip(
436+
expressions.stream(), operandTypes.stream(), FunctionConverter::coerceArgument)
437+
.collect(Collectors.toList());
438+
return Optional.of(generateBinding(call, matchFunction.get(), coercedArgs, outputType));
463439
}
464440

465441
protected String getName() {
@@ -481,56 +457,30 @@ public interface GenericCall {
481457
* Coerced types according to an expected output type. Coercion is only done for type mismatches,
482458
* not for nullability or parameter mismatches.
483459
*/
484-
private static List<Expression> coerceArguments(List<Expression> arguments, Type type) {
485-
return arguments.stream().map(a -> coerceArgument(a, type)).collect(Collectors.toList());
460+
private static List<Expression> coerceArguments(List<Expression> arguments, Type targetType) {
461+
return arguments.stream().map(a -> coerceArgument(a, targetType)).collect(Collectors.toList());
486462
}
487463

488464
private static Expression coerceArgument(Expression argument, Type type) {
489-
var typeMatches = isMatch(type, argument.getType());
490-
if (!typeMatches) {
491-
return ExpressionCreator.cast(type, argument, Expression.FailureBehavior.THROW_EXCEPTION);
465+
if (isMatch(type, argument.getType())) {
466+
return argument;
492467
}
493-
return argument;
468+
469+
return ExpressionCreator.cast(type, argument, Expression.FailureBehavior.THROW_EXCEPTION);
494470
}
495471

496472
protected abstract T generateBinding(
497-
C call, F function, List<FunctionArg> arguments, Type outputType);
473+
C call, F function, List<? extends FunctionArg> arguments, Type outputType);
498474

499-
public interface SingularArgumentMatcher<F> {
475+
@FunctionalInterface
476+
private interface SingularArgumentMatcher<F> {
500477
Optional<F> tryMatch(Type type, Type outputType);
501478
}
502479

503-
public interface SignatureMatcher<F> {
504-
Optional<F> tryMatch(List<Type> types, Type outputType);
505-
}
506-
507-
private static SignatureMatcher chainedSignature(SignatureMatcher... matchers) {
508-
return switch (matchers.length) {
509-
case 0 -> (types, outputType) -> Optional.empty();
510-
case 1 -> matchers[0];
511-
default -> (types, outputType) -> {
512-
for (SignatureMatcher m : matchers) {
513-
var t = m.tryMatch(types, outputType);
514-
if (t.isPresent()) {
515-
return t;
516-
}
517-
}
518-
return Optional.empty();
519-
};
520-
};
521-
}
522-
523-
private static boolean isMatch(Type inputType, ParameterizedType type) {
524-
if (type.isWildcard()) {
525-
return true;
526-
}
527-
return inputType.accept(new IgnoreNullableAndParameters(type));
528-
}
529-
530-
private static boolean isMatch(ParameterizedType inputType, ParameterizedType type) {
531-
if (type.isWildcard()) {
480+
private static boolean isMatch(ParameterizedType actualType, ParameterizedType targetType) {
481+
if (targetType.isWildcard()) {
532482
return true;
533483
}
534-
return inputType.accept(new IgnoreNullableAndParameters(type));
484+
return actualType.accept(new IgnoreNullableAndParameters(targetType));
535485
}
536486
}

isthmus/src/main/java/io/substrait/isthmus/expression/ScalarFunctionConverter.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ public Optional<Expression> convert(
6060
protected Expression generateBinding(
6161
WrappedScalarCall call,
6262
SimpleExtension.ScalarFunctionVariant function,
63-
List<FunctionArg> arguments,
63+
List<? extends FunctionArg> arguments,
6464
Type outputType) {
6565
return Expression.ScalarFunctionInvocation.builder()
6666
.outputType(outputType)

isthmus/src/main/java/io/substrait/isthmus/expression/WindowFunctionConverter.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ public WindowFunctionConverter(
5252
protected Expression.WindowFunctionInvocation generateBinding(
5353
WrappedWindowCall call,
5454
SimpleExtension.WindowFunctionVariant function,
55-
List<FunctionArg> arguments,
55+
List<? extends FunctionArg> arguments,
5656
Type outputType) {
5757
RexOver over = call.over;
5858
RexWindow window = over.getWindow();

isthmus/src/main/java/io/substrait/isthmus/expression/WindowRelFunctionConverter.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ public WindowRelFunctionConverter(
5151
protected ConsistentPartitionWindow.WindowRelFunctionInvocation generateBinding(
5252
WrappedWindowRelCall call,
5353
SimpleExtension.WindowFunctionVariant function,
54-
List<FunctionArg> arguments,
54+
List<? extends FunctionArg> arguments,
5555
Type outputType) {
5656
Window.RexWinAggCall over = call.getWinAggCall();
5757

isthmus/src/test/java/io/substrait/isthmus/FunctionConversionTest.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,4 +258,19 @@ public void unsupportedExtractDateWithIndexing() {
258258
assertThrows(
259259
UnsupportedOperationException.class, () -> reqReqDateFn.accept(expressionRexConverter));
260260
}
261+
262+
@Test
263+
public void concatStringLiteralAndVarchar() throws Exception {
264+
assertProtoPlanRoundrip("select 'part_'||P_NAME from PART");
265+
}
266+
267+
@Test
268+
public void concatCharAndVarchar() throws Exception {
269+
assertProtoPlanRoundrip("select P_BRAND||P_NAME from PART");
270+
}
271+
272+
@Test
273+
public void concatStringLiteralAndChar() throws Exception {
274+
assertProtoPlanRoundrip("select 'brand_'||P_BRAND from PART");
275+
}
261276
}
Lines changed: 13 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
package io.substrait.isthmus;
22

33
import com.google.protobuf.util.JsonFormat;
4+
import java.util.Set;
5+
import java.util.stream.IntStream;
46
import org.apache.calcite.adapter.tpcds.TpcdsSchema;
57
import org.junit.jupiter.params.ParameterizedTest;
6-
import org.junit.jupiter.params.provider.ValueSource;
8+
import org.junit.jupiter.params.provider.MethodSource;
79

810
/**
911
*
@@ -27,33 +29,24 @@
2729
*/
2830
public class TpcdsQueryNoValidation extends PlanTestBase {
2931

32+
static final Set<Integer> EXCLUDED =
33+
Set.of(2, 9, 12, 20, 27, 36, 47, 51, 53, 57, 63, 70, 86, 89, 98);
34+
35+
static IntStream testCases() {
36+
return IntStream.rangeClosed(1, 99).filter(n -> !EXCLUDED.contains(n));
37+
}
38+
3039
/**
3140
* This test only validates that generating substrait plans for TPC-DS queries does not fail. As
3241
* of now this test does not validate correctness of the generated plan
3342
*/
34-
private void testQuery(int i) throws Exception {
43+
@ParameterizedTest
44+
@MethodSource("testCases")
45+
void testQuery(int i) throws Exception {
3546
SqlToSubstrait s = new SqlToSubstrait();
3647
TpcdsSchema schema = new TpcdsSchema(1.0);
3748
String sql = asString(String.format("tpcds/queries/%02d.sql", i));
3849
var plan = s.execute(sql, "tpcds", schema);
3950
System.out.println(JsonFormat.printer().print(plan));
4051
}
41-
42-
@ParameterizedTest
43-
@ValueSource(
44-
ints = {
45-
1, 3, 4, 6, 7, 8, 10, 11, 13, 14, 15, 16, 17, 18, 19, 21, 22, 23, 24, 25, 26, 28, 29, 30,
46-
31, 32, 33, 34, 35, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 48, 49, 50, 52, 54, 55, 56, 58,
47-
59, 60, 61, 62, 64, 65, 67, 68, 69, 71, 72, 73, 74, 75, 76, 77, 78, 79, 81, 82, 83, 85, 87,
48-
88, 90, 92, 93, 94, 95, 96, 97, 99
49-
})
50-
public void tpcdsSuccess(int query) throws Exception {
51-
testQuery(query);
52-
}
53-
54-
@ParameterizedTest
55-
@ValueSource(ints = {2, 5, 9, 12, 20, 27, 36, 47, 51, 53, 57, 63, 66, 70, 80, 84, 86, 89, 91, 98})
56-
public void tpcdsFailure(int query) throws Exception {
57-
// testQuery(query);
58-
}
5952
}

0 commit comments

Comments
 (0)