From a625c02b9c672a6c08ccd32dfbb9a40e4092ffa7 Mon Sep 17 00:00:00 2001 From: MaheshRavishankar Date: Thu, 5 Jun 2025 15:03:42 -0700 Subject: [PATCH] [mlir][scf] Return `replacements` explicitly in `SCFTilingResult`. In #120115 the replacements for the tiled operations were wrapped within the `MergeResult` object. That is a bit of an obfuscation and not immediately obvious where to get the replacements post tiling. This changes the `SCFTilingResult` to have `replacements` explicit (as it was before that change). It also makes the `mergeOps` a separate field of `SCFTilingResult`, which is empty when the reduction type is `FullReduction`. Signed-off-by: MaheshRavishankar --- .../SCF/Transforms/TileUsingInterface.h | 16 ++--- .../mlir/Interfaces/TilingInterface.td | 3 +- .../TransformOps/LinalgTransformOps.cpp | 10 ++-- .../SCF/Transforms/TileUsingInterface.cpp | 58 ++++++++++--------- .../TestTilingInterfaceTransformOps.cpp | 3 +- 5 files changed, 47 insertions(+), 43 deletions(-) diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h index 33a43ce2ee7bb..f686ae07b9a99 100644 --- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h +++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h @@ -136,15 +136,17 @@ struct SCFTilingResult { SmallVector initialValues; /// The `scf.for` operations that iterate over the tiles. SmallVector loops; - /// The result generated by the loop nest in tiling, may hold partial results, - /// which need to be merged to match the computation of the untiled operation. - /// `mergeResult` contains the operations used to perform this merge from - /// partial results and the values that can be used as replacements of - /// the untiled operation. - MergeResult mergeResult; + /// Values to use as replacements for the untiled op. Is the same size as the + /// number of results of the untiled op. 
+ SmallVector replacements; /// Slices generated after tiling that can be used for fusing with the tiled /// producer. SmallVector generatedSlices; + /// In cases where there is an additional merge step after tiling + /// return the merged ops after tiling. This list is empty when reduction + /// tiling strategy is + /// `scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction`. + SmallVector mergeOps; }; /// Method to tile an op that implements the `TilingInterface` using @@ -362,7 +364,7 @@ lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op); /// ``` FailureOr tileReductionUsingScf(RewriterBase &b, PartialReductionOpInterface op, - ArrayRef tileSize); + ArrayRef tileSizes); } // namespace scf } // namespace mlir diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td index 50b69b8f8d833..cdf3d01ce8a84 100644 --- a/mlir/include/mlir/Interfaces/TilingInterface.td +++ b/mlir/include/mlir/Interfaces/TilingInterface.td @@ -363,7 +363,8 @@ def TilingInterface : OpInterface<"TilingInterface"> { ]; } -def PartialReductionOpInterface : OpInterface<"PartialReductionOpInterface"> { +def PartialReductionOpInterface : + OpInterface<"PartialReductionOpInterface", [TilingInterface]> { let description = [{ Interface for allowing operations to expose information needed to tile reductions using partial reduction followed by merge. 
This is diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index 1c3b621828315..b2c28f5eed33c 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -2381,7 +2381,7 @@ transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter, return emitDefaultDefiniteFailure(target); if (target->getNumResults()) - rewriter.replaceOp(target, maybeTilingResult->mergeResult.replacements); + rewriter.replaceOp(target, maybeTilingResult->replacements); else rewriter.eraseOp(target); @@ -2800,12 +2800,12 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForOp::applyToOne( if (failed(result)) return emitDefaultSilenceableFailure(target); - rewriter.replaceOp(target, result->mergeResult.replacements); + rewriter.replaceOp(target, result->replacements); for (Value initValue : result->initialValues) results.push_back(initValue.getDefiningOp()); for (auto parallelTiledOp : result->tiledOps) results.push_back(parallelTiledOp); - for (auto mergeOp : result->mergeResult.mergeOps) + for (auto mergeOp : result->mergeOps) results.push_back(mergeOp); results.push_back(result->loops.front()); return DiagnosedSilenceableFailure::success(); @@ -3229,7 +3229,7 @@ transform::TileUsingForOp::apply(transform::TransformRewriter &rewriter, if (failed(maybeTilingResult)) return DiagnosedSilenceableFailure::definiteFailure(); - rewriter.replaceOp(op, maybeTilingResult->mergeResult.replacements); + rewriter.replaceOp(op, maybeTilingResult->replacements); tiled.append(maybeTilingResult->tiledOps); for (const auto &en2 : llvm::enumerate(maybeTilingResult->loops)) @@ -3465,7 +3465,7 @@ DiagnosedSilenceableFailure transform::tileToForallOpImpl( if (failed(maybeTilingResult)) return transformOp.emitDefaultSilenceableFailure(tileableOp); - rewriter.replaceOp(tileableOp, maybeTilingResult->mergeResult.replacements); + 
+ rewriter.replaceOp(tileableOp, maybeTilingResult->replacements); tilingResult = *maybeTilingResult; diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index 57ee0f52e7491..a0f9b599d1351 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -1058,48 +1058,50 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op, assert(succeeded(tilingResult) && "expected tiling result to be computed after loop generation"); - SmallVector partialResults; if (loops.empty()) { // If loops are empty, the tiled op is used as the replacement for the // untiled op. - partialResults = tilingResult->tiledValues; - } else { - partialResults = llvm::map_to_vector(loops.front()->getResults(), + return scf::SCFTilingResult{tilingResult->tiledOps, initTensors, loops, + tilingResult->tiledValues, + tilingResult->generatedSlices}; + } + + auto loopResults = llvm::map_to_vector(loops.front()->getResults(), [](OpResult r) -> Value { return r; }); + + // For the full reduction case, there is nothing more to do. + if (options.reductionStrategy == + scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction) { + return scf::SCFTilingResult{tilingResult->tiledOps, initTensors, loops, + loopResults, tilingResult->generatedSlices}; } + // The results of the loop need to be merged. 
FailureOr mergeResult = - mergeTilingResults(rewriter, op, partialResults, options); + mergeTilingResults(rewriter, op, loopResults, options); if (failed(mergeResult)) { return rewriter.notifyMatchFailure( op, "Failed to merge partial results from tiling"); } - - return scf::SCFTilingResult{tilingResult->tiledOps, initTensors, loops, - mergeResult.value(), - tilingResult->generatedSlices}; + return scf::SCFTilingResult{tilingResult->tiledOps, + initTensors, + loops, + mergeResult->replacements, + tilingResult->generatedSlices, + mergeResult->mergeOps}; } FailureOr mlir::scf::tileReductionUsingScf(RewriterBase &b, PartialReductionOpInterface op, - ArrayRef tileSizes) { - SCFTilingOptions options; - options.setLoopType(SCFTilingOptions::LoopType::ForOp); - options.setReductionTilingStrategy(SCFTilingOptions::ReductionTilingStrategy:: - PartialReductionOuterReduction); - options.setTileSizes(tileSizes); - - TilingInterface tilingInterfaceOp = - dyn_cast(op.getOperation()); - if (!tilingInterfaceOp) { - return b.notifyMatchFailure( - op, - "Operation implementing PartialReductionOpInterface should implement " - "TilingInterface"); - } - - return tileUsingSCF(b, tilingInterfaceOp, options); + ArrayRef tileSize) { + scf::SCFTilingOptions options; + options.setLoopType(scf::SCFTilingOptions::LoopType::ForOp); + options.setReductionTilingStrategy( + scf::SCFTilingOptions::ReductionTilingStrategy:: + PartialReductionOuterReduction); + options.setTileSizes(tileSize); + return tileUsingSCF(b, op, options); } //===----------------------------------------------------------------------===// @@ -1539,8 +1541,8 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF( tiledAndFusedOps.insert_range(tilingResult->tiledOps); DenseMap replacements; - for (auto [origVal, replacement] : llvm::zip_equal( - consumer->getResults(), tilingResult->mergeResult.replacements)) { + for (auto [origVal, replacement] : + llvm::zip_equal(consumer->getResults(), tilingResult->replacements)) { 
replacements[origVal] = replacement; } diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp index 45d6ae3820159..9971f0cde4ed2 100644 --- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp +++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp @@ -260,8 +260,7 @@ applyTileToAll(RewriterBase &rewriter, Operation *transformOp, return failure(); // Perform the replacement of tiled and fused values. - rewriter.replaceOp(tilingInterfaceOp, - tiledResults->mergeResult.replacements); + rewriter.replaceOp(tilingInterfaceOp, tiledResults->replacements); // Report back the relevant handles to the transform op. tiledOps.push_back(tiledResults->tiledOps.front());