Skip to content

Commit 257bbb0

Browse files
authored
Better cross-environment error messages (#207)
* Expose env to Op, add nicer error messages when crossing environments Signed-off-by: Ryan Nett <[email protected]> * Remove outdated test Signed-off-by: Ryan Nett <[email protected]>
1 parent 3044d4b commit 257bbb0

File tree

12 files changed

+204
-34
lines changed

12 files changed

+204
-34
lines changed

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperation.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,11 @@ public String type() {
6868
return type;
6969
}
7070

71+
@Override
72+
public EagerSession env() {
73+
return session;
74+
}
75+
7176
@Override
7277
public int numOutputs() {
7378
return outputHandles.length;

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ public EagerOperation build() {
7575

7676
@Override
7777
public EagerOperationBuilder addInput(Output<?> input) {
78+
session.checkInput(input);
7879
addInput(opHandle, (TFE_TensorHandle) input.getUnsafeNativeHandle());
7980
return this;
8081
}
@@ -83,6 +84,7 @@ public EagerOperationBuilder addInput(Output<?> input) {
8384
public EagerOperationBuilder addInputList(Output<?>[] inputs) {
8485
TFE_TensorHandle[] inputHandles = new TFE_TensorHandle[inputs.length];
8586
for (int i = 0; i < inputs.length; ++i) {
87+
session.checkInput(inputs[i]);
8688
inputHandles[i] = (TFE_TensorHandle) inputs[i].getUnsafeNativeHandle();
8789
}
8890
addInputList(opHandle, inputHandles);

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import org.tensorflow.internal.c_api.TFE_Context;
2828
import org.tensorflow.internal.c_api.TFE_ContextOptions;
2929
import org.tensorflow.internal.c_api.TF_Status;
30+
import org.tensorflow.op.Op;
3031
import org.tensorflow.op.core.Assign;
3132
import org.tensorflow.op.core.Placeholder;
3233
import org.tensorflow.op.core.Variable;
@@ -297,6 +298,13 @@ public boolean isOpEnabled(String opType) {
297298
}
298299
}
299300

301+
@Override
302+
public void checkInput(Op input) {
303+
if (!input.env().isEager()) {
304+
throw new IllegalArgumentException("Can't use graph operation " + input + " in eager mode.");
305+
}
306+
}
307+
300308
TFE_Context nativeHandle() {
301309
checkSession();
302310
return nativeHandle;

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ExecutionEnvironment.java

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,11 @@
1515

1616
package org.tensorflow;
1717

18-
/** Defines an environment for creating and executing TensorFlow {@link Operation}s. */
18+
import org.tensorflow.op.Op;
19+
20+
/**
21+
* Defines an environment for creating and executing TensorFlow {@link Operation}s.
22+
*/
1923
public interface ExecutionEnvironment {
2024

2125
enum Types {
@@ -36,13 +40,23 @@ enum Types {
3640

3741
/**
3842
* Returns true if the given operation is valid in this execution environment.
43+
*
3944
* @param opType The op to check.
4045
* @return Whether the given operation is valid in this execution environment.
4146
*/
42-
default boolean isOpEnabled(String opType){
47+
default boolean isOpEnabled(String opType) {
4348
return true;
4449
}
4550

51+
/**
52+
* Checks that {@code input} is valid to use as an input in this execution environment. Throws {@link
53+
* IllegalArgumentException} if not.
54+
*
55+
* @param input The op to check
56+
* @throws IllegalArgumentException if input can't be used as an input in this execution environment.
57+
*/
58+
void checkInput(Op input);
59+
4660
/**
4761
* Get the type of this environment (from the `Environments` enumeration.
4862
*

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,17 @@ public Types environmentType() {
158158
return Types.GRAPH;
159159
}
160160

161+
@Override
162+
public void checkInput(Op input) {
163+
if (input.env().isEager()) {
164+
throw new IllegalArgumentException(
165+
"Input " + input + " was from an eager session, can't use in a graph. Use tf.constantOf(input.asTensor())");
166+
}
167+
if (input.env() != this) {
168+
throw new IllegalArgumentException("Input " + input + " was from a different graph, can't use.");
169+
}
170+
}
171+
161172
/**
162173
* Import a representation of a TensorFlow graph.
163174
*

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperation.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,13 @@ public String type() {
7373
}
7474
}
7575

76+
@Override
77+
public Graph env() {
78+
try (Graph.Reference r = graph.ref()) {
79+
return graph;
80+
}
81+
}
82+
7683
@Override
7784
public int numOutputs() {
7885
Graph.Reference r = graph.ref();

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperationBuilder.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,11 @@ public GraphOperationBuilder addControlInput(Operation control) {
9292
throw new IllegalArgumentException(
9393
"Only GraphOperation instances can be used as control inputs");
9494
}
95+
96+
if (control.env() != graph) {
97+
throw new IllegalArgumentException("Control input " + control + " was from a different graph, can't use.");
98+
}
99+
95100
Graph.Reference r = graph.ref();
96101
try {
97102
addControlInput(unsafeNativeHandle, ((GraphOperation) control).getUnsafeNativeHandle());
@@ -103,6 +108,7 @@ public GraphOperationBuilder addControlInput(Operation control) {
103108

104109
@Override
105110
public GraphOperationBuilder addInput(Output<?> input) {
111+
graph.checkInput(input);
106112
Graph.Reference r = graph.ref();
107113
try {
108114
addInput(unsafeNativeHandle, (TF_Operation) input.getUnsafeNativeHandle(), input.index());
@@ -114,6 +120,10 @@ public GraphOperationBuilder addInput(Output<?> input) {
114120

115121
@Override
116122
public GraphOperationBuilder addInputList(Output<?>[] inputs) {
123+
for (Output<?> input : inputs) {
124+
graph.checkInput(input);
125+
}
126+
117127
Graph.Reference r = graph.ref();
118128
try {
119129
TF_Operation[] opHandles = new TF_Operation[inputs.length];

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Operation.java

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,24 @@
2525
*/
2626
public interface Operation {
2727

28-
/** Returns the full name of the Operation. */
28+
/**
29+
* Returns the full name of the Operation.
30+
*/
2931
String name();
3032

3133
/**
32-
* Returns the type of the operation, i.e., the name of the computation performed by the
33-
* operation.
34+
* Returns the type of the operation, i.e., the name of the computation performed by the operation.
3435
*/
3536
String type();
3637

37-
/** Returns the number of tensors produced by this operation. */
38+
/**
39+
* Returns the execution environment this operation was created in.
40+
*/
41+
ExecutionEnvironment env();
42+
43+
/**
44+
* Returns the number of tensors produced by this operation.
45+
*/
3846
int numOutputs();
3947

4048
/**

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Op.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
package org.tensorflow.op;
1717

18+
import org.tensorflow.ExecutionEnvironment;
1819
import org.tensorflow.Operation;
1920

2021
/**
@@ -48,4 +49,11 @@ public interface Op {
4849
* @return an {@link Operation}
4950
*/
5051
Operation op();
52+
53+
/**
54+
* Return the execution environment this op was created in.
55+
*/
56+
default ExecutionEnvironment env() {
57+
return op().env();
58+
}
5159
}

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Scope.java

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,12 @@
1616
package org.tensorflow.op;
1717

1818
import java.util.ArrayList;
19-
2019
import org.tensorflow.DeviceSpec;
2120
import org.tensorflow.ExecutionEnvironment;
2221
import org.tensorflow.OperationBuilder;
2322

2423
/**
25-
* Manages groups of related properties when creating Tensorflow Operations, such as a common name
26-
* prefix.
24+
* Manages groups of related properties when creating Tensorflow Operations, such as a common name prefix.
2725
*
2826
* <p>A {@code Scope} is a container for common properties applied to TensorFlow Ops. Normal user
2927
* code initializes a {@code Scope} and provides it to Operation building classes. For example:
@@ -88,7 +86,9 @@ public Scope(ExecutionEnvironment env) {
8886
this(env, new NameScope(), new ArrayList<>(), DeviceSpec.newBuilder().build());
8987
}
9088

91-
/** Returns the execution environment used by this scope. */
89+
/**
90+
* Returns the execution environment used by this scope.
91+
*/
9292
public ExecutionEnvironment env() {
9393
return env;
9494
}
@@ -97,8 +97,7 @@ public ExecutionEnvironment env() {
9797
* Returns a new scope where added operations will have the provided name prefix.
9898
*
9999
* <p>Ops created with this scope will have {@code name/childScopeName/} as the prefix. The actual
100-
* name will be unique in the returned scope. All other properties are inherited from the current
101-
* scope.
100+
* name will be unique in the returned scope. All other properties are inherited from the current scope.
102101
*
103102
* <p>The child scope name must match the regular expression {@code [A-Za-z0-9.][A-Za-z0-9_.\-]*}
104103
*
@@ -129,7 +128,8 @@ public Scope withName(String opName) {
129128
/**
130129
* Return a new scope that uses the provided device specification for an op.
131130
*
132-
* <p>Operations created within this scope will place the created operations on the device(s) matching the provided spec.
131+
* <p>Operations created within this scope will place the created operations on the device(s) matching the provided
132+
* spec.
133133
*
134134
* @param deviceSpec device specification for an operator in the returned scope
135135
* @return a new Scope that uses opName for operations.
@@ -151,8 +151,8 @@ public Scope withDevice(DeviceSpec deviceSpec) {
151151
* }</pre>
152152
*
153153
* <p><b>Note:</b> if you provide a composite operator building class (i.e, a class that creates a
154-
* set of related operations by calling other operator building code), the provided name will act
155-
* as a subscope to all underlying operators.
154+
* set of related operations by calling other operator building code), the provided name will act as a subscope to all
155+
* underlying operators.
156156
*
157157
* @param defaultName name for the underlying operator.
158158
* @return unique name for the operator.
@@ -180,11 +180,15 @@ private Scope(
180180
* @return a new scope with the provided control dependencies
181181
*/
182182
public Scope withControlDependencies(Iterable<Op> controls) {
183+
for (Op control : controls) {
184+
env.checkInput(control);
185+
}
183186
return new Scope(env, nameScope, controls, deviceSpec);
184187
}
185188

186189
/**
187-
* Applies device specification and adds each Operand in controlDependencies as a control input to the provided builder.
190+
* Applies device specification and adds each Operand in controlDependencies as a control input to the provided
191+
* builder.
188192
*
189193
* @param builder OperationBuilder to add control inputs and device specification to
190194
*/
@@ -210,7 +214,9 @@ public OperationBuilder applyControlDependencies(OperationBuilder builder) {
210214
private final NameScope nameScope;
211215
private final DeviceSpec deviceSpec;
212216

213-
/** Returns device string from the scope. */
217+
/**
218+
* Returns device string from the scope.
219+
*/
214220
public String getDeviceString() {
215221
return deviceSpec.toString();
216222
}

0 commit comments

Comments
 (0)