Skip to content

Commit e4d08c0

Browse files
committed
Have scope detect used ops from environment
Signed-off-by: Ryan Nett <[email protected]>
1 parent 5ca958d commit e4d08c0

File tree

2 files changed

+63
-9
lines changed

2 files changed

+63
-9
lines changed

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/NameScope.java

Lines changed: 54 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,17 @@
1717

1818
import java.util.HashMap;
1919
import java.util.Map;
20+
import java.util.regex.Matcher;
2021
import java.util.regex.Pattern;
22+
import org.tensorflow.ExecutionEnvironment;
23+
import org.tensorflow.Graph;
2124

2225
/**
2326
* A class to manage scoped (hierarchical) names for operators.
2427
*
2528
* <p>{@code NameScope} manages hierarchical names where each component in the hierarchy is
26-
* separated by a forward slash {@code '/'}. For instance, {@code nn/Const_72} or {@code
27-
* nn/gradient/assign/init}. Each scope is a subtree in this hierarchy.
29+
* separated by a forward slash {@code '/'}. For instance, {@code nn/Const_72} or {@code nn/gradient/assign/init}. Each
30+
* scope is a subtree in this hierarchy.
2831
*
2932
* <p>Use {@code NameScope} to group related operations within a hierarchy, which for example lets
3033
* tensorboard coalesce nodes for better graph visualizations.
@@ -50,6 +53,45 @@ NameScope withName(String name) {
5053
return new NameScope(opPrefix, name, ids);
5154
}
5255

56+
private static final Pattern NAME_PATTERN = Pattern.compile("(.+)_(\\d+)", Pattern.DOTALL);
57+
58+
/**
59+
* "Import" used names from a graph. Useful when adding to a loaded graph.
60+
*/
61+
NameScope withUsedFrom(ExecutionEnvironment env) {
62+
63+
if (env instanceof Graph) {
64+
((Graph) env).operations().forEachRemaining(op -> {
65+
if (op.name().startsWith(opPrefix != null ? opPrefix : "")) {
66+
String name = op.name();
67+
68+
if (opPrefix != null) {
69+
name = name.substring(opPrefix.length() + 1);
70+
}
71+
72+
if (!name.contains("/")) {
73+
Matcher matcher = NAME_PATTERN.matcher(name);
74+
if (matcher.find()) {
75+
String realName = matcher.group(1);
76+
int num = Integer.parseInt(matcher.group(2)) + 1;
77+
78+
if (!(ids.containsKey(realName) && ids.get(realName) > num)) {
79+
ids.put(realName, num);
80+
}
81+
} else {
82+
if (!ids.containsKey(name)) {
83+
ids.put(name, 1);
84+
} else {
85+
ids.put(name, ids.get(name) + 1);
86+
}
87+
}
88+
}
89+
}
90+
});
91+
}
92+
return this;
93+
}
94+
5395
String makeOpName(String name) {
5496
checkPattern(NAME_REGEX, name);
5597
// Override with opName if it exists.
@@ -120,15 +162,22 @@ private String fullyQualify(String name) {
120162
// instance mapped to the next available numeric suffix for it.
121163
private final Map<String, Integer> ids;
122164

165+
static boolean isValidName(String name) {
166+
if (name == null) {
167+
return false;
168+
}
169+
return NAME_REGEX.matcher(name).matches();
170+
}
171+
123172
private static void checkPattern(Pattern pattern, String name) {
124173
if (name == null) {
125174
throw new IllegalArgumentException("Names cannot be null");
126175
}
127176
if (!pattern.matcher(name).matches()) {
128177
throw new IllegalArgumentException(
129-
String.format(
130-
"invalid name: '%s' does not match the regular expression %s",
131-
name, NAME_REGEX.pattern()));
178+
String.format(
179+
"invalid name: '%s' does not match the regular expression %s",
180+
name, NAME_REGEX.pattern()));
132181
}
133182
}
134183

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Scope.java

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ public final class Scope {
8383
* @param env The execution environment used by the scope.
8484
*/
8585
public Scope(ExecutionEnvironment env) {
86-
this(env, new NameScope(), new ArrayList<>(), DeviceSpec.newBuilder().build());
86+
this(env, new NameScope().withUsedFrom(env), new ArrayList<>(), DeviceSpec.newBuilder().build());
8787
}
8888

8989
/**
@@ -106,7 +106,7 @@ public ExecutionEnvironment env() {
106106
* @throws IllegalArgumentException if the name is invalid
107107
*/
108108
public Scope withSubScope(String childScopeName) {
109-
return new Scope(env, nameScope.withSubScope(childScopeName), controlDependencies, deviceSpec);
109+
return new Scope(env, nameScope.withSubScope(childScopeName).withUsedFrom(env), controlDependencies, deviceSpec);
110110
}
111111

112112
/**
@@ -141,7 +141,8 @@ public Scope withName(String opName) {
141141
* @throws IllegalArgumentException if the name is invalid
142142
*/
143143
public Scope withNameAsSubScope(String defaultName){
144-
return new Scope(env, nameScope.withSubScope(nameScope.makeOpName(defaultName)), controlDependencies, deviceSpec);
144+
return new Scope(env, nameScope.withSubScope(nameScope.makeOpName(defaultName)).withUsedFrom(env),
145+
controlDependencies, deviceSpec);
145146
}
146147

147148
/**
@@ -181,8 +182,12 @@ public String makeOpName(String defaultName) {
181182
return nameScope.makeOpName(defaultName);
182183
}
183184

185+
public static boolean isValidOpName(String name) {
186+
return NameScope.isValidName(name);
187+
}
188+
184189
private Scope(
185-
ExecutionEnvironment env, NameScope nameScope, Iterable<Op> controlDependencies, DeviceSpec deviceSpec) {
190+
ExecutionEnvironment env, NameScope nameScope, Iterable<Op> controlDependencies, DeviceSpec deviceSpec) {
186191
this.env = env;
187192
this.nameScope = nameScope;
188193
this.controlDependencies = controlDependencies;

0 commit comments

Comments
 (0)