17
17
18
18
import static org .tensorflow .internal .c_api .global .tensorflow .TF_LoadSessionFromSavedModel ;
19
19
import static org .tensorflow .internal .c_api .global .tensorflow .TF_NewGraph ;
20
- import static org .tensorflow .internal .c_api .global .tensorflow .TF_OperationGetAttrValueProto ;
21
20
import static org .tensorflow .internal .c_api .global .tensorflow .TF_SetConfig ;
22
21
23
22
import com .google .protobuf .InvalidProtocolBufferException ;
32
31
import java .util .LinkedHashMap ;
33
32
import java .util .List ;
34
33
import java .util .Map ;
35
- import java .util .regex .Pattern ;
36
34
import java .util .stream .Collectors ;
37
35
import org .bytedeco .javacpp .BytePointer ;
38
36
import org .bytedeco .javacpp .PointerPointer ;
39
37
import org .bytedeco .javacpp .PointerScope ;
40
38
import org .tensorflow .exceptions .TensorFlowException ;
41
39
import org .tensorflow .internal .c_api .TF_Buffer ;
42
40
import org .tensorflow .internal .c_api .TF_Graph ;
43
- import org .tensorflow .internal .c_api .TF_Operation ;
44
41
import org .tensorflow .internal .c_api .TF_Session ;
45
42
import org .tensorflow .internal .c_api .TF_SessionOptions ;
46
43
import org .tensorflow .internal .c_api .TF_Status ;
47
- import org .tensorflow .proto .framework .AttrValue ;
48
44
import org .tensorflow .proto .framework .ConfigProto ;
49
45
import org .tensorflow .proto .framework .MetaGraphDef ;
50
46
import org .tensorflow .proto .framework .MetaGraphDef .MetaInfoDef ;
51
47
import org .tensorflow .proto .framework .RunOptions ;
52
48
import org .tensorflow .proto .framework .SavedModel ;
53
- import org .tensorflow .proto .framework .SignatureDef ;
54
- import org .tensorflow .proto .framework .TensorInfo ;
55
49
import org .tensorflow .proto .util .SaverDef ;
56
50
57
51
/**
@@ -387,53 +381,6 @@ private SavedModelBundle(Graph graph, Session session, MetaGraphDef metaGraphDef
387
381
this .functions = functions ;
388
382
}
389
383
390
- private static final Pattern INFERENCE_FUNCTION_NAME_PATTERN = Pattern
391
- .compile ("__inference_(.+)_\\ d+" , Pattern .DOTALL );
392
-
393
- /**
394
- * Check that all outputs of the signature come from a single call op that takes the inputs.
395
- */
396
- private static GraphOperation findFunctionWrapper (Graph graph , SignatureDef signatureDef ) {
397
-
398
- GraphOperation callOp = null ;
399
- for (TensorInfo output : signatureDef .getOutputsMap ().values ()) {
400
- GraphOperation op = (GraphOperation ) graph .outputOrError (output .getName ()).op ();
401
- if (callOp == null ) {
402
- callOp = op ;
403
- } else if (!callOp .equals (op )) {
404
- return null ;
405
- }
406
- }
407
-
408
- if (callOp == null ) {
409
- return null ;
410
- }
411
-
412
- if (callOp != null ) {
413
-
414
- if (callOp .numInputs () != signatureDef .getInputsCount () || callOp .numOutputs () != signatureDef
415
- .getOutputsCount ()) {
416
- return null ;
417
- }
418
-
419
- int i = 0 ;
420
- List <Operand <?>> opInputs = callOp .inputs ();
421
-
422
- for (TensorInfo input : signatureDef .getInputsMap ().values ()) {
423
- if (!graph .outputOrError (input .getName ()).equals (opInputs .get (i ))) {
424
- return null ;
425
- }
426
- i ++;
427
- }
428
- }
429
-
430
- if (!callOp .type ().equals (ConcreteFunction .CALL_OP ) && !callOp .type ().equals (ConcreteFunction .STATEFUL_CALL_OP )) {
431
- return null ;
432
- }
433
-
434
- return callOp ;
435
- }
436
-
437
384
/**
438
385
* Create a SavedModelBundle object from a handle to the C TF_Graph object and to the C TF_Session object, plus the
439
386
* MetaGraphDef.
@@ -454,53 +401,6 @@ private static SavedModelBundle fromHandle(
454
401
List <ConcreteFunction > graphFunctions = graph .getFunctions ();
455
402
metaGraphDef .getSignatureDefMap ().forEach ((signatureName , signatureDef ) -> {
456
403
457
- GraphOperation callOp = findFunctionWrapper (graph , signatureDef );
458
-
459
- // if the function is a thin wrapper around a function call, unwrap it
460
- if (callOp != null ) {
461
-
462
- try (PointerScope scope = new PointerScope ()) {
463
- TF_Operation op = ((GraphOperation ) graph
464
- .outputOrError (signatureDef .getOutputsMap ().values ().iterator ().next ().getName ()).op ())
465
- .getUnsafeNativeHandle ();
466
- TF_Status status = TF_Status .newStatus ();
467
- TF_Buffer buff = TF_Buffer .newBuffer ();
468
- TF_OperationGetAttrValueProto (op , "f" , buff , status );
469
- status .throwExceptionIfNotOK ();
470
- AttrValue def = AttrValue .parseFrom (buff .dataAsByteBuffer ());
471
-
472
- String functionName = def .getFunc ().getName ();
473
-
474
- ConcreteFunction function = null ;
475
- for (ConcreteFunction fn : graphFunctions ) {
476
- if (fn .getNativeFunctionName ().equals (functionName )) {
477
- function = fn ;
478
- break ;
479
- }
480
- }
481
-
482
- if (function != null ) {
483
- functions .put (signatureName , function .withNewSignature (new Signature (signatureName , signatureDef )));
484
- }
485
- } catch (InvalidProtocolBufferException | IllegalArgumentException ignored ) {
486
-
487
- }
488
- }
489
- //
490
- // // try to do the unwrapping based on name if there are no outputs (and thus we can't find the call op)
491
- // if (!functions.containsKey(signatureName) && signatureDef.getOutputsCount() < 1) {
492
- // for (ConcreteFunction fn : graphFunctions) {
493
- // Matcher matcher = INFERENCE_FUNCTION_NAME_PATTERN.matcher(fn.getNativeFunctionName());
494
- // if (matcher.find()) {
495
- // String fnName = matcher.group(1);
496
- // if (fnName.equals(signatureName)) {
497
- // functions.put(signatureName, fn);
498
- // break;
499
- // }
500
- // }
501
- // }
502
- // }
503
-
504
404
// otherwise use the wrapper
505
405
if (!functions .containsKey (signatureName )) {
506
406
Signature signature = new Signature (signatureName , signatureDef );
0 commit comments