Skip to content

Commit 09c136e

Browse files
Shajankarllessard
authored andcommitted
Create, save and load models using functional API
Python models that contain tf.function is inconvenient to be consumed by Java clients. This proposal provides an API to (a) Invoke a tf.function, given the signature name (b) Retrieve the node name in the graph corresponding to a tf.function Co-authored-by: Shajan Dasan <[email protected]> Save models as functions (#103) * Draft: Java API to use tf.function available on SavedModel. (#89) Python models that contain tf.function is inconvenient to be consumed by Java clients. This proposal provides an API to (a) Invoke a tf.function, given the signature name (b) Retrieve the node name in the graph corresponding to a tf.function Co-authored-by: Shajan Dasan <[email protected]> * Change API for creating concrete functions and exporting them to a saved model Co-authored-by: Karl Lessard <[email protected]> Rename signature name to key Print function signature when converting to String Add method that returns the signature of all functions in a saved model Add unit tests for python created SavedModel with tf.function
1 parent 99659c0 commit 09c136e

File tree

12 files changed

+1211
-8
lines changed

12 files changed

+1211
-8
lines changed
Lines changed: 291 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,291 @@
1+
/*
2+
* Copyright 2020 The TensorFlow Authors. All rights reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.tensorflow;
17+
18+
import java.io.IOException;
19+
import java.util.List;
20+
import java.util.ListIterator;
21+
import java.util.HashMap;
22+
import java.util.Map;
23+
import java.util.function.Function;
24+
import org.tensorflow.op.Ops;
25+
import org.tensorflow.proto.framework.SignatureDef;
26+
import org.tensorflow.proto.framework.TensorInfo;
27+
28+
/**
29+
* A graph that can be invoked as a single function, with an input and output signature.
30+
*
31+
* <p>A function can also invoke a
32+
* <a href="https://www.tensorflow.org/api_docs/python/tf/function">tf.function</a>
33+
* defined in a {@link SavedModelBundle}.
34+
*
35+
* <pre>{@code
36+
* ConcreteFunction myFunction = savedModelBundle.function("myFunctionSignatureName");
37+
* Map<String, Tensor<?>> outputTensorMap = myFunction.call(inputTensorMap);
38+
* }</pre>
39+
*/
40+
public class ConcreteFunction implements AutoCloseable {
41+
42+
/**
43+
* Creates a function by building a new graph.
44+
*
45+
* <p/>The {@code functionBuilder} must initialize the function graph from the provided
46+
* {@link Ops} instance and return a valid signature that will be used to feed the input tensors
47+
* and fetch the output tensors on execution.
48+
*
49+
* <p/>The function will be the owner of the new graph and its resulting session. Therefore,
50+
* the function must be enclosed properly with a try-with-resources block to guarantee that
51+
* all native resources will be freed once the function is discarded. For example:
52+
*
53+
* <pre>{@code
54+
* public class MyModel {
55+
*
56+
* public static Signature addTwo(Ops tf) {
57+
* Placeholder<TFloat32> input = tf.placeholder(TFloat32.DTYPE);
58+
* Add<TFloat32> output = tf.math.add(input, tf.constant(2.0f));
59+
* return Signature.builder("addTwo").input("x", input).output("y", output).build();
60+
* }
61+
*
62+
* public static void main(String args[]) {
63+
* try (ConcreteFunction function = ConcreteFunction.create(MyModel::addTwo);
64+
* Tensor<TFloat32> x = TFloat32.scalarOf(2.0f)) {
65+
* assertEquals(4.0f, function.call(x).expect(TFloat32.DTYPE).data().getFloat());
66+
* }
67+
* }
68+
* }
69+
* }</pre>
70+
*
71+
* @param functionBuilder function builder
72+
* @return the new function
73+
*/
74+
public static ConcreteFunction create(Function<Ops, Signature> functionBuilder) {
75+
Graph graph = new Graph();
76+
try {
77+
Ops tf = Ops.create(graph);
78+
Signature signature = functionBuilder.apply(tf);
79+
return new ConcreteFunction(signature, graph, new Session(graph), Ownership.GRAPH_AND_SESSION);
80+
} catch (Exception e) {
81+
graph.close();
82+
throw e;
83+
}
84+
}
85+
86+
/**
87+
* Create a function from a signature and an existing graph.
88+
*
89+
* <p/>The function will keep the ownership of the session used to run the graph but not
90+
* the graph itself, meaning that the lifetime of the latter can extend beyond the scope
91+
* of the function. For example:
92+
*
93+
* <pre>{@code
94+
* try (Graph g = new Graph()) {
95+
* Placeholder<TFloat32> input = tf.placeholder(TFloat32.DTYPE);
96+
* Add<TFloat32> output = tf.math.add(input, tf.constant(2.0f));
97+
* Signature signature = Signature.builder().input("x", input).output("y", output).build();
98+
*
99+
* try (ConcreteFunction f = ConcreteFunction.create(signature, g);
100+
* Tensor<TFloat32> x = TFloat32.scalarOf(2.0f)) {
101+
* assertEquals(4.0f, function.call(x).expect(TFloat32.DTYPE).data().getFloat());
102+
* }
103+
* // Graph g is still valid at this point
104+
* }
105+
* }</pre>
106+
*
107+
* @param signature signature of the function to create
108+
* @param graph a valid and initialized graph
109+
* @return a new function
110+
*/
111+
public static ConcreteFunction create(Signature signature, Graph graph) {
112+
return new ConcreteFunction(signature, graph, new Session(graph), Ownership.SESSION_ONLY);
113+
}
114+
115+
/**
116+
* Create a function from a signature and a valid graph session.
117+
*
118+
* <p/>The function will not own the session nor its graph, meaning that their lifetime
119+
* can extend beyond the scope of the function. Therefore the function does not need to be
120+
* closed after its usage. For example:
121+
*
122+
* <pre>{@code
123+
* try (Graph g = new Graph()) {
124+
* Placeholder<TFloat32> input = tf.placeholder(TFloat32.DTYPE);
125+
* Add<TFloat32> output = tf.math.add(input, tf.constant(2.0f));
126+
* Signature signature = Signature.builder().input("x", input).output("y", output).build();
127+
*
128+
* try (Session s = new Session(g)) {
129+
* // Auto-closing the function just as an example but this is not required since it has
130+
* // no effect
131+
* try (ConcreteFunction f = ConcreteFunction.create(signature, s);
132+
* Tensor<TFloat32> t = TFloat32.scalarOf(2.0f)) {
133+
* assertEquals(4.0f, function.call(x).expect(TFloat32.DTYPE).data().getFloat());
134+
* }
135+
* // Session s is still valid at this point
136+
* }
137+
* // Graph g is still valid at this point
138+
* }
139+
* }</pre>
140+
*
141+
* @param signature signature of the function to create
142+
* @param graph a valid session to an initialized graph
143+
* @return a new function
144+
*/
145+
public static ConcreteFunction create(Signature signature, Session session) {
146+
return new ConcreteFunction(signature, session.graph(), session, Ownership.NONE);
147+
}
148+
149+
/**
150+
* Returns the signature of this function
151+
*/
152+
public Signature signature() {
153+
return signature;
154+
}
155+
156+
/**
157+
* Invokes a function.
158+
*
159+
* <p>Caller is responsible for closing all Tensors.
160+
*
161+
* @param tensor input tensor
162+
* @return output tensor
163+
*/
164+
public Map<String, Tensor<?>> call(Map<String, Tensor<?>> arguments)
165+
throws IllegalArgumentException {
166+
167+
final SignatureDef signatureDef = signature.asSignatureDef();
168+
final Session.Runner runner = session.runner();
169+
170+
signatureDef.getInputsMap().forEach((argName, t) -> {
171+
Tensor<?> tensor = arguments.get(argName);
172+
if (tensor == null) {
173+
throw new IllegalArgumentException(String.format("Missing argument [%s]", argName));
174+
}
175+
runner.feed(t.getName(), tensor);
176+
});
177+
178+
Map<String, TensorInfo> outputToNode = signatureDef.getOutputsMap();
179+
outputToNode.values().forEach(t -> runner.fetch(t.getName()));
180+
181+
List<Tensor<?>> resultTensors = runner.run();
182+
try {
183+
ListIterator<Tensor<?>> resultTensorIter = resultTensors.listIterator();
184+
Map<String, Tensor<?>> returnMap = new HashMap<String, Tensor<?>>();
185+
186+
// Use the output names as present in the signature definition
187+
for (String nodeName: outputToNode.keySet()) {
188+
returnMap.put(nodeName, resultTensorIter.next());
189+
}
190+
return returnMap;
191+
192+
} catch (Exception e) {
193+
// Release tensors before throwing exception
194+
for (Tensor<?> t : resultTensors) {
195+
t.close();
196+
}
197+
throw e;
198+
}
199+
}
200+
201+
/**
202+
* Invokes a function with a single input and output.
203+
*
204+
* <p>Caller is responsible for closing all Tensors.
205+
*
206+
* @param tensor input tensor
207+
* @return output tensor
208+
* @throws IllegalArgumentException if there are multiple input or output parameters defined
209+
* in the function
210+
*/
211+
public Tensor<?> call(Tensor<?> tensor) throws IllegalArgumentException {
212+
final SignatureDef signatureDef = signature.asSignatureDef();
213+
214+
if (signatureDef.getInputsCount() != 1) {
215+
throw new IllegalArgumentException(
216+
String.format("Function [%s] requires multiple inputs", signatureDef.getMethodName()));
217+
}
218+
String inputNodeName = signatureDef.getInputsMap().values().iterator().next().getName();
219+
220+
if (signatureDef.getOutputsCount() != 1) {
221+
throw new IllegalArgumentException(
222+
String.format("Function [%s] has multiple outputs", signatureDef.getMethodName()));
223+
}
224+
String outputNodeName = signatureDef.getOutputsMap().values().iterator().next().getName();
225+
226+
return session.runner().feed(inputNodeName, tensor).fetch(outputNodeName).run().get(0);
227+
}
228+
229+
/**
230+
* Export this function as a saved model.
231+
*
232+
* <p>This method is convenient shortcut equivalent to
233+
* {@code SavedModel.exporter(exportDir).withFunction(this).export()}
234+
*/
235+
public void save(String exportDir) throws IOException {
236+
SavedModelBundle.exporter(exportDir)
237+
.withFunction(this)
238+
.export();
239+
}
240+
241+
/**
242+
* Returns the session used to execute the graph when calling this function
243+
*
244+
* <p>In general, a user does not need to handle directly the session of a function and rely
245+
* on {@link #call(Map)} to execute the graph instead. But in some cases, direct access to
246+
* the session might be necessary, as it allows more running options.
247+
*
248+
* @return the function session
249+
*/
250+
public Session session() {
251+
return session;
252+
}
253+
254+
/**
255+
* Returns the graph of this function
256+
*/
257+
public Graph graph() {
258+
return graph;
259+
}
260+
261+
@Override
262+
public void close() {
263+
if (ownership != Ownership.NONE) {
264+
session.close();
265+
if (ownership == Ownership.GRAPH_AND_SESSION) {
266+
graph.close();
267+
}
268+
}
269+
}
270+
271+
@Override
272+
public String toString() {
273+
return signature.toString();
274+
}
275+
276+
private enum Ownership {
277+
GRAPH_AND_SESSION, SESSION_ONLY, NONE;
278+
}
279+
280+
private final Graph graph;
281+
private final Session session;
282+
private final Signature signature;
283+
private final Ownership ownership;
284+
285+
ConcreteFunction(Signature signature, Graph graph, Session session, Ownership ownership) {
286+
this.graph = graph;
287+
this.session = session;
288+
this.signature = signature;
289+
this.ownership = ownership;
290+
}
291+
}

0 commit comments

Comments
 (0)