Skip to content

Commit 1f0e4de

Browse files
feat(isthmus): support for SQL TRIM function (#401)
1 parent f9c8691 commit 1f0e4de

14 files changed

+550
-209
lines changed

build.gradle.kts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import com.diffplug.gradle.spotless.SpotlessExtension
22
import com.diffplug.gradle.spotless.SpotlessPlugin
33
import com.github.vlsi.gradle.dsl.configureEach
4+
import org.gradle.api.tasks.testing.logging.TestExceptionFormat
45

56
plugins {
67
`maven-publish`
@@ -43,6 +44,7 @@ allprojects {
4344
val javaToolchains = project.extensions.getByType<JavaToolchainService>()
4445
useJUnitPlatform()
4546
javaLauncher.set(javaToolchains.launcherFor { languageVersion.set(JavaLanguageVersion.of(11)) })
47+
testLogging { exceptionFormat = TestExceptionFormat.FULL }
4648
}
4749
tasks.withType<JavaCompile> {
4850
sourceCompatibility = "17"

isthmus/src/main/java/io/substrait/isthmus/expression/EnumConverter.java

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import org.apache.calcite.rex.RexBuilder;
1515
import org.apache.calcite.rex.RexLiteral;
1616
import org.apache.calcite.rex.RexNode;
17+
import org.apache.calcite.sql.fun.SqlTrimFunction;
1718
import org.apache.calcite.sql.type.SqlTypeName;
1819

1920
/**
@@ -52,15 +53,38 @@ public class EnumConverter {
5253
calciteEnumMap.put(
5354
argAnchor(DefaultExtensionCatalog.FUNCTIONS_DATETIME, "extract:req_time", 0),
5455
TimeUnitRange.class);
56+
57+
calciteEnumMap.put(
58+
argAnchor(DefaultExtensionCatalog.FUNCTIONS_STRING, "trim:vchar_vchar", 0),
59+
SqlTrimFunction.Flag.class);
60+
calciteEnumMap.put(
61+
argAnchor(DefaultExtensionCatalog.FUNCTIONS_STRING, "trim:str_str", 0),
62+
SqlTrimFunction.Flag.class);
63+
calciteEnumMap.put(
64+
argAnchor(DefaultExtensionCatalog.FUNCTIONS_STRING, "ltrim:vchar_vchar", 0),
65+
SqlTrimFunction.Flag.class);
66+
calciteEnumMap.put(
67+
argAnchor(DefaultExtensionCatalog.FUNCTIONS_STRING, "ltrim:str_str", 0),
68+
SqlTrimFunction.Flag.class);
69+
calciteEnumMap.put(
70+
argAnchor(DefaultExtensionCatalog.FUNCTIONS_STRING, "rtrim:vchar_vchar", 0),
71+
SqlTrimFunction.Flag.class);
72+
calciteEnumMap.put(
73+
argAnchor(DefaultExtensionCatalog.FUNCTIONS_STRING, "rtrim:str_str", 0),
74+
SqlTrimFunction.Flag.class);
5575
}
5676

5777
private static Optional<Enum<?>> constructValue(
5878
Class<? extends Enum<?>> cls, Supplier<Optional<String>> option) {
5979
if (cls.isAssignableFrom(TimeUnitRange.class)) {
6080
return option.get().map(TimeUnitRange::valueOf);
61-
} else {
62-
return Optional.empty();
6381
}
82+
83+
if (cls.isAssignableFrom(SqlTrimFunction.Flag.class)) {
84+
return option.get().map(SqlTrimFunction.Flag::valueOf);
85+
}
86+
87+
return Optional.empty();
6488
}
6589

6690
static Optional<RexLiteral> toRex(
@@ -123,9 +147,7 @@ static boolean canConvert(Enum<?> value) {
123147
}
124148

125149
static boolean isEnumValue(RexNode value) {
126-
return value != null
127-
&& (value instanceof RexLiteral)
128-
&& value.getType().getSqlTypeName() == SqlTypeName.SYMBOL;
150+
return value instanceof RexLiteral && value.getType().getSqlTypeName() == SqlTypeName.SYMBOL;
129151
}
130152

131153
private static class ArgAnchor {

isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -363,11 +363,11 @@ public RexNode visit(Expression.ScalarFunctionInvocation expr) throws RuntimeExc
363363
callConversionFailureMessage(
364364
"scalar", expr.declaration().name(), expr.arguments())));
365365

366-
var eArgs = expr.arguments();
366+
var eArgs = scalarFunctionConverter.getExpressionArguments(expr);
367367
var args =
368-
IntStream.range(0, expr.arguments().size())
368+
IntStream.range(0, eArgs.size())
369369
.mapToObj(i -> eArgs.get(i).accept(expr.declaration(), i, this))
370-
.collect(java.util.stream.Collectors.toList());
370+
.collect(Collectors.toList());
371371

372372
RelDataType returnType = typeConverter.toCalcite(typeFactory, expr.outputType());
373373
return rexBuilder.makeCall(returnType, operator, args);
@@ -395,9 +395,9 @@ public RexNode visit(Expression.WindowFunctionInvocation expr) throws RuntimeExc
395395

396396
List<FunctionArg> eArgs = expr.arguments();
397397
List<RexNode> args =
398-
IntStream.range(0, expr.arguments().size())
398+
IntStream.range(0, eArgs.size())
399399
.mapToObj(i -> eArgs.get(i).accept(expr.declaration(), i, this))
400-
.collect(java.util.stream.Collectors.toList());
400+
.collect(Collectors.toList());
401401

402402
List<RexNode> partitionKeys =
403403
expr.partitionBy().stream().map(e -> e.accept(this)).collect(Collectors.toList());

isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ public FunctionConverter(
6767
signatures.addAll(additionalSignatures);
6868
signatures.addAll(getSigs());
6969
this.typeFactory = typeFactory;
70-
this.substraitFuncKeyToSqlOperatorMap = ArrayListMultimap.<String, SqlOperator>create();
70+
this.substraitFuncKeyToSqlOperatorMap = ArrayListMultimap.create();
7171

7272
var alm = ArrayListMultimap.<String, F>create();
7373
for (var f : functions) {
@@ -78,22 +78,19 @@ public FunctionConverter(
7878
signatures.stream()
7979
.collect(
8080
Multimaps.toMultimap(
81-
FunctionMappings.Sig::name, f -> f, () -> ArrayListMultimap.create()));
81+
FunctionMappings.Sig::name, Function.identity(), ArrayListMultimap::create));
8282
var matcherMap = new IdentityHashMap<SqlOperator, FunctionFinder>();
8383
for (String key : alm.keySet()) {
8484
var sigs = calciteOperators.get(key);
85-
if (sigs == null) {
86-
logger.atInfo().log("Dropping function due to no binding: {}", key);
87-
continue;
85+
if (sigs.isEmpty()) {
86+
logger.atInfo().log("No binding for function: {}", key);
8887
}
8988

9089
for (var sig : sigs) {
9190
var implList = alm.get(key);
92-
if (implList == null || implList.isEmpty()) {
93-
continue;
91+
if (!implList.isEmpty()) {
92+
matcherMap.put(sig.operator(), new FunctionFinder(key, sig.operator(), implList));
9493
}
95-
96-
matcherMap.put(sig.operator(), new FunctionFinder(key, sig.operator(), implList));
9794
}
9895
}
9996

@@ -110,14 +107,16 @@ public FunctionConverter(
110107

111108
public Optional<SqlOperator> getSqlOperatorFromSubstraitFunc(String key, Type outputType) {
112109
var resolver = getTypeBasedResolver();
113-
if (!substraitFuncKeyToSqlOperatorMap.containsKey(key)) {
110+
var operators = substraitFuncKeyToSqlOperatorMap.get(key);
111+
if (operators.isEmpty()) {
114112
return Optional.empty();
115113
}
116-
var operators = substraitFuncKeyToSqlOperatorMap.get(key);
114+
117115
// only one SqlOperator is possible
118116
if (operators.size() == 1) {
119117
return Optional.of(operators.iterator().next());
120118
}
119+
121120
// at least 2 operators. Use output type to resolve SqlOperator.
122121
String outputTypeStr = outputType.accept(ToTypeString.INSTANCE);
123122
var resolvedOperators =
@@ -146,15 +145,15 @@ private Map<SqlOperator, FunctionMappings.TypeBasedResolver> getTypeBasedResolve
146145
protected abstract ImmutableList<FunctionMappings.Sig> getSigs();
147146

148147
protected class FunctionFinder {
149-
private final String name;
148+
private final String substraitName;
150149
private final SqlOperator operator;
151150
private final List<F> functions;
152151
private final Map<String, F> directMap;
153152
private final Optional<SingularArgumentMatcher<F>> singularInputType;
154153
private final Util.IntRange argRange;
155154

156-
public FunctionFinder(String name, SqlOperator operator, List<F> functions) {
157-
this.name = name;
155+
public FunctionFinder(String substraitName, SqlOperator operator, List<F> functions) {
156+
this.substraitName = substraitName;
158157
this.operator = operator;
159158
this.functions = functions;
160159
this.argRange =
@@ -167,7 +166,7 @@ public FunctionFinder(String name, SqlOperator operator, List<F> functions) {
167166
String key = func.key();
168167
directMap.put(key, func);
169168
if (func.requiredArguments().size() != func.args().size()) {
170-
directMap.put(F.constructKey(name, func.requiredArguments()), func);
169+
directMap.put(F.constructKey(substraitName, func.requiredArguments()), func);
171170
}
172171
}
173172
this.directMap = directMap.build();
@@ -357,7 +356,7 @@ public Optional<T> attemptMatch(C call, Function<RexNode, Expression> topLevelCo
357356

358357
Optional<String> directMatchKey =
359358
possibleKeys
360-
.map(argList -> name + ":" + argList)
359+
.map(argList -> substraitName + ":" + argList)
361360
.filter(directMap::containsKey)
362361
.findFirst();
363362

@@ -438,8 +437,8 @@ private Optional<T> matchCoerced(C call, Type outputType, List<Expression> expre
438437
return Optional.of(generateBinding(call, matchFunction.get(), coercedArgs, outputType));
439438
}
440439

441-
protected String getName() {
442-
return name;
440+
protected String getSubstraitName() {
441+
return substraitName;
443442
}
444443

445444
public SqlOperator getOperator() {

0 commit comments

Comments
 (0)