@@ -97,15 +97,26 @@ ExecutionState::ExecutionState(RunIdentifierTy id, const DAGNode *root,
9797 // Create Placeholders for the symbols of all intermediate nodes. These are
9898 // not in the ExecutionContext passed to Executor::run, so they must be
9999 // created by the Executor.
100+ auto *resultBindings = resultCtx_->getPlaceholderBindings ();
100101 for (const auto &symbolPair : symbolTable) {
101102 const auto &symbolName = symbolPair.first ;
102103 const auto &symbolInfo = symbolPair.second ;
103104
104105 if (symbolInfo.symbolCategory == SymbolCategory::Placeholder) {
105- auto PH = module_->getPlaceholderByName (symbolName);
106- // If a PH name is provided it had to come from the Module originally.
107- DCHECK (PH) << " Placeholder: " << symbolName << " is not in the module" ;
108- nodeInputPhBindings->allocate (PH);
106+ auto *PH = resultBindings->getPlaceholderByName (symbolName);
107+ if (!PH) {
108+ PH = module_->getPlaceholderByName (symbolName);
109+ DCHECK (PH) << " Placeholder: " << symbolName
110+ << " is not in the module" ;
111+
112+ // allocate into the resultBindings because they have the longest
113+ // lifetime.
114+ resultBindings->allocate (PH);
115+ intermediatePlaceholders_.push_back (PH);
116+ }
117+
118+ nodeInputPhBindings->insert (
119+ PH, resultBindings->get (PH)->getUnowned (PH->dims ()));
109120 }
110121 }
111122
@@ -123,26 +134,6 @@ ExecutionState::ExecutionState(RunIdentifierTy id, const DAGNode *root,
123134 }
124135}
125136
126- void ExecutionState::insertIntoNodeCtx (const DAGNode *node,
127- llvm::StringRef name, Tensor &&T) {
128- // Get a raw pointer to the input ExecutionContext for the node. It should
129- // have been created in the constructor.
130- auto ctxIt = inputCtxs_.find (node);
131-
132- if (ctxIt == inputCtxs_.end ()) {
133- assert (!" Input bindings not found but should exist!" );
134- }
135-
136- PlaceholderBindings *bindings = (ctxIt->second )->getPlaceholderBindings ();
137- assert (bindings && " Input bindings for node is null" );
138-
139- // Insert the placeholder-tensor pair.
140- std::lock_guard<std::mutex> lock (bindingsMtx_);
141- auto *tensor = bindings->get (bindings->getPlaceholderByName (name));
142- assert (tensor && " Placeholder should have already been created" );
143- *tensor = std::move (T);
144- }
145-
146137std::unique_ptr<ExecutionContext>
147138ExecutionState::getUniqueNodeContextPtr (const DAGNode *node) {
148139 // The input PlaceholderBindings for the node should have been created in the
@@ -201,22 +192,6 @@ bool ExecutionState::incrementNodeParentsDone(const DAGNode *node,
201192 return (newValue == numParents);
202193}
203194
204- void ExecutionState::insertIntoResultCtx (llvm::StringRef name, Tensor &&T) {
205- // The result PlaceholderBindings should have been been created in the
206- // constructor and should not yet have been moved out if this function is
207- // being called.
208- assert (resultCtx_ && resultCtx_->getPlaceholderBindings () &&
209- " Execution result bindings should exist!" );
210- std::lock_guard<std::mutex> lock (bindingsMtx_);
211- auto *resultBindings = resultCtx_->getPlaceholderBindings ();
212- Tensor *tensor =
213- resultBindings->get (resultBindings->getPlaceholderByName (name));
214-
215- if (tensor) {
216- *tensor = std::move (T);
217- }
218- }
219-
220195void ExecutionState::insertIntoTraceContext (std::vector<TraceEvent> &events) {
221196 if (!resultCtx_->getTraceContext ()) {
222197 events.clear ();
@@ -229,6 +204,13 @@ void ExecutionState::insertIntoTraceContext(std::vector<TraceEvent> &events) {
229204 std::back_inserter (resultCtx_->getTraceContext ()->getTraceEvents ()));
230205}
231206
207+ void ExecutionState::removeIntermediatePlaceholders () {
208+ for (auto &p : intermediatePlaceholders_) {
209+ resultCtx_->getPlaceholderBindings ()->erase (p);
210+ }
211+ intermediatePlaceholders_.clear ();
212+ }
213+
232214std::unique_ptr<ExecutionContext> ExecutionState::getUniqueResultContextPtr () {
233215 // The result PlaceholderBindings should have been created in the
234216 // constructor.
@@ -308,39 +290,11 @@ void ThreadPoolExecutor::run(const DAGNode *root,
308290 inflightBarrier_.increment (numChildren);
309291
310292 for (auto const &node : root->children ) {
311- // Propagate placeholders from the given starter PlaceholderBindings into
312- // the input PlaceholderBindings for the current node being processed.
313- propagatePlaceholdersForNode (executionState, node,
314- executionState->getRawResultContextPtr ());
315-
316293 // Execute the node.
317294 executeDAGNode (executionState, node);
318295 }
319296}
320297
321- void ThreadPoolExecutor::propagatePlaceholdersForNode (
322- std::shared_ptr<ExecutionState> executionState, const DAGNode *node,
323- const ExecutionContext *ctx) {
324- ScopedTraceBlock (executionState->getRawResultContextPtr ()->getTraceContext (),
325- " EX_propagateInputs" );
326- // Get the symbol table for the node.
327- const SymbolTableTy &symbolTable = node->runtimeBundle ->getSymbolTable ();
328-
329- for (const auto &symbolPair : symbolTable) {
330- const auto &symbolName = symbolPair.first ;
331-
332- auto *placeholder =
333- ctx->getPlaceholderBindings ()->getPlaceholderByName (symbolName);
334-
335- // If ctx provides a mapping for the symbol, copy it into the context for
336- // the node.
337- if (placeholder) {
338- const auto *tensor = ctx->getPlaceholderBindings ()->get (placeholder);
339- executionState->insertIntoNodeCtx (node, symbolName, tensor->clone ());
340- }
341- }
342- }
343-
344298void ThreadPoolExecutor::executeDAGNode (
345299 std::shared_ptr<ExecutionState> executionState, DAGNode *node) {
346300 // If execution has already failed due to another node, don't bother running
@@ -396,25 +350,10 @@ void ThreadPoolExecutor::executeDAGNode(
396350 });
397351}
398352
399- void ThreadPoolExecutor::propagateOutputPlaceholders (
400- std::shared_ptr<ExecutionState> executionState,
401- PlaceholderBindings *bindings) {
402- ScopedTraceBlock (executionState->getRawResultContextPtr ()->getTraceContext (),
403- " EX_propagateOutputs" );
404- // Copy all of the Placeholders in bindings into the result
405- // PlaceholderBindings for the run.
406- for (const auto &phTensorPair : bindings->pairs ()) {
407- auto *placeholder = phTensorPair.first ;
408- auto *tensor = phTensorPair.second ;
409-
410- executionState->insertIntoResultCtx (placeholder->getName (),
411- std::move (*tensor));
412- }
413- }
414-
415353void ThreadPoolExecutor::handleDeviceManagerResult (
416354 std::shared_ptr<ExecutionState> executionState, llvm::Error err,
417355 std::unique_ptr<ExecutionContext> ctx, const DAGNode *node) {
356+
418357 // If executionState is null, that means that the object was deleted
419358 // while a node was executing. That should never happen.
420359 assert (executionState && " Execution state should not be null" );
@@ -430,26 +369,15 @@ void ThreadPoolExecutor::handleDeviceManagerResult(
430369 // If the DeviceManager executed the node, propagate its output Placeholders
431370 // to its children or the result PlaceholderBindings as appropriate.
432371 if (runWasSuccess) {
433- if ((node->children ).empty ()) {
434- // If the node has no children, propagate its outputs to the result
435- // PlaceholderBindings for the run.
436- propagateOutputPlaceholders (executionState,
437- ctx->getPlaceholderBindings ());
438- } else {
439- // If the node has children, propagate its outputs to the input
440- // PlaceholderBindings for any of its children that need them as inputs.
441- for (auto &child : node->children ) {
442- propagatePlaceholdersForNode (executionState, child, ctx.get ());
443-
444- // Execute any child that has no parent nodes left to execute.
445- bool childReadyToExecute =
446- executionState->incrementNodeParentsDone (child);
447- if (childReadyToExecute) {
448- // Mark the node as "inflight" (i.e. currently executing).
449- executionState->incrementInflightNodes ();
450- inflightBarrier_.increment ();
451- executeDAGNode (executionState, child);
452- }
372+ for (auto &child : node->children ) {
373+ // Execute any child that has no parent nodes left to execute.
374+ bool childReadyToExecute =
375+ executionState->incrementNodeParentsDone (child);
376+ if (childReadyToExecute) {
377+ // Mark the node as "inflight" (i.e. currently executing).
378+ executionState->incrementInflightNodes ();
379+ inflightBarrier_.increment ();
380+ executeDAGNode (executionState, child);
453381 }
454382 }
455383 }
@@ -464,6 +392,9 @@ void ThreadPoolExecutor::handleDeviceManagerResult(
464392 }
465393
466394 if (noNodesInflight) {
395+ // Remove the intermediate placeholders so we don't leak them to the caller.
396+ executionState->removeIntermediatePlaceholders ();
397+
467398 // If there are no nodes inflight, that means all nodes are done. Call
468399 // the callback and erase the state information.
469400 ResultCBTy cb = executionState->getCallback ();
0 commit comments