Skip to content
This repository was archived by the owner on Jul 1, 2025. It is now read-only.

Commit d299b98

Browse files
authored
Remove tensor copies from the executor (#2821)
1 parent 446a5b5 commit d299b98

File tree

5 files changed

+66
-140
lines changed

5 files changed

+66
-140
lines changed

include/glow/Graph/PlaceholderBindings.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,10 @@ class PlaceholderBindings final {
8888
/// tensors.
8989
void clear();
9090

91+
/// Removes the Tensor backing Placeholder \p P;
92+
/// \p P must be a valid Placeholder registered in the bindings.
93+
void erase(Placeholder *P);
94+
9195
/// \returns a copy of the PlaceholderBindings, with each placeholder mapped
9296
/// to a new Tensor, with their own memory.
9397
PlaceholderBindings clone() const;

lib/Graph/PlaceholderBindings.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,14 @@ void PlaceholderBindings::clear() {
9797
nameMap_.clear();
9898
}
9999

100+
void PlaceholderBindings::erase(Placeholder *P) {
101+
assert(nameMap_.count(P->getName()) &&
102+
"Placeholder must already be registered");
103+
nameMap_.erase(P->getName());
104+
delete map_[P];
105+
map_.erase(P);
106+
}
107+
100108
PlaceholderBindings PlaceholderBindings::clone() const {
101109
PlaceholderBindings cloned;
102110
for (auto PH : map_) {

lib/Runtime/Executor/ThreadPoolExecutor.cpp

Lines changed: 35 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -97,15 +97,26 @@ ExecutionState::ExecutionState(RunIdentifierTy id, const DAGNode *root,
9797
// Create Placeholders for the symbols of all intermediate nodes. These are
9898
// not in the ExecutionContext passed to Executor::run, so they must be
9999
// created by the Executor.
100+
auto *resultBindings = resultCtx_->getPlaceholderBindings();
100101
for (const auto &symbolPair : symbolTable) {
101102
const auto &symbolName = symbolPair.first;
102103
const auto &symbolInfo = symbolPair.second;
103104

104105
if (symbolInfo.symbolCategory == SymbolCategory::Placeholder) {
105-
auto PH = module_->getPlaceholderByName(symbolName);
106-
// If a PH name is provided it had to come from the Module originally.
107-
DCHECK(PH) << "Placeholder: " << symbolName << " is not in the module";
108-
nodeInputPhBindings->allocate(PH);
106+
auto *PH = resultBindings->getPlaceholderByName(symbolName);
107+
if (!PH) {
108+
PH = module_->getPlaceholderByName(symbolName);
109+
DCHECK(PH) << "Placeholder: " << symbolName
110+
<< " is not in the module";
111+
112+
// allocate into the resultBindings because they have the longest
113+
// lifetime.
114+
resultBindings->allocate(PH);
115+
intermediatePlaceholders_.push_back(PH);
116+
}
117+
118+
nodeInputPhBindings->insert(
119+
PH, resultBindings->get(PH)->getUnowned(PH->dims()));
109120
}
110121
}
111122

@@ -123,26 +134,6 @@ ExecutionState::ExecutionState(RunIdentifierTy id, const DAGNode *root,
123134
}
124135
}
125136

126-
void ExecutionState::insertIntoNodeCtx(const DAGNode *node,
127-
llvm::StringRef name, Tensor &&T) {
128-
// Get a raw pointer to the input ExecutionContext for the node. It should
129-
// have been created in the constructor.
130-
auto ctxIt = inputCtxs_.find(node);
131-
132-
if (ctxIt == inputCtxs_.end()) {
133-
assert(!"Input bindings not found but should exist!");
134-
}
135-
136-
PlaceholderBindings *bindings = (ctxIt->second)->getPlaceholderBindings();
137-
assert(bindings && "Input bindings for node is null");
138-
139-
// Insert the placeholder-tensor pair.
140-
std::lock_guard<std::mutex> lock(bindingsMtx_);
141-
auto *tensor = bindings->get(bindings->getPlaceholderByName(name));
142-
assert(tensor && "Placeholder should have already been created");
143-
*tensor = std::move(T);
144-
}
145-
146137
std::unique_ptr<ExecutionContext>
147138
ExecutionState::getUniqueNodeContextPtr(const DAGNode *node) {
148139
// The input PlaceholderBindings for the node should have been created in the
@@ -201,22 +192,6 @@ bool ExecutionState::incrementNodeParentsDone(const DAGNode *node,
201192
return (newValue == numParents);
202193
}
203194

204-
void ExecutionState::insertIntoResultCtx(llvm::StringRef name, Tensor &&T) {
205-
// The result PlaceholderBindings should have been been created in the
206-
// constructor and should not yet have been moved out if this function is
207-
// being called.
208-
assert(resultCtx_ && resultCtx_->getPlaceholderBindings() &&
209-
"Execution result bindings should exist!");
210-
std::lock_guard<std::mutex> lock(bindingsMtx_);
211-
auto *resultBindings = resultCtx_->getPlaceholderBindings();
212-
Tensor *tensor =
213-
resultBindings->get(resultBindings->getPlaceholderByName(name));
214-
215-
if (tensor) {
216-
*tensor = std::move(T);
217-
}
218-
}
219-
220195
void ExecutionState::insertIntoTraceContext(std::vector<TraceEvent> &events) {
221196
if (!resultCtx_->getTraceContext()) {
222197
events.clear();
@@ -229,6 +204,13 @@ void ExecutionState::insertIntoTraceContext(std::vector<TraceEvent> &events) {
229204
std::back_inserter(resultCtx_->getTraceContext()->getTraceEvents()));
230205
}
231206

207+
void ExecutionState::removeIntermediatePlaceholders() {
208+
for (auto &p : intermediatePlaceholders_) {
209+
resultCtx_->getPlaceholderBindings()->erase(p);
210+
}
211+
intermediatePlaceholders_.clear();
212+
}
213+
232214
std::unique_ptr<ExecutionContext> ExecutionState::getUniqueResultContextPtr() {
233215
// The result PlaceholderBindings should have been been created in the
234216
// constructor.
@@ -308,39 +290,11 @@ void ThreadPoolExecutor::run(const DAGNode *root,
308290
inflightBarrier_.increment(numChildren);
309291

310292
for (auto const &node : root->children) {
311-
// Propagate placeholders from the given starter PlaceholderBindings into
312-
// the input PlaceholderBindings for the current node being processed.
313-
propagatePlaceholdersForNode(executionState, node,
314-
executionState->getRawResultContextPtr());
315-
316293
// Execute the node.
317294
executeDAGNode(executionState, node);
318295
}
319296
}
320297

321-
void ThreadPoolExecutor::propagatePlaceholdersForNode(
322-
std::shared_ptr<ExecutionState> executionState, const DAGNode *node,
323-
const ExecutionContext *ctx) {
324-
ScopedTraceBlock(executionState->getRawResultContextPtr()->getTraceContext(),
325-
"EX_propagateInputs");
326-
// Get the symbol table for the node.
327-
const SymbolTableTy &symbolTable = node->runtimeBundle->getSymbolTable();
328-
329-
for (const auto &symbolPair : symbolTable) {
330-
const auto &symbolName = symbolPair.first;
331-
332-
auto *placeholder =
333-
ctx->getPlaceholderBindings()->getPlaceholderByName(symbolName);
334-
335-
// If ctx provides a mapping for the symbol, copy it into the context for
336-
// the node.
337-
if (placeholder) {
338-
const auto *tensor = ctx->getPlaceholderBindings()->get(placeholder);
339-
executionState->insertIntoNodeCtx(node, symbolName, tensor->clone());
340-
}
341-
}
342-
}
343-
344298
void ThreadPoolExecutor::executeDAGNode(
345299
std::shared_ptr<ExecutionState> executionState, DAGNode *node) {
346300
// If execution has already failed due to another node, don't bother running
@@ -396,25 +350,10 @@ void ThreadPoolExecutor::executeDAGNode(
396350
});
397351
}
398352

399-
void ThreadPoolExecutor::propagateOutputPlaceholders(
400-
std::shared_ptr<ExecutionState> executionState,
401-
PlaceholderBindings *bindings) {
402-
ScopedTraceBlock(executionState->getRawResultContextPtr()->getTraceContext(),
403-
"EX_propagateOutputs");
404-
// Copy all of the Placeholders in bindings into the result
405-
// PlaceholderBindings for the run.
406-
for (const auto &phTensorPair : bindings->pairs()) {
407-
auto *placeholder = phTensorPair.first;
408-
auto *tensor = phTensorPair.second;
409-
410-
executionState->insertIntoResultCtx(placeholder->getName(),
411-
std::move(*tensor));
412-
}
413-
}
414-
415353
void ThreadPoolExecutor::handleDeviceManagerResult(
416354
std::shared_ptr<ExecutionState> executionState, llvm::Error err,
417355
std::unique_ptr<ExecutionContext> ctx, const DAGNode *node) {
356+
418357
// If executionState is null, that means that the object was deleted
419358
// while a node was executing. That should never happen.
420359
assert(executionState && "Execution state should not be null");
@@ -430,26 +369,15 @@ void ThreadPoolExecutor::handleDeviceManagerResult(
430369
// If the DeviceManager executed the node, propagate its output Placeholders
431370
// to its children or the result PlaceholderBindings as appropriate.
432371
if (runWasSuccess) {
433-
if ((node->children).empty()) {
434-
// If the node has no children, propagate its outputs to the result
435-
// PlaceholderBindings for the run.
436-
propagateOutputPlaceholders(executionState,
437-
ctx->getPlaceholderBindings());
438-
} else {
439-
// If the node has children, propagate its outputs to the input
440-
// PlaceholderBindings for any of its children that need them as inputs.
441-
for (auto &child : node->children) {
442-
propagatePlaceholdersForNode(executionState, child, ctx.get());
443-
444-
// Execute any child that has no parent nodes left to execute.
445-
bool childReadyToExecute =
446-
executionState->incrementNodeParentsDone(child);
447-
if (childReadyToExecute) {
448-
// Mark the node as "inflight" (i.e. currently executing).
449-
executionState->incrementInflightNodes();
450-
inflightBarrier_.increment();
451-
executeDAGNode(executionState, child);
452-
}
372+
for (auto &child : node->children) {
373+
// Execute any child that has no parent nodes left to execute.
374+
bool childReadyToExecute =
375+
executionState->incrementNodeParentsDone(child);
376+
if (childReadyToExecute) {
377+
// Mark the node as "inflight" (i.e. currently executing).
378+
executionState->incrementInflightNodes();
379+
inflightBarrier_.increment();
380+
executeDAGNode(executionState, child);
453381
}
454382
}
455383
}
@@ -464,6 +392,9 @@ void ThreadPoolExecutor::handleDeviceManagerResult(
464392
}
465393

466394
if (noNodesInflight) {
395+
// Remove the intermediate placeholders so we don't leak them to the caller.
396+
executionState->removeIntermediatePlaceholders();
397+
467398
// If there are no nodes inflight, that means all nodes are done. Call
468399
// the callback and erase the state information.
469400
ResultCBTy cb = executionState->getCallback();

lib/Runtime/Executor/ThreadPoolExecutor.h

Lines changed: 6 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -61,11 +61,6 @@ class ExecutionState final {
6161
std::unique_ptr<ExecutionContext> resultContext,
6262
ResultCBTy doneCb);
6363

64-
/// Insert \p T as the mapped Tensor for the Placeholder named \p name in the
65-
/// bindings of the context for \p node. This should not be called at the same
66-
/// time as getUniqueNodeContextPtr().
67-
void insertIntoNodeCtx(const DAGNode *node, llvm::StringRef name, Tensor &&T);
68-
6964
/// \returns a unique pointer to an input bindings for \p node. This should
7065
/// not be called at the same time as insertIntoNodeCtx().
7166
std::unique_ptr<ExecutionContext>
@@ -84,14 +79,12 @@ class ExecutionState final {
8479
/// otherwise.
8580
bool incrementNodeParentsDone(const DAGNode *node, unsigned increment = 1);
8681

87-
/// Insert \p T as the mapped Tensor for the Placeholder in the bindings of
88-
/// the result context named \p name. This should not be called at the same
89-
/// time as getUniqueResultPlaceholderBindingsPtr().
90-
void insertIntoResultCtx(llvm::StringRef name, Tensor &&T);
91-
9282
/// Move all events from the provided vector into the top level resultContxt.
9383
void insertIntoTraceContext(std::vector<TraceEvent> &events);
9484

85+
/// Remove intermediate placeholders not required in the final output.
86+
void removeIntermediatePlaceholders();
87+
9588
/// \returns a unique pointer to the result bindings. This should not be
9689
/// called at the same time as getRawResultPlaceholderBindingsPtr() or
9790
/// insertIntoResultCtx().
@@ -126,9 +119,9 @@ class ExecutionState final {
126119
std::unordered_map<const DAGNode *, std::unique_ptr<ExecutionContext>>
127120
inputCtxs_;
128121
/// Placeholders for tensors generated by DAG nodes that aren't the final
129-
/// output (i.e. they have children). The set of currently executing nodes.
130-
std::unordered_map<std::string, std::unique_ptr<Placeholder>>
131-
intermediatePlaceholders_;
122+
/// output (i.e. they have children). The owning pointer for these tensors
123+
/// exists in the resultCtx and are removed before the ResultCB is called.
124+
std::vector<Placeholder *> intermediatePlaceholders_;
132125
/// This is populated with the roots when a run starts, and does not become
133126
/// empty until execution finishes.
134127
std::atomic<unsigned> inflightNodes_;
@@ -162,20 +155,6 @@ class ThreadPoolExecutor final : public Executor {
162155
void shutdown() override;
163156

164157
private:
165-
/// Propagate Placeholders from \p ctx into the final output
166-
/// ExecutionContext for the run corresponding to \p executionState.
167-
void
168-
propagateOutputPlaceholders(std::shared_ptr<ExecutionState> executionState,
169-
PlaceholderBindings *bindings);
170-
171-
/// Propagate Placeholders needed by \p node from \p ctx into
172-
/// the ExecutionContext for \p node within the run corresponding to \p
173-
/// executionState.
174-
void
175-
propagatePlaceholdersForNode(std::shared_ptr<ExecutionState> executionState,
176-
const DAGNode *node,
177-
const ExecutionContext *ctx);
178-
179158
/// Execute the DAG node specified by \p node within the run corresponding to
180159
/// \p executionState.
181160
void executeDAGNode(std::shared_ptr<ExecutionState> executionState,

tests/unittests/ExecutorTest.cpp

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -72,11 +72,11 @@ class TestDeviceManager final : public runtime::DeviceManager {
7272
/// call \p resultCB with it after checking that \p context contains the
7373
/// expected Placeholder-Tensor mappings.
7474
void doRunFunction(std::string functionName,
75-
std::shared_ptr<ExecutionContext> context,
75+
std::unique_ptr<ExecutionContext> context,
7676
ResultCBTy resultCB) {
77+
7778
RunIdentifierTy runId = 0;
7879
bool successResult = false;
79-
std::unique_ptr<ExecutionContext> resultContext = nullptr;
8080

8181
// Retrieve the registered response for the function if there is one.
8282
if (context && resultCB && resultMap_.count(functionName)) {
@@ -95,14 +95,18 @@ class TestDeviceManager final : public runtime::DeviceManager {
9595
// ones.
9696
runId = registeredResult->runId;
9797
successResult = registeredResult->success;
98-
resultContext = std::move(registeredResult->resultContext);
98+
99+
for (auto p : registeredResult->resultContext->getPlaceholderBindings()
100+
->pairs()) {
101+
context->getPlaceholderBindings()->get(p.first)->assign(p.second);
102+
}
99103
}
100104
}
101105

102106
if (successResult) {
103-
resultCB(runId, llvm::Error::success(), std::move(resultContext));
107+
resultCB(runId, llvm::Error::success(), std::move(context));
104108
} else {
105-
resultCB(runId, MAKE_ERR("An error occurred"), std::move(resultContext));
109+
resultCB(runId, MAKE_ERR("An error occurred"), std::move(context));
106110
}
107111
}
108112

@@ -113,10 +117,10 @@ class TestDeviceManager final : public runtime::DeviceManager {
113117
ResultCBTy resultCB) override {
114118
// Give the call to the thread pool to process to make the tests
115119
// multithreaded if needed.
116-
std::shared_ptr<ExecutionContext> sharedContext = std::move(context);
117-
this->threadPool_.submit([this, functionName, sharedContext, resultCB]() {
118-
this->doRunFunction(functionName, sharedContext, resultCB);
119-
});
120+
this->threadPool_.submit(
121+
[this, functionName, context = std::move(context), resultCB]() mutable {
122+
this->doRunFunction(functionName, std::move(context), resultCB);
123+
});
120124
return 0;
121125
}
122126

0 commit comments

Comments
 (0)