diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java index 92e4cabdbd1..c68b6ee8ff7 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java @@ -19,6 +19,8 @@ import java.nio.charset.Charset; import java.util.List; +import java.util.Map; +import org.tensorflow.ConcreteFunction; import org.tensorflow.DeviceSpec; import org.tensorflow.EagerSession; import org.tensorflow.ExecutionEnvironment; @@ -87,6 +89,7 @@ import org.tensorflow.op.core.ExtractVolumePatches; import org.tensorflow.op.core.Fill; import org.tensorflow.op.core.Fingerprint; +import org.tensorflow.op.core.Function; import org.tensorflow.op.core.Gather; import org.tensorflow.op.core.GatherNd; import org.tensorflow.op.core.GetSessionHandle; @@ -1116,6 +1119,31 @@ public Bucketize bucketize(Operand input, List boundar return Bucketize.create(scope, input, boundaries); } + /** + * Calls the function in an execution environment, adding its graph as a function if it isn't + * already present. Only works for functions with a single input and output. + * + * @param argument the argument to the call + * @return the output of the function + * @see ConcreteFunction#call(Ops, Operand) + */ + public Operand call(ConcreteFunction function, Operand argument) { + return Function.call(scope, function, argument); + } + + /** + * Calls the function in an execution environment, adding its graph as a function if it isn't + * already present. The inputs and outputs are keyed by the names set in the {@code Signature}. + * + * @param arguments the arguments to the call + * @return the outputs of the function + * @see ConcreteFunction#call(Ops, Map) + */ + public Map> call(ConcreteFunction function, + Map> arguments) { + return Function.call(scope, function, arguments); + } + /** * Clips tensor values to a specified min and max. * Given a tensor {@code t}, this operation returns a tensor of the same type and diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/internal/c_api/TF_Function.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/internal/c_api/TF_Function.java index e370b2f9f08..829d1cede3c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/internal/c_api/TF_Function.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/internal/c_api/TF_Function.java @@ -13,7 +13,7 @@ // Once created and added to graphs, functions can be invoked by creating an // operation whose operation type matches the function name. @Opaque @Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class) -public class TF_Function extends Pointer { +public class TF_Function extends org.tensorflow.internal.c_api.AbstractTF_Function { /** Empty constructor. Calls {@code super((Pointer)null)}. */ public TF_Function() { super((Pointer)null); } /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java index 71dc0f7cefc..3e264e0e25d 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java @@ -1,54 +1,81 @@ -/* - * Copyright 2020 The TensorFlow Authors. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. +/* Copyright 2020-2021 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= */ package org.tensorflow; -import java.io.IOException; +import static org.tensorflow.internal.c_api.global.tensorflow.TF_FunctionSetAttrValueProto; +import static org.tensorflow.internal.c_api.global.tensorflow.TF_GraphToFunction; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashSet; +import java.util.LinkedHashMap; import java.util.List; -import java.util.ListIterator; -import java.util.HashMap; import java.util.Map; +import java.util.Set; import java.util.function.Function; +import java.util.stream.Collectors; +import org.bytedeco.javacpp.BytePointer; +import org.bytedeco.javacpp.Pointer; +import org.bytedeco.javacpp.PointerPointer; +import org.bytedeco.javacpp.PointerScope; +import org.tensorflow.Graph.Reference; +import org.tensorflow.internal.c_api.TF_Function; +import org.tensorflow.internal.c_api.TF_Operation; +import org.tensorflow.internal.c_api.TF_Output; +import org.tensorflow.internal.c_api.TF_Status; import org.tensorflow.op.Ops; +import org.tensorflow.op.Scope; +import org.tensorflow.op.core.Placeholder; +import org.tensorflow.op.core.PlaceholderWithDefault; +import org.tensorflow.proto.framework.AttrValue; +import org.tensorflow.proto.framework.DataType; +import org.tensorflow.proto.framework.FunctionDef; +import org.tensorflow.proto.framework.OpDef.ArgDef; import org.tensorflow.proto.framework.SignatureDef; import org.tensorflow.proto.framework.TensorInfo; +import org.tensorflow.proto.framework.TensorShapeProto; +import org.tensorflow.types.TBool; +import org.tensorflow.types.family.TType; /** * A graph that can be invoked as a single function, with an input and output signature. * - *

A function can also invoke a - * tf.function - * defined in a {@link SavedModelBundle}. + *

A function can also invoke a tf.function defined in a {@link + * SavedModelBundle}. * *

{@code
  * ConcreteFunction myFunction = savedModelBundle.function("myFunctionSignatureName");
  * Map outputTensorMap = myFunction.call(inputTensorMap);
  * }
*/ -public class ConcreteFunction implements AutoCloseable { +public class ConcreteFunction implements AutoCloseable, TensorFunction { /** * Creates a function by building a new graph. * - *

The {@code functionBuilder} must initialize the function graph from the provided - * {@link Ops} instance and return a valid signature that will be used to feed the input tensors - * and fetch the output tensors on execution. + *

The {@code functionBuilder} must initialize the function graph from the provided {@link Ops} + * instance and return a valid signature that will be used to feed the input tensors and fetch the + * output tensors on execution. * - *

The function will be the owner of the new graph and its resulting session. Therefore, - * the function must be enclosed properly with a try-with-resources block to guarantee that - * all native resources will be freed once the function is discarded. For example: + *

The function will be the owner of the new graph and its resulting session. Therefore, the + * function must be enclosed properly with a try-with-resources block to guarantee that all native + * resources will be freed once the function is discarded. For example: * *

{@code
    * public class MyModel {
@@ -72,23 +99,19 @@ public class ConcreteFunction implements AutoCloseable {
    * @return the new function
    */
   public static ConcreteFunction create(Function functionBuilder) {
-    Graph graph = new Graph();
-    try {
+    try (Graph graph = new Graph()) {
       Ops tf = Ops.create(graph);
       Signature signature = functionBuilder.apply(tf);
-      return new ConcreteFunction(signature, graph, new Session(graph), Ownership.GRAPH_AND_SESSION);
-    } catch (Exception e) {
-      graph.close();
-      throw e;
+      return buildFromGraph(graph, signature);
     }
   }
 
   /**
    * Create a function from a signature and an existing graph.
    *
-   * 

The function will keep the ownership of the session used to run the graph but not - * the graph itself, meaning that the lifetime of the latter can extend beyond the scope - * of the function. For example: + *

The function will keep the ownership of the session used to run the graph but not the graph + * itself, meaning that the lifetime of the latter can extend beyond the scope of the function. + * For example: * *

{@code
    * try (Graph g = new Graph()) {
@@ -109,15 +132,15 @@ public static ConcreteFunction create(Function functionBuilder)
    * @return a new function
    */
   public static ConcreteFunction create(Signature signature, Graph graph) {
-    return new ConcreteFunction(signature, graph, new Session(graph), Ownership.SESSION_ONLY);
+    return buildFromGraph(graph, signature);
   }
 
   /**
    * Create a function from a signature and a valid graph session.
    *
-   * 

The function will not own the session nor its graph, meaning that their lifetime - * can extend beyond the scope of the function. Therefore the function does not need to be - * closed after its usage. For example: + *

The function will not own the session nor its graph, meaning that their lifetime can extend + * beyond the scope of the function. Therefore the function does not need to be closed after its + * usage. For example: * *

{@code
    * try (Graph g = new Graph()) {
@@ -143,152 +166,480 @@ public static ConcreteFunction create(Signature signature, Graph graph) {
    * @return a new function
    */
   public static ConcreteFunction create(Signature signature, Session session) {
-    return new ConcreteFunction(signature, session.graph(), session, Ownership.NONE);
+    return buildFromGraph(session.graph(), signature);
   }
 
-  /**
-   * Returns the signature of this function
-   */
+  /** Returns the signature of this function */
+  @Override
   public Signature signature() {
     return signature;
   }
 
   /**
-   * Invokes a function.
-   *
-   * 

Caller is responsible for closing all Tensors. + * Get the name of the function definition. This is what it will show up under in the graph and + * any exported GraphDefs, and should be used for anything using tensorflow core directly. + */ + public String getDefinedName() { + return nativeFunction.getName(); + } + + /** Get the {@link FunctionDef} proto. */ + public FunctionDef getFunctionDef() { + return nativeFunction.getFunctionDef(); + } + + /** Get whether the function is stateful. */ + public boolean isStateful() { + return nativeFunction.isStateful(); + } + + Set getDependencies() { + return dependencies; + } + + @Override + public void close() { + scope.close(); + } + + @Override + public String toString() { + return signature.toString(); + } + + // TODO migrate to the actual ops once they are generated + public static final String CALL_OP = "PartitionedCall"; + // TODO migrate to the actual ops once they are generated + public static final String STATEFUL_CALL_OP = "StatefulPartitionedCall"; + + /** + * Calls the function in an execution environment, adding its graph as a function if it isn't + * already present. The inputs and outputs are keyed by the names set in the {@code Signature}. * - * @param arguments list of tensors to pass in input to the function, - * mapped by their signature name - * @return output tensors resulting from the execution of the function, - * mapped by their signature name + * @param scope the scope to call the function in + * @param arguments the arguments to the call + * @return the outputs of the function */ - public Map call(Map arguments) - throws IllegalArgumentException { + public Map> call(Scope scope, Map> arguments) { + List> inputList = new ArrayList<>(); - final SignatureDef signatureDef = signature.asSignatureDef(); - final Session.Runner runner = session.runner(); + Output[] inputs = new Output[signature().inputNames().size()]; - signatureDef.getInputsMap().forEach((argName, t) -> { - Tensor tensor = arguments.get(argName); - if (tensor == null) { - throw new IllegalArgumentException(String.format("Missing argument [%s]", argName)); + int i = 0; + for (String inputName : signature().inputNames()) { + if (!arguments.containsKey(inputName)) { + throw new IllegalArgumentException( + "Function " + + signature().methodName() + + " has parameter \"" + + inputName + + "\", but no argument was passed for it."); } - runner.feed(t.getName(), tensor); - }); - Map outputToNode = signatureDef.getOutputsMap(); - outputToNode.values().forEach(t -> runner.fetch(t.getName())); + Operand input = arguments.get(inputName); + if (input == null) { + throw new IllegalArgumentException( + "Can't pass null as an argument to a function. Argument \"" + + inputName + + "\" was null."); + } + inputs[i] = input.asOutput(); + i++; + } - List resultTensors = runner.run(); - try { - ListIterator resultTensorIter = resultTensors.listIterator(); - Map returnMap = new HashMap(); + scope.env().attachFunction(this); + String name = getDefinedName(); - // Use the output names as present in the signature definition - for (String nodeName: outputToNode.keySet()) { - returnMap.put(nodeName, resultTensorIter.next()); - } - return returnMap; + String displayName = Scope.isValidOpName(name) ? name : "FunctionCall"; + + OperationBuilder opBuilder = + scope + .env() + .opBuilder(isStateful() ? STATEFUL_CALL_OP : CALL_OP, scope.makeOpName(displayName)); + + opBuilder.addInputList(inputs); + + opBuilder.setAttr("f", this); + opBuilder.setAttr("Tin", inputDtypes); + opBuilder.setAttr("Tout", outputDtypes); + + opBuilder = scope.apply(opBuilder); + Operation op = opBuilder.build(); + + int numOutputs1 = op.numOutputs(); + List> outputList = new ArrayList<>(signature().outputNames().size()); + + for (i = 0; i < numOutputs1; i++) { + outputList.add(op.output(i)); + } + + Map> namedOutputs = new LinkedHashMap<>(signature().outputNames().size()); - } catch (Exception e) { - // Release tensors before throwing exception - for (Tensor t : resultTensors) { - t.close(); + List outputNames = new ArrayList<>(signature().outputNames()); + for (i = 0; i < outputNames.size(); i++) { + String outputName = outputNames.get(i); + + if (i > outputList.size()) { + throw new IllegalStateException( + "Somehow, not all required outputs were returned from the function"); } - throw e; + + Operand output = outputList.get(i); + namedOutputs.put(outputName, output); } + + return Collections.unmodifiableMap(namedOutputs); } /** - * Invokes a function with a single input and output. - * - *

Caller is responsible for closing all Tensors. + * Calls the function in an execution environment, adding its graph as a function if it isn't + * already present. Only works for functions with a single input and output. * - * @param tensor input tensor - * @return output tensor - * @throws IllegalArgumentException if there are multiple input or output parameters defined - * in the function + * @param scope the scope to call the function in + * @param argument the argument to the call + * @return the output of the function */ - public Tensor call(Tensor tensor) throws IllegalArgumentException { + public Operand call(Scope scope, Operand argument) { final SignatureDef signatureDef = signature.asSignatureDef(); if (signatureDef.getInputsCount() != 1) { throw new IllegalArgumentException( - String.format("Function [%s] requires multiple inputs", signatureDef.getMethodName())); + String.format("Function [%s] requires multiple inputs", signatureDef.getMethodName())); } - String inputNodeName = signatureDef.getInputsMap().values().iterator().next().getName(); + String inputName = signatureDef.getInputsMap().keySet().iterator().next(); if (signatureDef.getOutputsCount() != 1) { throw new IllegalArgumentException( - String.format("Function [%s] has multiple outputs", signatureDef.getMethodName())); + String.format("Function [%s] has multiple outputs", signatureDef.getMethodName())); } - String outputNodeName = signatureDef.getOutputsMap().values().iterator().next().getName(); + String outputName = signatureDef.getOutputsMap().keySet().iterator().next(); + + Map> inputMap = new LinkedHashMap<>(); + inputMap.put(inputName, argument); + + return call(scope, inputMap).get(outputName); + } + + @Override + public Map call(Map arguments) { + // FIXME need to manage input/output operand lifetimes + Ops tf = Ops.create(); + Map> inputs = new LinkedHashMap<>(arguments.size()); - return session.runner().feed(inputNodeName, tensor).fetch(outputNodeName).run().get(0); + for (String inputName : arguments.keySet()) { + Tensor argument = arguments.get(inputName); + inputs.put(inputName, tf.constantOf((TType) argument)); + } + Map> outputs = tf.call(this, inputs); + Map tensorOutputs = new LinkedHashMap<>(outputs.size()); + for (String outputName : outputs.keySet()) { + tensorOutputs.put(outputName, outputs.get(outputName).asTensor()); + } + return tensorOutputs; } /** - * Export this function as a saved model. - * - *

This method is convenient shortcut equivalent to - * {@code SavedModel.exporter(exportDir).withFunction(this).export()} + * Calls the function in an execution environment, adding its graph as a function if it isn't + * already present. The inputs and outputs are keyed by the names set in the {@code Signature}. * - * @param exportDir directory where to export the saved model - * @throws IOException if saved model or variable state cannot be written on disk + * @param tf the scope to call the function in + * @param arguments the arguments to the call + * @return the outputs of the function */ - public void save(String exportDir) throws IOException { - SavedModelBundle.exporter(exportDir).withFunction(this).export(); + public Map> call(Ops tf, Map> arguments) { + return tf.call(this, arguments); } /** - * Returns the session used to execute the graph when calling this function + * Calls the function in an execution environment, adding its graph as a function if it isn't + * already present. Only works for functions with a single input and output. * - *

In general, a user does not need to handle directly the session of a function and rely - * on {@link #call(Map)} to execute the graph instead. But in some cases, direct access to - * the session might be necessary, as it allows more running options. - * - * @return the function session + * @param tf the scope to call the function in + * @param argument the argument to the call + * @return the output of the function + */ + public Operand call(Ops tf, Operand argument) { + return tf.call(this, argument); + } + + TF_Function nativeHandle() { + if (nativeFunction.getNativeHandle().isNull()) { + throw new IllegalStateException("Function has been closed"); + } + return nativeFunction.getNativeHandle(); + } + + /** All native functions should have deallocators registered */ + ConcreteFunction( + Signature signature, + NativeFunction nativeFunction, + Collection availableFunctions) { + this(signature, nativeFunction, nativeFunction.getAllDependencies(availableFunctions)); + } + + /** + * Detects the signature from the handle. Does not close passed functions. All passed functions + * should have deallocators. */ - public Session session() { - return session; + static ConcreteFunction fromNativeHandle( + NativeFunction nativeFunction, Collection availableFunctions) { + + Signature.Builder builder = + Signature.builder() + .methodName(nativeFunction.getFunctionDef().getSignature().getName()) + .key(nativeFunction.getName()); + + for (ArgDef input : nativeFunction.getFunctionDef().getSignature().getInputArgList()) { + TensorInfo info = + TensorInfo.newBuilder() + .setDtype(input.getType()) + .setTensorShape(TensorShapeProto.newBuilder().setUnknownRank(true).build()) + .setName(input.getName()) + .build(); + + builder.input(input.getName(), info); + } + + for (ArgDef outputDef : nativeFunction.getFunctionDef().getSignature().getOutputArgList()) { + TensorInfo info = + TensorInfo.newBuilder() + .setDtype(outputDef.getType()) + .setTensorShape(TensorShapeProto.newBuilder().setUnknownRank(true).build()) + .setName(outputDef.getName()) + .build(); + + builder.output(outputDef.getName(), info); + } + + return new ConcreteFunction(builder.build(), nativeFunction, availableFunctions); + } + + private final Signature signature; + private final NativeFunction nativeFunction; + private final PointerScope scope; + private final Set dependencies; + private final DataType[] inputDtypes; + private final DataType[] outputDtypes; + + /** All native functions should have deallocators registered */ + private ConcreteFunction( + Signature signature, NativeFunction nativeFunction, Set dependencies) { + this.signature = signature; + this.nativeFunction = nativeFunction; + this.dependencies = Collections.unmodifiableSet(dependencies); + + if (this.signature.getInputs().size() + != nativeFunction.getFunctionDef().getSignature().getInputArgCount()) { + throw new IllegalArgumentException( + "Signature must have the same number of inputs as the native function. Expected " + + nativeFunction.getFunctionDef().getSignature().getInputArgCount() + + ", got " + + this.signature.getInputs().size()); + } + + if (this.signature.getOutputs().size() + != nativeFunction.getFunctionDef().getSignature().getOutputArgCount()) { + throw new IllegalArgumentException( + "New signature must have the same number of outputs as the native function. Expected " + + nativeFunction.getFunctionDef().getSignature().getOutputArgCount() + + ", got " + + this.signature.getOutputs().size()); + } + + inputDtypes = + this.signature.getInputs().values().stream().map(x -> x.dataType).toArray(DataType[]::new); + + List inputs = Arrays.asList(inputDtypes); + List nativeInputs = + nativeFunction.getFunctionDef().getSignature().getInputArgList().stream() + .map(ArgDef::getType) + .collect(Collectors.toList()); + + if (!dataTypesMatch(inputs, nativeInputs)) { + throw new IllegalArgumentException( + "Data types of the signature's inputs must match the native function's (in order). Expected " + + nativeInputs + + ", got " + + inputs); + } + + outputDtypes = + signature().getOutputs().values().stream().map(x -> x.dataType).toArray(DataType[]::new); + + List outputs = Arrays.asList(outputDtypes); + List nativeOutputs = + nativeFunction.getFunctionDef().getSignature().getOutputArgList().stream() + .map(ArgDef::getType) + .collect(Collectors.toList()); + + if (!dataTypesMatch(outputs, nativeOutputs)) { + throw new IllegalArgumentException( + "Data types of the signature's outputs must match the native function's (in order). Expected " + + nativeOutputs + + ", got " + + outputs); + } + + try (PointerScope scope = new PointerScope()) { + this.scope = scope; + scope.extend(); + scope.attach(this.nativeFunction.getNativeHandle()); + this.dependencies.forEach(scope::attach); + } } /** - * Returns the graph of this function + * FIXME: This causes native errors when I use it (Linux GPU, 6.1 CC), but I'm leaving it because + * how to enable XLA JIT is extremely non-obvious. + * + *

Causes {@code OP_REQUIRES failed at xla_ops.cc:363 : Not found: could not find registered + * platform with id: 0x7f75af03e6e8} (it's a warning, but the resulting TF_Status fails). */ - public Graph graph() { - return graph; + private void makeJit() { + try (PointerScope scope = new PointerScope()) { + byte[] bytes = AttrValue.newBuilder().setB(true).build().toByteArray(); + BytePointer trueValue = new BytePointer(bytes); + + TF_Status status1 = TF_Status.newStatus(); + TF_FunctionSetAttrValueProto( + nativeHandle(), "_XlaMustCompile", trueValue, bytes.length, status1); + status1.throwExceptionIfNotOK(); + + TF_Status status2 = TF_Status.newStatus(); + TF_FunctionSetAttrValueProto(nativeHandle(), "_noinline", trueValue, bytes.length, status2); + status2.throwExceptionIfNotOK(); + } } - @Override - public void close() { - if (ownership != Ownership.NONE) { - session.close(); - if (ownership == Ownership.GRAPH_AND_SESSION) { - graph.close(); + private static boolean dataTypesMatch(List a, List b) { + if (a.size() != b.size()) { + return false; + } + + for (int i = 0; i < a.size(); i++) { + DataType aType = a.get(i); + DataType bType = b.get(i); + + if (aType != DataType.DT_INVALID && bType != DataType.DT_INVALID && !a.equals(b)) { + return false; } } + + return true; } - @Override - public String toString() { - return signature.toString(); + private static TF_Operation outputHandle(Operand operand) { + if (operand == null) { + throw new NullPointerException("Can't get output handle for null operand"); + } + + Pointer handle = operand.asOutput().getUnsafeNativeHandle(); + if (handle.isNull()) { + throw new NullPointerException("Native handle of operand is null, has it been closed?"); + } + + if (!(handle instanceof TF_Operation)) { + throw new IllegalArgumentException("Operand was not a graph operand"); + } + + return (TF_Operation) handle; } - private enum Ownership { - GRAPH_AND_SESSION, SESSION_ONLY, NONE; + private static TF_Output resolveToOutput(Graph graph, List> operands) { + TF_Output handles = new TF_Output(operands.size()); + for (int i = 0; i < operands.size(); i++) { + Operand input = operands.get(i); + graph.checkInput(input); + TF_Operation handle = outputHandle(input); + + handles.position(i).oper(handle).index(input.asOutput().index()); + } + handles.position(0); + return handles; } - private final Graph graph; - private final Session session; - private final Signature signature; - private final Ownership ownership; + private static ConcreteFunction buildFromGraph(Graph graph, Signature signature) { + try (PointerScope scope = new PointerScope(); + Reference ref = graph.ref()) { + TF_Status status = TF_Status.newStatus(); - ConcreteFunction(Signature signature, Graph graph, Session session, Ownership ownership) { - this.graph = graph; - this.session = session; - this.signature = signature; - this.ownership = ownership; + List> inputs = + signature.getInputs().entrySet().stream() + .map( + (x) -> + TensorFunction.validateDescription(x.getValue(), graph, x.getKey(), "Input")) + .collect(Collectors.toList()); + + List> outputs = + signature.getOutputs().entrySet().stream() + .map( + (x) -> + TensorFunction.validateDescription(x.getValue(), graph, x.getKey(), "Output")) + .collect(Collectors.toList()); + + List ops = + new ArrayList<>(graph.completeSubgraph(new HashSet<>(inputs), new HashSet<>(outputs))); + + inputs.forEach(input -> ops.remove((GraphOperation) input.op())); + + ops.forEach( + x -> { + if (x.type().equals(Placeholder.OP_NAME) + || x.type().equals(PlaceholderWithDefault.OP_NAME)) { + throw new IllegalArgumentException( + "Can't calculate outputs (" + + outputs + + ") from inputs (" + + inputs + + "), " + + "they also depend on \"" + + x + + "\""); + } + }); + + // Python sometimes has NoOps as outputs + Ops tf = Ops.create(graph).withSubScope("functionControlOutputs"); + for (int i = 0; i < outputs.size(); i++) { + Operand output = outputs.get(i); + if (output.op().numOutputs() < 1) { + Operand realOutput = + tf.withControlDependencies(Collections.singletonList(output)) + .withName(output.op().name() + "_control") + .constant(false); + ops.add((GraphOperation) realOutput.op()); + outputs.set(i, realOutput); + } + } + + PointerPointer operations = new PointerPointer<>(ops.size()); + for (int i = 0; i < ops.size(); i++) { + operations.put(i, ops.get(i).getUnsafeNativeHandle()); + } + + TF_Function handle = + TF_GraphToFunction( + ref.nativeHandle(), + new BytePointer(signature.key()), + (byte) 1, + ops.size(), + operations, + inputs.size(), + resolveToOutput(graph, inputs), + outputs.size(), + resolveToOutput(graph, outputs), + null, + null, + new BytePointer( + signature.methodName() != null + ? signature.methodName() + : "Method " + signature.key()), + status); + + handle.withDeallocator(); + status.throwExceptionIfNotOK(); + return new ConcreteFunction( + signature, new NativeFunction(handle), graph.getNativeFunctions(scope)); + } } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java index f1dd6216a79..e3283ee2ab3 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java @@ -1,18 +1,18 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019-2021 The TensorFlow Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ package org.tensorflow; import static org.tensorflow.internal.c_api.global.tensorflow.TFE_Execute; @@ -22,6 +22,8 @@ import static org.tensorflow.internal.c_api.global.tensorflow.TFE_OpSetAttrBoolList; import static org.tensorflow.internal.c_api.global.tensorflow.TFE_OpSetAttrFloat; import static org.tensorflow.internal.c_api.global.tensorflow.TFE_OpSetAttrFloatList; +import static org.tensorflow.internal.c_api.global.tensorflow.TFE_OpSetAttrFunctionList; +import static org.tensorflow.internal.c_api.global.tensorflow.TFE_OpSetAttrFunctionName; import static org.tensorflow.internal.c_api.global.tensorflow.TFE_OpSetAttrInt; import static org.tensorflow.internal.c_api.global.tensorflow.TFE_OpSetAttrIntList; import static org.tensorflow.internal.c_api.global.tensorflow.TFE_OpSetAttrShape; @@ -35,6 +37,9 @@ import java.nio.charset.Charset; import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; import org.bytedeco.javacpp.BooleanPointer; import org.bytedeco.javacpp.BytePointer; import org.bytedeco.javacpp.IntPointer; @@ -88,7 +93,8 @@ public EagerOperationBuilder addInputList(Output[] inputs) { @Override public OperationBuilder addControlInput(Operation control) { - // No-op. Any operations passed to this method will already be evaluated (b/c eager evaluation). + // No-op. Any operations passed to this method will already be evaluated (b/c eager + // evaluation). return this; } @@ -217,15 +223,35 @@ public EagerOperationBuilder setAttr(String name, Shape[] values) { return this; } + @Override + public OperationBuilder setAttr(String name, ConcreteFunction value) { + session.attachFunction(value); + setAttrFunctionName(opHandle, name, value.getDefinedName()); + return this; + } + + @Override + public OperationBuilder setAttr(String name, ConcreteFunction[] value) { + for (ConcreteFunction fn : value) { + session.attachFunction(fn); + } + + setAttrFunctionList( + opHandle, + session.nativeHandle(), + name, + Arrays.stream(value).map(ConcreteFunction::getDefinedName).collect(Collectors.toList())); + + return this; + } + private TFE_Op opHandle; private final EagerSession session; private final String type; private final String name; - /** - * This value should be >= to the maximum number of outputs in any op - */ + /** This value should be >= to the maximum number of outputs in any op */ private static final int MAX_OUTPUTS_PER_OP = 1000; private static void requireOp(TFE_Op handle) { @@ -267,7 +293,8 @@ private static TFE_TensorHandle[] execute(TFE_Op opHandle, EagerSession session) requireOp(opHandle); try (PointerScope scope = new PointerScope()) { IntPointer numRetvals = new IntPointer(1).put(MAX_OUTPUTS_PER_OP); - PointerPointer retvals = new PointerPointer(MAX_OUTPUTS_PER_OP); + PointerPointer retvals = + new PointerPointer(MAX_OUTPUTS_PER_OP); TF_Status status = TF_Status.newStatus(); TFE_Execute(opHandle, retvals, numRetvals, status); status.throwExceptionIfNotOK(); @@ -294,7 +321,8 @@ private static void addInput(TFE_Op opHandle, TFE_TensorHandle tensorHandle) { private static void addInputList(TFE_Op opHandle, TFE_TensorHandle[] tensorHandles) { requireOp(opHandle); try (PointerScope scope = new PointerScope()) { - PointerPointer tensorPointers = new PointerPointer(tensorHandles.length); + PointerPointer tensorPointers = + new PointerPointer(tensorHandles.length); for (int i = 0; i < tensorHandles.length; ++i) { requireTensorHandle(tensorHandles[i]); tensorPointers.put(i, tensorHandles[i]); @@ -363,7 +391,8 @@ private static void setAttrBool(TFE_Op opHandle, String name, boolean value) { private static void setAttrBoolList(TFE_Op opHandle, String name, boolean[] values) { requireOp(opHandle); try (PointerScope scope = new PointerScope()) { - TFE_OpSetAttrBoolList(opHandle, name, new BytePointer(new BooleanPointer(values)), values.length); + TFE_OpSetAttrBoolList( + opHandle, name, new BytePointer(new BooleanPointer(values)), values.length); } } @@ -408,8 +437,36 @@ private static void setAttrShapeList(TFE_Op opHandle, String name, long[] shapes shapesPointer.position(shapesPointer.position() + numDims[i] * 8); } TF_Status status = TF_Status.newStatus(); - TFE_OpSetAttrShapeList(opHandle, new BytePointer(name), shapesPointers, new IntPointer(numDims), - numDims.length, status); + TFE_OpSetAttrShapeList( + opHandle, + new BytePointer(name), + shapesPointers, + new IntPointer(numDims), + numDims.length, + status); + } + } + + private static void setAttrFunctionName(TFE_Op opHandle, String attrName, String functionName) { + requireOp(opHandle); + try (PointerScope scope = new PointerScope()) { + TFE_OpSetAttrFunctionName(opHandle, attrName, functionName, functionName.length()); + } + } + + private static void setAttrFunctionList( + TFE_Op opHandle, TFE_Context context, String attrName, List functionNames) { + requireOp(opHandle); + requireContext(context); + try (PointerScope scope = new PointerScope()) { + PointerPointer fns = new PointerPointer<>(functionNames.size()); + for (int i = 0; i < functionNames.size(); i++) { + TF_Status status = TF_Status.newStatus(); + TFE_Op op = TFE_Op.newOp(context, functionNames.get(i), status); + status.throwExceptionIfNotOK(); + fns.put(i, op); + } + TFE_OpSetAttrFunctionList(opHandle, new BytePointer(attrName), fns, functionNames.size()); } } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java index c5d67128406..84fe7675c40 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java @@ -15,6 +15,7 @@ */ package org.tensorflow; +import static org.tensorflow.internal.c_api.global.tensorflow.TFE_ContextAddFunction; import static org.tensorflow.internal.c_api.global.tensorflow.TFE_ContextOptionsSetAsync; import static org.tensorflow.internal.c_api.global.tensorflow.TFE_ContextOptionsSetConfig; import static org.tensorflow.internal.c_api.global.tensorflow.TFE_ContextOptionsSetDevicePlacementPolicy; @@ -284,6 +285,25 @@ public OperationBuilder opBuilder(String type, String name) { return new EagerOperationBuilder(this, type, name); } + @Override + public void attachFunction(ConcreteFunction function) { + checkSession(); + try (PointerScope scope = new PointerScope()) { + TF_Status status = TF_Status.newStatus(); + TFE_ContextAddFunction(nativeHandle, function.nativeHandle(), status); + status.throwExceptionIfNotOK(); + + function + .getDependencies() + .forEach( + fn -> { + TF_Status status2 = TF_Status.newStatus(); + TFE_ContextAddFunction(nativeHandle, fn, status2); + status2.throwExceptionIfNotOK(); + }); + } + } + @Override public Types environmentType() { return Types.EAGER; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ExecutionEnvironment.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ExecutionEnvironment.java index a18c7fff38b..6f50aeafe98 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ExecutionEnvironment.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ExecutionEnvironment.java @@ -37,6 +37,15 @@ enum Types { */ OperationBuilder opBuilder(String type, String name); + /** + * Attach the function and its dependencies to this execution environment, allowing it to be + * called. + * + *

Done automatically in the {@link org.tensorflow.op.Ops#call(ConcreteFunction, + * java.util.Map)} ops. + */ + void attachFunction(ConcreteFunction function); + /** * Returns true if the given operation is valid in this execution environment. * diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java index b69fe89da0a..f3e712492b8 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java @@ -18,8 +18,11 @@ import static org.tensorflow.internal.c_api.global.tensorflow.TF_AddGradientsWithPrefix; import static org.tensorflow.internal.c_api.global.tensorflow.TF_DeleteGraph; import static org.tensorflow.internal.c_api.global.tensorflow.TF_FinishWhile; +import static org.tensorflow.internal.c_api.global.tensorflow.TF_GraphCopyFunction; +import static org.tensorflow.internal.c_api.global.tensorflow.TF_GraphGetFunctions; import static org.tensorflow.internal.c_api.global.tensorflow.TF_GraphImportGraphDef; import static org.tensorflow.internal.c_api.global.tensorflow.TF_GraphNextOperation; +import static org.tensorflow.internal.c_api.global.tensorflow.TF_GraphNumFunctions; import static org.tensorflow.internal.c_api.global.tensorflow.TF_GraphOperationByName; import static org.tensorflow.internal.c_api.global.tensorflow.TF_GraphToGraphDef; import static org.tensorflow.internal.c_api.global.tensorflow.TF_ImportGraphDefOptionsSetPrefix; @@ -39,10 +42,12 @@ import java.util.stream.Collectors; import org.bytedeco.javacpp.BytePointer; import org.bytedeco.javacpp.Pointer; +import org.bytedeco.javacpp.PointerPointer; import org.bytedeco.javacpp.PointerScope; import org.bytedeco.javacpp.SizeTPointer; import org.tensorflow.exceptions.TensorFlowException; import org.tensorflow.internal.c_api.TF_Buffer; +import org.tensorflow.internal.c_api.TF_Function; import org.tensorflow.internal.c_api.TF_Graph; import org.tensorflow.internal.c_api.TF_ImportGraphDefOptions; import org.tensorflow.internal.c_api.TF_Operation; @@ -378,6 +383,95 @@ public GraphOperationBuilder opBuilder(String type, String name) { return new GraphOperationBuilder(this, type, name); } + @Override + public void attachFunction(ConcreteFunction function) { + try (Reference ref = ref(); + PointerScope scope = new PointerScope()) { + TF_Status status = TF_Status.newStatus(); + TF_GraphCopyFunction(ref.nativeHandle(), function.nativeHandle(), null, status); + status.throwExceptionIfNotOK(); + + function + .getDependencies() + .forEach( + x -> { + TF_Status status2 = TF_Status.newStatus(); + TF_GraphCopyFunction(ref.nativeHandle(), x, null, status2); + status2.throwExceptionIfNotOK(); + }); + } + } + + /** + * Get the graph's functions. + * + * @param outerScope the pointer scope to attach the functions to. + */ + List getNativeFunctions(PointerScope outerScope) { + try (Reference ref = ref(); + PointerScope scope = new PointerScope()) { + TF_Status status = TF_Status.newStatus(); + + int numFunctions = TF_GraphNumFunctions(ref.nativeHandle()); + + PointerPointer output = new PointerPointer<>(numFunctions); + + TF_GraphGetFunctions(ref.nativeHandle(), output, numFunctions, status); + status.throwExceptionIfNotOK(); + + List funcs = new ArrayList<>(numFunctions); + for (int i = 0; i < numFunctions; i++) { + TF_Function function = output.get(TF_Function.class, i); + + function.withDeallocator(); + outerScope.attach(function); + + funcs.add(new NativeFunction(function)); + } + + return funcs; + } + } + + /** + * Get the function attached to the graph with the given native name. Returns {@code null} if none + * found. + * + * @param key the name of the native function. Note that this may include an argument hash. + * @return the found {@link ConcreteFunction}, or {@code null} if none were found with the correct + * name + */ + public ConcreteFunction getFunction(String key) { + try (Reference ref = ref(); + PointerScope scope = new PointerScope()) { + List funcs = getNativeFunctions(scope); + + for (NativeFunction f : funcs) { + + if (f.getName().equals(key)) { + return ConcreteFunction.fromNativeHandle(f, funcs); + } + } + } + return null; + } + + /** + * Get the functions attached to the graph. + * + * @return all functions attached to this graph. + */ + public List getFunctions() { + try (Reference ref = ref(); + PointerScope scope = new PointerScope()) { + List funcs = getNativeFunctions(scope); + + return funcs.stream() + .map(x -> ConcreteFunction.fromNativeHandle(x, funcs)) + .collect(Collectors.toList()); + } + } + @Override public Types environmentType() { return Types.GRAPH; @@ -1077,12 +1171,20 @@ private static SaverDef addVariableSaver(Graph graph) { } } + Placeholder saveFilename = tf.withName("filename").placeholder(TString.class); + + if (varNames.isEmpty()) { + return SaverDef.newBuilder() + .setFilenameTensorName(saveFilename.op().name()) + .setSaveTensorName(tf.withName("empty_save").identity(saveFilename).op().name()) + .setRestoreOpName(tf.withName("restore_all").noOp().op().name()) + .build(); + } + // FIXME Need an easier way to initialize an NdArray from a list String[] tmp = new String[varNames.size()]; Constant varNamesTensor = tf.constant(StdArrays.ndCopyOf(varNames.toArray(tmp))); Operand varSlices = tf.zerosLike(varNamesTensor); - - Placeholder saveFilename = tf.withName("filename").placeholder(TString.class); Save saveVariables = tf.train.save(saveFilename, varNamesTensor, varSlices, varOutputs); Identity id = tf.withControlDependencies(Arrays.asList(saveFilename, saveVariables)) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperationBuilder.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperationBuilder.java index 72858ece572..53ab50db4b4 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperationBuilder.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperationBuilder.java @@ -1,18 +1,18 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019-2021 The TensorFlow Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ package org.tensorflow; import static org.tensorflow.internal.c_api.global.tensorflow.TF_AddControlInput; @@ -24,6 +24,7 @@ import static org.tensorflow.internal.c_api.global.tensorflow.TF_SetAttrBoolList; import static org.tensorflow.internal.c_api.global.tensorflow.TF_SetAttrFloat; import static org.tensorflow.internal.c_api.global.tensorflow.TF_SetAttrFloatList; +import static org.tensorflow.internal.c_api.global.tensorflow.TF_SetAttrFuncName; import static org.tensorflow.internal.c_api.global.tensorflow.TF_SetAttrInt; import static org.tensorflow.internal.c_api.global.tensorflow.TF_SetAttrIntList; import static org.tensorflow.internal.c_api.global.tensorflow.TF_SetAttrShape; @@ -34,9 +35,13 @@ import static org.tensorflow.internal.c_api.global.tensorflow.TF_SetAttrTensorList; import static org.tensorflow.internal.c_api.global.tensorflow.TF_SetAttrType; import static org.tensorflow.internal.c_api.global.tensorflow.TF_SetAttrTypeList; +import static org.tensorflow.internal.c_api.global.tensorflow.TF_SetAttrValueProto; import static org.tensorflow.internal.c_api.global.tensorflow.TF_SetDevice; import java.nio.charset.Charset; +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; import org.bytedeco.javacpp.BooleanPointer; import org.bytedeco.javacpp.BytePointer; import org.bytedeco.javacpp.IntPointer; @@ -45,6 +50,7 @@ import org.bytedeco.javacpp.PointerPointer; import org.bytedeco.javacpp.PointerScope; import org.bytedeco.javacpp.SizeTPointer; +import org.tensorflow.Graph.Reference; import org.tensorflow.internal.c_api.TF_Graph; import org.tensorflow.internal.c_api.TF_Operation; import org.tensorflow.internal.c_api.TF_OperationDescription; @@ -52,11 +58,12 @@ import org.tensorflow.internal.c_api.TF_Status; import org.tensorflow.internal.c_api.TF_Tensor; import org.tensorflow.ndarray.Shape; +import org.tensorflow.proto.framework.AttrValue; +import org.tensorflow.proto.framework.AttrValue.ListValue; import org.tensorflow.proto.framework.DataType; +import org.tensorflow.proto.framework.NameAttrList; -/** - * An {@link OperationBuilder} for adding {@link GraphOperation}s to a {@link Graph}. - */ +/** An {@link OperationBuilder} for adding {@link GraphOperation}s to a {@link Graph}. */ public final class GraphOperationBuilder implements OperationBuilder { GraphOperationBuilder(Graph graph, String type, String name) { @@ -94,7 +101,8 @@ public GraphOperationBuilder addControlInput(Operation control) { } if (control.env() != graph) { - throw new IllegalArgumentException("Control input " + control + " was from a different graph, can't use."); + throw new IllegalArgumentException( + "Control input " + control + " was from a different graph, can't use."); } Graph.Reference r = graph.ref(); @@ -344,6 +352,30 @@ public GraphOperationBuilder setAttr(String name, String[] value) { return this; } + @Override + public OperationBuilder setAttr(String name, ConcreteFunction value) { + graph.attachFunction(value); + try (Reference r = graph.ref()) { + setAttrFunctionName(unsafeNativeHandle, name, value.getDefinedName()); + } + return this; + } + + @Override + public OperationBuilder setAttr(String name, ConcreteFunction[] value) { + for (ConcreteFunction f : value) { + graph.attachFunction(f); + } + + try (Reference r = graph.ref()) { + setAttrFunctionList( + unsafeNativeHandle, + name, + Arrays.stream(value).map(ConcreteFunction::getDefinedName).collect(Collectors.toList())); + } + return this; + } + private TF_OperationDescription unsafeNativeHandle; private Graph graph; @@ -394,11 +426,16 @@ private static void addInput(TF_OperationDescription handle, TF_Operation opHand } } - private static void addInputList(TF_OperationDescription handle, TF_Operation[] opHandles, int[] indices) { + private static void addInputList( + TF_OperationDescription handle, TF_Operation[] opHandles, int[] indices) { requireHandle(handle); if (indices.length != opHandles.length) { - throw new IllegalArgumentException("mismatch in number of Operations (" - + opHandles.length + ") and output indices (" + indices.length + ") provided"); + throw new IllegalArgumentException( + "mismatch in number of Operations (" + + opHandles.length + + ") and output indices (" + + indices.length + + ") provided"); } try (PointerScope scope = new PointerScope()) { @@ -412,8 +449,8 @@ private static void addInputList(TF_OperationDescription handle, TF_Operation[] private static void addControlInput(TF_OperationDescription handle, TF_Operation opHandle) { if (opHandle == null || opHandle.isNull()) { - throw new IllegalStateException("control input is not valid, " - + "perhaps the Graph containing it has been closed()?"); + throw new IllegalStateException( + "control input is not valid, " + "perhaps the Graph containing it has been closed()?"); } requireHandle(handle); TF_AddControlInput(handle, opHandle); @@ -459,7 +496,8 @@ private static void setAttrBool(TF_OperationDescription handle, String name, boo TF_SetAttrBool(handle, name, (byte) (value ? 1 : 0)); } - private static void setAttrBoolList(TF_OperationDescription handle, String name, boolean[] value) { + private static void setAttrBoolList( + TF_OperationDescription handle, String name, boolean[] value) { requireHandle(handle); try (PointerScope scope = new PointerScope()) { TF_SetAttrBoolList(handle, name, new BytePointer(new BooleanPointer(value)), value.length); @@ -476,7 +514,8 @@ private static void setAttrTypeList(TF_OperationDescription handle, String name, TF_SetAttrTypeList(handle, name, type, type.length); } - private static void setAttrTensor(TF_OperationDescription handle, String name, TF_Tensor tensorHandle) { + private static void setAttrTensor( + TF_OperationDescription handle, String name, TF_Tensor tensorHandle) { requireHandle(handle); requireTensor(tensorHandle); @@ -487,7 +526,8 @@ private static void setAttrTensor(TF_OperationDescription handle, String name, T } } - private static void setAttrTensorList(TF_OperationDescription handle, String name, TF_Tensor[] tensorHandles) { + private static void setAttrTensorList( + TF_OperationDescription handle, String name, TF_Tensor[] tensorHandles) { requireHandle(handle); try (PointerScope scope = new PointerScope()) { @@ -498,12 +538,14 @@ private static void setAttrTensorList(TF_OperationDescription handle, String nam } TF_Status status = TF_Status.newStatus(); - TF_SetAttrTensorList(handle, new BytePointer(name), tensors.position(0), tensorHandles.length, status); + TF_SetAttrTensorList( + handle, new BytePointer(name), tensors.position(0), tensorHandles.length, status); status.throwExceptionIfNotOK(); } } - private static void setAttrShape(TF_OperationDescription handle, String name, long[] shape, int numDims) { + private static void setAttrShape( + TF_OperationDescription handle, String name, long[] shape, int numDims) { requireHandle(handle); // num_dims and env->GetArrayLength(shape) are assumed to be consistent. @@ -511,7 +553,8 @@ private static void setAttrShape(TF_OperationDescription handle, String name, lo TF_SetAttrShape(handle, name, shape, numDims); } - private static void setAttrShapeList(TF_OperationDescription handle, String name, long[] shapes, int[] numDims) { + private static void setAttrShapeList( + TF_OperationDescription handle, String name, long[] shapes, int[] numDims) { requireHandle(handle); try (PointerScope scope = new PointerScope()) { @@ -521,11 +564,13 @@ private static void setAttrShapeList(TF_OperationDescription handle, String name shapesPointers.put(i, shapesPointer); shapesPointer.position(shapesPointer.position() + numDims[i] * 8); } - TF_SetAttrShapeList(handle, new BytePointer(name), shapesPointers, new IntPointer(numDims), numDims.length); + TF_SetAttrShapeList( + handle, new BytePointer(name), shapesPointers, new IntPointer(numDims), numDims.length); } } - private static void setAttrStringList(TF_OperationDescription handle, String name, byte[][] value) { + private static void setAttrStringList( + TF_OperationDescription handle, String name, byte[][] value) { requireHandle(handle); try (PointerScope scope = new PointerScope()) { @@ -539,4 +584,33 @@ private static void setAttrStringList(TF_OperationDescription handle, String nam TF_SetAttrStringList(handle, new BytePointer(name), valuePointers, lengths, value.length); } } + + private static void setAttrFunctionName( + TF_OperationDescription opHandle, String attrName, String functionName) { + requireHandle(opHandle); + try (PointerScope scope = new PointerScope()) { + TF_SetAttrFuncName(opHandle, attrName, functionName, functionName.length()); + } + } + + private static void setAttrFunctionList( + TF_OperationDescription opHandle, String attrName, List functionNames) { + requireHandle(opHandle); + try (PointerScope scope = new PointerScope()) { + TF_Status status = TF_Status.newStatus(); + AttrValue value = + AttrValue.newBuilder() + .setList( + ListValue.newBuilder() + .addAllFunc( + functionNames.stream() + .map(x -> NameAttrList.newBuilder().setName(x).build()) + .collect(Collectors.toList())) + .build()) + .build(); + byte[] bytes = value.toByteArray(); + TF_SetAttrValueProto(opHandle, attrName, new BytePointer(bytes), bytes.length, status); + status.throwExceptionIfNotOK(); + } + } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/NativeFunction.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/NativeFunction.java new file mode 100644 index 00000000000..faab6dbca7b --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/NativeFunction.java @@ -0,0 +1,155 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ +package org.tensorflow; + +import static org.tensorflow.internal.c_api.global.tensorflow.TF_FunctionName; +import static org.tensorflow.internal.c_api.global.tensorflow.TF_FunctionToFunctionDef; + +import com.google.protobuf.InvalidProtocolBufferException; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Queue; +import java.util.Set; +import java.util.stream.Collectors; +import org.bytedeco.javacpp.PointerScope; +import org.tensorflow.internal.c_api.TF_Buffer; +import org.tensorflow.internal.c_api.TF_Function; +import org.tensorflow.internal.c_api.TF_Status; +import org.tensorflow.proto.framework.FunctionDef; +import org.tensorflow.proto.framework.NodeDef; + +/** + * A class holding a native function handle and providing cached access to it's {@link FunctionDef}. + */ +class NativeFunction { + public NativeFunction(TF_Function nativeHandle) { + this.nativeHandle = nativeHandle; + } + + /** Get the native handle. No guarantees about liveness are made. */ + public TF_Function getNativeHandle() { + return nativeHandle; + } + + /** Get the function's {@link FunctionDef} */ + public synchronized FunctionDef getFunctionDef() { + if (functionDef == null) { + try (PointerScope scope = new PointerScope()) { + TF_Buffer funcDefBuffer = TF_Buffer.newBuffer(); + TF_Status status = TF_Status.newStatus(); + + TF_FunctionToFunctionDef(nativeHandle, funcDefBuffer, status); + status.throwExceptionIfNotOK(); + + try { + functionDef = FunctionDef.parseFrom(funcDefBuffer.dataAsByteBuffer()); + } catch (InvalidProtocolBufferException e) { + throw new IllegalStateException("Failed to parse FunctionDef proto", e); + } + } + } + + return functionDef; + } + + /** Get the first-level dependencies of the function. */ + public synchronized List getDependencies() { + if (dependencies == null) { + Set deps = new LinkedHashSet<>(); + + for (NodeDef node : getFunctionDef().getNodeDefList()) { + if (node.getOp().equals(ConcreteFunction.CALL_OP) + || node.getOp().equals(ConcreteFunction.STATEFUL_CALL_OP)) { + deps.add(node.getAttrMap().get("f").getFunc().getName()); + } + } + dependencies = Collections.unmodifiableList(new ArrayList<>(deps)); + } + + return dependencies; + } + + /** Get whether the function is stateful (whether it has stateful ops). */ + public synchronized boolean isStateful() { + if (stateful == null) { + stateful = + getFunctionDef().getSignature().getIsStateful() + || getFunctionDef().getNodeDefList().stream() + .anyMatch(x -> TensorFlow.isOpStateful(x.getOp())); + } + return stateful; + } + + /** Get the name of the function. */ + public synchronized String getName() { + if (name == null) { + try (PointerScope scope = new PointerScope()) { + return TF_FunctionName(nativeHandle).getString(); + } + } + + return name; + } + + synchronized Set getAllDependencies(Collection availableFunctions) { + Map fnMap = + availableFunctions.stream().collect(Collectors.toMap(NativeFunction::getName, e -> e)); + Set done = new LinkedHashSet<>(1 + getDependencies().size()); + + Queue todo = new ArrayDeque<>(1 + getDependencies().size()); + todo.add(this); + + while (!todo.isEmpty()) { + NativeFunction next = todo.remove(); + + if (!done.add(next.getName())) { + continue; + } + + for (String dep : next.getDependencies()) { + if (!done.contains(dep)) { + NativeFunction fn = fnMap.get(dep); + + if (fn == null) { + throw new IllegalStateException( + "Function " + dep + " is required, but not present in graph."); + } + + todo.add(fn); + } + } + } + + done.remove(getName()); + + return done.stream() + .map(fnMap::get) + .map(NativeFunction::getNativeHandle) + .collect(Collectors.toSet()); + } + + private final TF_Function nativeHandle; + + private FunctionDef functionDef = null; + private List dependencies = null; + private Boolean stateful = null; + private String name = null; +} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/OperationBuilder.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/OperationBuilder.java index a487d8b9237..569f37c8f4a 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/OperationBuilder.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/OperationBuilder.java @@ -1,18 +1,18 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019-2021 The TensorFlow Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ package org.tensorflow; import org.tensorflow.ndarray.Shape; @@ -49,7 +49,7 @@ public interface OperationBuilder { * *

The OperationBuilder is not usable after build() returns. */ - Operation build(); + Operation build(); /** * Add the output of another operation as the next input of the operation being built. @@ -57,7 +57,7 @@ public interface OperationBuilder { * @param input {@link Output} supposed to be the input of the operation being built. * @return the OperationBuilder instance for chaining. */ - OperationBuilder addInput(Output input); + OperationBuilder addInput(Output input); /** * Add the outputs of another operation as the next inputs of the operation being built. @@ -65,7 +65,7 @@ public interface OperationBuilder { * @param inputs list of {@link Output} supposed to be the inputs of the operation being built. * @return the OperationBuilder instance for chaining. */ - OperationBuilder addInputList(Output[] inputs); + OperationBuilder addInputList(Output[] inputs); /** * Ensure that the operation does not execute before the control operation does. @@ -80,7 +80,7 @@ public interface OperationBuilder { * @param control operation that must be executed before running this operation. * @return the OperationBuilder instance for chaining. */ - OperationBuilder addControlInput(Operation control); + OperationBuilder addControlInput(Operation control); /** * Set the device requested for computing the operation being built. @@ -88,7 +88,7 @@ public interface OperationBuilder { * @param device the requested device, as a string * @return the OperationBuilder instance for chaining. */ - OperationBuilder setDevice(String device); + OperationBuilder setDevice(String device); /** * Set the string values of an attribute of the operation being built. @@ -97,7 +97,7 @@ public interface OperationBuilder { * @param value attribute values * @return the OperationBuilder instance for chaining. */ - OperationBuilder setAttr(String name, String[] value); + OperationBuilder setAttr(String name, String[] value); /** * Set the string value of an attribute of the operation being built. @@ -106,7 +106,7 @@ public interface OperationBuilder { * @param value attribute value * @return the OperationBuilder instance for chaining. */ - OperationBuilder setAttr(String name, String value); + OperationBuilder setAttr(String name, String value); /** * Set the byte values of an attribute of the operation being built. @@ -115,7 +115,7 @@ public interface OperationBuilder { * @param value attribute values * @return the OperationBuilder instance for chaining. */ - OperationBuilder setAttr(String name, byte[] value); + OperationBuilder setAttr(String name, byte[] value); /** * Set the long value of an attribute of the operation being built. @@ -124,7 +124,7 @@ public interface OperationBuilder { * @param value attribute value * @return the OperationBuilder instance for chaining. */ - OperationBuilder setAttr(String name, long value); + OperationBuilder setAttr(String name, long value); /** * Set the long values of an attribute of the operation being built. @@ -133,7 +133,7 @@ public interface OperationBuilder { * @param value attribute values * @return the OperationBuilder instance for chaining. */ - OperationBuilder setAttr(String name, long[] value); + OperationBuilder setAttr(String name, long[] value); /** * Set the float value of an attribute of the operation being built. @@ -142,7 +142,7 @@ public interface OperationBuilder { * @param value attribute value * @return the OperationBuilder instance for chaining. */ - OperationBuilder setAttr(String name, float value); + OperationBuilder setAttr(String name, float value); /** * Set the float values of an attribute of the operation being built. @@ -151,7 +151,7 @@ public interface OperationBuilder { * @param value attribute values * @return the OperationBuilder instance for chaining. */ - OperationBuilder setAttr(String name, float[] value); + OperationBuilder setAttr(String name, float[] value); /** * Set the boolean value of an attribute of the operation being built. @@ -160,7 +160,7 @@ public interface OperationBuilder { * @param value attribute value * @return the OperationBuilder instance for chaining. */ - OperationBuilder setAttr(String name, boolean value); + OperationBuilder setAttr(String name, boolean value); /** * Set the boolean values of an attribute of the operation being built. @@ -169,7 +169,7 @@ public interface OperationBuilder { * @param value attribute values * @return the OperationBuilder instance for chaining. */ - OperationBuilder setAttr(String name, boolean[] value); + OperationBuilder setAttr(String name, boolean[] value); /** * Set the type value of an attribute of the operation being built. @@ -178,7 +178,7 @@ public interface OperationBuilder { * @param value attribute value * @return the OperationBuilder instance for chaining. */ - OperationBuilder setAttr(String name, DataType value); + OperationBuilder setAttr(String name, DataType value); /** * Set the type values of an attribute of the operation being built. @@ -187,7 +187,7 @@ public interface OperationBuilder { * @param value attribute values * @return the OperationBuilder instance for chaining. */ - OperationBuilder setAttr(String name, DataType[] value); + OperationBuilder setAttr(String name, DataType[] value); /** * Set the tensor value of an attribute of the operation being built. @@ -196,7 +196,7 @@ public interface OperationBuilder { * @param value attribute value * @return the OperationBuilder instance for chaining. */ - OperationBuilder setAttr(String name, Tensor value); + OperationBuilder setAttr(String name, Tensor value); /** * Set the tensor values of an attribute of the operation being built. @@ -205,7 +205,7 @@ public interface OperationBuilder { * @param value attribute values * @return the OperationBuilder instance for chaining. */ - OperationBuilder setAttr(String name, Tensor[] value); + OperationBuilder setAttr(String name, Tensor[] value); /** * Set the shape value of an attribute of the operation being built. @@ -214,7 +214,7 @@ public interface OperationBuilder { * @param value attribute value * @return the OperationBuilder instance for chaining. */ - OperationBuilder setAttr(String name, Shape value); + OperationBuilder setAttr(String name, Shape value); /** * Set the shape values of an attribute of the operation being built. @@ -223,5 +223,25 @@ public interface OperationBuilder { * @param value attribute values * @return the OperationBuilder instance for chaining. */ - OperationBuilder setAttr(String name, Shape[] value); + OperationBuilder setAttr(String name, Shape[] value); + + /** + * Set the function value of an attribute of the operation being built. Also attaches the function + * and dependencies to the execution environment. + * + * @param name attribute name + * @param value attribute value + * @return the OperationBuilder instance for chaining. + */ + OperationBuilder setAttr(String name, ConcreteFunction value); + + /** + * Set the function values of an attribute of the operation being built. Also attaches the + * functions and dependencies to the execution environment. + * + * @param name attribute name + * @param value attribute value + * @return the OperationBuilder instance for chaining. + */ + OperationBuilder setAttr(String name, ConcreteFunction[] value); } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java index 6992e5eee37..3a6433701e6 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java @@ -1,18 +1,18 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019-2021 The TensorFlow Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ package org.tensorflow; import static org.tensorflow.internal.c_api.global.tensorflow.TF_LoadSessionFromSavedModel; @@ -25,11 +25,13 @@ import java.io.OutputStream; import java.nio.file.Path; import java.nio.file.Paths; +import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import java.util.Map.Entry; import java.util.stream.Collectors; import org.bytedeco.javacpp.BytePointer; import org.bytedeco.javacpp.PointerPointer; @@ -50,6 +52,9 @@ /** * SavedModelBundle represents a model loaded from storage. * + *

All operations on a loaded bundle, and any functions from it, share the same underlying + * session. The session is initialized when loaded. + * *

The model consists of a description of the computation (a {@link Graph}), a {@link Session} * with tensors (e.g., parameters or variables in the graph) initialized to values saved in storage, * and a description of the model as a The concrete function carries a signature (i.e. a list of user-friendly input and outputs - * names to a graph) and a valid session to a graph to be saved in the model. + *

The function carries a signature (i.e. a list of user-friendly input and outputs names to + * a graph) and a valid session to a graph to be saved in the model. * *

Note:Eventually, TensorFlow for Java will support the export of functions objects like - * the Python API does but right now, only session-centric models are supported (i.e. models that - * has a single main graph and one or more signatures). These models are compatible with those - * exported by TensorFlow 1.x or by TensorFlow 2.x estimators. - * - *
Therefore, all functions exported in a model should share the same session at the moment - * or an exception will be thrown.
+ * the Python API does but right now, only session-centric models are supported (i.e. models + * that has a single main graph and one or more signatures). These models are compatible with + * those exported by TensorFlow 1.x or by TensorFlow 2.x estimators.
+ * Therefore, all functions exported in a model should share the same session at the moment or + * an exception will be thrown. This applies to sessions set via {@link + * #withSession(Session)} as well, the exporter can only even have one session. * * @param function a function carrying a signature and a valid session to the graph to be saved * @return this object - * @throws IllegalArgumentException if a function with the same name has already been added to the model - * @throws UnsupportedOperationException if this function does not share the same session with the other - * functions added to this model + * @throws IllegalArgumentException if a function with the same name has already been added to + * the model + * @throws UnsupportedOperationException if the session is already set to a different session */ - public Exporter withFunction(ConcreteFunction function) { + public Exporter withFunction(SessionFunction function) { Signature signature = function.signature(); if (functions.containsKey(signature.key())) { - throw new IllegalArgumentException("Function \"" + signature.key() + "\" was already added to the model"); + throw new IllegalArgumentException( + "Function \"" + signature.key() + "\" was already added to the model"); + } + if (session != null && session != function.session()) { + throw new UnsupportedOperationException( + "This exporter already has a session that differs from the passed function's session"); } + + session = function.session(); functions.put(signature.key(), function); + metaGraphDefBuilder.putSignatureDef(signature.key(), signature.asSignatureDef()); + return this; + } + + /** + * Save multiple functions. Wrapper around {@link #withFunction(SessionFunction)}. All functions + * must have the same session. + * + * @param functions the functions to export + * @return this object + * @throws IllegalArgumentException if a function with the same name has already been added to + * the model + * @throws UnsupportedOperationException if the session is already set to a different session + * @see #withFunction(SessionFunction) + */ + public Exporter withFunctions(SessionFunction... functions) { + for (SessionFunction f : functions) { + withFunction(f); + } + return this; + } + + /** + * Add a signature to the model. This wraps the signature in a {@link SessionFunction} using the + * exporter's already-set session. As such, either {@link #withSession(Session)} or {@link + * #withFunction(SessionFunction)} must be called before this method. + * + * @throws IllegalStateException if no session has been set + * @return this + */ + public Exporter withSignature(Signature signature) { if (session == null) { - session = function.session(); - } else if (session != function.session()) { - throw new UnsupportedOperationException("Saving multiple functions with different graphs/sessions is not supported yet."); + throw new IllegalStateException( + "Session has not been set yet, you must call withSession or withFunction first."); + } + return withFunction(session.function(signature)); + } + + /** + * Add multiple signatures to the model. Wraps {@link #withSignature(Signature)} + * + *

Either {@link #withSession(Session)} or {@link * #withFunction(SessionFunction)} must + * be called before this method, and the session set there will be used for these + * signatures. + * + * @throws IllegalStateException if no session has been set + * @return this + * @see #withSession(Session) + */ + public Exporter withSignatures(Signature... signatures) { + for (Signature s : signatures) { + withSignature(s); } - metaGraphDefBuilder.putSignatureDef(signature.key(), signature.asSignatureDef()); return this; } @@ -178,7 +252,7 @@ public Exporter withFunction(ConcreteFunction function) { * @throws IOException if saved model or variable state cannot be written on disk */ public void export() throws IOException { - if (functions.isEmpty() || session == null) { + if (functions.isEmpty()) { throw new IllegalStateException("Model should contain at least one valid function"); } Graph graph = session.graph(); @@ -187,10 +261,11 @@ public void export() throws IOException { // new ops to the graph for saving and restoring the variables. SaverDef saverDef = graph.saverDef(); - MetaGraphDef.Builder metaGraphDef = metaGraphDefBuilder - .setSaverDef(saverDef) - .setGraphDef(graph.toGraphDef()) - .setMetaInfoDef(MetaInfoDef.newBuilder().addAllTags(Arrays.asList(tags))); + MetaGraphDef.Builder metaGraphDef = + metaGraphDefBuilder + .setSaverDef(saverDef) + .setGraphDef(graph.toGraphDef()) + .setMetaInfoDef(MetaInfoDef.newBuilder().addAllTags(Arrays.asList(tags))); functions.forEach((k, f) -> metaGraphDef.putSignatureDef(k, f.signature().asSignatureDef())); // Make sure saved model directories exist @@ -213,10 +288,10 @@ public void export() throws IOException { } private final String exportDir; - private String[] tags = { DEFAULT_TAG }; + private String[] tags = {DEFAULT_TAG}; private final MetaGraphDef.Builder metaGraphDefBuilder = MetaGraphDef.newBuilder(); - private final Map functions = new LinkedHashMap<>(); private Session session; + private final Map functions = new LinkedHashMap<>(); } /** @@ -289,9 +364,7 @@ public Session session() { return session; } - /** - * Return the signature of all functions available in this saved model. - */ + /** Return the signature of all functions available in this saved model. */ public List signatures() { return functions.values().stream().map(f -> f.signature()).collect(Collectors.toList()); } @@ -304,13 +377,14 @@ public List signatures() { * Map outputTensorMap = myFunction.call(session, inputTensorMap); * }

* + * All functions use the bundle's underlying session. + * * @param signatureKey name of the {@code SignatureDef} in the saved model. * @return object that can be used to make calls to a function - * @throws IllegalArgumentException if {@code signatureKey} is not found in this - * saved model. + * @throws IllegalArgumentException if {@code signatureKey} is not found in this saved model. */ - public ConcreteFunction function(String signatureKey) { - ConcreteFunction function = functions.get(signatureKey); + public TensorFunction function(String signatureKey) { + SessionFunction function = functions.get(signatureKey); if (function == null) { throw new IllegalArgumentException( String.format("Function with signature [%s] not found", signatureKey)); @@ -318,24 +392,37 @@ public ConcreteFunction function(String signatureKey) { return function; } + /** + * Get all functions in the bundle. + * + *

All functions use the bundle's underlying session. + */ + public List functions() { + return new ArrayList<>(functions.values()); + } + /** * Invokes the default function directly from this model. * *

The default function selection is done based on the first of the following conditions that * is true: + * *

    - *
  • The function is the only signature available attached to the main graph of this saved model
  • - *
  • The function is mapped to the default signature name, which is "serving_default"
  • + *
  • The function is the only signature available attached to the main graph of this saved + * model + *
  • The function is mapped to the default signature name, which is "serving_default" *
* *

Caller is responsible for closing all returned Tensors. * + *

This uses the model's underlying session + * * @param arguments list of input tensors, mapped by their signature name * @return list of output tensors, mapped by the signature name * @throws IllegalArgumentException if no function can be selected by default */ public Map call(Map arguments) { - ConcreteFunction function = null; + SessionFunction function = null; if (functions.size() == 1) { function = functions.values().iterator().next(); } else { @@ -360,13 +447,17 @@ public void close() { private final Graph graph; private final Session session; private final MetaGraphDef metaGraphDef; - private final Map functions; + private final Map functions; - private SavedModelBundle(Graph graph, Session session, MetaGraphDef metaGraphDef, Map functions) { + private SavedModelBundle( + Graph graph, Session session, MetaGraphDef metaGraphDef, Map signatures) { this.graph = graph; this.session = session; this.metaGraphDef = metaGraphDef; - this.functions = functions; + this.functions = + signatures.entrySet().stream() + .collect( + Collectors.toMap(Entry::getKey, e -> new SessionFunction(e.getValue(), session))); } /** @@ -385,11 +476,17 @@ private static SavedModelBundle fromHandle( // Note that the saved model will remain the owner of the graph and the session, meaning // that the functions do not need to be closed by the user and if it does, it should have // no effect. - final Map functions = new HashMap<>(metaGraphDef.getSignatureDefCount()); - metaGraphDef.getSignatureDefMap().forEach((signatureName, signatureDef) -> { - Signature signature = new Signature(signatureName, signatureDef); - functions.put(signatureName, ConcreteFunction.create(signature, session)); - }); + final Map functions = new HashMap<>(metaGraphDef.getSignatureDefCount()); + + metaGraphDef + .getSignatureDefMap() + .forEach( + (signatureName, signatureDef) -> { + if (!functions.containsKey(signatureName)) { + Signature signature = new Signature(signatureName, signatureDef); + functions.put(signatureName, signature); + } + }); return new SavedModelBundle(graph, session, metaGraphDef, functions); } @@ -412,14 +509,22 @@ private static SavedModelBundle load( // load the session TF_Graph graph = TF_NewGraph(); TF_Buffer metagraphDef = TF_Buffer.newBuffer(); - TF_Session session = TF_LoadSessionFromSavedModel( - opts, runOpts, new BytePointer(exportDir), new PointerPointer(tags), - tags.length, graph, metagraphDef, status); + TF_Session session = + TF_LoadSessionFromSavedModel( + opts, + runOpts, + new BytePointer(exportDir), + new PointerPointer(tags), + tags.length, + graph, + metagraphDef, + status); status.throwExceptionIfNotOK(); // handle the result try { - bundle = fromHandle(graph, session, MetaGraphDef.parseFrom(metagraphDef.dataAsByteBuffer())); + bundle = + fromHandle(graph, session, MetaGraphDef.parseFrom(metagraphDef.dataAsByteBuffer())); } catch (InvalidProtocolBufferException e) { throw new TensorFlowException("Cannot parse MetaGraphDef protocol buffer", e); } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java index 58fb62b5fee..fd0b390bc28 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java @@ -1,18 +1,18 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019-2021 The TensorFlow Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ package org.tensorflow; import static org.tensorflow.Graph.resolveOutputs; @@ -23,6 +23,7 @@ import com.google.protobuf.InvalidProtocolBufferException; import java.util.ArrayList; import java.util.List; +import java.util.Map; import org.bytedeco.javacpp.BytePointer; import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.PointerPointer; @@ -89,9 +90,11 @@ public Session(Graph g) { * Construct a new session with the associated {@link Graph} and configuration options. * * @param g The {@link Graph} the created Session will operate on. - * @param config Configuration parameters for the session specified as a ConfigProto - * protocol buffer. - * @throws IllegalArgumentException if the config is not a valid serialization of the ConfigProto protocol buffer. + * @param config Configuration parameters for the session specified as a ConfigProto + * protocol buffer. + * @throws IllegalArgumentException if the config is not a valid serialization of the ConfigProto + * protocol buffer. */ public Session(Graph g, ConfigProto config) { graph = g; @@ -104,9 +107,7 @@ public Session(Graph g, ConfigProto config) { } } - /** - * Wrap an existing session with the associated {@link Graph}. - */ + /** Wrap an existing session with the associated {@link Graph}. */ Session(Graph g, TF_Session nativeHandle) { graph = g; this.nativeHandle = nativeHandle; @@ -144,20 +145,22 @@ public void close() { * Run {@link Operation}s and evaluate {@link Tensor Tensors}. * *

A Runner runs the necessary graph fragments to execute every {@link Operation} required to - * evaluate the {@link Tensor Tensors} to fetch. The {@link #feed(String, int, Tensor)} call allows callers to - * override the value of {@link Tensor Tensors} in the graph by substituting the provided {@link Tensor Tensors} for - * the outputs of the operations provided to {@link #feed(String, int, Tensor)}. + * evaluate the {@link Tensor Tensors} to fetch. The {@link #feed(String, int, Tensor)} call + * allows callers to override the value of {@link Tensor Tensors} in the graph by substituting the + * provided {@link Tensor Tensors} for the outputs of the operations provided to {@link + * #feed(String, int, Tensor)}. */ public final class Runner { /** * Avoid evaluating {@code operation} and substitute {@code t} for the value it produces. * - * @param operation Is either the string name of the operation, in which case this method is a shorthand for {@code - * feed(operation, 0)}, or it is a string of the form - * operation_name:output_index , in which case this method acts like {@code - * feed(operation_name, output_index)}. These colon-separated names are commonly used in the {@code SignatureDef} - * protocol buffer messages that are included in {@link SavedModelBundle#metaGraphDef()}. + * @param operation Is either the string name of the operation, in which case this method is a + * shorthand for {@code feed(operation, 0)}, or it is a string of the form + * operation_name:output_index , in which case this method acts like {@code + * feed(operation_name, output_index)}. These colon-separated names are commonly used in the + * {@code SignatureDef} protocol buffer messages that are included in {@link + * SavedModelBundle#metaGraphDef()}. * @param t the tensor substituting the operation * @return this session runner * @throws IllegalArgumentException if no output exists with the provided name @@ -167,8 +170,8 @@ public Runner feed(String operation, Tensor t) { } /** - * Avoid evaluating the {@code index}-th output of {@code operation} by substituting {@code t} for the value it - * produces. + * Avoid evaluating the {@code index}-th output of {@code operation} by substituting {@code t} + * for the value it produces. * *

Operations in a {@link Graph} can have multiple outputs, {@code index} identifies which * one {@code t} is being provided for. @@ -187,7 +190,8 @@ public Runner feed(String operation, int index, Tensor t) { } /** - * Use {@code t} instead of the Tensor referred to by executing the operation referred to by {@code operand}. + * Use {@code t} instead of the Tensor referred to by executing the operation referred to by + * {@code operand}. * * @param operand the node in the graph representing the operation to substitute * @param t the tensor substituting the operation @@ -195,8 +199,12 @@ public Runner feed(String operation, int index, Tensor t) { */ public Runner feed(Operand operand, Tensor t) { if (operand.env() != graph) { - throw new IllegalStateException("Can't feed value for operand " + operand + ", it is from " + - (operand.env().isEager() ? "an eager session" : "a different graph") + "."); + throw new IllegalStateException( + "Can't feed value for operand " + + operand + + ", it is from " + + (operand.env().isEager() ? "an eager session" : "a different graph") + + "."); } inputs.add(operand.asOutput()); @@ -207,13 +215,14 @@ public Runner feed(Operand operand, Tensor t) { /** * Make {@link #run()} return the output of {@code operation}. * - * If the output is a resource variable, will fetch the value. + *

If the output is a resource variable, will fetch the value. * - * @param operation Is either the string name of the operation, in which case this method is a shorthand for {@code - * fetch(operation, 0)}, or it is a string of the form - * operation_name:output_index , in which case this method acts like {@code - * fetch(operation_name, output_index)}. These colon-separated names are commonly used in the {@code SignatureDef} - * protocol buffer messages that are included in {@link SavedModelBundle#metaGraphDef()}. + * @param operation Is either the string name of the operation, in which case this method is a + * shorthand for {@code fetch(operation, 0)}, or it is a string of the form + * operation_name:output_index , in which case this method acts like {@code + * fetch(operation_name, output_index)}. These colon-separated names are commonly used in + * the {@code SignatureDef} protocol buffer messages that are included in {@link + * SavedModelBundle#metaGraphDef()}. * @return this session runner * @throws IllegalArgumentException if no output exists with the provided name */ @@ -224,7 +233,7 @@ public Runner fetch(String operation) { /** * Make {@link #run()} return the {@code index}-th output of {@code operation}. * - * If the output is a resource variable, will fetch the value. + *

If the output is a resource variable, will fetch the value. * *

Operations in a {@link Graph} can have multiple outputs, {@code index} identifies which * one to return. @@ -242,15 +251,19 @@ public Runner fetch(String operation, int index) { /** * Makes {@link #run()} return the Tensor referred to by {@code output}. * - * If {@code output} is a resource variable, will fetch the value. + *

If {@code output} is a resource variable, will fetch the value. * * @param output the node to fetch the tensor from * @return this session runner */ public Runner fetch(Output output) { if (output.env() != graph) { - throw new IllegalStateException("Can't fetch output " + output + ", it is from " + - (output.env().isEager() ? "an eager session" : "a different graph") + "."); + throw new IllegalStateException( + "Can't fetch output " + + output + + ", it is from " + + (output.env().isEager() ? "an eager session" : "a different graph") + + "."); } if (output.dataType() == DataType.DT_RESOURCE) { @@ -275,8 +288,11 @@ public Runner fetch(Output output) { } if (read == null) { - read = Ops.create(graph).withSubScope("session_reads").withName(output.op().name() + "_read") - .readVariableOp(output, TensorTypeRegistry.find(valueDt).type()); + read = + Ops.create(graph) + .withSubScope("session_reads") + .withName(output.op().name() + "_read") + .readVariableOp(output, TensorTypeRegistry.find(valueDt).type()); } outputs.add(read.asOutput()); @@ -289,7 +305,7 @@ public Runner fetch(Output output) { /** * Makes {@link #run()} return the Tensor referred to by the output of {@code operand}. * - * If {@code operand} is a resource variable, will fetch the value. + *

If {@code operand} is a resource variable, will fetch the value. * * @param operand the node to fetch the tensor from, as an operand * @return this session runner @@ -299,7 +315,8 @@ public Runner fetch(Operand operand) { } /** - * Make {@link #run()} execute {@code operation}, but not return any evaluated {@link Tensor Tensors}. + * Make {@link #run()} execute {@code operation}, but not return any evaluated {@link Tensor + * Tensors}. * * @param operation the string name of the operation to execute * @return this session runner @@ -310,7 +327,8 @@ public Runner addTarget(String operation) { } /** - * Make {@link #run()} execute {@code operation}, but not return any evaluated {@link Tensor Tensors}. + * Make {@link #run()} execute {@code operation}, but not return any evaluated {@link Tensor + * Tensors}. * * @param operation the operation to execute * @return this session runner @@ -319,8 +337,12 @@ public Runner addTarget(String operation) { */ public Runner addTarget(Operation operation) { if (operation.env() != graph) { - throw new IllegalStateException("Can't target operation " + operation + ", it is from " + - (operation.env().isEager() ? "an eager session" : "a different graph") + "."); + throw new IllegalStateException( + "Can't target operation " + + operation + + ", it is from " + + (operation.env().isEager() ? "an eager session" : "a different graph") + + "."); } targets.add((GraphOperation) operation); return this; @@ -340,7 +362,8 @@ public Runner addTarget(Op op) { * Set options (typically for debugging) for this run. * *

The options are presented as a RunOptions protocol buffer. + * href="https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto">RunOptions + * protocol buffer. * * @param options a {@code RunOptions} proto * @return this session runner @@ -354,11 +377,13 @@ public Runner setOptions(RunOptions options) { * Execute the graph fragments necessary to compute all requested fetches. * *

WARNING: The caller assumes ownership of all returned {@link Tensor Tensors}, i.e., - * the caller must call {@link Tensor#close} on all elements of the returned list to free up resources. + * the caller must call {@link Tensor#close} on all elements of the returned list to free up + * resources. * *

TODO(ashankar): Reconsider the return type here. Two things in particular: (a) Make it - * easier for the caller to cleanup (perhaps returning something like AutoCloseableList in SessionTest.java), and - * (b) Evaluate whether the return value should be a list, or maybe a {@code Map}? + * easier for the caller to cleanup (perhaps returning something like AutoCloseableList in + * SessionTest.java), and (b) Evaluate whether the return value should be a list, or maybe a + * {@code Map}? * *

TODO(andrewmyers): It would also be good if whatever is returned here made it easier to * extract output tensors in a type-safe way. @@ -373,7 +398,8 @@ public List run() { * Execute graph fragments to compute requested fetches and return metadata about the run. * *

This is exactly like {@link #run()}, but in addition to the requested Tensors, also - * returns metadata about the graph execution in the form of a RunMetadata + * returns metadata about the graph execution in the form of a RunMetadata * protocol buffer. * * @return list of resulting tensors fetched by this session runner, with execution metadata @@ -474,9 +500,7 @@ public void close() { private RunOptions runOptions = null; } - /** - * Create a Runner to execute graph operations and evaluate Tensors. - */ + /** Create a Runner to execute graph operations and evaluate Tensors. */ public Runner runner() { return new Runner(); } @@ -504,6 +528,24 @@ public void run(Op op) { runner().addTarget(op.op()).run(); } + /** + * Create a new session function, backed by this session. + * + * @param signature the signature of the function. + */ + public SessionFunction function(Signature signature) { + return new SessionFunction(signature, this); + } + + /** + * Create and call a function, returning the results. + * + * @param signature the signature of the function + * @param arguments the arguments to call with. + */ + public Map run(Signature signature, Map arguments) { + return function(signature).call(arguments); + } /** * Execute the graph's initializers. @@ -511,9 +553,12 @@ public void run(Op op) { *

This method is equivalent to {@code session.run(Ops.create(session.graph).init())}. */ public void runInit() { - Runner runner = runner(); - graph.initializers().forEach(runner::addTarget); - runner.run(); + List initializers = graph.initializers(); + if (!initializers.isEmpty()) { + Runner runner = runner(); + initializers.forEach(runner::addTarget); + runner.run(); + } } /** @@ -524,14 +569,15 @@ public void runInit() { * mymodel/myvariables/variables, then the generated files will be located under * mymodel/myvariables and named variables.data-*-of-* * - *

Note that this method might alter the underlying graph if it is the first time that one - * of its sessions is saved, see {@link Graph#saverDef()} for more details. + *

Note that this method might alter the underlying graph if it is the first time that one of + * its sessions is saved, see {@link Graph#saverDef()} for more details. * * @param prefix prefix to the variable files to save */ public void save(String prefix) { SaverDef saverDef = graph.saverDef(); - runner().addTarget(saverDef.getSaveTensorName()) + runner() + .addTarget(saverDef.getSaveTensorName()) .feed(saverDef.getFilenameTensorName(), TString.scalarOf(prefix)) .run(); } @@ -539,19 +585,20 @@ public void save(String prefix) { /** * Restore the actual state of the variables of this session's graph. * - *

{@code prefix} is the path where the files containing the variables state live, - * followed by the filename prefix. For example, if {@code prefix} is set to - * mymodel/myvariables/variables, then the files are loaded from - * mymodel/myvariables and named variables.data-*-of-* + *

{@code prefix} is the path where the files containing the variables state live, followed by + * the filename prefix. For example, if {@code prefix} is set to + * mymodel/myvariables/variables, then the files are loaded from mymodel/myvariables + * and named variables.data-*-of-* * - *

Note that this method might alter the underlying graph if it is the first time that one - * of its sessions is saved, see {@link Graph#saverDef()} for more details. + *

Note that this method might alter the underlying graph if it is the first time that one of + * its sessions is saved, see {@link Graph#saverDef()} for more details. * * @param prefix prefix to restore from */ public void restore(String prefix) { SaverDef saverDef = graph.saverDef(); - runner().addTarget(saverDef.getRestoreOpName()) + runner() + .addTarget(saverDef.getRestoreOpName()) .feed(saverDef.getFilenameTensorName(), TString.scalarOf(prefix)) .run(); } @@ -563,16 +610,15 @@ public void restore(String prefix) { */ public static final class Run { - /** - * Tensors from requested fetches. - */ + /** Tensors from requested fetches. */ public List outputs; /** * Metadata about the run. * *

A RunMetadata protocol buffer. + * href="https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto">RunMetadata + * protocol buffer. */ public RunMetadata metadata; } @@ -639,21 +685,22 @@ private static void delete(TF_Session handle) { * * @param handle to the C API TF_Session object (Session.nativeHandle) * @param runOptions A RunOptions protocol buffer, or null - * @param inputTensorHandles together with inputOpHandles and inputOpIndices specifies the values that are being "fed" - * (do not need to be computed) during graph execution. inputTensorHandles[i] (which corresponds to a - * Tensor.nativeHandle) is considered to be the inputOpIndices[i]-th output of the Operation inputOpHandles[i]. Thus, - * it is required that inputOpHandles.length == inputOpIndices.length == inputTensorHandles.length. + * @param inputTensorHandles together with inputOpHandles and inputOpIndices specifies the values + * that are being "fed" (do not need to be computed) during graph execution. + * inputTensorHandles[i] (which corresponds to a Tensor.nativeHandle) is considered to be the + * inputOpIndices[i]-th output of the Operation inputOpHandles[i]. Thus, it is required that + * inputOpHandles.length == inputOpIndices.length == inputTensorHandles.length. * @param inputOpHandles (see inputOpIndices) * @param inputOpIndices (see inputTensorHandles) * @param outputOpHandles (see outputOpIndices) - * @param outputOpIndices together with outputOpHandles identifies the set of values that should be computed. The - * outputOpIndices[i]-th output of the Operation outputOpHandles[i], It is required that outputOpHandles.length == - * outputOpIndices.length. - * @param targetOpHandles is the set of Operations in the graph that are to be executed but whose output will not be - * returned + * @param outputOpIndices together with outputOpHandles identifies the set of values that should + * be computed. The outputOpIndices[i]-th output of the Operation outputOpHandles[i], It is + * required that outputOpHandles.length == outputOpIndices.length. + * @param targetOpHandles is the set of Operations in the graph that are to be executed but whose + * output will not be returned * @param wantRunMetadata indicates whether metadata about this execution should be returned. - * @param outputTensors will be filled in with tensors to the outputs requested. It is required that outputs.length == - * outputOpHandles.length. + * @param outputTensors will be filled in with tensors to the outputs requested. It is required + * that outputs.length == outputOpHandles.length. * @return if wantRunMetadata is true, a RunMetadata protocol buffer, false otherwise. */ private static RunMetadata run( diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SessionFunction.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SessionFunction.java new file mode 100644 index 00000000000..07bc418ac51 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SessionFunction.java @@ -0,0 +1,127 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ +package org.tensorflow; + +import java.io.IOException; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + +/** + * A callable function backed by a session. All calls of this function will be ran on the same + * session. + * + *

Does no resource management, the session and all returned tensors are the caller's + * responsibility. + * + *

Does not initialize the session, since it may be shared. + */ +public class SessionFunction implements TensorFunction { + + private final Signature signature; + private final Session session; + + public SessionFunction(Signature signature, Session session) { + this.signature = signature; + this.session = session; + + signature + .getInputs() + .forEach( + (name, description) -> { + TensorFunction.validateDescription(description, session.graph(), name, "Input"); + }); + + signature + .getInputs() + .forEach( + (name, description) -> { + TensorFunction.validateDescription(description, session.graph(), name, "Output"); + }); + } + + public static SessionFunction create(Signature signature, Session session) { + return new SessionFunction(signature, session); + } + + /** + * Save this function using {@link SavedModelBundle}. + * + *

This is identical to calling {@code + * SavedModelBundle.exporter(exportDir).withFunction(this).export()}. + * + * @param exportDir the directory path containing a saved model. + * @throws IOException if saved model or variable state cannot be written on disk + */ + public void save(String exportDir) throws IOException { + SavedModelBundle.exporter(exportDir).withFunction(this).export(); + } + + @Override + public Signature signature() { + return signature; + } + + public Session session() { + return session; + } + + /** + * Get a new function with the same signature, but backed by a new session. + * + * @param session the new backing session. + */ + public SessionFunction withNewSession(Session session) { + return new SessionFunction(signature, session); + } + + @Override + public Map call(Map arguments) { + Session.Runner runner = session.runner(); + signature + .getInputs() + .forEach( + (argName, operand) -> { + if (!arguments.containsKey(argName)) { + throw new IllegalArgumentException( + "No argument found for parameter \"" + argName + "\""); + } + Tensor value = arguments.get(argName); + + if (value == null) { + throw new IllegalArgumentException( + "Can't pass null as an argument to a function. Argument \"" + + argName + + "\" was null."); + } + + runner.feed(operand.name, value); + }); + + signature.getOutputs().values().forEach(x -> runner.fetch(x.name)); + + List results = runner.run(); + + Map outputs = new LinkedHashMap<>(results.size()); + int i = 0; + for (String outputName : signature.outputNames()) { + outputs.put(outputName, results.get(i)); + i++; + } + + return outputs; + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java index 66b4dad4132..41fab27e068 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java @@ -1,21 +1,22 @@ -/* - * Copyright 2020 The TensorFlow Authors. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. +/* Copyright 2020-2021 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= */ package org.tensorflow; -import java.util.HashMap; +import java.util.Collections; +import java.util.LinkedHashMap; import java.util.Map; import java.util.Set; import org.tensorflow.ndarray.Shape; @@ -26,27 +27,32 @@ import org.tensorflow.proto.framework.TensorShapeProto.Dim; /** - * Describe the inputs and outputs of an executable entity, such as a {@link ConcreteFunction}, among - * other useful metadata. + * Describe the inputs and outputs of an executable entity, such as a {@link ConcreteFunction}, + * among other useful metadata. */ -public class Signature { +public class Signature { /** The default signature key, when not provided */ public static final String DEFAULT_KEY = "serving_default"; public static class TensorDescription { + + /** The name of the tensor's operand in the graph */ + public final String name; + /** The data type of the tensor */ public final DataType dataType; + + /** The shape of the tensor */ public final Shape shape; - public TensorDescription(DataType dataType, Shape shape) { + public TensorDescription(DataType dataType, Shape shape, String name) { this.dataType = dataType; this.shape = shape; + this.name = name; } } - /** - * Builds a new function signature. - */ + /** Builds a new function signature. */ public static class Builder { /** @@ -76,12 +82,30 @@ public Builder key(String key) { */ public Builder input(String inputName, Operand input) { if (signatureBuilder.containsInputs(inputName)) { - throw new IllegalArgumentException("\"" + inputName + "\" is already being mapped to another input"); + throw new IllegalArgumentException( + "\"" + inputName + "\" is already being mapped to another input"); } signatureBuilder.putInputs(inputName, toTensorInfo(input.asOutput())); return this; } + /** + * Register a tensor as an input of the function. + * + * @param inputName user-friendly name for this input tensor + * @param input input tensor info + * @return this builder + * @throws IllegalArgumentException if {@code inputName} is already mapped to another input + */ + Builder input(String inputName, TensorInfo input) { + if (signatureBuilder.containsInputs(inputName)) { + throw new IllegalArgumentException( + "\"" + inputName + "\" is already being mapped to another input"); + } + signatureBuilder.putInputs(inputName, input); + return this; + } + /** * Register a tensor as an output of the function. * @@ -92,12 +116,30 @@ public Builder input(String inputName, Operand input) { */ public Builder output(String outputName, Operand output) { if (signatureBuilder.containsOutputs(outputName)) { - throw new IllegalArgumentException("\"" + outputName + "\" is already being mapped to another output"); + throw new IllegalArgumentException( + "\"" + outputName + "\" is already being mapped to another output"); } signatureBuilder.putOutputs(outputName, toTensorInfo(output.asOutput())); return this; } + /** + * Register a tensor as an output of the function. + * + * @param outputName user-friendly name for this output tensor + * @param output output tensor + * @return this builder + * @throws IllegalArgumentException if {@code outputName} is already mapped to another output + */ + Builder output(String outputName, TensorInfo output) { + if (signatureBuilder.containsOutputs(outputName)) { + throw new IllegalArgumentException( + "\"" + outputName + "\" is already being mapped to another output"); + } + signatureBuilder.putOutputs(outputName, output); + return this; + } + /** * Provide extensible name information enabling third-party users to mark a signature as * supporting a particular method @@ -110,9 +152,7 @@ public Builder methodName(String methodName) { return this; } - /** - * Returns a signature from the provided data. - */ + /** Returns a signature from the provided data. */ public Signature build() { return new Signature(key, signatureBuilder.build()); } @@ -134,44 +174,41 @@ private static TensorInfo toTensorInfo(Output operand) { private final SignatureDef.Builder signatureBuilder = SignatureDef.newBuilder(); } - /** - * Returns a new builder for creating a signature - */ + /** Returns a new builder for creating a signature */ public static Builder builder() { return new Builder(); } /** - * Return the key of this signature + * Returns a new builder for creating a signature, with the methodName and key set to {@code name} */ + public static Builder builder(String name) { + return new Builder().methodName(name).key(name); + } + + /** Return the key of this signature */ public String key() { return key; } - /** - * Returns the method name of this signature (e.g. as exposed by TF serving) or null if none - */ + /** Returns the method name of this signature (e.g. as exposed by TF serving) or null if none */ public String methodName() { return signatureDef.getMethodName().isEmpty() ? null : signatureDef.getMethodName(); } - /** - * Returns the names of the inputs in this signature - */ + /** Returns the names of the inputs in this signature */ public Set inputNames() { return signatureDef.getInputsMap().keySet(); } - /** - * Returns the names of the outputs in this signature - */ + /** Returns the names of the outputs in this signature */ public Set outputNames() { return signatureDef.getOutputsMap().keySet(); } @Override public String toString() { - StringBuilder strBuilder = new StringBuilder("Signature for \"" + key +"\":\n"); + StringBuilder strBuilder = new StringBuilder("Signature for \"" + key + "\":\n"); if (!methodName().isEmpty()) { strBuilder.append("\tMethod: \"").append(methodName()).append("\"\n"); } @@ -186,30 +223,40 @@ public String toString() { return strBuilder.toString(); } - private Map buildTensorDescriptionMap(Map dataMapIn) { - Map dataTypeMap = new HashMap<>(); - dataMapIn.forEach((a, b) -> { - long[] tensorDims = b.getTensorShape().getDimList().stream().mapToLong(d -> d.getSize()).toArray(); - Shape tensorShape = Shape.of(tensorDims); - dataTypeMap.put(a, new TensorDescription(b.getDtype(), - tensorShape)); - }); - return dataTypeMap; + private Map buildTensorDescriptionMap( + Map dataMapIn) { + Map dataTypeMap = new LinkedHashMap<>(); + dataMapIn.forEach( + (name, info) -> { + long[] tensorDims = + info.getTensorShape().getDimList().stream().mapToLong(d -> d.getSize()).toArray(); + Shape tensorShape = Shape.of(tensorDims); + dataTypeMap.put( + name, new TensorDescription(info.getDtype(), tensorShape, info.getName())); + }); + return Collections.unmodifiableMap(dataTypeMap); } /** - * Returns the names of the inputs in this signature mapped to their expected data type and shape - * @return + * Returns the names of the inputs in this signature mapped to their expected data type, shape, + * and operand name */ public Map getInputs() { - return buildTensorDescriptionMap(signatureDef.getInputsMap()); + if (inputMap == null) { + inputMap = buildTensorDescriptionMap(signatureDef.getInputsMap()); + } + return inputMap; } /** - * Returns the names of the outputs in this signature mapped to their expected data type and shape + * Returns the names of the outputs in this signature mapped to their expected data type, shape, + * and operand name */ public Map getOutputs() { - return buildTensorDescriptionMap(signatureDef.getOutputsMap()); + if (outputMap == null) { + outputMap = buildTensorDescriptionMap(signatureDef.getOutputsMap()); + } + return outputMap; } Signature(String key, SignatureDef signatureDef) { @@ -223,21 +270,25 @@ SignatureDef asSignatureDef() { private final String key; private final SignatureDef signatureDef; + private Map inputMap; + private Map outputMap; private static void printTensorInfo(Map tensorMap, StringBuilder strBuilder) { - tensorMap.forEach((key, tensorInfo) -> { - strBuilder.append("\t\t\"") - .append(key) - .append("\": dtype=") - .append(tensorInfo.getDtype().name()) - .append(", shape=("); - for (int i = 0; i < tensorInfo.getTensorShape().getDimCount(); ++i) { - strBuilder.append(tensorInfo.getTensorShape().getDim(i).getSize()); - if (i < tensorInfo.getTensorShape().getDimCount() - 1) { - strBuilder.append(", "); - } - } - strBuilder.append(")\n"); - }); + tensorMap.forEach( + (key, tensorInfo) -> { + strBuilder + .append("\t\t\"") + .append(key) + .append("\": dtype=") + .append(tensorInfo.getDtype().name()) + .append(", shape=("); + for (int i = 0; i < tensorInfo.getTensorShape().getDimCount(); ++i) { + strBuilder.append(tensorInfo.getTensorShape().getDim(i).getSize()); + if (i < tensorInfo.getTensorShape().getDimCount() - 1) { + strBuilder.append(", "); + } + } + strBuilder.append(")\n"); + }); } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlow.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlow.java index de481d256a3..23f4c62bc7f 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlow.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlow.java @@ -1,18 +1,18 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019-2021 The TensorFlow Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ package org.tensorflow; import static org.tensorflow.internal.c_api.global.tensorflow.TF_DeleteBuffer; @@ -23,6 +23,8 @@ import static org.tensorflow.internal.c_api.global.tensorflow.TF_Version; import com.google.protobuf.InvalidProtocolBufferException; +import java.util.Set; +import java.util.stream.Collectors; import org.bytedeco.javacpp.PointerScope; import org.tensorflow.exceptions.TensorFlowException; import org.tensorflow.internal.c_api.TF_Buffer; @@ -56,6 +58,20 @@ public static OpList registeredOpList() { } } + private static Set statefulOps; + + public static synchronized boolean isOpStateful(String opType) { + if (statefulOps == null) { + statefulOps = + registeredOpList().getOpList().stream() + .filter(x -> x.getIsStateful()) + .map(x -> x.getName()) + .collect(Collectors.toSet()); + } + + return statefulOps.contains(opType); + } + /** * Load the dynamic library in filename and register the operations and kernels present in that * library. diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFunction.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFunction.java new file mode 100644 index 00000000000..0304d786494 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFunction.java @@ -0,0 +1,129 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ +package org.tensorflow; + +import java.util.LinkedHashMap; +import java.util.Map; +import org.tensorflow.Signature.TensorDescription; + +/** A function that can be called with tensors. */ +public interface TensorFunction { + + /** Returns the signature of this function */ + Signature signature(); + + /** + * Invokes a function using the default eager session. + * + *

Caller is responsible for closing all Tensors. + * + * @param arguments list of tensors to pass in input to the function, mapped by their signature + * name + * @return output tensors resulting from the execution of the function, mapped by their signature + * name + * @throws IllegalArgumentException if the passed arguments don't match up to the function's + * parameters. + */ + Map call(Map arguments); + + /** + * Invokes a function with a single input and output using the default eager session. + * + *

Caller is responsible for closing all Tensors. + * + * @param tensor input tensor + * @return output tensor + * @throws IllegalArgumentException if there are multiple input or output parameters defined in + * the function + */ + default Tensor call(Tensor tensor) { + if (signature().inputNames().size() > 1) { + throw new IllegalArgumentException( + "Can't use call(Tensor) on function \"" + + signature().methodName() + + "\" with more than one input."); + } + if (signature().inputNames().size() < 1) { + throw new IllegalArgumentException( + "Can't use call(Tensor) on function \"" + + signature().methodName() + + "\" with no inputs."); + } + if (signature().outputNames().size() > 1) { + throw new IllegalArgumentException( + "Can't use call(Tensor) on function \"" + + signature().methodName() + + "\" with more than one output."); + } + if (signature().outputNames().size() < 1) { + throw new IllegalArgumentException( + "Can't use call(Tensor) on function \"" + + signature().methodName() + + "\" with no outputs."); + } + + String inputName = signature().inputNames().iterator().next(); + String outputName = signature().outputNames().iterator().next(); + + Map inputMap = new LinkedHashMap<>(); + inputMap.put(inputName, tensor); + + return call(inputMap).get(outputName); + } + + static Operand validateDescription( + TensorDescription description, Graph graph, String name, String prefix) { + Output operand = graph.output(description.name); + if (operand == null) { + throw new IllegalArgumentException( + prefix + + " \"" + + name + + "\"'s operand \"" + + description.name + + "\" does not exist on the graph."); + } + + if (operand.dataType() != description.dataType) { + throw new IllegalArgumentException( + prefix + + " \"" + + name + + "\"'s operand \"" + + description.name + + "\" has data type " + + operand.dataType() + + " in the graph, but the signature requires data type " + + description.dataType + + "."); + } + + if (!operand.shape().isCompatibleWith(description.shape)) { + throw new IllegalArgumentException( + prefix + + " \"" + + name + + "\"'s operand \"" + + description.name + + "\" has shape " + + operand.shape() + + ", which is incompatible with the signature's required shape of " + + description.shape + + "."); + } + return operand; + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Function.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Function.java new file mode 100644 index 00000000000..a3647b5671d --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Function.java @@ -0,0 +1,53 @@ +/* Copyright 2019-2021 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ +package org.tensorflow.internal.c_api; + +import static org.tensorflow.internal.c_api.global.tensorflow.TF_DeleteFunction; + +import org.bytedeco.javacpp.Pointer; +import org.bytedeco.javacpp.annotation.Properties; + +@Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class) +public abstract class AbstractTF_Function extends Pointer { + + protected static class DeleteDeallocator extends TF_Function implements Deallocator { + + DeleteDeallocator(TF_Function s) { + super(s); + } + + @Override + public void deallocate() { + if (!isNull()) { + TF_DeleteFunction(this); + } + setNull(); + } + } + + public AbstractTF_Function(Pointer p) { + super(p); + } + + public TF_Function withDeallocator() { + return this.deallocator(new DeleteDeallocator((TF_Function) this)); + } + + /** Calls the deallocator, if registered, otherwise has no effect. */ + public void delete() { + deallocate(); + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/presets/tensorflow.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/presets/tensorflow.java index 17bf9dbf79a..66dead59967 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/presets/tensorflow.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/presets/tensorflow.java @@ -1,5 +1,4 @@ -/* - Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019-2021 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -14,7 +13,6 @@ limitations under the License. ======================================================================= */ - package org.tensorflow.internal.c_api.presets; import java.util.List; @@ -28,204 +26,363 @@ import org.bytedeco.javacpp.tools.InfoMap; import org.bytedeco.javacpp.tools.InfoMapper; -/** - * - * @author Samuel Audet - */ +/** @author Samuel Audet */ @Properties( value = { - @Platform( - value = {"linux", "macosx", "windows"}, - compiler = "cpp11", - include = { - "tensorflow/core/platform/ctstring_internal.h", - "tensorflow/core/platform/ctstring.h", - "tensorflow/core/util/port.h", - "tensorflow/c/tf_attrtype.h", - "tensorflow/c/c_api_macros.h", - "tensorflow/c/tf_datatype.h", - "tensorflow/c/tf_status.h", - "tensorflow/c/tf_tensor.h", - "tensorflow/c/tf_tstring.h", - "tensorflow/c/c_api.h", -// "tensorflow/c/env.h", - "tensorflow/c/kernels.h", - "tensorflow/c/ops.h", - "tensorflow/c/eager/c_api.h" - }, - link = "tensorflow_cc@.2", - preload = {"iomp5", "mklml", "mklml_intel", "tensorflow_framework@.2"}, - preloadresource = "/org/bytedeco/mkldnn/", - resource = {"LICENSE", "THIRD_PARTY_TF_JNI_LICENSES"} - ), - @Platform( - value = "windows", - preload = { - "api-ms-win-crt-locale-l1-1-0", "api-ms-win-crt-string-l1-1-0", "api-ms-win-crt-stdio-l1-1-0", "api-ms-win-crt-math-l1-1-0", - "api-ms-win-crt-heap-l1-1-0", "api-ms-win-crt-runtime-l1-1-0", "api-ms-win-crt-convert-l1-1-0", "api-ms-win-crt-environment-l1-1-0", - "api-ms-win-crt-time-l1-1-0", "api-ms-win-crt-filesystem-l1-1-0", "api-ms-win-crt-utility-l1-1-0", "api-ms-win-crt-multibyte-l1-1-0", - "api-ms-win-core-string-l1-1-0", "api-ms-win-core-errorhandling-l1-1-0", "api-ms-win-core-timezone-l1-1-0", "api-ms-win-core-file-l1-1-0", - "api-ms-win-core-namedpipe-l1-1-0", "api-ms-win-core-handle-l1-1-0", "api-ms-win-core-file-l2-1-0", "api-ms-win-core-heap-l1-1-0", - "api-ms-win-core-libraryloader-l1-1-0", "api-ms-win-core-synch-l1-1-0", "api-ms-win-core-processthreads-l1-1-0", - "api-ms-win-core-processenvironment-l1-1-0", "api-ms-win-core-datetime-l1-1-0", "api-ms-win-core-localization-l1-2-0", - "api-ms-win-core-sysinfo-l1-1-0", "api-ms-win-core-synch-l1-2-0", "api-ms-win-core-console-l1-1-0", "api-ms-win-core-debug-l1-1-0", - "api-ms-win-core-rtlsupport-l1-1-0", "api-ms-win-core-processthreads-l1-1-1", "api-ms-win-core-file-l1-2-0", "api-ms-win-core-profile-l1-1-0", - "api-ms-win-core-memory-l1-1-0", "api-ms-win-core-util-l1-1-0", "api-ms-win-core-interlocked-l1-1-0", "ucrtbase", - "vcruntime140", "vcruntime140_1", "msvcp140", "concrt140", "vcomp140", "msvcr120", "libiomp5md", "mklml", "tensorflow_framework" - } - ), - @Platform( - value = "windows-x86", - preloadpath = { - "C:/Program Files (x86)/Microsoft Visual Studio 14.0/VC/redist/x86/Microsoft.VC140.CRT/", - "C:/Program Files (x86)/Microsoft Visual Studio 14.0/VC/redist/x86/Microsoft.VC140.OpenMP/", - "C:/Program Files (x86)/Windows Kits/10/Redist/ucrt/DLLs/x86/" - } - ), - @Platform( - value = "windows-x86_64", - preloadpath = { - "C:/Program Files (x86)/Microsoft Visual Studio 14.0/VC/redist/x64/Microsoft.VC140.CRT/", - "C:/Program Files (x86)/Microsoft Visual Studio 14.0/VC/redist/x64/Microsoft.VC140.OpenMP/", - "C:/Program Files (x86)/Windows Kits/10/Redist/ucrt/DLLs/x64/" - } - ), - @Platform( - value = {"linux", "macosx", "windows"}, - extension = {"-mkl", "-gpu", "-mkl-gpu"} - ) + @Platform( + value = {"linux", "macosx", "windows"}, + compiler = "cpp11", + include = { + "tensorflow/core/platform/ctstring_internal.h", + "tensorflow/core/platform/ctstring.h", + "tensorflow/core/util/port.h", + "tensorflow/c/tf_attrtype.h", + "tensorflow/c/c_api_macros.h", + "tensorflow/c/tf_datatype.h", + "tensorflow/c/tf_status.h", + "tensorflow/c/tf_tensor.h", + "tensorflow/c/tf_tstring.h", + "tensorflow/c/c_api.h", + // "tensorflow/c/env.h", + "tensorflow/c/kernels.h", + "tensorflow/c/ops.h", + "tensorflow/c/eager/c_api.h" + }, + link = "tensorflow_cc@.2", + preload = {"iomp5", "mklml", "mklml_intel", "tensorflow_framework@.2"}, + preloadresource = "/org/bytedeco/mkldnn/", + resource = {"LICENSE", "THIRD_PARTY_TF_JNI_LICENSES"}), + @Platform( + value = "windows", + preload = { + "api-ms-win-crt-locale-l1-1-0", + "api-ms-win-crt-string-l1-1-0", + "api-ms-win-crt-stdio-l1-1-0", + "api-ms-win-crt-math-l1-1-0", + "api-ms-win-crt-heap-l1-1-0", + "api-ms-win-crt-runtime-l1-1-0", + "api-ms-win-crt-convert-l1-1-0", + "api-ms-win-crt-environment-l1-1-0", + "api-ms-win-crt-time-l1-1-0", + "api-ms-win-crt-filesystem-l1-1-0", + "api-ms-win-crt-utility-l1-1-0", + "api-ms-win-crt-multibyte-l1-1-0", + "api-ms-win-core-string-l1-1-0", + "api-ms-win-core-errorhandling-l1-1-0", + "api-ms-win-core-timezone-l1-1-0", + "api-ms-win-core-file-l1-1-0", + "api-ms-win-core-namedpipe-l1-1-0", + "api-ms-win-core-handle-l1-1-0", + "api-ms-win-core-file-l2-1-0", + "api-ms-win-core-heap-l1-1-0", + "api-ms-win-core-libraryloader-l1-1-0", + "api-ms-win-core-synch-l1-1-0", + "api-ms-win-core-processthreads-l1-1-0", + "api-ms-win-core-processenvironment-l1-1-0", + "api-ms-win-core-datetime-l1-1-0", + "api-ms-win-core-localization-l1-2-0", + "api-ms-win-core-sysinfo-l1-1-0", + "api-ms-win-core-synch-l1-2-0", + "api-ms-win-core-console-l1-1-0", + "api-ms-win-core-debug-l1-1-0", + "api-ms-win-core-rtlsupport-l1-1-0", + "api-ms-win-core-processthreads-l1-1-1", + "api-ms-win-core-file-l1-2-0", + "api-ms-win-core-profile-l1-1-0", + "api-ms-win-core-memory-l1-1-0", + "api-ms-win-core-util-l1-1-0", + "api-ms-win-core-interlocked-l1-1-0", + "ucrtbase", + "vcruntime140", + "vcruntime140_1", + "msvcp140", + "concrt140", + "vcomp140", + "msvcr120", + "libiomp5md", + "mklml", + "tensorflow_framework" + }), + @Platform( + value = "windows-x86", + preloadpath = { + "C:/Program Files (x86)/Microsoft Visual Studio 14.0/VC/redist/x86/Microsoft.VC140.CRT/", + "C:/Program Files (x86)/Microsoft Visual Studio 14.0/VC/redist/x86/Microsoft.VC140.OpenMP/", + "C:/Program Files (x86)/Windows Kits/10/Redist/ucrt/DLLs/x86/" + }), + @Platform( + value = "windows-x86_64", + preloadpath = { + "C:/Program Files (x86)/Microsoft Visual Studio 14.0/VC/redist/x64/Microsoft.VC140.CRT/", + "C:/Program Files (x86)/Microsoft Visual Studio 14.0/VC/redist/x64/Microsoft.VC140.OpenMP/", + "C:/Program Files (x86)/Windows Kits/10/Redist/ucrt/DLLs/x64/" + }), + @Platform( + value = {"linux", "macosx", "windows"}, + extension = {"-mkl", "-gpu", "-mkl-gpu"}) }, target = "org.tensorflow.internal.c_api", global = "org.tensorflow.internal.c_api.global.tensorflow") @NoException public class tensorflow implements LoadEnabled, InfoMapper { - @Override public void init(ClassProperties properties) { - String platform = properties.getProperty("platform"); - String extension = properties.getProperty("platform.extension"); - List preloads = properties.get("platform.preload"); - List resources = properties.get("platform.preloadresource"); - List preloadpaths = properties.get("platform.preloadpath"); - - String vcredistdir = System.getenv("VCToolsRedistDir"); - if (vcredistdir != null && vcredistdir.length() > 0) { - switch (platform) { - case "windows-x86": - preloadpaths.add(0, vcredistdir + "\\x86\\Microsoft.VC142.CRT"); - preloadpaths.add(1, vcredistdir + "\\x86\\Microsoft.VC142.OpenMP"); - preloadpaths.add(2, vcredistdir + "\\x86\\Microsoft.VC141.CRT"); - preloadpaths.add(3, vcredistdir + "\\x86\\Microsoft.VC141.OpenMP"); - break; - case "windows-x86_64": - preloadpaths.add(0, vcredistdir + "\\x64\\Microsoft.VC142.CRT"); - preloadpaths.add(1, vcredistdir + "\\x64\\Microsoft.VC142.OpenMP"); - preloadpaths.add(2, vcredistdir + "\\x64\\Microsoft.VC141.CRT"); - preloadpaths.add(3, vcredistdir + "\\x64\\Microsoft.VC141.OpenMP"); - break; - default: - // not Windows - } - } + @Override + public void init(ClassProperties properties) { + String platform = properties.getProperty("platform"); + String extension = properties.getProperty("platform.extension"); + List preloads = properties.get("platform.preload"); + List resources = properties.get("platform.preloadresource"); + List preloadpaths = properties.get("platform.preloadpath"); - // Only apply this at load time - if (!Loader.isLoadLibraries()) { - return; - } + String vcredistdir = System.getenv("VCToolsRedistDir"); + if (vcredistdir != null && vcredistdir.length() > 0) { + switch (platform) { + case "windows-x86": + preloadpaths.add(0, vcredistdir + "\\x86\\Microsoft.VC142.CRT"); + preloadpaths.add(1, vcredistdir + "\\x86\\Microsoft.VC142.OpenMP"); + preloadpaths.add(2, vcredistdir + "\\x86\\Microsoft.VC141.CRT"); + preloadpaths.add(3, vcredistdir + "\\x86\\Microsoft.VC141.OpenMP"); + break; + case "windows-x86_64": + preloadpaths.add(0, vcredistdir + "\\x64\\Microsoft.VC142.CRT"); + preloadpaths.add(1, vcredistdir + "\\x64\\Microsoft.VC142.OpenMP"); + preloadpaths.add(2, vcredistdir + "\\x64\\Microsoft.VC141.CRT"); + preloadpaths.add(3, vcredistdir + "\\x64\\Microsoft.VC141.OpenMP"); + break; + default: + // not Windows + } + } - // Let users enable loading of the full version of MKL - String load = System.getProperty("org.bytedeco.openblas.load", - System.getProperty("org.bytedeco.mklml.load", "")).toLowerCase(); + // Only apply this at load time + if (!Loader.isLoadLibraries()) { + return; + } - int i = 0; - if (load.equals("mkl") || load.equals("mkl_rt")) { - String[] libs = {"iomp5", "libiomp5md", "mkl_core", "mkl_avx", "mkl_avx2", "mkl_avx512", "mkl_avx512_mic", - "mkl_def", "mkl_mc", "mkl_mc3", "mkl_intel_lp64", "mkl_intel_thread", "mkl_gnu_thread", "mkl_rt"}; - for (i = 0; i < libs.length; i++) { - preloads.add(i, libs[i] + "#" + libs[i]); - } - load = "mkl_rt"; - resources.add("/org/bytedeco/mkl/"); - } + // Let users enable loading of the full version of MKL + String load = + System.getProperty( + "org.bytedeco.openblas.load", System.getProperty("org.bytedeco.mklml.load", "")) + .toLowerCase(); - if (load.length() > 0) { - if (platform.startsWith("linux")) { - preloads.add(i, load + "#mklml_intel"); - } else if (platform.startsWith("macosx")) { - preloads.add(i, load + "#mklml"); - } else if (platform.startsWith("windows")) { - preloads.add(i, load + "#mklml"); - } - } + int i = 0; + if (load.equals("mkl") || load.equals("mkl_rt")) { + String[] libs = { + "iomp5", + "libiomp5md", + "mkl_core", + "mkl_avx", + "mkl_avx2", + "mkl_avx512", + "mkl_avx512_mic", + "mkl_def", + "mkl_mc", + "mkl_mc3", + "mkl_intel_lp64", + "mkl_intel_thread", + "mkl_gnu_thread", + "mkl_rt" + }; + for (i = 0; i < libs.length; i++) { + preloads.add(i, libs[i] + "#" + libs[i]); + } + load = "mkl_rt"; + resources.add("/org/bytedeco/mkl/"); + } - // Only apply this at load time since we don't want to copy the CUDA libraries here - if (!Loader.isLoadLibraries() || extension == null || !extension.endsWith("-gpu")) { - return; - } - String[] libs = {"cudart", "cublasLt", "cublas", "cufft", "curand", "cusolver", "cusparse", "cudnn", "nccl", "nvrtc", "myelin", "nvinfer", - "cudnn_ops_infer", "cudnn_ops_train", "cudnn_adv_infer", "cudnn_adv_train", "cudnn_cnn_infer", "cudnn_cnn_train"}; - for (String lib : libs) { - if (platform.startsWith("linux")) { - lib += lib.startsWith("cudnn") ? "@.8" - : lib.equals("nccl") ? "@.2" - : lib.equals("myelin") ? "@.1" - : lib.equals("nvinfer") ? "@.7" - : lib.equals("cufft") || lib.equals("curand") || lib.equals("cusolver") ? "@.10" - : lib.equals("cudart") ? "@.11.0" - : lib.equals("nvrtc") ? "@.11.0" - : "@.11"; - } else if (platform.startsWith("windows")) { - lib += lib.startsWith("cudnn") ? "64_8" - : lib.equals("nccl") ? "64_2" - : lib.equals("myelin") ? "64_1" - : lib.equals("nvinfer") ? "64_7" - : lib.equals("cufft") || lib.equals("curand") || lib.equals("cusolver") ? "64_10" - : lib.equals("cudart") ? "64_110" - : lib.equals("nvrtc") ? "64_110_0" - : "64_11"; - } else { - continue; // no CUDA - } - if (!preloads.contains(lib)) { - preloads.add(i++, lib); - } - } - if (i > 0) { - resources.add("/org/bytedeco/cuda/"); - resources.add("/org/bytedeco/tensorrt/"); - } + if (load.length() > 0) { + if (platform.startsWith("linux")) { + preloads.add(i, load + "#mklml_intel"); + } else if (platform.startsWith("macosx")) { + preloads.add(i, load + "#mklml"); + } else if (platform.startsWith("windows")) { + preloads.add(i, load + "#mklml"); + } } - public void map(InfoMap infoMap) { - infoMap.put(new Info("TF_CAPI_EXPORT", "TF_Bool").cppTypes().annotations()) - .put(new Info("TF_Buffer::data").javaText("public native @Const Pointer data(); public native TF_Buffer data(Pointer data);")) - .put(new Info("TF_Status").pointerTypes("TF_Status").base("org.tensorflow.internal.c_api.AbstractTF_Status")) - .put(new Info("TF_Buffer").pointerTypes("TF_Buffer").base("org.tensorflow.internal.c_api.AbstractTF_Buffer")) - .put(new Info("TF_Tensor").pointerTypes("TF_Tensor").base("org.tensorflow.internal.c_api.AbstractTF_Tensor")) - .put(new Info("TF_Session").pointerTypes("TF_Session").base("org.tensorflow.internal.c_api.AbstractTF_Session")) - .put(new Info("TF_SessionOptions").pointerTypes("TF_SessionOptions").base("org.tensorflow.internal.c_api.AbstractTF_SessionOptions")) - .put(new Info("TF_Graph").pointerTypes("TF_Graph").base("org.tensorflow.internal.c_api.AbstractTF_Graph")) - .put(new Info("TF_Graph::graph").javaText("public native @MemberGetter @ByRef Graph graph();")) - .put(new Info("TF_Graph::refiner").javaText("public native @MemberGetter @ByRef ShapeRefiner refiner();")) - .put(new Info("TF_ImportGraphDefOptions").pointerTypes("TF_ImportGraphDefOptions").base("org.tensorflow.internal.c_api.AbstractTF_ImportGraphDefOptions")) - .put(new Info("TF_Operation", "TF_WhileParams", "TFE_MonitoringCounterCell", "TFE_MonitoringSamplerCell", - "TFE_MonitoringCounter0", "TFE_MonitoringCounter1", "TFE_MonitoringCounter2", - "TFE_MonitoringIntGaugeCell", "TFE_MonitoringStringGaugeCell", "TFE_MonitoringBoolGaugeCell", - "TFE_MonitoringIntGauge0", "TFE_MonitoringIntGauge1", "TFE_MonitoringIntGauge2", - "TFE_MonitoringStringGauge0", "TFE_MonitoringStringGauge1", "TFE_MonitoringStringGauge2", - "TFE_MonitoringBoolGauge0", "TFE_MonitoringBoolGauge1", "TFE_MonitoringBoolGauge2", - "TFE_MonitoringSampler0", "TFE_MonitoringSampler1", "TFE_MonitoringSampler2").purify()) - .put(new Info("TF_Operation::node").javaText("public native @MemberGetter @ByRef Node node();")) - .put(new Info("TFE_MonitoringCounterCell::cell").javaText("public native @MemberGetter @ByRef CounterCell cell();")) - .put(new Info("TFE_MonitoringSamplerCell::cell").javaText("public native @MemberGetter @ByRef SamplerCell cell();")) - .put(new Info("TFE_MonitoringIntGaugeCell::cell").javaText("public native @MemberGetter @ByRef IntGaugeCell cell();")) - .put(new Info("TFE_MonitoringStringGaugeCell::cell").javaText("public native @MemberGetter @ByRef StringGaugeCell cell();")) - .put(new Info("TFE_MonitoringBoolGaugeCell::cell").javaText("public native @MemberGetter @ByRef BoolGaugeCell cell();")) - .put(new Info("TFE_Context").pointerTypes("TFE_Context").base("org.tensorflow.internal.c_api.AbstractTFE_Context")) - .put(new Info("TFE_ContextOptions").pointerTypes("TFE_ContextOptions").base("org.tensorflow.internal.c_api.AbstractTFE_ContextOptions")) - .put(new Info("TFE_Context::context").javaText("@MemberGetter public native @ByRef EagerContext context();")) - .put(new Info("TFE_Op").pointerTypes("TFE_Op").base("org.tensorflow.internal.c_api.AbstractTFE_Op")) - .put(new Info("TFE_Op::operation").javaText("@MemberGetter public native @ByRef EagerOperation operation();")) - .put(new Info("TFE_TensorHandle").pointerTypes("TFE_TensorHandle").base("org.tensorflow.internal.c_api.AbstractTFE_TensorHandle")) - .put(new Info("TF_ShapeInferenceContextDimValueKnown", "TFE_NewTensorHandle(const tensorflow::Tensor&, TF_Status*)").skip()); + // Only apply this at load time since we don't want to copy the CUDA libraries here + if (!Loader.isLoadLibraries() || extension == null || !extension.endsWith("-gpu")) { + return; + } + String[] libs = { + "cudart", + "cublasLt", + "cublas", + "cufft", + "curand", + "cusolver", + "cusparse", + "cudnn", + "nccl", + "nvrtc", + "myelin", + "nvinfer", + "cudnn_ops_infer", + "cudnn_ops_train", + "cudnn_adv_infer", + "cudnn_adv_train", + "cudnn_cnn_infer", + "cudnn_cnn_train" + }; + for (String lib : libs) { + if (platform.startsWith("linux")) { + lib += + lib.startsWith("cudnn") + ? "@.8" + : lib.equals("nccl") + ? "@.2" + : lib.equals("myelin") + ? "@.1" + : lib.equals("nvinfer") + ? "@.7" + : lib.equals("cufft") || lib.equals("curand") || lib.equals("cusolver") + ? "@.10" + : lib.equals("cudart") + ? "@.11.0" + : lib.equals("nvrtc") ? "@.11.0" : "@.11"; + } else if (platform.startsWith("windows")) { + lib += + lib.startsWith("cudnn") + ? "64_8" + : lib.equals("nccl") + ? "64_2" + : lib.equals("myelin") + ? "64_1" + : lib.equals("nvinfer") + ? "64_7" + : lib.equals("cufft") || lib.equals("curand") || lib.equals("cusolver") + ? "64_10" + : lib.equals("cudart") + ? "64_110" + : lib.equals("nvrtc") ? "64_110_0" : "64_11"; + } else { + continue; // no CUDA + } + if (!preloads.contains(lib)) { + preloads.add(i++, lib); + } } + if (i > 0) { + resources.add("/org/bytedeco/cuda/"); + resources.add("/org/bytedeco/tensorrt/"); + } + } + + @Override + public void map(InfoMap infoMap) { + infoMap + .put(new Info("TF_CAPI_EXPORT", "TF_Bool").cppTypes().annotations()) + .put( + new Info("TF_Buffer::data") + .javaText( + "public native @Const Pointer data(); public native TF_Buffer data(Pointer data);")) + .put( + new Info("TF_Status") + .pointerTypes("TF_Status") + .base("org.tensorflow.internal.c_api.AbstractTF_Status")) + .put( + new Info("TF_Buffer") + .pointerTypes("TF_Buffer") + .base("org.tensorflow.internal.c_api.AbstractTF_Buffer")) + .put( + new Info("TF_Tensor") + .pointerTypes("TF_Tensor") + .base("org.tensorflow.internal.c_api.AbstractTF_Tensor")) + .put( + new Info("TF_Session") + .pointerTypes("TF_Session") + .base("org.tensorflow.internal.c_api.AbstractTF_Session")) + .put( + new Info("TF_SessionOptions") + .pointerTypes("TF_SessionOptions") + .base("org.tensorflow.internal.c_api.AbstractTF_SessionOptions")) + .put( + new Info("TF_Graph") + .pointerTypes("TF_Graph") + .base("org.tensorflow.internal.c_api.AbstractTF_Graph")) + .put( + new Info("TF_Graph::graph") + .javaText("public native @MemberGetter @ByRef Graph graph();")) + .put( + new Info("TF_Graph::refiner") + .javaText("public native @MemberGetter @ByRef ShapeRefiner refiner();")) + .put( + new Info("TF_Function") + .pointerTypes("TF_Function") + .base("org.tensorflow.internal.c_api.AbstractTF_Function")) + .put( + new Info("TF_ImportGraphDefOptions") + .pointerTypes("TF_ImportGraphDefOptions") + .base("org.tensorflow.internal.c_api.AbstractTF_ImportGraphDefOptions")) + .put( + new Info( + "TF_Operation", + "TF_WhileParams", + "TFE_MonitoringCounterCell", + "TFE_MonitoringSamplerCell", + "TFE_MonitoringCounter0", + "TFE_MonitoringCounter1", + "TFE_MonitoringCounter2", + "TFE_MonitoringIntGaugeCell", + "TFE_MonitoringStringGaugeCell", + "TFE_MonitoringBoolGaugeCell", + "TFE_MonitoringIntGauge0", + "TFE_MonitoringIntGauge1", + "TFE_MonitoringIntGauge2", + "TFE_MonitoringStringGauge0", + "TFE_MonitoringStringGauge1", + "TFE_MonitoringStringGauge2", + "TFE_MonitoringBoolGauge0", + "TFE_MonitoringBoolGauge1", + "TFE_MonitoringBoolGauge2", + "TFE_MonitoringSampler0", + "TFE_MonitoringSampler1", + "TFE_MonitoringSampler2") + .purify()) + .put( + new Info("TF_Operation::node") + .javaText("public native @MemberGetter @ByRef Node node();")) + .put( + new Info("TFE_MonitoringCounterCell::cell") + .javaText("public native @MemberGetter @ByRef CounterCell cell();")) + .put( + new Info("TFE_MonitoringSamplerCell::cell") + .javaText("public native @MemberGetter @ByRef SamplerCell cell();")) + .put( + new Info("TFE_MonitoringIntGaugeCell::cell") + .javaText("public native @MemberGetter @ByRef IntGaugeCell cell();")) + .put( + new Info("TFE_MonitoringStringGaugeCell::cell") + .javaText("public native @MemberGetter @ByRef StringGaugeCell cell();")) + .put( + new Info("TFE_MonitoringBoolGaugeCell::cell") + .javaText("public native @MemberGetter @ByRef BoolGaugeCell cell();")) + .put( + new Info("TFE_Context") + .pointerTypes("TFE_Context") + .base("org.tensorflow.internal.c_api.AbstractTFE_Context")) + .put( + new Info("TFE_ContextOptions") + .pointerTypes("TFE_ContextOptions") + .base("org.tensorflow.internal.c_api.AbstractTFE_ContextOptions")) + .put( + new Info("TFE_Context::context") + .javaText("@MemberGetter public native @ByRef EagerContext context();")) + .put( + new Info("TFE_Op") + .pointerTypes("TFE_Op") + .base("org.tensorflow.internal.c_api.AbstractTFE_Op")) + .put( + new Info("TFE_Op::operation") + .javaText("@MemberGetter public native @ByRef EagerOperation operation();")) + .put( + new Info("TFE_TensorHandle") + .pointerTypes("TFE_TensorHandle") + .base("org.tensorflow.internal.c_api.AbstractTFE_TensorHandle")) + .put( + new Info( + "TF_ShapeInferenceContextDimValueKnown", + "TFE_NewTensorHandle(const tensorflow::Tensor&, TF_Status*)") + .skip()); + } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Function.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Function.java new file mode 100644 index 00000000000..255a62e1253 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Function.java @@ -0,0 +1,58 @@ +/* Copyright 2020-2021 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ +package org.tensorflow.op.core; + +import java.util.Map; +import org.tensorflow.ConcreteFunction; +import org.tensorflow.Operand; +import org.tensorflow.op.Ops; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; + +/** Ops for calling {@link ConcreteFunction}. */ +@Operator(name = "call") +public abstract class Function { + + /** + * Calls the function in an execution environment, adding its graph as a function if it isn't + * already present. The inputs and outputs are keyed by the names set in the {@code Signature}. + * + * @param scope the scope to call the function in + * @param arguments the arguments to the call + * @return the outputs of the function + * @see ConcreteFunction#call(Ops, Map) + */ + @Endpoint + public static Map> call( + Scope scope, ConcreteFunction function, Map> arguments) { + return function.call(scope, arguments); + } + + /** + * Calls the function in an execution environment, adding its graph as a function if it isn't + * already present. Only works for functions with a single input and output. + * + * @param scope the scope to call the function in + * @param argument the argument to the call + * @return the output of the function + * @see ConcreteFunction#call(Ops, Operand) + */ + @Endpoint + public static Operand call(Scope scope, ConcreteFunction function, Operand argument) { + return function.call(scope, argument); + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java index b2b2c34e223..64c33f451fb 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java @@ -1,28 +1,31 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020-2021 The TensorFlow Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ package org.tensorflow; import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import java.util.Arrays; import org.junit.jupiter.api.Test; import org.tensorflow.op.Ops; import org.tensorflow.op.core.Init; import org.tensorflow.op.core.Placeholder; import org.tensorflow.op.math.Add; import org.tensorflow.op.math.Sub; +import org.tensorflow.proto.framework.DataType; import org.tensorflow.types.TFloat32; public class ConcreteFunctionTest { @@ -30,7 +33,7 @@ public class ConcreteFunctionTest { private static Signature plusFive(Ops tf) { Placeholder input = tf.placeholder(TFloat32.class); Add output = tf.math.add(input, tf.constant(5.0f)); - Init init = tf.init(); // for native resource management tests + Init init = tf.init(); // for native resource management tests return Signature.builder().key("plusFive").input("x", input).output("y", output).build(); } @@ -40,11 +43,25 @@ private static Signature minusTwo(Ops tf) { return Signature.builder().key("minusTwo").input("x", input).output("y", output).build(); } + @SuppressWarnings("unchecked") + private static Signature plusFiveMinusTwo(Ops tf) { + Placeholder input = tf.placeholder(TFloat32.class); + try (ConcreteFunction plusFive = ConcreteFunction.create(ConcreteFunctionTest::plusFive); + ConcreteFunction minusTwo = ConcreteFunction.create(ConcreteFunctionTest::minusTwo)) { + Operand result = (Operand) minusTwo.call(tf, plusFive.call(tf, input)); + return Signature.builder() + .key("plusFiveMinusTwo") + .input("x", input) + .output("y", result) + .build(); + } + } + @Test public void createFunction() { try (ConcreteFunction f = ConcreteFunction.create(ConcreteFunctionTest::plusFive); TFloat32 x = TFloat32.scalarOf(3.0f)) { - assertEquals(8.0f, ((TFloat32)f.call(x)).getFloat()); + assertEquals(8.0f, ((TFloat32) f.call(x)).getFloat()); } } @@ -54,7 +71,7 @@ public void createFunctionFromGraph() { Signature signature = plusFive(Ops.create(g)); try (ConcreteFunction f = ConcreteFunction.create(signature, g); TFloat32 x = TFloat32.scalarOf(3.0f)) { - assertEquals(8.0f, ((TFloat32)f.call(x)).getFloat()); + assertEquals(8.0f, ((TFloat32) f.call(x)).getFloat()); } } } @@ -66,7 +83,7 @@ public void createFunctionFromSession() { try (Session s = new Session(g)) { try (ConcreteFunction f = ConcreteFunction.create(signature, s); TFloat32 x = TFloat32.scalarOf(3.0f)) { - assertEquals(8.0f, ((TFloat32)f.call(x)).getFloat()); + assertEquals(8.0f, ((TFloat32) f.call(x)).getFloat()); } } } @@ -77,45 +94,109 @@ public void chainFunctions() { try (ConcreteFunction f1 = ConcreteFunction.create(ConcreteFunctionTest::plusFive); ConcreteFunction f2 = ConcreteFunction.create(ConcreteFunctionTest::minusTwo); TFloat32 x = TFloat32.scalarOf(3.0f)) { - assertEquals(6.0f, ((TFloat32)f2.call(f1.call(x))).getFloat()); + assertEquals(6.0f, ((TFloat32) f2.call(f1.call(x))).getFloat()); } } @Test - public void closingFunctionReleaseAllResourcesItOwns() { - Graph g; - Session s; - try (ConcreteFunction f = ConcreteFunction.create(ConcreteFunctionTest::plusFive)) { - g = f.graph(); - s = f.session(); + public void getGraphFunctions() { + try (ConcreteFunction function = ConcreteFunction.create(ConcreteFunctionTest::plusFive); + Graph g = new Graph()) { + Ops tf = Ops.create(g); + tf.call(function, tf.constant(3f)); + + ConcreteFunction attached = g.getFunction(function.getDefinedName()); + assertNotNull(attached); + + try (TFloat32 x = TFloat32.scalarOf(10f); + TFloat32 y = (TFloat32) attached.call(x)) { + assertEquals(15f, y.getFloat()); + } } - assertThrows(IllegalStateException.class, () -> s.run("Add")); - assertThrows(IllegalStateException.class, () -> g.toGraphDef()); } @Test - public void closingFunctionCreatedFromGraphOnlyReleaseResourcesItOwns() { - try (Graph g = new Graph()) { - Signature signature = plusFive(Ops.create(g)); - Session s; - try (ConcreteFunction f = ConcreteFunction.create(signature, g)) { - s = f.session(); + public void testNestedFunctionEager() { + try (EagerSession sess = EagerSession.create(); + ConcreteFunction function = + ConcreteFunction.create(ConcreteFunctionTest::plusFiveMinusTwo)) { + Ops tf = Ops.create(sess); + Operand a = tf.constant(10f); + Operand result = (Operand) function.call(tf, a); + try (TFloat32 t = result.asTensor()) { + assertEquals(13f, t.getFloat()); } - assertThrows(IllegalStateException.class, () -> s.run(Init.DEFAULT_NAME)); - g.toGraphDef(); // check that graph is still valid } } @Test - public void closingFunctionCreatedFromSessionDoesNotReleaseResources() { - try (Graph g = new Graph()) { - Signature signature = plusFive(Ops.create(g)); - try (Session s = new Session(g)) { - try (ConcreteFunction f = ConcreteFunction.create(signature, s)) { - } - s.run(Init.DEFAULT_NAME); // check that session is still valid + public void testNestedFunctionGraph() { + try (Graph graph = new Graph(); + ConcreteFunction function = + ConcreteFunction.create(ConcreteFunctionTest::plusFiveMinusTwo)) { + Ops tf = Ops.create(graph); + Operand a = tf.constant(10f); + Operand result = (Operand) function.call(tf, a); + try (Session sess = new Session(graph); + TFloat32 t = (TFloat32) sess.runner().fetch(result).run().get(0)) { + assertEquals(13f, t.getFloat()); + } + } + } + + private static Signature square(Ops tf) { + Placeholder input = tf.placeholder(TFloat32.class); + Operand output = tf.math.square(input); + return Signature.builder() + .methodName("square") + .key("square") + .input("x", input) + .output("y", output) + .build(); + } + + // call op gradients are not defined in c++ + // @Test + public void testGradientsGraph() { + try (Graph g = new Graph(); + ConcreteFunction square = ConcreteFunction.create(ConcreteFunctionTest::square); + Session s = new Session(g)) { + Ops tf = Ops.create(g); + + Output x1 = tf.placeholder(TFloat32.class).output(); + Output x2 = tf.placeholder(TFloat32.class).output(); + Output y0 = (Output) square.call(tf, x1); + Output y1 = (Output) square.call(tf, y0); + Output y2 = tf.math.addN(Arrays.asList(y0, x2)).sum(); + + Output[] grads0 = g.addGradients(y1, new Output[] {x1}); + assertNotNull(grads0); + assertEquals(1, grads0.length); + assertEquals(DataType.DT_FLOAT, grads0[0].dataType()); + + Output[] grads1 = g.addGradients(y2, new Output[] {x1, x2}); + assertNotNull(grads1); + assertEquals(2, grads1.length); + assertEquals(DataType.DT_FLOAT, grads1[0].dataType()); + assertEquals(DataType.DT_FLOAT, grads1[1].dataType()); + + try (TFloat32 c1 = TFloat32.scalarOf(3.0f); + TFloat32 c2 = TFloat32.scalarOf(2.0f); + AutoCloseableList outputs = + new AutoCloseableList<>( + s.runner() + .feed(x1, c1) + .feed(x2, c2) + .fetch(grads0[0]) + .fetch(grads1[0]) + .fetch(grads1[1]) + .run())) { + + assertEquals(3, outputs.size()); + assertEquals(108.0f, ((TFloat32) outputs.get(0)).getFloat(), 0.0f); + assertEquals(6.0f, ((TFloat32) outputs.get(1)).getFloat(), 0.0f); + assertEquals(1.0f, ((TFloat32) outputs.get(2)).getFloat(), 0.0f); } - g.toGraphDef(); // check that graph is still valid } } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerOperationBuilderTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerOperationBuilderTest.java index b39ecec9881..b694e0e5a39 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerOperationBuilderTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerOperationBuilderTest.java @@ -1,18 +1,18 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019-2021 The TensorFlow Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ package org.tensorflow; import static org.junit.jupiter.api.Assertions.fail; @@ -134,7 +134,8 @@ public void setAttrs() { .addInput(tf.constant(10.00000f).asOutput()) .setAttr("tolerance", 0.1f) .build(); - // Missing tests: list(string), list(byte), list(bool), list(type) + // Missing tests: list(string), list(byte), list(bool), list(type), list(func) + // func is done via ConcreteFunction execution } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphOperationBuilderTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphOperationBuilderTest.java index 33ae979ccbd..d0e79534d2c 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphOperationBuilderTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphOperationBuilderTest.java @@ -1,18 +1,18 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019-2021 The TensorFlow Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ package org.tensorflow; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -106,7 +106,8 @@ public void setAttr() { .setAttr("pooling_ratio", new float[] {1.0f, 1.44f, 1.73f, 1.0f}) .build(); assertTrue(hasNode(g, "FloatList")); - // Missing tests: float, list(dtype), list(tensor), list(string), list(bool) + // Missing tests: float, list(dtype), list(tensor), list(string), list(bool), list(func) + // func is done via ConcreteFunction execution } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java index 032c835c0cc..1561842a689 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java @@ -1,18 +1,18 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019-2021 The TensorFlow Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ package org.tensorflow; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -55,8 +55,12 @@ public class SavedModelBundleTest { static { try { - SAVED_MODEL_PATH = Paths.get(SavedModelBundleTest.class.getResource("/saved_model").toURI()).toString(); - SAVED_MODEL_PY_PATH = Paths.get(SavedModelBundleTest.class.getResource("/saved_model_using_python/model").toURI()).toString(); + SAVED_MODEL_PATH = + Paths.get(SavedModelBundleTest.class.getResource("/saved_model").toURI()).toString(); + SAVED_MODEL_PY_PATH = + Paths.get( + SavedModelBundleTest.class.getResource("/saved_model_using_python/model").toURI()) + .toString(); } catch (URISyntaxException e) { throw new RuntimeException(e); } @@ -85,38 +89,84 @@ public void loadNonExistentBundle() { @Test public void loader() { - try (SavedModelBundle bundle = SavedModelBundle.loader(SAVED_MODEL_PATH) - .withTags("serve") - .withConfigProto(sillyConfigProto()) - .withRunOptions(sillyRunOptions()) - .load()) { + try (SavedModelBundle bundle = + SavedModelBundle.loader(SAVED_MODEL_PATH) + .withTags("serve") + .withConfigProto(sillyConfigProto()) + .withRunOptions(sillyRunOptions()) + .load()) { assertNotNull(bundle.session()); assertNotNull(bundle.graph()); assertNotNull(bundle.metaGraphDef()); } } + @Test + public void exportMultipleFunctions() throws IOException { + Path testFolder = Files.createTempDirectory("tf-saved-model-export-test"); + float reducedSum; + try (Graph g = new Graph()) { + Ops tf = Ops.create(g); + Signature f1Signature = buildGraphWithVariables(tf, Shape.of(1, 1)); + Signature f2Signature = buildIdentityGraph(tf, "identity"); + try (Session s = new Session(g); ) { + SessionFunction f1 = SessionFunction.create(f1Signature, s); + SessionFunction f2 = SessionFunction.create(f2Signature, s); + s.runInit(); + try (TFloat32 x = TFloat32.tensorOf(StdArrays.ndCopyOf(new float[] {2, 2})); + TFloat32 t = (TFloat32) f1.call(x)) { + reducedSum = t.getFloat(); + } + SavedModelBundle.exporter(testFolder.toString()).withFunction(f1).withFunction(f2).export(); + } + } + try (SavedModelBundle model = SavedModelBundle.load(testFolder.toString())) { + assertEquals(2, model.signatures().size()); + TensorFunction f1 = model.function(Signature.DEFAULT_KEY); + assertNotNull(f1); + try (TFloat32 x = TFloat32.tensorOf(StdArrays.ndCopyOf(new float[] {2, 2})); + TFloat32 t = (TFloat32) f1.call(x)) { + assertEquals(reducedSum, t.getFloat(), EPSILON); + } + TensorFunction f2 = model.function("identity"); + assertNotNull(f2); + try (TFloat32 x = TFloat32.scalarOf(10.0f); + TFloat32 t = (TFloat32) f2.call(x)) { + assertEquals(10.0f, t.getFloat(), 0.0f); + } + try { + model.function("NoSuchFunction"); + fail(); + } catch (IllegalArgumentException e) { + // as expected + } + } + } + @Test public void exportFunctionWithVariables() throws IOException { Path testFolder = Files.createTempDirectory("tf-saved-model-export-test"); float reducedSum; - FloatNdArray xValue = StdArrays.ndCopyOf(new float[][]{{0, 1, 2}, {3, 4, 5}}); + FloatNdArray xValue = StdArrays.ndCopyOf(new float[][] {{0, 1, 2}, {3, 4, 5}}); Shape xyShape = Shape.of(2, 3L); - try (ConcreteFunction f = ConcreteFunction.create(tf -> buildGraphWithVariables(tf, xyShape))) { + try (Graph g = new Graph(); + Session session = new Session(g)) { + Ops tf = Ops.create(g); + SessionFunction f = session.function(buildGraphWithVariables(tf, xyShape)); // Init variable state by running the Init operation directly - f.session().run(Init.DEFAULT_NAME); + session.runInit(); // Call the graph and remember the result of computation for later try (TFloat32 xTensor = TFloat32.tensorOf(xValue); - TFloat32 zTensor = (TFloat32)f.call(xTensor)) { + TFloat32 zTensor = (TFloat32) f.call(xTensor)) { reducedSum = zTensor.getFloat(); } // Save/export the model (which is a single function in this case) f.save(testFolder.toString()); } assertTrue(Files.exists(testFolder.resolve(Paths.get("variables", "variables.index")))); - assertTrue(Files - .exists(testFolder.resolve(Paths.get("variables", "variables.data-00000-of-00001")))); + assertTrue( + Files.exists(testFolder.resolve(Paths.get("variables", "variables.data-00000-of-00001")))); assertTrue(Files.exists(testFolder.resolve("saved_model.pb"))); // Reload the model just saved and validate its data @@ -125,10 +175,11 @@ public void exportFunctionWithVariables() throws IOException { assertNotNull(savedModel.metaGraphDef()); assertNotNull(savedModel.metaGraphDef().getSaverDef()); assertEquals(1, savedModel.metaGraphDef().getSignatureDefCount()); - assertEquals(Signature.DEFAULT_KEY, + assertEquals( + Signature.DEFAULT_KEY, savedModel.metaGraphDef().getSignatureDefMap().keySet().iterator().next()); - ConcreteFunction function = savedModel.function(Signature.DEFAULT_KEY); + TensorFunction function = savedModel.function(Signature.DEFAULT_KEY); assertNotNull(function); Signature signature = function.signature(); @@ -155,12 +206,13 @@ public void exportFunctionWithVariables() throws IOException { try (TFloat32 xTensor = TFloat32.tensorOf(xValue)) { // Call the saved model function and make sure it returns the same result as before - try (TFloat32 zTensor = (TFloat32)function.call(xTensor)) { + try (TFloat32 zTensor = (TFloat32) function.call(xTensor)) { assertEquals(reducedSum, zTensor.getFloat(), EPSILON); } // Now call the same function directly from the model try (TFloat32 zTensor = - (TFloat32)savedModel.call(Collections.singletonMap("input", xTensor)).get("reducedSum")) { + (TFloat32) + savedModel.call(Collections.singletonMap("input", xTensor)).get("reducedSum")) { assertEquals(reducedSum, zTensor.getFloat(), EPSILON); } } @@ -168,73 +220,27 @@ public void exportFunctionWithVariables() throws IOException { } @Test - public void exportMultipleFunctions() throws IOException { + public void cannotExportMultipleFunctionsWithDifferentSessions() throws IOException { Path testFolder = Files.createTempDirectory("tf-saved-model-export-test"); - float reducedSum; - try (Graph g = new Graph()) { + try (Graph g = new Graph(); + Session s1 = new Session(g); + Session s2 = new Session(g)) { Ops tf = Ops.create(g); Signature f1Signature = buildGraphWithVariables(tf, Shape.of(1, 1)); Signature f2Signature = buildIdentityGraph(tf, "identity"); - try (Session s = new Session(g); - ConcreteFunction f1 = ConcreteFunction.create(f1Signature, s); - ConcreteFunction f2 = ConcreteFunction.create(f2Signature, s)) { - f1.session().run(Init.DEFAULT_NAME); - try (TFloat32 x = TFloat32.tensorOf(StdArrays.ndCopyOf(new float[]{2, 2})); - TFloat32 t = (TFloat32)f1.call(x)) { - reducedSum = t.getFloat(); - } - SavedModelBundle.exporter(testFolder.toString()) - .withFunction(f1) - .withFunction(f2) - .export(); - } - } - try (SavedModelBundle model = SavedModelBundle.load(testFolder.toString())) { - assertEquals(2, model.signatures().size()); - ConcreteFunction f1 = model.function(Signature.DEFAULT_KEY); - assertNotNull(f1); - try (TFloat32 x = TFloat32.tensorOf(StdArrays.ndCopyOf(new float[]{2, 2})); - TFloat32 t = (TFloat32)f1.call(x)) { - assertEquals(reducedSum, t.getFloat(), EPSILON); - } - ConcreteFunction f2 = model.function("identity"); - assertNotNull(f2); - try (TFloat32 x = TFloat32.scalarOf(10.0f); - TFloat32 t = (TFloat32)f2.call(x)) { - assertEquals(10.0f, t.getFloat(), 0.0f); - } + SessionFunction f1 = s1.function(f1Signature); + SessionFunction f2 = s2.function(f2Signature); + s1.runInit(); + s2.runInit(); try { - model.function("NoSuchFunction"); + SavedModelBundle.exporter(testFolder.toString()).withFunction(f1).withFunction(f2).export(); fail(); - } catch (IllegalArgumentException e) { + } catch (UnsupportedOperationException e) { // as expected } } } - @Test - public void cannotExportMultipleFunctionsWithDifferentSessions() throws IOException { - Path testFolder = Files.createTempDirectory("tf-saved-model-export-test"); - try (Graph g = new Graph()) { - Ops tf = Ops.create(g); - Signature f1Signature = buildGraphWithVariables(tf, Shape.of(1, 1)); - Signature f2Signature = buildIdentityGraph(tf, "identity"); - try (ConcreteFunction f1 = ConcreteFunction.create(f1Signature, g); - ConcreteFunction f2 = ConcreteFunction.create(f2Signature, g)) { - f1.session().run(Init.DEFAULT_NAME); - try { - SavedModelBundle.exporter(testFolder.toString()) - .withFunction(f1) - .withFunction(f2) - .export(); - fail(); - } catch (UnsupportedOperationException e) { - // as expected - } - } - } - } - @Test public void cannotExportMultipleFunctionsWithSameSignatureKey() throws IOException { Path testFolder = Files.createTempDirectory("tf-saved-model-export-test"); @@ -242,15 +248,12 @@ public void cannotExportMultipleFunctionsWithSameSignatureKey() throws IOExcepti Ops tf = Ops.create(g); Signature f1Signature = buildGraphWithVariables(tf, Shape.of(1, 1)); Signature f2Signature = buildIdentityGraph(tf, Signature.DEFAULT_KEY); - try (Session s = new Session(g); - ConcreteFunction f1 = ConcreteFunction.create(f1Signature, s); - ConcreteFunction f2 = ConcreteFunction.create(f2Signature, s)) { - f1.session().run(Init.DEFAULT_NAME); + try (Session s = new Session(g); ) { + SessionFunction f1 = SessionFunction.create(f1Signature, s); + SessionFunction f2 = SessionFunction.create(f2Signature, s); + s.runInit(); try { - SavedModelBundle.exporter(testFolder.toString()) - .withFunction(f1) - .withFunction(f2) - .export(); + SavedModelBundle.exporter(testFolder.toString()).withFunctions(f1, f2).export(); fail(); } catch (IllegalArgumentException e) { // as expected @@ -261,24 +264,21 @@ public void cannotExportMultipleFunctionsWithSameSignatureKey() throws IOExcepti @Test public void cannotExportOrImportInvalidTags() { - assertThrows(IllegalArgumentException.class, () -> - SavedModelBundle.loader("/").withTags(null) - ); - assertThrows(IllegalArgumentException.class, () -> - SavedModelBundle.loader("/").withTags(new String[]{"tag", null}) - ); - assertThrows(IllegalArgumentException.class, () -> - SavedModelBundle.loader("/").withTags(new String[]{"tag", ""}) - ); - assertThrows(IllegalArgumentException.class, () -> - SavedModelBundle.exporter("/").withTags(null) - ); - assertThrows(IllegalArgumentException.class, () -> - SavedModelBundle.exporter("/").withTags(new String[]{"tag", null}) - ); - assertThrows(IllegalArgumentException.class, () -> - SavedModelBundle.exporter("/").withTags(new String[]{"tag", ""}) - ); + assertThrows(IllegalArgumentException.class, () -> SavedModelBundle.loader("/").withTags(null)); + assertThrows( + IllegalArgumentException.class, + () -> SavedModelBundle.loader("/").withTags(new String[] {"tag", null})); + assertThrows( + IllegalArgumentException.class, + () -> SavedModelBundle.loader("/").withTags(new String[] {"tag", ""})); + assertThrows( + IllegalArgumentException.class, () -> SavedModelBundle.exporter("/").withTags(null)); + assertThrows( + IllegalArgumentException.class, + () -> SavedModelBundle.exporter("/").withTags(new String[] {"tag", null})); + assertThrows( + IllegalArgumentException.class, + () -> SavedModelBundle.exporter("/").withTags(new String[] {"tag", ""})); } @Test @@ -289,10 +289,11 @@ public void pythonTfFunction() { * Test model was created in python * Signature name used for saving 'add', argument names 'a' and 'b' */ - ConcreteFunction add = bundle.function("add"); + TensorFunction add = bundle.function("add"); Map args = new HashMap<>(); try (TFloat32 a = TFloat32.scalarOf(10.0f); TFloat32 b = TFloat32.scalarOf(15.5f)) { + System.out.println(add.signature()); args.put("a", a); args.put("b", b); Map result = add.call(args); @@ -304,12 +305,15 @@ public void pythonTfFunction() { args.clear(); // variable unwrapping happens in Session, which is used by ConcreteFunction.call - ConcreteFunction getVariable = bundle.function("get_variable"); + TensorFunction getVariable = bundle.function("get_variable"); try (TFloat32 dummy = TFloat32.scalarOf(1.0f)) { - args.put("dummy",dummy); + args.put("dummy", dummy); // TF functions always require an input, so we supply a dummy one here // This test actually checks that resource variables can be loaded correctly. - try (TFloat32 v = (TFloat32) getVariable.call(args) + try (TFloat32 v = + (TFloat32) + getVariable + .call(args) .get(getVariable.signature().outputNames().iterator().next())) { assertEquals(2f, v.getFloat()); } @@ -319,8 +323,9 @@ public void pythonTfFunction() { private static Signature buildGraphWithVariables(Ops tf, Shape xShape) { Placeholder x = tf.placeholder(TFloat32.class, Placeholder.shape(xShape)); - Variable y = tf.withName("variable") - .variable(tf.random.randomUniform(tf.constant(xShape), TFloat32.class)); + Variable y = + tf.withName("variable") + .variable(tf.random.randomUniform(tf.constant(xShape), TFloat32.class)); ReduceSum z = tf.reduceSum(tf.math.add(x, y), tf.array(0, 1)); Init init = tf.init(); return Signature.builder().input("input", x).output("reducedSum", z).build(); @@ -333,9 +338,7 @@ private static Signature buildIdentityGraph(Ops tf, String signatureKey) { } private static RunOptions sillyRunOptions() { - return RunOptions.newBuilder() - .setTraceLevel(RunOptions.TraceLevel.FULL_TRACE) - .build(); + return RunOptions.newBuilder().setTraceLevel(RunOptions.TraceLevel.FULL_TRACE).build(); } private static ConfigProto sillyConfigProto() { diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java index 4223a03ee23..8a3e64c3336 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java @@ -1,18 +1,18 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019-2021 The TensorFlow Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ package org.tensorflow; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -43,18 +43,33 @@ import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt32; -/** - * Unit tests for {@link org.tensorflow.Session}. - */ +/** Unit tests for {@link org.tensorflow.Session}. */ public class SessionTest { + @Test + public void runUsingFunction() { + try (Graph g = new Graph(); + Session s = new Session(g)) { + Ops tf = Ops.create(g); + transpose_A_times_X(tf, new int[][] {{2}, {3}}); + Signature sig = + Signature.builder("sess").input("X", g.output("X")).output("Y", g.output("Y")).build(); + SessionFunction func = s.function(sig); + + try (TInt32 x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][] {{5}, {7}})); + TInt32 y = (TInt32) func.call(x)) { + assertEquals(31, y.getInt(0, 0)); + } + } + } + @Test public void runUsingOperationNames() { try (Graph g = new Graph(); Session s = new Session(g)) { Ops tf = Ops.create(g); - transpose_A_times_X(tf, new int[][]{{2}, {3}}); - try (TInt32 x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][]{{5}, {7}})); + transpose_A_times_X(tf, new int[][] {{2}, {3}}); + try (TInt32 x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][] {{5}, {7}})); AutoCloseableList outputs = new AutoCloseableList<>(s.runner().feed("X", x).fetch("Y").run())) { assertEquals(1, outputs.size()); @@ -68,10 +83,10 @@ public void runUsingOperationHandles() { try (Graph g = new Graph(); Session s = new Session(g)) { Ops tf = Ops.create(g); - transpose_A_times_X(tf, new int[][]{{2}, {3}}); + transpose_A_times_X(tf, new int[][] {{2}, {3}}); Output feed = g.operation("X").output(0); Output fetch = g.operation("Y").output(0); - try (TInt32 x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][]{{5}, {7}})); + try (TInt32 x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][] {{5}, {7}})); AutoCloseableList outputs = new AutoCloseableList<>(s.runner().feed(feed, x).fetch(fetch).run())) { assertEquals(1, outputs.size()); @@ -95,12 +110,9 @@ public void runUsingColonSeparatedNames() { } // Feed using colon separated names. try (TInt32 fed = TInt32.vectorOf(4, 3, 2, 1); - TInt32 fetched = (TInt32) s.runner() - .feed("Split:0", fed) - .feed("Split:1", fed) - .fetch("Add") - .run() - .get(0)) { + TInt32 fetched = + (TInt32) + s.runner().feed("Split:0", fed).feed("Split:1", fed).fetch("Add").run().get(0)) { assertEquals(NdArrays.vectorOf(8, 6, 4, 2), fetched); } } @@ -111,13 +123,14 @@ public void runWithMetadata() { try (Graph g = new Graph(); Session s = new Session(g)) { Ops tf = Ops.create(g); - transpose_A_times_X(tf, new int[][]{{2}, {3}}); - try (TInt32 x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][]{{5}, {7}}))) { - Session.Run result = s.runner() - .feed("X", x) - .fetch("Y") - .setOptions(fullTraceRunOptions()) - .runAndFetchMetadata(); + transpose_A_times_X(tf, new int[][] {{2}, {3}}); + try (TInt32 x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][] {{5}, {7}}))) { + Session.Run result = + s.runner() + .feed("X", x) + .fetch("Y") + .setOptions(fullTraceRunOptions()) + .runAndFetchMetadata(); // Sanity check on outputs. AutoCloseableList outputs = new AutoCloseableList<>(result.outputs); assertEquals(1, outputs.size()); @@ -163,8 +176,7 @@ public void failOnUseAfterClose() { @Test public void createWithConfigProto() { try (Graph g = new Graph(); - Session s = new Session(g, singleThreadConfigProto())) { - } + Session s = new Session(g, singleThreadConfigProto())) {} } @Test @@ -219,10 +231,12 @@ public void saveAndRestore() throws IOException { Path testFolder = Files.createTempDirectory("tf-session-save-restore-test"); try (Graph g = new Graph()) { Ops tf = Ops.create(g); - Variable x = tf.withName("x") - .variable(tf.random.randomUniform(tf.constant(Shape.of(3, 3L)), TFloat32.class)); - Variable y = tf.withName("y") - .variable(tf.random.randomUniform(tf.constant(Shape.of(3, 3L)), TFloat32.class)); + Variable x = + tf.withName("x") + .variable(tf.random.randomUniform(tf.constant(Shape.of(3, 3L)), TFloat32.class)); + Variable y = + tf.withName("y") + .variable(tf.random.randomUniform(tf.constant(Shape.of(3, 3L)), TFloat32.class)); Init init = tf.init(); try (Session s = new Session(g)) { @@ -234,9 +248,10 @@ public void saveAndRestore() throws IOException { restoredGraph.importGraphDef(graphDef); try (Session restoredSession = new Session(restoredGraph)) { restoredSession.restore(testFolder.resolve("checkpoint").toString()); - try (AutoCloseableList oldList = new AutoCloseableList<>(s.runner().fetch("x").fetch("y").run()); - AutoCloseableList newList = new AutoCloseableList<>( - restoredSession.runner().fetch("x").fetch("y").run())) { + try (AutoCloseableList oldList = + new AutoCloseableList<>(s.runner().fetch("x").fetch("y").run()); + AutoCloseableList newList = + new AutoCloseableList<>(restoredSession.runner().fetch("x").fetch("y").run())) { assertEquals(oldList.get(0), newList.get(0)); assertEquals(oldList.get(1), newList.get(1)); } @@ -265,7 +280,6 @@ public static void testFetchVariable() { try (TInt32 value = (TInt32) s.runner().addTarget(assign).fetch(variable).run().get(0)) { assertEquals(2, value.getInt()); } - } } @@ -295,14 +309,11 @@ public static void testFetchVariableReusingRead() { } assertEquals(0, numOperations(g) - ops); - } } private static RunOptions fullTraceRunOptions() { - return RunOptions.newBuilder() - .setTraceLevel(RunOptions.TraceLevel.FULL_TRACE) - .build(); + return RunOptions.newBuilder().setTraceLevel(RunOptions.TraceLevel.FULL_TRACE).build(); } private static ConfigProto singleThreadConfigProto() { @@ -313,10 +324,11 @@ private static ConfigProto singleThreadConfigProto() { } private static void transpose_A_times_X(Ops tf, int[][] a) { - tf.withName("Y").linalg.matMul( - tf.withName("A").constant(a), - tf.withName("X").placeholder(TInt32.class), - MatMul.transposeA(true).transposeB(false) - ); + tf.withName("Y") + .linalg + .matMul( + tf.withName("A").constant(a), + tf.withName("X").placeholder(TInt32.class), + MatMul.transposeA(true).transposeB(false)); } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/FunctionTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/FunctionTest.java new file mode 100644 index 00000000000..be4386698fa --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/FunctionTest.java @@ -0,0 +1,67 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ +package org.tensorflow.op.core; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import org.junit.jupiter.api.Test; +import org.tensorflow.ConcreteFunction; +import org.tensorflow.EagerSession; +import org.tensorflow.Graph; +import org.tensorflow.Operand; +import org.tensorflow.Session; +import org.tensorflow.Signature; +import org.tensorflow.op.Ops; +import org.tensorflow.op.math.Add; +import org.tensorflow.types.TFloat32; + +/** Tests for GraphFunction and it's ops */ +public class FunctionTest { + + private static Signature plusFive(Ops tf) { + Placeholder input = tf.placeholder(TFloat32.class); + Add output = tf.math.add(input, tf.constant(5.0f)); + Init init = tf.init(); // for native resource management tests + return Signature.builder().key("plusFive").input("x", input).output("y", output).build(); + } + + @Test + public void testConcreteFunctionEager() { + try (EagerSession sess = EagerSession.create(); + ConcreteFunction function = ConcreteFunction.create(FunctionTest::plusFive)) { + Ops tf = Ops.create(sess); + Operand a = tf.constant(10f); + Operand result = (Operand) function.call(tf, a); + try (TFloat32 t = result.asTensor()) { + assertEquals(15f, t.getFloat()); + } + } + } + + @Test + public void testConcreteFunctionGraph() { + try (Graph graph = new Graph(); + ConcreteFunction function = ConcreteFunction.create(FunctionTest::plusFive)) { + Ops tf = Ops.create(graph); + Operand a = tf.constant(10f); + Operand result = (Operand) function.call(tf, a); + try (Session sess = new Session(graph); + TFloat32 t = (TFloat32) sess.runner().fetch(result).run().get(0)) { + assertEquals(15f, t.getFloat()); + } + } + } +}