Skip to content

Commit d5bd64e

Browse files
committed
Proper attribute setters
Signed-off-by: Ryan Nett <[email protected]>
1 parent c42d45c commit d5bd64e

File tree

6 files changed

+100
-15
lines changed

6 files changed

+100
-15
lines changed

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,7 @@ public Map<String, Operand<?>> call(Scope scope,
238238

239239
opBuilder.addInputList(inputList.stream().map(Operand::asOutput).toArray(Output[]::new));
240240

241-
opBuilder.setFunctionName("f", name);
241+
opBuilder.setAttr("f", this);
242242
opBuilder.setAttr("Tin", inputList.stream().map(x -> x.asOutput().dataType()).toArray(DataType[]::new));
243243
opBuilder.setAttr("Tout", signature().getOutputs().values().stream().map(x -> x.dataType).toArray(DataType[]::new));
244244

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import static org.tensorflow.internal.c_api.global.tensorflow.TFE_OpSetAttrBoolList;
2323
import static org.tensorflow.internal.c_api.global.tensorflow.TFE_OpSetAttrFloat;
2424
import static org.tensorflow.internal.c_api.global.tensorflow.TFE_OpSetAttrFloatList;
25+
import static org.tensorflow.internal.c_api.global.tensorflow.TFE_OpSetAttrFunctionList;
2526
import static org.tensorflow.internal.c_api.global.tensorflow.TFE_OpSetAttrFunctionName;
2627
import static org.tensorflow.internal.c_api.global.tensorflow.TFE_OpSetAttrInt;
2728
import static org.tensorflow.internal.c_api.global.tensorflow.TFE_OpSetAttrIntList;
@@ -36,6 +37,9 @@
3637

3738
import java.nio.charset.Charset;
3839
import java.nio.charset.StandardCharsets;
40+
import java.util.Arrays;
41+
import java.util.List;
42+
import java.util.stream.Collectors;
3943
import org.bytedeco.javacpp.BooleanPointer;
4044
import org.bytedeco.javacpp.BytePointer;
4145
import org.bytedeco.javacpp.IntPointer;
@@ -219,8 +223,22 @@ public EagerOperationBuilder setAttr(String name, Shape[] values) {
219223
}
220224

221225
@Override
222-
public OperationBuilder setFunctionName(String attrName, String functionName) {
223-
setAttrFunctionName(opHandle, attrName, functionName);
226+
public OperationBuilder setAttr(String name, ConcreteFunction value) {
227+
session.attachFunction(value);
228+
setAttrFunctionName(opHandle, name, value.getNativeFunctionName());
229+
return this;
230+
}
231+
232+
@Override
233+
public OperationBuilder setAttr(String name, ConcreteFunction[] value) {
234+
for (ConcreteFunction fn : value) {
235+
session.attachFunction(fn);
236+
}
237+
238+
setAttrFunctionList(opHandle, session.nativeHandle(), name, Arrays.stream(value)
239+
.map(ConcreteFunction::getNativeFunctionName)
240+
.collect(Collectors.toList()));
241+
224242
return this;
225243
}
226244

@@ -416,7 +434,7 @@ private static void setAttrShapeList(TFE_Op opHandle, String name, long[] shapes
416434
}
417435
TF_Status status = TF_Status.newStatus();
418436
TFE_OpSetAttrShapeList(opHandle, new BytePointer(name), shapesPointers, new IntPointer(numDims),
419-
numDims.length, status);
437+
numDims.length, status);
420438
}
421439
}
422440

@@ -426,4 +444,20 @@ private static void setAttrFunctionName(TFE_Op opHandle, String attrName, String
426444
TFE_OpSetAttrFunctionName(opHandle, attrName, functionName, functionName.length());
427445
}
428446
}
447+
448+
private static void setAttrFunctionList(TFE_Op opHandle, TFE_Context context, String attrName,
449+
List<String> functionNames) {
450+
requireOp(opHandle);
451+
requireContext(context);
452+
try (PointerScope scope = new PointerScope()) {
453+
PointerPointer<TFE_Op> fns = new PointerPointer<>(functionNames.size());
454+
for (int i = 0; i < functionNames.size(); i++) {
455+
TF_Status status = TF_Status.newStatus();
456+
TFE_Op op = TFE_Op.newOp(context, functionNames.get(i), status);
457+
status.throwExceptionIfNotOK();
458+
fns.put(i, op);
459+
}
460+
TFE_OpSetAttrFunctionList(opHandle, new BytePointer(attrName), fns, functionNames.size());
461+
}
462+
}
429463
}

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperationBuilder.java

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,13 @@
3535
import static org.tensorflow.internal.c_api.global.tensorflow.TF_SetAttrTensorList;
3636
import static org.tensorflow.internal.c_api.global.tensorflow.TF_SetAttrType;
3737
import static org.tensorflow.internal.c_api.global.tensorflow.TF_SetAttrTypeList;
38+
import static org.tensorflow.internal.c_api.global.tensorflow.TF_SetAttrValueProto;
3839
import static org.tensorflow.internal.c_api.global.tensorflow.TF_SetDevice;
3940

4041
import java.nio.charset.Charset;
42+
import java.util.Arrays;
43+
import java.util.List;
44+
import java.util.stream.Collectors;
4145
import org.bytedeco.javacpp.BooleanPointer;
4246
import org.bytedeco.javacpp.BytePointer;
4347
import org.bytedeco.javacpp.IntPointer;
@@ -54,7 +58,10 @@
5458
import org.tensorflow.internal.c_api.TF_Status;
5559
import org.tensorflow.internal.c_api.TF_Tensor;
5660
import org.tensorflow.ndarray.Shape;
61+
import org.tensorflow.proto.framework.AttrValue;
62+
import org.tensorflow.proto.framework.AttrValue.ListValue;
5763
import org.tensorflow.proto.framework.DataType;
64+
import org.tensorflow.proto.framework.NameAttrList;
5865

