Skip to content

[mlir][Transforms][NFC] Store per-pattern IR modifications in separate state #145319

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 26, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 75 additions & 64 deletions mlir/lib/Transforms/Utils/DialectConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1080,6 +1080,16 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// to modify/access them is invalid rewriter API usage.
SetVector<Operation *> replacedOps;

/// A set of operations that were created by the current pattern.
SetVector<Operation *> patternNewOps;

/// A set of operations that were modified by the current pattern.
SetVector<Operation *> patternModifiedOps;

/// A set of blocks that were inserted (newly-created blocks or moved blocks)
/// by the current pattern.
SetVector<Block *> patternInsertedBlocks;

/// A mapping of all unresolved materializations (UnrealizedConversionCastOp)
/// to the corresponding rewrite objects.
DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationRewrite *>
Expand Down Expand Up @@ -1571,6 +1581,7 @@ void ConversionPatternRewriterImpl::notifyOperationInserted(
if (!previous.isSet()) {
// This is a newly created op.
appendRewrite<CreateOperationRewrite>(op);
patternNewOps.insert(op);
return;
}
Operation *prevOp = previous.getPoint() == previous.getBlock()->end()
Expand Down Expand Up @@ -1655,6 +1666,8 @@ void ConversionPatternRewriterImpl::notifyBlockInserted(
}
});

patternInsertedBlocks.insert(block);

if (!previous) {
// This is a newly created block.
appendRewrite<CreateBlockRewrite>(block);
Expand Down Expand Up @@ -1852,6 +1865,8 @@ void ConversionPatternRewriter::finalizeOpModification(Operation *op) {
assert(!impl->wasOpReplaced(op) &&
"attempting to modify a replaced/erased op");
PatternRewriter::finalizeOpModification(op);
impl->patternModifiedOps.insert(op);

// There is nothing to do here, we only need to track the operation at the
// start of the update.
#ifndef NDEBUG
Expand Down Expand Up @@ -1964,21 +1979,25 @@ class OperationLegalizer {
/// Legalize the resultant IR after successfully applying the given pattern.
LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern,
ConversionPatternRewriter &rewriter,
RewriterState &curState);
const SetVector<Operation *> &newOps,
const SetVector<Operation *> &modifiedOps,
const SetVector<Block *> &insertedBlocks);

/// Legalizes the actions registered during the execution of a pattern.
LogicalResult
legalizePatternBlockRewrites(Operation *op,
ConversionPatternRewriter &rewriter,
ConversionPatternRewriterImpl &impl,
RewriterState &state, RewriterState &newState);
LogicalResult legalizePatternCreatedOperations(
ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
RewriterState &state, RewriterState &newState);
LogicalResult legalizePatternRootUpdates(ConversionPatternRewriter &rewriter,
ConversionPatternRewriterImpl &impl,
RewriterState &state,
RewriterState &newState);
const SetVector<Block *> &insertedBlocks,
const SetVector<Operation *> &newOps);
LogicalResult
legalizePatternCreatedOperations(ConversionPatternRewriter &rewriter,
ConversionPatternRewriterImpl &impl,
const SetVector<Operation *> &newOps);
LogicalResult
legalizePatternRootUpdates(ConversionPatternRewriter &rewriter,
ConversionPatternRewriterImpl &impl,
const SetVector<Operation *> &modifiedOps);

//===--------------------------------------------------------------------===//
// Cost Model
Expand Down Expand Up @@ -2131,6 +2150,15 @@ OperationLegalizer::legalize(Operation *op,
return failure();
}

/// Helper function that moves and returns the given object. Also resets the
/// original object, so that it is in a valid, empty state again.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am no C++ whiz but this unusual. Perhaps the comment could be expanded to explain why this is needed?

Copy link
Member Author

@matthias-springer matthias-springer Jun 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moving an object leaves it in an unspecified state. This helper function both moves the object and initializes the old object, so that it can be immediately reused again.

template <typename T>
static T moveAndReset(T &obj) {
T result = std::move(obj);
obj = T();
return result;
}

