Skip to content
This repository was archived by the owner on Jul 1, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions include/glow/Graph/PlaceholderBindings.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,10 @@ class PlaceholderBindings final {
/// tensors.
void clear();

/// Removes the Tensor backing Placeholder \p P;
/// \p P must be a valid Placeholder registered in the bindings.
void erase(Placeholder *P);

/// \returns a copy of the PlaceholderBindings, with each placeholder mapped
/// to a new Tensor, with their own memory.
PlaceholderBindings clone() const;
Expand Down
8 changes: 8 additions & 0 deletions lib/Graph/PlaceholderBindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,14 @@ void PlaceholderBindings::clear() {
nameMap_.clear();
}

void PlaceholderBindings::erase(Placeholder *P) {
assert(nameMap_.count(P->getName()) &&
"Placeholder must already be registered");
nameMap_.erase(P->getName());
delete map_[P];
map_.erase(P);
}

PlaceholderBindings PlaceholderBindings::clone() const {
PlaceholderBindings cloned;
for (auto PH : map_) {
Expand Down
139 changes: 35 additions & 104 deletions lib/Runtime/Executor/ThreadPoolExecutor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,15 +97,26 @@ ExecutionState::ExecutionState(RunIdentifierTy id, const DAGNode *root,
// Create Placeholders for the symbols of all intermediate nodes. These are
// not in the ExecutionContext passed to Executor::run, so they must be
// created by the Executor.
auto *resultBindings = resultCtx_->getPlaceholderBindings();
for (const auto &symbolPair : symbolTable) {
const auto &symbolName = symbolPair.first;
const auto &symbolInfo = symbolPair.second;

if (symbolInfo.symbolCategory == SymbolCategory::Placeholder) {
auto PH = module_->getPlaceholderByName(symbolName);
// If a PH name is provided it had to come from the Module originally.
DCHECK(PH) << "Placeholder: " << symbolName << " is not in the module";
nodeInputPhBindings->allocate(PH);
auto *PH = resultBindings->getPlaceholderByName(symbolName);
if (!PH) {
PH = module_->getPlaceholderByName(symbolName);
DCHECK(PH) << "Placeholder: " << symbolName
<< " is not in the module";

// allocate into the resultBindings because they have the longest
// lifetime.
resultBindings->allocate(PH);
intermediatePlaceholders_.push_back(PH);
}

nodeInputPhBindings->insert(
PH, resultBindings->get(PH)->getUnowned(PH->dims()));
}
}

Expand All @@ -123,26 +134,6 @@ ExecutionState::ExecutionState(RunIdentifierTy id, const DAGNode *root,
}
}

void ExecutionState::insertIntoNodeCtx(const DAGNode *node,
llvm::StringRef name, Tensor &&T) {
// Get a raw pointer to the input ExecutionContext for the node. It should
// have been created in the constructor.
auto ctxIt = inputCtxs_.find(node);

if (ctxIt == inputCtxs_.end()) {
assert(!"Input bindings not found but should exist!");
}

PlaceholderBindings *bindings = (ctxIt->second)->getPlaceholderBindings();
assert(bindings && "Input bindings for node is null");

// Insert the placeholder-tensor pair.
std::lock_guard<std::mutex> lock(bindingsMtx_);
auto *tensor = bindings->get(bindings->getPlaceholderByName(name));
assert(tensor && "Placeholder should have already been created");
*tensor = std::move(T);
}

std::unique_ptr<ExecutionContext>
ExecutionState::getUniqueNodeContextPtr(const DAGNode *node) {
// The input PlaceholderBindings for the node should have been created in the
Expand Down Expand Up @@ -201,22 +192,6 @@ bool ExecutionState::incrementNodeParentsDone(const DAGNode *node,
return (newValue == numParents);
}

void ExecutionState::insertIntoResultCtx(llvm::StringRef name, Tensor &&T) {
// The result PlaceholderBindings should have been been created in the
// constructor and should not yet have been moved out if this function is
// being called.
assert(resultCtx_ && resultCtx_->getPlaceholderBindings() &&
"Execution result bindings should exist!");
std::lock_guard<std::mutex> lock(bindingsMtx_);
auto *resultBindings = resultCtx_->getPlaceholderBindings();
Tensor *tensor =
resultBindings->get(resultBindings->getPlaceholderByName(name));

if (tensor) {
*tensor = std::move(T);
}
}

void ExecutionState::insertIntoTraceContext(std::vector<TraceEvent> &events) {
if (!resultCtx_->getTraceContext()) {
events.clear();
Expand All @@ -229,6 +204,13 @@ void ExecutionState::insertIntoTraceContext(std::vector<TraceEvent> &events) {
std::back_inserter(resultCtx_->getTraceContext()->getTraceEvents()));
}

void ExecutionState::removeIntermediatePlaceholders() {
for (auto &p : intermediatePlaceholders_) {
resultCtx_->getPlaceholderBindings()->erase(p);
}
intermediatePlaceholders_.clear();
}

std::unique_ptr<ExecutionContext> ExecutionState::getUniqueResultContextPtr() {
// The result PlaceholderBindings should have been been created in the
// constructor.
Expand Down Expand Up @@ -308,39 +290,11 @@ void ThreadPoolExecutor::run(const DAGNode *root,
inflightBarrier_.increment(numChildren);

for (auto const &node : root->children) {
// Propagate placeholders from the given starter PlaceholderBindings into
// the input PlaceholderBindings for the current node being processed.
propagatePlaceholdersForNode(executionState, node,
executionState->getRawResultContextPtr());

// Execute the node.
executeDAGNode(executionState, node);
}
}

void ThreadPoolExecutor::propagatePlaceholdersForNode(
std::shared_ptr<ExecutionState> executionState, const DAGNode *node,
const ExecutionContext *ctx) {
ScopedTraceBlock(executionState->getRawResultContextPtr()->getTraceContext(),
"EX_propagateInputs");
// Get the symbol table for the node.
const SymbolTableTy &symbolTable = node->runtimeBundle->getSymbolTable();

for (const auto &symbolPair : symbolTable) {
const auto &symbolName = symbolPair.first;

auto *placeholder =
ctx->getPlaceholderBindings()->getPlaceholderByName(symbolName);

// If ctx provides a mapping for the symbol, copy it into the context for
// the node.
if (placeholder) {
const auto *tensor = ctx->getPlaceholderBindings()->get(placeholder);
executionState->insertIntoNodeCtx(node, symbolName, tensor->clone());
}
}
}

void ThreadPoolExecutor::executeDAGNode(
std::shared_ptr<ExecutionState> executionState, DAGNode *node) {
// If execution has already failed due to another node, don't bother running
Expand Down Expand Up @@ -396,25 +350,10 @@ void ThreadPoolExecutor::executeDAGNode(
});
}

void ThreadPoolExecutor::propagateOutputPlaceholders(
std::shared_ptr<ExecutionState> executionState,
PlaceholderBindings *bindings) {
ScopedTraceBlock(executionState->getRawResultContextPtr()->getTraceContext(),
"EX_propagateOutputs");
// Copy all of the Placeholders in bindings into the result
// PlaceholderBindings for the run.
for (const auto &phTensorPair : bindings->pairs()) {
auto *placeholder = phTensorPair.first;
auto *tensor = phTensorPair.second;

executionState->insertIntoResultCtx(placeholder->getName(),
std::move(*tensor));
}
}

void ThreadPoolExecutor::handleDeviceManagerResult(
std::shared_ptr<ExecutionState> executionState, llvm::Error err,
std::unique_ptr<ExecutionContext> ctx, const DAGNode *node) {

// If executionState is null, that means that the object was deleted
// while a node was executing. That should never happen.
assert(executionState && "Execution state should not be null");
Expand All @@ -430,26 +369,15 @@ void ThreadPoolExecutor::handleDeviceManagerResult(
// If the DeviceManager executed the node, propagate its output Placeholders
// to its children or the result PlaceholderBindings as appropriate.
if (runWasSuccess) {
if ((node->children).empty()) {
// If the node has no children, propagate its outputs to the result
// PlaceholderBindings for the run.
propagateOutputPlaceholders(executionState,
ctx->getPlaceholderBindings());
} else {
// If the node has children, propagate its outputs to the input
// PlaceholderBindings for any of its children that need them as inputs.
for (auto &child : node->children) {
propagatePlaceholdersForNode(executionState, child, ctx.get());

// Execute any child that has no parent nodes left to execute.
bool childReadyToExecute =
executionState->incrementNodeParentsDone(child);
if (childReadyToExecute) {
// Mark the node as "inflight" (i.e. currently executing).
executionState->incrementInflightNodes();
inflightBarrier_.increment();
executeDAGNode(executionState, child);
}
for (auto &child : node->children) {
// Execute any child that has no parent nodes left to execute.
bool childReadyToExecute =
executionState->incrementNodeParentsDone(child);
if (childReadyToExecute) {
// Mark the node as "inflight" (i.e. currently executing).
executionState->incrementInflightNodes();
inflightBarrier_.increment();
executeDAGNode(executionState, child);
}
}
}
Expand All @@ -464,6 +392,9 @@ void ThreadPoolExecutor::handleDeviceManagerResult(
}

if (noNodesInflight) {
// Remove the intermediate placeholders so we don't leak them to the caller.
executionState->removeIntermediatePlaceholders();

// If there are no nodes inflight, that means all nodes are done. Call
// the callback and erase the state information.
ResultCBTy cb = executionState->getCallback();
Expand Down
33 changes: 6 additions & 27 deletions lib/Runtime/Executor/ThreadPoolExecutor.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,6 @@ class ExecutionState final {
std::unique_ptr<ExecutionContext> resultContext,
ResultCBTy doneCb);

/// Insert \p T as the mapped Tensor for the Placeholder named \p name in the
/// bindings of the context for \p node. This should not be called at the same
/// time as getUniqueNodeContextPtr().
void insertIntoNodeCtx(const DAGNode *node, llvm::StringRef name, Tensor &&T);

/// \returns a unique pointer to an input bindings for \p node. This should
/// not be called at the same time as insertIntoNodeCtx().
std::unique_ptr<ExecutionContext>
Expand All @@ -84,14 +79,12 @@ class ExecutionState final {
/// otherwise.
bool incrementNodeParentsDone(const DAGNode *node, unsigned increment = 1);

/// Insert \p T as the mapped Tensor for the Placeholder in the bindings of
/// the result context named \p name. This should not be called at the same
/// time as getUniqueResultPlaceholderBindingsPtr().
void insertIntoResultCtx(llvm::StringRef name, Tensor &&T);

/// Move all events from the provided vector into the top level resultContxt.
void insertIntoTraceContext(std::vector<TraceEvent> &events);

/// Remove intermediate placeholders not required in the final output.
void removeIntermediatePlaceholders();

/// \returns a unique pointer to the result bindings. This should not be
/// called at the same time as getRawResultPlaceholderBindingsPtr() or
/// insertIntoResultCtx().
Expand Down Expand Up @@ -126,9 +119,9 @@ class ExecutionState final {
std::unordered_map<const DAGNode *, std::unique_ptr<ExecutionContext>>
inputCtxs_;
/// Placeholders for tensors generated by DAG nodes that aren't the final
/// output (i.e. they have children). The set of currently executing nodes.
std::unordered_map<std::string, std::unique_ptr<Placeholder>>
intermediatePlaceholders_;
/// output (i.e. they have children). The owning pointer for these tensors
/// exists in the resultCtx and are removed before the ResultCB is called.
std::vector<Placeholder *> intermediatePlaceholders_;
/// This is populated with the roots when a run starts, and does not become
/// empty until execution finishes.
std::atomic<unsigned> inflightNodes_;
Expand Down Expand Up @@ -162,20 +155,6 @@ class ThreadPoolExecutor final : public Executor {
void shutdown() override;

private:
/// Propagate Placeholders from \p ctx into the final output
/// ExecutionContext for the run corresponding to \p executionState.
void
propagateOutputPlaceholders(std::shared_ptr<ExecutionState> executionState,
PlaceholderBindings *bindings);

/// Propagate Placeholders needed by \p node from \p ctx into
/// the ExecutionContext for \p node within the run corresponding to \p
/// executionState.
void
propagatePlaceholdersForNode(std::shared_ptr<ExecutionState> executionState,
const DAGNode *node,
const ExecutionContext *ctx);

/// Execute the DAG node specified by \p node within the run corresponding to
/// \p executionState.
void executeDAGNode(std::shared_ptr<ExecutionState> executionState,
Expand Down
22 changes: 13 additions & 9 deletions tests/unittests/ExecutorTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,11 @@ class TestDeviceManager final : public runtime::DeviceManager {
/// call \p resultCB with it after checking that \p context contains the
/// expected Placeholder-Tensor mappings.
void doRunFunction(std::string functionName,
std::shared_ptr<ExecutionContext> context,
std::unique_ptr<ExecutionContext> context,
ResultCBTy resultCB) {

RunIdentifierTy runId = 0;
bool successResult = false;
std::unique_ptr<ExecutionContext> resultContext = nullptr;

// Retrieve the registered response for the function if there is one.
if (context && resultCB && resultMap_.count(functionName)) {
Expand All @@ -95,14 +95,18 @@ class TestDeviceManager final : public runtime::DeviceManager {
// ones.
runId = registeredResult->runId;
successResult = registeredResult->success;
resultContext = std::move(registeredResult->resultContext);

for (auto p : registeredResult->resultContext->getPlaceholderBindings()
->pairs()) {
context->getPlaceholderBindings()->get(p.first)->assign(p.second);
}
}
}

if (successResult) {
resultCB(runId, llvm::Error::success(), std::move(resultContext));
resultCB(runId, llvm::Error::success(), std::move(context));
} else {
resultCB(runId, MAKE_ERR("An error occurred"), std::move(resultContext));
resultCB(runId, MAKE_ERR("An error occurred"), std::move(context));
}
}

Expand All @@ -113,10 +117,10 @@ class TestDeviceManager final : public runtime::DeviceManager {
ResultCBTy resultCB) override {
// Give the call to the thread pool to process to make the tests
// multithreaded if needed.
std::shared_ptr<ExecutionContext> sharedContext = std::move(context);
this->threadPool_.submit([this, functionName, sharedContext, resultCB]() {
this->doRunFunction(functionName, sharedContext, resultCB);
});
this->threadPool_.submit(
[this, functionName, context = std::move(context), resultCB]() mutable {
this->doRunFunction(functionName, std::move(context), resultCB);
});
return 0;
}

Expand Down