Skip to content
This repository was archived by the owner on Feb 7, 2023. It is now read-only.

Commit 7df0d66

Browse files
benzyxfacebook-github-bot
authored andcommitted
Implementation for Graph Transforms
Summary: The Implementation of Graph Transformations, with the PatternMatch and ReplaceMatch rules. Reviewed By: akyrola Differential Revision: D5404144 fbshipit-source-id: 2bab68e6bff2e841ea9fb64df5d92ea945e704af
1 parent f5bbac6 commit 7df0d66

File tree

7 files changed

+482
-4
lines changed

7 files changed

+482
-4
lines changed

caffe2/contrib/transform/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
if(USE_TRANSFORMS)
22
message(STATUS "Include Graph Transformations")
33
set(Caffe2_CONTRIB_TRANSFORMS_CPU_SRC
4+
"${CMAKE_CURRENT_SOURCE_DIR}/transform.cc"
45
"${CMAKE_CURRENT_SOURCE_DIR}/graph.cc"
56
)
67

caffe2/contrib/transform/graph.cc

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,22 @@ NetDef Graph::GetNetDef() {
172172
return netdef;
173173
}
174174

175+
void Graph::DeactivateSubgraph(std::vector<int> subgraph) {
176+
for (int idx : subgraph) {
177+
// remove all edges connected to inactive node
178+
for (const auto& edge : node(idx).parents) {
179+
int parent = edge.first;
180+
node(parent).children.erase(idx);
181+
}
182+
for (const auto& edge : node(idx).children) {
183+
int child = edge.first;
184+
node(child).parents.erase(idx);
185+
}
186+
// actually mark flags as false
187+
node(idx).active = false;
188+
}
189+
}
190+
175191
} // namespace transform
176192

177193
} // namespace caffe2

