17
17
18
18
import java .util .HashMap ;
19
19
import java .util .Map ;
20
+ import java .util .regex .Matcher ;
20
21
import java .util .regex .Pattern ;
22
+ import org .tensorflow .ExecutionEnvironment ;
23
+ import org .tensorflow .Graph ;
21
24
22
25
/**
23
26
* A class to manage scoped (hierarchical) names for operators.
24
27
*
25
28
* <p>{@code NameScope} manages hierarchical names where each component in the hierarchy is
26
- * separated by a forward slash {@code '/'}. For instance, {@code nn/Const_72} or {@code
27
- * nn/gradient/assign/init}. Each scope is a subtree in this hierarchy.
29
+ * separated by a forward slash {@code '/'}. For instance, {@code nn/Const_72} or {@code nn/gradient/assign/init}. Each
30
+ * scope is a subtree in this hierarchy.
28
31
*
29
32
* <p>Use {@code NameScope} to group related operations within a hierarchy, which for example lets
30
33
* tensorboard coalesce nodes for better graph visualizations.
@@ -50,6 +53,45 @@ NameScope withName(String name) {
50
53
return new NameScope (opPrefix , name , ids );
51
54
}
52
55
56
+ private static final Pattern NAME_PATTERN = Pattern .compile ("(.+)_(\\ d+)" , Pattern .DOTALL );
57
+
58
+ /**
59
+ * "Import" used names from a graph. Useful when adding to a loaded graph.
60
+ */
61
+ NameScope withUsedFrom (ExecutionEnvironment env ) {
62
+
63
+ if (env instanceof Graph ) {
64
+ ((Graph ) env ).operations ().forEachRemaining (op -> {
65
+ if (op .name ().startsWith (opPrefix != null ? opPrefix : "" )) {
66
+ String name = op .name ();
67
+
68
+ if (opPrefix != null ) {
69
+ name = name .substring (opPrefix .length () + 1 );
70
+ }
71
+
72
+ if (!name .contains ("/" )) {
73
+ Matcher matcher = NAME_PATTERN .matcher (name );
74
+ if (matcher .find ()) {
75
+ String realName = matcher .group (1 );
76
+ int num = Integer .parseInt (matcher .group (2 )) + 1 ;
77
+
78
+ if (!(ids .containsKey (realName ) && ids .get (realName ) > num )) {
79
+ ids .put (realName , num );
80
+ }
81
+ } else {
82
+ if (!ids .containsKey (name )) {
83
+ ids .put (name , 1 );
84
+ } else {
85
+ ids .put (name , ids .get (name ) + 1 );
86
+ }
87
+ }
88
+ }
89
+ }
90
+ });
91
+ }
92
+ return this ;
93
+ }
94
+
53
95
String makeOpName (String name ) {
54
96
checkPattern (NAME_REGEX , name );
55
97
// Override with opName if it exists.
@@ -120,15 +162,22 @@ private String fullyQualify(String name) {
120
162
// instance mapped to the next available numeric suffix for it.
121
163
private final Map <String , Integer > ids ;
122
164
165
+ static boolean isValidName (String name ) {
166
+ if (name == null ) {
167
+ return false ;
168
+ }
169
+ return NAME_REGEX .matcher (name ).matches ();
170
+ }
171
+
123
172
private static void checkPattern (Pattern pattern , String name ) {
124
173
if (name == null ) {
125
174
throw new IllegalArgumentException ("Names cannot be null" );
126
175
}
127
176
if (!pattern .matcher (name ).matches ()) {
128
177
throw new IllegalArgumentException (
129
- String .format (
130
- "invalid name: '%s' does not match the regular expression %s" ,
131
- name , NAME_REGEX .pattern ()));
178
+ String .format (
179
+ "invalid name: '%s' does not match the regular expression %s" ,
180
+ name , NAME_REGEX .pattern ()));
132
181
}
133
182
}
134
183
0 commit comments