diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 7cfe7250d02c3..955a106c21941 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -1080,6 +1080,16 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// to modify/access them is invalid rewriter API usage. SetVector replacedOps; + /// A set of operations that were created by the current pattern. + SetVector patternNewOps; + + /// A set of operations that were modified by the current pattern. + SetVector patternModifiedOps; + + /// A set of blocks that were inserted (newly-created blocks or moved blocks) + /// by the current pattern. + SetVector patternInsertedBlocks; + /// A mapping of all unresolved materializations (UnrealizedConversionCastOp) /// to the corresponding rewrite objects. DenseMap @@ -1571,6 +1581,7 @@ void ConversionPatternRewriterImpl::notifyOperationInserted( if (!previous.isSet()) { // This is a newly created op. appendRewrite(op); + patternNewOps.insert(op); return; } Operation *prevOp = previous.getPoint() == previous.getBlock()->end() @@ -1655,6 +1666,8 @@ void ConversionPatternRewriterImpl::notifyBlockInserted( } }); + patternInsertedBlocks.insert(block); + if (!previous) { // This is a newly created block. appendRewrite(block); @@ -1852,6 +1865,8 @@ void ConversionPatternRewriter::finalizeOpModification(Operation *op) { assert(!impl->wasOpReplaced(op) && "attempting to modify a replaced/erased op"); PatternRewriter::finalizeOpModification(op); + impl->patternModifiedOps.insert(op); + // There is nothing to do here, we only need to track the operation at the // start of the update. #ifndef NDEBUG @@ -1964,21 +1979,25 @@ class OperationLegalizer { /// Legalize the resultant IR after successfully applying the given pattern. LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern, ConversionPatternRewriter &rewriter, - RewriterState &curState); + const SetVector &newOps, + const SetVector &modifiedOps, + const SetVector &insertedBlocks); /// Legalizes the actions registered during the execution of a pattern. LogicalResult legalizePatternBlockRewrites(Operation *op, ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl, - RewriterState &state, RewriterState &newState); - LogicalResult legalizePatternCreatedOperations( - ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl, - RewriterState &state, RewriterState &newState); - LogicalResult legalizePatternRootUpdates(ConversionPatternRewriter &rewriter, - ConversionPatternRewriterImpl &impl, - RewriterState &state, - RewriterState &newState); + const SetVector &insertedBlocks, + const SetVector &newOps); + LogicalResult + legalizePatternCreatedOperations(ConversionPatternRewriter &rewriter, + ConversionPatternRewriterImpl &impl, + const SetVector &newOps); + LogicalResult + legalizePatternRootUpdates(ConversionPatternRewriter &rewriter, + ConversionPatternRewriterImpl &impl, + const SetVector &modifiedOps); //===--------------------------------------------------------------------===// // Cost Model @@ -2131,6 +2150,15 @@ OperationLegalizer::legalize(Operation *op, return failure(); } +/// Helper function that moves and returns the given object. Also resets the +/// original object, so that it is in a valid, empty state again. +template +static T moveAndReset(T &obj) { + T result = std::move(obj); + obj = T(); + return result; +} + LogicalResult OperationLegalizer::legalizeWithFold(Operation *op, ConversionPatternRewriter &rewriter) { @@ -2192,6 +2220,9 @@ OperationLegalizer::legalizeWithPattern(Operation *op, RewriterState curState = rewriterImpl.getCurrentState(); auto onFailure = [&](const Pattern &pattern) { assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates"); + rewriterImpl.patternNewOps.clear(); + rewriterImpl.patternModifiedOps.clear(); + rewriterImpl.patternInsertedBlocks.clear(); LLVM_DEBUG({ logFailure(rewriterImpl.logger, "pattern failed to match"); if (rewriterImpl.config.notifyCallback) { @@ -2212,7 +2243,13 @@ OperationLegalizer::legalizeWithPattern(Operation *op, // successfully applied. auto onSuccess = [&](const Pattern &pattern) { assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates"); - auto result = legalizePatternResult(op, pattern, rewriter, curState); + SetVector newOps = moveAndReset(rewriterImpl.patternNewOps); + SetVector modifiedOps = + moveAndReset(rewriterImpl.patternModifiedOps); + SetVector insertedBlocks = + moveAndReset(rewriterImpl.patternInsertedBlocks); + auto result = legalizePatternResult(op, pattern, rewriter, newOps, + modifiedOps, insertedBlocks); appliedPatterns.erase(&pattern); if (failed(result)) { if (!rewriterImpl.config.allowPatternRollback) @@ -2253,10 +2290,11 @@ bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern, return true; } -LogicalResult -OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern, - ConversionPatternRewriter &rewriter, - RewriterState &curState) { +LogicalResult OperationLegalizer::legalizePatternResult( + Operation *op, const Pattern &pattern, ConversionPatternRewriter &rewriter, + const SetVector &newOps, + const SetVector &modifiedOps, + const SetVector &insertedBlocks) { auto &impl = rewriter.getImpl(); assert(impl.pendingRootUpdates.empty() && "dangling root updates"); @@ -2274,12 +2312,10 @@ OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern, #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS // Legalize each of the actions registered during application. - RewriterState newState = impl.getCurrentState(); - if (failed(legalizePatternBlockRewrites(op, rewriter, impl, curState, - newState)) || - failed(legalizePatternRootUpdates(rewriter, impl, curState, newState)) || - failed(legalizePatternCreatedOperations(rewriter, impl, curState, - newState))) { + if (failed(legalizePatternBlockRewrites(op, rewriter, impl, insertedBlocks, + newOps)) || + failed(legalizePatternRootUpdates(rewriter, impl, modifiedOps)) || + failed(legalizePatternCreatedOperations(rewriter, impl, newOps))) { return failure(); } @@ -2289,20 +2325,14 @@ OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern, LogicalResult OperationLegalizer::legalizePatternBlockRewrites( Operation *op, ConversionPatternRewriter &rewriter, - ConversionPatternRewriterImpl &impl, RewriterState &state, - RewriterState &newState) { - SmallPtrSet operationsToIgnore; + ConversionPatternRewriterImpl &impl, + const SetVector &insertedBlocks, + const SetVector &newOps) { + SmallPtrSet alreadyLegalized; // If the pattern moved or created any blocks, make sure the types of block // arguments get legalized. - for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) { - BlockRewrite *rewrite = dyn_cast(impl.rewrites[i].get()); - if (!rewrite) - continue; - Block *block = rewrite->getBlock(); - if (isa(rewrite)) - continue; + for (Block *block : insertedBlocks) { // Only check blocks outside of the current operation. Operation *parentOp = block->getParentOp(); if (!parentOp || parentOp == op || block->getNumArguments() == 0) @@ -2322,41 +2352,26 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites( continue; } - // Otherwise, check that this operation isn't one generated by this pattern. - // This is because we will attempt to legalize the parent operation, and - // blocks in regions created by this pattern will already be legalized later - // on. If we haven't built the set yet, build it now. - if (operationsToIgnore.empty()) { - for (unsigned i = state.numRewrites, e = impl.rewrites.size(); i != e; - ++i) { - auto *createOp = - dyn_cast(impl.rewrites[i].get()); - if (!createOp) - continue; - operationsToIgnore.insert(createOp->getOperation()); + // Otherwise, try to legalize the parent operation if it was not generated + // by this pattern. This is because we will attempt to legalize the parent + // operation, and blocks in regions created by this pattern will already be + // legalized later on. + if (!newOps.count(parentOp) && alreadyLegalized.insert(parentOp).second) { + if (failed(legalize(parentOp, rewriter))) { + LLVM_DEBUG(logFailure( + impl.logger, "operation '{0}'({1}) became illegal after rewrite", + parentOp->getName(), parentOp)); + return failure(); } } - - // If this operation should be considered for re-legalization, try it. - if (operationsToIgnore.insert(parentOp).second && - failed(legalize(parentOp, rewriter))) { - LLVM_DEBUG(logFailure(impl.logger, - "operation '{0}'({1}) became illegal after rewrite", - parentOp->getName(), parentOp)); - return failure(); - } } return success(); } LogicalResult OperationLegalizer::legalizePatternCreatedOperations( ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl, - RewriterState &state, RewriterState &newState) { - for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) { - auto *createOp = dyn_cast(impl.rewrites[i].get()); - if (!createOp) - continue; - Operation *op = createOp->getOperation(); + const SetVector &newOps) { + for (Operation *op : newOps) { if (failed(legalize(op, rewriter))) { LLVM_DEBUG(logFailure(impl.logger, "failed to legalize generated operation '{0}'({1})", @@ -2369,12 +2384,8 @@ LogicalResult OperationLegalizer::legalizePatternCreatedOperations( LogicalResult OperationLegalizer::legalizePatternRootUpdates( ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl, - RewriterState &state, RewriterState &newState) { - for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) { - auto *rewrite = dyn_cast(impl.rewrites[i].get()); - if (!rewrite) - continue; - Operation *op = rewrite->getOperation(); + const SetVector &modifiedOps) { + for (Operation *op : modifiedOps) { if (failed(legalize(op, rewriter))) { LLVM_DEBUG(logFailure( impl.logger, "failed to legalize operation updated in-place '{0}'",