caffe2/contrib/transform/graph.h

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,26 @@ namespace transform {
1717
*/
1818
struct Node {
1919
public:
20+
// Empty constructor for resize
21+
Node() {}
22+
23+
// Alternate constructor
24+
Node(
25+
const OperatorDef& op,
26+
bool active,
27+
std::map<int, string> parents,
28+
std::map<int, string> children)
29+
: op(op), active(active), parents(parents), children(children) {}
30+
2031
// The OperatorDef which this node represents.
2132
OperatorDef op;
2233

2334
// Keeps track of if an operator has been deleted through a transformation.
2435
bool active = true;
2536

2637
// Stores a pair (idx, blob), the index of the child, and the blob of edge.
27-
std::map<int, string> children;
2838
std::map<int, string> parents;
39+
std::map<int, string> children;
2940
};
3041

3142
/**
@@ -83,6 +94,11 @@ struct Graph {
8394
*/
8495
NetDef GetNetDef();
8596

97+
/**
98+
* Deactivate a subgraph, and get rid of all edges into this subgraph.
99+
*/
100+
void DeactivateSubgraph(std::vector<int> subgraph);
101+
86102
const size_t size() const {
87103
return nodes_.size();
88104
}

caffe2/contrib/transform/graph_test.cc

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,23 +22,20 @@ class DummyOp final : public OperatorBase {
2222
};
2323

2424
REGISTER_CPU_OPERATOR(DummyOp1, DummyOp);
25-
REGISTER_CUDA_OPERATOR(DummyOp1, DummyOp);
2625

2726
OPERATOR_SCHEMA(DummyOp1)
2827
.NumInputs(0, INT_MAX)
2928
.NumOutputs(0, INT_MAX)
3029
.AllowInplace({{0, 0}, {1, 1}});
3130

3231
REGISTER_CPU_OPERATOR(DummyOp2, DummyOp);
33-
REGISTER_CUDA_OPERATOR(DummyOp2, DummyOp);
3432

3533
OPERATOR_SCHEMA(DummyOp2)
3634
.NumInputs(0, INT_MAX)
3735
.NumOutputs(0, INT_MAX)
3836
.AllowInplace({{0, 0}, {1, 1}});
3937

4038
REGISTER_CPU_OPERATOR(DummyOp3, DummyOp);
41-
REGISTER_CUDA_OPERATOR(DummyOp3, DummyOp);
4239

4340
OPERATOR_SCHEMA(DummyOp3)
4441
.NumInputs(0, INT_MAX)

caffe2/contrib/transform/transform.cc

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
#include "caffe2/contrib/transform/transform.h"
2+
3+
#include "caffe2/core/common.h"
4+
#include "caffe2/core/logging.h"
5+
#include "caffe2/core/net.h"
6+
#include "caffe2/proto/caffe2.pb.h"
7+
8+
namespace caffe2 {
9+
10+
using transform::Graph;
11+
12+
CAFFE_DEFINE_REGISTRY(TransformRegistry, Transform);
13+
14+
std::vector<std::vector<int>> Transform::PatternMatch(const Graph& graph) {
15+
std::vector<std::vector<int>> matches;
16+
17+
// Consider every possible node as the starting point.
18+
for (int idx = 0; idx < graph.size(); ++idx) {
19+
// The current working subgraph. We will try to add new nodes to this,
20+
// when invoking the PatternRule.
21+
std::vector<int> subgraph;
22+
23+
// The largest "validated" subgraph found so far.
24+
// This will be mutated by PatternMatchHelper.
25+
std::vector<int> best_subgraph;
26+
27+
// Only begin to match if the start node is accepted.
28+
if (PatternRule(graph, subgraph, idx)) {
29+
subgraph.push_back(idx);
30+
PatternMatchHelper(graph, &subgraph, &best_subgraph);
31+
subgraph.pop_back();
32+
}
33+
if (best_subgraph.size() > 0) { // match found
34+
matches.push_back(best_subgraph);
35+
}
36+
}
37+
return matches;
38+
}
39+
40+
void Transform::TryNeighbors(
41+
const Graph& graph,
42+
const std::map<int, string>& neighbors,
43+
std::vector<int>* subgraph_ptr,
44+
std::vector<int>* best_subgraph_ptr) {
45+
auto& subgraph = *subgraph_ptr;
46+
for (const auto& edge : neighbors) {
47+
int j = edge.first;
48+
if (std::find(subgraph.begin(), subgraph.end(), j) == subgraph.end()) {
49+
if (PatternRule(graph, subgraph, j)) {
50+
subgraph.push_back(j);
51+
PatternMatchHelper(graph, subgraph_ptr, best_subgraph_ptr);
52+
subgraph.pop_back();
53+
}
54+
}
55+
}
56+
}
57+
58+
void Transform::PatternMatchHelper(
59+
const Graph& graph,
60+
std::vector<int>* subgraph_ptr,
61+
std::vector<int>* best_subgraph_ptr) {
62+
CHECK(subgraph_ptr);
63+
auto& subgraph = *subgraph_ptr;
64+
CHECK(best_subgraph_ptr);
65+
auto& best_subgraph = *best_subgraph_ptr;
66+
67+
// If the current subgraph is valid, and the largest we've seen so far,
68+
// make it the best_subgraph.
69+
if (ValidatorRule(graph, subgraph) &&
70+
subgraph.size() > best_subgraph.size()) {
71+
best_subgraph = subgraph;
72+
}
73+
74+
// Try adding each parent and child of every node in the subgraph,
75+
// and see if we can accept it.
76+
for (int i : subgraph) {
77+
TryNeighbors(
78+
graph, graph.node(i).children, subgraph_ptr, best_subgraph_ptr);
79+
TryNeighbors(graph, graph.node(i).parents, subgraph_ptr, best_subgraph_ptr);
80+
}
81+
}
82+
83+
void Transform::ReplacePattern(
84+
const std::vector<vector<int>>& matches,
85+
Graph* graph) {
86+
// Simply try to apply the replace rule upon every match.
87+
for (const auto& match : matches) {
88+
if (!ReplaceRule(match, graph)) {
89+
CAFFE_THROW("Replace failed!");
90+
}
91+
}
92+
}
93+
94+
// The simple interface - performs the transformation upon a NetDef, and returns
95+
// the result.
96+
NetDef Transform::ApplyTo(const NetDef& orig_net) {
97+
Graph g(orig_net);
98+
const auto matches = PatternMatch(g);
99+
ReplacePattern(matches, &g);
100+
return g.GetNetDef();
101+
}
102+
103+
} // namespace Caffe2

caffe2/contrib/transform/transform.h

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
#pragma once
2+
3+
#include "caffe2/contrib/transform/graph.h"
4+
#include "caffe2/core/common.h"
5+
#include "caffe2/proto/caffe2.pb.h"
6+
#include "caffe2/utils/proto_utils.h"
7+
8+
namespace caffe2 {
9+
10+
/**
11+
* The Transform Base Object
12+
*
13+
* A Transform is an operation which manipulates a Caffe2 NetDef.
14+
* You can consider it as a function: Transform.ApplyTo(NetDef) -> NetDef
15+
*
16+
* A Transform Operation does 4 things:
17+
* 1) Creates a Graph object from a NetDef, which stores connections.
18+
* 2) Pattern Matches on the Graph, to find subgraphs it wants to change.
19+
* 3) Replaces the subgraphs that it's matched with new operators.
20+
* 4) Creates a NetDef from the changed Graph, and returns it.
21+
*
22+
* The effect of a Transform is defined by its 3 protected virtual functions.
23+
* 1) PatternRule determines for an ordered subgraph and a node, whether to
24+
* consider adding the node to the subgraph.
25+
* 2) ValidatorRule determines, for an ordered subgraph, whether it is a
26+
* match.
27+
* 3) ReplaceRule mutates the graph, based on a matched subgraph.
28+
*
29+
* This is the base class for all derived classes to base off. To create your
30+
* own transform, write your implementations for PatternRule, ValidatorRule, and
31+
* ReplaceRule.
32+
*/
33+
class Transform {
34+
public:
35+
Transform() {}
36+
37+
/**
38+
* Apply a Transform onto a NetDef.
39+
* Returns the transformed NetDef.
40+
*/
41+
NetDef ApplyTo(const NetDef& orig_net_def);
42+
43+
virtual ~Transform() {}
44+
45+
/**
46+
* Generates all matches (stored as ordered subgraphs) and returns them.
47+
*
48+
* A match is stored as vector<int>, which is a mapping to OperatorDefs
49+
* in Graph. The order matters.
50+
*/
51+
std::vector<std::vector<int>> PatternMatch(const transform::Graph& graph);
52+
53+
/**
54+
* Applies the replace rule onto each of the matches found.
55+
*/
56+
void ReplacePattern(
57+
const std::vector<std::vector<int>>& matches,
58+
transform::Graph* graph);
59+
60+
protected:
61+
/**
62+
* The PatternRule essentially answers:
63+
* Given the current subgraph (ordered), should we append the new node at idx?
64+
*/
65+
virtual bool PatternRule(
66+
const transform::Graph& g,
67+
const std::vector<int>& subgraph,
68+
int idx) {
69+
CAFFE_NOT_IMPLEMENTED;
70+
}
71+
72+
/**
73+
* The ValidatorRule essentially answers:
74+
* Given a subgraph, can we accept it?
75+
*/
76+
virtual bool ValidatorRule(
77+
const transform::Graph& g,
78+
const std::vector<int>& subgraph) {
79+
CAFFE_NOT_IMPLEMENTED;
80+
}
81+
82+
/**
83+
* The ReplaceRule actually mutates the graph, and applies the transformation
84+
* upon the subgraph.
85+
*/
86+
virtual bool ReplaceRule(
87+
const std::vector<int>& subgraph,
88+
transform::Graph* g_ptr) {
89+
CAFFE_NOT_IMPLEMENTED;
90+
}
91+
92+
private:
93+
/**
94+
* A helper function for PatternMatch, which keeps track of the best subgraph
95+
* so far.
96+
*/
97+
void PatternMatchHelper(
98+
const transform::Graph& graph,
99+
std::vector<int>* subgraph_ptr,
100+
std::vector<int>* best_subgraph_ptr);
101+
/**
102+
* Attempts to append each neighbor to the end of the subgraph.
103+
*/
104+
void TryNeighbors(
105+
const transform::Graph& graph,
106+
const std::map<int, string>& neighbors,
107+
std::vector<int>* subgraph_ptr,
108+
std::vector<int>* best_subgraph_ptr);
109+
};
110+
111+
CAFFE_DECLARE_REGISTRY(TransformRegistry, Transform);
112+
113+
} // namespace

0 commit comments

Comments
 (0)