Skip to content
This repository was archived by the owner on Jul 1, 2025. It is now read-only.

Commit 518b458

Browse files
committed
[graph-scheduler] If a node has a user which does not have any users and which does not require any additional memory, schedule this user right after the current node
This handles QuantizationProfileNode scheduling in a more general way, which is not dependent on the kind of the node. Also, take the opportunity to generalize the "SaveNode hack" in the scheduler and make it independent of the Node's kind.
1 parent cdbe0db commit 518b458

File tree

2 files changed

+87
-11
lines changed

2 files changed

+87
-11
lines changed

lib/IR/ChildMemSizeBasedScheduler.cpp

Lines changed: 39 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -98,24 +98,27 @@ void ChildMemSizeBasedScheduler::orderChildNodesAndSchedule(Node *N) {
9898
orderedChildren.push_back(N->getPredicate());
9999
}
100100

101-
// SaveNode hack:
102101
// We don't model memory dependencies, but we still need to honor them.
103-
// Make sure the SaveNode happens after the last use of the output
104-
// placeholder.
105-
if (auto *save = dyn_cast<SaveNode>(N)) {
106-
auto *destination = save->getOutput().getNode();
102+
// Make sure that a node mutating any of its inputs happens after the last
103+
// non-mutating use of the operand being mutated. Some examples of such nodes
104+
// would be SaveNode and QuantizationProfileNode.
105+
for (unsigned idx = 0, e = N->getNumInputs(); idx < e; ++idx) {
106+
// We don't care about inputs that are not mutated by the node.
107+
if (!N->isOverwrittenNthInput(idx)) {
108+
continue;
109+
}
110+
auto mutatedInput = N->getNthInput(idx);
111+
auto *destination = mutatedInput.getNode();
107112
for (NodeUse &use : destination->getUsers()) {
108113
Node *user = use.getUser();
109-
if (user == save) {
114+
if (user == N) {
110115
continue;
111116
}
112-
// Storage nodes may have users scattered across different functions.
117+
// Nodes may have users scattered across different functions.
113118
// Only account for the ones in this function.
114119
if (&G_ != user->getParent()) {
115120
continue;
116121
}
117-
assert(!isa<SaveNode>(user) &&
118-
"Placeholder must be saved at most once in each function");
119122
orderedChildren.push_back(user);
120123
}
121124
}
@@ -148,9 +151,34 @@ void ChildMemSizeBasedScheduler::orderChildNodesAndSchedule(Node *N) {
148151
orderChildNodesAndSchedule(child);
149152
}
150153

151-
// Schedule the node after all its children are scheduled.
152-
DEBUG_GLOW(llvm::dbgs() << "Scheduled node: " << N->getName() << "\n");
154+
// Schedule the node after all its children are scheduled. We need to perform
155+
// an extra isScheduled check here, because the code below may have scheduled
156+
// the current node while scheduling its children.
157+
if (isScheduled(N)) {
158+
return;
159+
}
153160
scheduled_.push_back(N);
161+
// If this node has a user which does not have any users and which does not
162+
// require any additional memory, schedule it here, because we don't want to
163+
// extend the lifetime of this value for no reason. We want to execute and get
164+
// rid of this node as soon as possible to reduce the memory pressure.
165+
for (NodeUse &use : N->getUsers()) {
166+
Node *user = use.getUser();
167+
// Users may be scattered across different functions.
168+
// Only accounts for the ones in that function.
169+
if (&G_ != user->getParent()) {
170+
continue;
171+
}
172+
// Bail if a nodes has users, because nodes that have users can't be
173+
// scheduled safely without violating dependencies.
174+
if (user->getNumUsers()) {
175+
continue;
176+
}
177+
// Schedule a node if it does not require any additional memory.
178+
if (resultMemSize_[user] == 0) {
179+
orderChildNodesAndSchedule(user);
180+
}
181+
}
154182
}
155183

156184
void ChildMemSizeBasedScheduler::scheduleNodes() {

tests/unittests/GraphSchedulerTest.cpp

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,3 +113,51 @@ TEST(GraphScheduler, testMaxSizeLessThanResultSize) {
113113
std::distance(schedule.begin(), concatSmallIt));
114114
}
115115
}
116+
117+
TEST(GraphScheduler, ScheduleQuantizationProfileRightAfterNodeBeingProfiled) {
118+
Module MD;
119+
PlaceholderBindings bindings;
120+
auto *input1 =
121+
MD.createPlaceholder(ElemKind::FloatTy, {1, 4, 4}, "input1", false);
122+
bindings.allocate(input1);
123+
auto *input2 =
124+
MD.createPlaceholder(ElemKind::FloatTy, {1, 4, 4}, "input2", false);
125+
bindings.allocate(input2);
126+
Function *F = MD.createFunction("F");
127+
Node *add = F->createAdd("add", input1, input2);
128+
Node *sub = F->createSub("sub", input1, input2);
129+
Node *mul = F->createMul("mul", add, sub);
130+
Node *save = F->createSave("save", mul);
131+
Node *quantizationProfileAdd =
132+
F->createQuantizationProfile(bindings, "qpAdd", add);
133+
Node *quantizationProfileSub =
134+
F->createQuantizationProfile(bindings, "qpSub", sub);
135+
136+
// Each QuantizationProfileNode has no users of its own, so the
137+
// scheduler is expected to place it right after the node being
138+
// profiled (add or sub) instead of extending that value's
139+
// lifetime. The save node is expected to be scheduled last,
140+
// as checked at the end of this test.
141+
NodesPtrList schedule;
142+
ChildMemSizeBasedScheduler scheduler(*F, schedule);
143+
scheduler.schedule();
144+
145+
// Find the positions of add and quantizationProfileAdd in the schedule.
146+
auto addIt = std::find(schedule.begin(), schedule.end(), add);
147+
auto qpAddIt =
148+
std::find(schedule.begin(), schedule.end(), quantizationProfileAdd);
149+
// Expect the quantization profiling node to be scheduled right after the node
150+
// being profiled.
151+
EXPECT_EQ(++addIt, qpAddIt);
152+
153+
// Find the positions of sub and quantizationProfileSub in the schedule.
154+
auto subIt = std::find(schedule.begin(), schedule.end(), sub);
155+
auto qpSubIt =
156+
std::find(schedule.begin(), schedule.end(), quantizationProfileSub);
157+
// Expect the quantization profiling node to be scheduled right after the node
158+
// being profiled.
159+
EXPECT_EQ(++subIt, qpSubIt);
160+
161+
// Expect the save node to be the last in the schedule.
162+
EXPECT_EQ(save, schedule.back());
163+
}

0 commit comments

Comments
 (0)