Skip to content

Commit c42d45c

Browse files
committed
Remove unwrapping
Signed-off-by: Ryan Nett <[email protected]>
1 parent a3803e5 commit c42d45c

File tree

1 file changed

+0
-100
lines changed

1 file changed

+0
-100
lines changed

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java

Lines changed: 0 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717

1818
import static org.tensorflow.internal.c_api.global.tensorflow.TF_LoadSessionFromSavedModel;
1919
import static org.tensorflow.internal.c_api.global.tensorflow.TF_NewGraph;
20-
import static org.tensorflow.internal.c_api.global.tensorflow.TF_OperationGetAttrValueProto;
2120
import static org.tensorflow.internal.c_api.global.tensorflow.TF_SetConfig;
2221

2322
import com.google.protobuf.InvalidProtocolBufferException;
@@ -32,26 +31,21 @@
3231
import java.util.LinkedHashMap;
3332
import java.util.List;
3433
import java.util.Map;
35-
import java.util.regex.Pattern;
3634
import java.util.stream.Collectors;
3735
import org.bytedeco.javacpp.BytePointer;
3836
import org.bytedeco.javacpp.PointerPointer;
3937
import org.bytedeco.javacpp.PointerScope;
4038
import org.tensorflow.exceptions.TensorFlowException;
4139
import org.tensorflow.internal.c_api.TF_Buffer;
4240
import org.tensorflow.internal.c_api.TF_Graph;
43-
import org.tensorflow.internal.c_api.TF_Operation;
4441
import org.tensorflow.internal.c_api.TF_Session;
4542
import org.tensorflow.internal.c_api.TF_SessionOptions;
4643
import org.tensorflow.internal.c_api.TF_Status;
47-
import org.tensorflow.proto.framework.AttrValue;
4844
import org.tensorflow.proto.framework.ConfigProto;
4945
import org.tensorflow.proto.framework.MetaGraphDef;
5046
import org.tensorflow.proto.framework.MetaGraphDef.MetaInfoDef;
5147
import org.tensorflow.proto.framework.RunOptions;
5248
import org.tensorflow.proto.framework.SavedModel;
53-
import org.tensorflow.proto.framework.SignatureDef;
54-
import org.tensorflow.proto.framework.TensorInfo;
5549
import org.tensorflow.proto.util.SaverDef;
5650

5751
/**
@@ -387,53 +381,6 @@ private SavedModelBundle(Graph graph, Session session, MetaGraphDef metaGraphDef
387381
this.functions = functions;
388382
}
389383

390-
private static final Pattern INFERENCE_FUNCTION_NAME_PATTERN = Pattern
391-
.compile("__inference_(.+)_\\d+", Pattern.DOTALL);
392-
393-
/**
394-
* Check that all outputs of the signature come from a single call op that takes the inputs.
395-
*/
396-
private static GraphOperation findFunctionWrapper(Graph graph, SignatureDef signatureDef) {
397-
398-
GraphOperation callOp = null;
399-
for (TensorInfo output : signatureDef.getOutputsMap().values()) {
400-
GraphOperation op = (GraphOperation) graph.outputOrError(output.getName()).op();
401-
if (callOp == null) {
402-
callOp = op;
403-
} else if (!callOp.equals(op)) {
404-
return null;
405-
}
406-
}
407-
408-
if (callOp == null) {
409-
return null;
410-
}
411-
412-
if (callOp != null) {
413-
414-
if (callOp.numInputs() != signatureDef.getInputsCount() || callOp.numOutputs() != signatureDef
415-
.getOutputsCount()) {
416-
return null;
417-
}
418-
419-
int i = 0;
420-
List<Operand<?>> opInputs = callOp.inputs();
421-
422-
for (TensorInfo input : signatureDef.getInputsMap().values()) {
423-
if (!graph.outputOrError(input.getName()).equals(opInputs.get(i))) {
424-
return null;
425-
}
426-
i++;
427-
}
428-
}
429-
430-
if (!callOp.type().equals(ConcreteFunction.CALL_OP) && !callOp.type().equals(ConcreteFunction.STATEFUL_CALL_OP)) {
431-
return null;
432-
}
433-
434-
return callOp;
435-
}
436-
437384
/**
438385
* Create a SavedModelBundle object from a handle to the C TF_Graph object and to the C TF_Session object, plus the
439386
* MetaGraphDef.
@@ -454,53 +401,6 @@ private static SavedModelBundle fromHandle(
454401
List<ConcreteFunction> graphFunctions = graph.getFunctions();
455402
metaGraphDef.getSignatureDefMap().forEach((signatureName, signatureDef) -> {
456403

457-
GraphOperation callOp = findFunctionWrapper(graph, signatureDef);
458-
459-
// if the function is a thin wrapper around a function call, unwrap it
460-
if (callOp != null) {
461-
462-
try (PointerScope scope = new PointerScope()) {
463-
TF_Operation op = ((GraphOperation) graph
464-
.outputOrError(signatureDef.getOutputsMap().values().iterator().next().getName()).op())
465-
.getUnsafeNativeHandle();
466-
TF_Status status = TF_Status.newStatus();
467-
TF_Buffer buff = TF_Buffer.newBuffer();
468-
TF_OperationGetAttrValueProto(op, "f", buff, status);
469-
status.throwExceptionIfNotOK();
470-
AttrValue def = AttrValue.parseFrom(buff.dataAsByteBuffer());
471-
472-
String functionName = def.getFunc().getName();
473-
474-
ConcreteFunction function = null;
475-
for (ConcreteFunction fn : graphFunctions) {
476-
if (fn.getNativeFunctionName().equals(functionName)) {
477-
function = fn;
478-
break;
479-
}
480-
}
481-
482-
if (function != null) {
483-
functions.put(signatureName, function.withNewSignature(new Signature(signatureName, signatureDef)));
484-
}
485-
} catch (InvalidProtocolBufferException | IllegalArgumentException ignored) {
486-
487-
}
488-
}
489-
//
490-
// // try to do the unwrapping based on name if there are no outputs (and thus we can't find the call op)
491-
// if (!functions.containsKey(signatureName) && signatureDef.getOutputsCount() < 1) {
492-
// for (ConcreteFunction fn : graphFunctions) {
493-
// Matcher matcher = INFERENCE_FUNCTION_NAME_PATTERN.matcher(fn.getNativeFunctionName());
494-
// if (matcher.find()) {
495-
// String fnName = matcher.group(1);
496-
// if (fnName.equals(signatureName)) {
497-
// functions.put(signatureName, fn);
498-
// break;
499-
// }
500-
// }
501-
// }
502-
// }
503-
504404
// otherwise use the wrapper
505405
if (!functions.containsKey(signatureName)) {
506406
Signature signature = new Signature(signatureName, signatureDef);

0 commit comments

Comments
 (0)