LogicalResult
OperationLegalizer::legalizeWithFold(Operation *op,
ConversionPatternRewriter &rewriter) {
Expand Down Expand Up @@ -2192,6 +2220,9 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
RewriterState curState = rewriterImpl.getCurrentState();
auto onFailure = [&](const Pattern &pattern) {
assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
rewriterImpl.patternNewOps.clear();
rewriterImpl.patternModifiedOps.clear();
rewriterImpl.patternInsertedBlocks.clear();
LLVM_DEBUG({
logFailure(rewriterImpl.logger, "pattern failed to match");
if (rewriterImpl.config.notifyCallback) {
Expand All @@ -2212,7 +2243,13 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
// successfully applied.
auto onSuccess = [&](const Pattern &pattern) {
assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
auto result = legalizePatternResult(op, pattern, rewriter, curState);
SetVector<Operation *> newOps = moveAndReset(rewriterImpl.patternNewOps);
SetVector<Operation *> modifiedOps =
moveAndReset(rewriterImpl.patternModifiedOps);
SetVector<Block *> insertedBlocks =
moveAndReset(rewriterImpl.patternInsertedBlocks);
auto result = legalizePatternResult(op, pattern, rewriter, newOps,
modifiedOps, insertedBlocks);
appliedPatterns.erase(&pattern);
if (failed(result)) {
if (!rewriterImpl.config.allowPatternRollback)
Expand Down Expand Up @@ -2253,10 +2290,11 @@ bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern,
return true;
}

LogicalResult
OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern,
ConversionPatternRewriter &rewriter,
RewriterState &curState) {
LogicalResult OperationLegalizer::legalizePatternResult(
Operation *op, const Pattern &pattern, ConversionPatternRewriter &rewriter,
const SetVector<Operation *> &newOps,
const SetVector<Operation *> &modifiedOps,
const SetVector<Block *> &insertedBlocks) {
auto &impl = rewriter.getImpl();
assert(impl.pendingRootUpdates.empty() && "dangling root updates");

Expand All @@ -2274,12 +2312,10 @@ OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern,
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS

// Legalize each of the actions registered during application.
RewriterState newState = impl.getCurrentState();
if (failed(legalizePatternBlockRewrites(op, rewriter, impl, curState,
newState)) ||
failed(legalizePatternRootUpdates(rewriter, impl, curState, newState)) ||
failed(legalizePatternCreatedOperations(rewriter, impl, curState,
newState))) {
if (failed(legalizePatternBlockRewrites(op, rewriter, impl, insertedBlocks,
newOps)) ||
failed(legalizePatternRootUpdates(rewriter, impl, modifiedOps)) ||
failed(legalizePatternCreatedOperations(rewriter, impl, newOps))) {
return failure();
}

Expand All @@ -2289,20 +2325,14 @@ OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern,

LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
Operation *op, ConversionPatternRewriter &rewriter,
ConversionPatternRewriterImpl &impl, RewriterState &state,
RewriterState &newState) {
SmallPtrSet<Operation *, 16> operationsToIgnore;
ConversionPatternRewriterImpl &impl,
const SetVector<Block *> &insertedBlocks,
const SetVector<Operation *> &newOps) {
SmallPtrSet<Operation *, 16> alreadyLegalized;

// If the pattern moved or created any blocks, make sure the types of block
// arguments get legalized.
for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
BlockRewrite *rewrite = dyn_cast<BlockRewrite>(impl.rewrites[i].get());
if (!rewrite)
continue;
Block *block = rewrite->getBlock();
if (isa<BlockTypeConversionRewrite, EraseBlockRewrite,
ReplaceBlockArgRewrite, InlineBlockRewrite>(rewrite))
continue;
for (Block *block : insertedBlocks) {
// Only check blocks outside of the current operation.
Operation *parentOp = block->getParentOp();
if (!parentOp || parentOp == op || block->getNumArguments() == 0)
Expand All @@ -2322,41 +2352,26 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
continue;
}

// Otherwise, check that this operation isn't one generated by this pattern.
// This is because we will attempt to legalize the parent operation, and
// blocks in regions created by this pattern will already be legalized later
// on. If we haven't built the set yet, build it now.
if (operationsToIgnore.empty()) {
for (unsigned i = state.numRewrites, e = impl.rewrites.size(); i != e;
++i) {
auto *createOp =
dyn_cast<CreateOperationRewrite>(impl.rewrites[i].get());
if (!createOp)
continue;
operationsToIgnore.insert(createOp->getOperation());
// Otherwise, try to legalize the parent operation if it was not generated
// by this pattern. This is because we will attempt to legalize the parent
// operation, and blocks in regions created by this pattern will already be
// legalized later on.
if (!newOps.count(parentOp) && alreadyLegalized.insert(parentOp).second) {
if (failed(legalize(parentOp, rewriter))) {
LLVM_DEBUG(logFailure(
impl.logger, "operation '{0}'({1}) became illegal after rewrite",
parentOp->getName(), parentOp));
return failure();
}
}

// If this operation should be considered for re-legalization, try it.
if (operationsToIgnore.insert(parentOp).second &&
failed(legalize(parentOp, rewriter))) {
LLVM_DEBUG(logFailure(impl.logger,
"operation '{0}'({1}) became illegal after rewrite",
parentOp->getName(), parentOp));
return failure();
}
}
return success();
}

LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
RewriterState &state, RewriterState &newState) {
for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
auto *createOp = dyn_cast<CreateOperationRewrite>(impl.rewrites[i].get());
if (!createOp)
continue;
Operation *op = createOp->getOperation();
const SetVector<Operation *> &newOps) {
for (Operation *op : newOps) {
if (failed(legalize(op, rewriter))) {
LLVM_DEBUG(logFailure(impl.logger,
"failed to legalize generated operation '{0}'({1})",
Expand All @@ -2369,12 +2384,8 @@ LogicalResult OperationLegalizer::legalizePatternCreatedOperations(

LogicalResult OperationLegalizer::legalizePatternRootUpdates(
ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
RewriterState &state, RewriterState &newState) {
for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
auto *rewrite = dyn_cast<ModifyOperationRewrite>(impl.rewrites[i].get());
if (!rewrite)
continue;
Operation *op = rewrite->getOperation();
const SetVector<Operation *> &modifiedOps) {
for (Operation *op : modifiedOps) {
if (failed(legalize(op, rewriter))) {
LLVM_DEBUG(logFailure(
impl.logger, "failed to legalize operation updated in-place '{0}'",
Expand Down