25
25
import java .util .List ;
26
26
import java .util .Locale ;
27
27
import java .util .Map ;
28
+ import java .util .Objects ;
28
29
import java .util .Optional ;
29
30
import java .util .Set ;
30
31
import java .util .function .Function ;
@@ -149,7 +150,6 @@ protected class FunctionFinder {
149
150
private final SqlOperator operator ;
150
151
private final List <F > functions ;
151
152
private final Map <String , F > directMap ;
152
- private final SignatureMatcher <F > matcher ;
153
153
private final Optional <SingularArgumentMatcher <F >> singularInputType ;
154
154
private final Util .IntRange argRange ;
155
155
@@ -161,7 +161,6 @@ public FunctionFinder(String name, SqlOperator operator, List<F> functions) {
161
161
Util .IntRange .of (
162
162
functions .stream ().mapToInt (t -> t .getRange ().getStartInclusive ()).min ().getAsInt (),
163
163
functions .stream ().mapToInt (t -> t .getRange ().getEndExclusive ()).max ().getAsInt ());
164
- this .matcher = getSignatureMatcher (operator , functions );
165
164
this .singularInputType = getSingularInputType (functions );
166
165
var directMap = ImmutableMap .<String , F >builder ();
167
166
for (var func : functions ) {
@@ -178,21 +177,18 @@ public boolean allowedArgCount(int count) {
178
177
return argRange .within (count );
179
178
}
180
179
181
- private static <F extends SimpleExtension .Function > SignatureMatcher <F > getSignatureMatcher (
182
- SqlOperator operator , List <F > functions ) {
183
- return (inputTypes , outputType ) -> {
184
- for (F function : functions ) {
185
- List <SimpleExtension .Argument > args = function .requiredArguments ();
186
- // Make sure that arguments & return are within bounds and match the types
187
- if (function .returnType () instanceof ParameterizedType
188
- && isMatch (outputType , (ParameterizedType ) function .returnType ())
189
- && inputTypesSatisfyDefinedArguments (inputTypes , args )) {
190
- return Optional .of (function );
191
- }
180
+ private Optional <F > signatureMatch (List <Type > inputTypes , Type outputType ) {
181
+ for (F function : functions ) {
182
+ List <SimpleExtension .Argument > args = function .requiredArguments ();
183
+ // Make sure that arguments & return are within bounds and match the types
184
+ if (function .returnType () instanceof ParameterizedType
185
+ && isMatch (outputType , (ParameterizedType ) function .returnType ())
186
+ && inputTypesMatchDefinedArguments (inputTypes , args )) {
187
+ return Optional .of (function );
192
188
}
189
+ }
193
190
194
- return Optional .empty ();
195
- };
191
+ return Optional .empty ();
196
192
}
197
193
198
194
/**
@@ -208,7 +204,7 @@ && inputTypesSatisfyDefinedArguments(inputTypes, args)) {
208
204
* @param args expected arguments as defined in a {@link SimpleExtension.Function}
209
205
* @return true if the {@code inputTypes} satisfy the {@code args}, false otherwise
210
206
*/
211
- private static boolean inputTypesSatisfyDefinedArguments (
207
+ private static boolean inputTypesMatchDefinedArguments (
212
208
List <Type > inputTypes , List <SimpleExtension .Argument > args ) {
213
209
214
210
Map <String , Set <Type >> wildcardToType = new HashMap <>();
@@ -318,7 +314,7 @@ private Stream<String> matchKeys(List<RexNode> rexOperands, List<String> opTypes
318
314
319
315
assert (rexOperands .size () == opTypes .size ());
320
316
321
- if (rexOperands .size () == 0 ) {
317
+ if (rexOperands .isEmpty () ) {
322
318
return Stream .of ("" );
323
319
} else {
324
320
List <List <String >> argTypeLists =
@@ -357,13 +353,12 @@ public Optional<T> attemptMatch(C call, Function<RexNode, Expression> topLevelCo
357
353
// try to do a direct match
358
354
List <String > typeStrings =
359
355
opTypes .stream ().map (t -> t .accept (ToTypeString .INSTANCE )).collect (Collectors .toList ());
360
- Stream <String > possibleKeys =
361
- matchKeys (call .getOperands ().collect (Collectors .toList ()), typeStrings );
356
+ Stream <String > possibleKeys = matchKeys (operandsList , typeStrings );
362
357
363
358
Optional <String > directMatchKey =
364
359
possibleKeys
365
360
.map (argList -> name + ":" + argList )
366
- .filter (k -> directMap . containsKey ( k ) )
361
+ .filter (directMap :: containsKey )
367
362
.findFirst ();
368
363
369
364
if (directMatchKey .isPresent ()) {
@@ -376,14 +371,13 @@ public Optional<T> attemptMatch(C call, Function<RexNode, Expression> topLevelCo
376
371
RexNode r = operandsList .get (i );
377
372
Expression o = operands .get (i );
378
373
if (EnumConverter .isEnumValue (r )) {
379
- return EnumConverter .fromRex (variant , (RexLiteral ) r , i )
380
- .orElseGet (() -> null );
374
+ return EnumConverter .fromRex (variant , (RexLiteral ) r , i ).orElse (null );
381
375
} else {
382
376
return o ;
383
377
}
384
378
})
385
379
.collect (Collectors .toList ());
386
- boolean allArgsMapped = funcArgs .stream ().filter (e -> e == null ).findFirst ().isEmpty ();
380
+ boolean allArgsMapped = funcArgs .stream ().filter (Objects :: isNull ).findFirst ().isEmpty ();
387
381
if (allArgsMapped ) {
388
382
return Optional .of (generateBinding (call , variant , funcArgs , outputType ));
389
383
} else {
@@ -413,53 +407,35 @@ private Optional<T> matchByLeastRestrictive(
413
407
return Optional .empty ();
414
408
}
415
409
Type type = typeConverter .toSubstrait (leastRestrictive );
416
- var out = singularInputType .get ().tryMatch (type , outputType );
417
-
418
- if (out .isPresent ()) {
419
- var declaration = out .get ();
420
- var coercedArgs = coerceArguments (operands , type );
421
- declaration .validateOutputType (coercedArgs , outputType );
422
- return Optional .of (
423
- generateBinding (
424
- call ,
425
- out .get (),
426
- coercedArgs .stream ().map (FunctionArg .class ::cast ).collect (Collectors .toList ()),
427
- outputType ));
428
- }
429
- return Optional .empty ();
410
+ var out = singularInputType .orElseThrow ().tryMatch (type , outputType );
411
+
412
+ return out .map (
413
+ declaration -> {
414
+ var coercedArgs = coerceArguments (operands , type );
415
+ declaration .validateOutputType (coercedArgs , outputType );
416
+ return generateBinding (call , out .get (), coercedArgs , outputType );
417
+ });
430
418
}
431
419
432
- private Optional <T > matchCoerced (C call , Type outputType , List <Expression > operands ) {
433
-
420
+ private Optional <T > matchCoerced (C call , Type outputType , List <Expression > expressions ) {
434
421
// Convert the operands to the proper Substrait type
435
- List <Type > allTypes =
422
+ List <Type > operandTypes =
436
423
call .getOperands ()
437
424
.map (RexNode ::getType )
438
425
.map (typeConverter ::toSubstrait )
439
426
.collect (Collectors .toList ());
440
427
441
- // See if all the input types match the function
442
- Optional <F > matchFunction = this .matcher .tryMatch (allTypes , outputType );
443
- if (matchFunction .isPresent ()) {
444
- List <Expression > coerced =
445
- Streams .zip (
446
- operands .stream (),
447
- call .getOperands (),
448
- (a , b ) -> {
449
- Type type = typeConverter .toSubstrait (b .getType ());
450
- return coerceArgument (a , type );
451
- })
452
- .collect (Collectors .toList ());
453
-
454
- return Optional .of (
455
- generateBinding (
456
- call ,
457
- matchFunction .get (),
458
- coerced .stream ().map (FunctionArg .class ::cast ).collect (Collectors .toList ()),
459
- outputType ));
428
+ // See if all the input types can be made to match the function
429
+ Optional <F > matchFunction = signatureMatch (operandTypes , outputType );
430
+ if (matchFunction .isEmpty ()) {
431
+ return Optional .empty ();
460
432
}
461
433
462
- return Optional .empty ();
434
+ var coercedArgs =
435
+ Streams .zip (
436
+ expressions .stream (), operandTypes .stream (), FunctionConverter ::coerceArgument )
437
+ .collect (Collectors .toList ());
438
+ return Optional .of (generateBinding (call , matchFunction .get (), coercedArgs , outputType ));
463
439
}
464
440
465
441
protected String getName () {
@@ -481,56 +457,30 @@ public interface GenericCall {
481
457
* Coerced types according to an expected output type. Coercion is only done for type mismatches,
482
458
* not for nullability or parameter mismatches.
483
459
*/
484
- private static List <Expression > coerceArguments (List <Expression > arguments , Type type ) {
485
- return arguments .stream ().map (a -> coerceArgument (a , type )).collect (Collectors .toList ());
460
+ private static List <Expression > coerceArguments (List <Expression > arguments , Type targetType ) {
461
+ return arguments .stream ().map (a -> coerceArgument (a , targetType )).collect (Collectors .toList ());
486
462
}
487
463
488
464
private static Expression coerceArgument (Expression argument , Type type ) {
489
- var typeMatches = isMatch (type , argument .getType ());
490
- if (!typeMatches ) {
491
- return ExpressionCreator .cast (type , argument , Expression .FailureBehavior .THROW_EXCEPTION );
465
+ if (isMatch (type , argument .getType ())) {
466
+ return argument ;
492
467
}
493
- return argument ;
468
+
469
+ return ExpressionCreator .cast (type , argument , Expression .FailureBehavior .THROW_EXCEPTION );
494
470
}
495
471
496
472
protected abstract T generateBinding (
497
- C call , F function , List <FunctionArg > arguments , Type outputType );
473
+ C call , F function , List <? extends FunctionArg > arguments , Type outputType );
498
474
499
- public interface SingularArgumentMatcher <F > {
475
+ @ FunctionalInterface
476
+ private interface SingularArgumentMatcher <F > {
500
477
Optional <F > tryMatch (Type type , Type outputType );
501
478
}
502
479
503
- public interface SignatureMatcher <F > {
504
- Optional <F > tryMatch (List <Type > types , Type outputType );
505
- }
506
-
507
- private static SignatureMatcher chainedSignature (SignatureMatcher ... matchers ) {
508
- return switch (matchers .length ) {
509
- case 0 -> (types , outputType ) -> Optional .empty ();
510
- case 1 -> matchers [0 ];
511
- default -> (types , outputType ) -> {
512
- for (SignatureMatcher m : matchers ) {
513
- var t = m .tryMatch (types , outputType );
514
- if (t .isPresent ()) {
515
- return t ;
516
- }
517
- }
518
- return Optional .empty ();
519
- };
520
- };
521
- }
522
-
523
- private static boolean isMatch (Type inputType , ParameterizedType type ) {
524
- if (type .isWildcard ()) {
525
- return true ;
526
- }
527
- return inputType .accept (new IgnoreNullableAndParameters (type ));
528
- }
529
-
530
- private static boolean isMatch (ParameterizedType inputType , ParameterizedType type ) {
531
- if (type .isWildcard ()) {
480
+ private static boolean isMatch (ParameterizedType actualType , ParameterizedType targetType ) {
481
+ if (targetType .isWildcard ()) {
532
482
return true ;
533
483
}
534
- return inputType .accept (new IgnoreNullableAndParameters (type ));
484
+ return actualType .accept (new IgnoreNullableAndParameters (targetType ));
535
485
}
536
486
}
0 commit comments