5966
/**
6067
* An {@link OperationBuilder} for adding {@link GraphOperation}s to a {@link Graph}.
@@ -347,9 +354,24 @@ public GraphOperationBuilder setAttr(String name, String[] value) {
347354
}
348355

349356
@Override
350-
public OperationBuilder setFunctionName(String attrName, String functionName) {
357+
public OperationBuilder setAttr(String name, ConcreteFunction value) {
358+
graph.attachFunction(value);
351359
try (Reference r = graph.ref()) {
352-
setAttrFunctionName(unsafeNativeHandle, attrName, functionName);
360+
setAttrFunctionName(unsafeNativeHandle, name, value.getNativeFunctionName());
361+
}
362+
return this;
363+
}
364+
365+
@Override
366+
public OperationBuilder setAttr(String name, ConcreteFunction[] value) {
367+
for (ConcreteFunction f : value) {
368+
graph.attachFunction(f);
369+
}
370+
371+
try (Reference r = graph.ref()) {
372+
setAttrFunctionList(unsafeNativeHandle, name, Arrays.stream(value)
373+
.map(ConcreteFunction::getNativeFunctionName)
374+
.collect(Collectors.toList()));
353375
}
354376
return this;
355377
}
@@ -556,4 +578,20 @@ private static void setAttrFunctionName(TF_OperationDescription opHandle, String
556578
TF_SetAttrFuncName(opHandle, attrName, functionName, functionName.length());
557579
}
558580
}
581+
582+
private static void setAttrFunctionList(TF_OperationDescription opHandle, String attrName,
583+
List<String> functionNames) {
584+
requireHandle(opHandle);
585+
try (PointerScope scope = new PointerScope()) {
586+
TF_Status status = TF_Status.newStatus();
587+
AttrValue value = AttrValue.newBuilder().setList(ListValue.newBuilder().addAllFunc(
588+
functionNames.stream()
589+
.map(x -> NameAttrList.newBuilder().setName(x).build())
590+
.collect(Collectors.toList())
591+
).build()).build();
592+
byte[] bytes = value.toByteArray();
593+
TF_SetAttrValueProto(opHandle, attrName, new BytePointer(bytes), bytes.length, status);
594+
status.throwExceptionIfNotOK();
595+
}
596+
}
559597
}

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/OperationBuilder.java

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -226,11 +226,22 @@ public interface OperationBuilder {
226226
OperationBuilder setAttr(String name, Shape[] value);
227227

228228
/**
229-
* Set a function name attribute of the operation being build.
229+
* Set the function value of an attribute of the operation being built. Also attaches the function and dependencies to
230+
* the execution environment.
230231
*
231-
* @param attrName the attribute to set
232-
* @param functionName the function name
232+
* @param name attribute name
233+
* @param value attribute value
234+
* @return the OperationBuilder instance for chaining.
235+
*/
236+
OperationBuilder setAttr(String name, ConcreteFunction value);
237+
238+
/**
239+
* Set the function values of an attribute of the operation being built. Also attaches the functions and dependencies
240+
* to the execution environment.
241+
*
242+
* @param name attribute name
243+
* @param value attribute value
233244
* @return the OperationBuilder instance for chaining.
234245
*/
235-
OperationBuilder setFunctionName(String attrName, String functionName);
246+
OperationBuilder setAttr(String name, ConcreteFunction[] value);
236247
}

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerOperationBuilderTest.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ public void setAttrs() {
124124
.build();
125125
// bool
126126
opBuilder(session, "All", "Bool")
127-
.addInput(tf.constant(new boolean[] {true, true, false}).asOutput())
127+
.addInput(tf.constant(new boolean[]{true, true, false}).asOutput())
128128
.addInput(tf.constant(0).asOutput())
129129
.setAttr("keep_dims", false)
130130
.build();
@@ -134,7 +134,8 @@ public void setAttrs() {
134134
.addInput(tf.constant(10.00000f).asOutput())
135135
.setAttr("tolerance", 0.1f)
136136
.build();
137-
// Missing tests: list(string), list(byte), list(bool), list(type)
137+
// Missing tests: list(string), list(byte), list(bool), list(type), list(func)
138+
// func is done via ConcreteFunction execution
138139
}
139140
}
140141

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphOperationBuilderTest.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,17 +96,18 @@ public void setAttr() {
9696
g.opBuilder("MaxPool", "IntList")
9797
.addInput(tf.constant(new float[2][2][2][2]).asOutput())
9898
.setAttr("ksize", new long[] {1, 1, 1, 1})
99-
.setAttr("strides", new long[] {1, 1, 1, 1})
99+
.setAttr("strides", new long[]{1, 1, 1, 1})
100100
.setAttr("padding", "SAME")
101101
.build();
102102
assertTrue(hasNode(g, "IntList"));
103103
// list(float)
104104
g.opBuilder("FractionalMaxPool", "FloatList")
105105
.addInput(tf.constant(new float[2][2][2][2]).asOutput())
106-
.setAttr("pooling_ratio", new float[] {1.0f, 1.44f, 1.73f, 1.0f})
106+
.setAttr("pooling_ratio", new float[]{1.0f, 1.44f, 1.73f, 1.0f})
107107
.build();
108108
assertTrue(hasNode(g, "FloatList"));
109-
// Missing tests: float, list(dtype), list(tensor), list(string), list(bool)
109+
// Missing tests: float, list(dtype), list(tensor), list(string), list(bool), list(func)
110+
// func is done via ConcreteFunction execution
110111
}
111112
}
112113

0 commit comments

Comments
 (0)