diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h index 892675954493b..a4ee893ca5341 100644 --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h @@ -10,7 +10,9 @@ #define MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_TRANSFORMS_H #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Operation.h" +#include "mlir/Interfaces/SubsetOpInterface.h" namespace mlir { namespace bufferization { @@ -34,13 +36,35 @@ struct OneShotBufferizationOptions; /// "tensor.empty" op. LogicalResult eliminateEmptyTensors(RewriterBase &rewriter, Operation *op); +/// A function type that defines a callback to control the construction +/// of the subset extraction of the `SubsetInsertionOpInterface`. +/// The subset extraction value can be used as a replacement for the +/// `emptyTensorOp` value which is being consumed by `user`, failing +/// of building such a value should be indicated with an empty value. +/// This function should guarantee the legality of the replacement, +/// i.e. the replacement should dominate the user of the `emptyTensorOp` +/// being eliminated. +using ControlBuildSubsetExtractionFn = + std::function; + +/// This method builds and returns a subset extraction value for the +/// destination tensor that the given `op` inserts into. +/// It returns a value which should replace the `emptyTensorOp` use +/// that is being consumed by `user`. +/// If no such a value found it will return an empty Value. +Value buildSubsetExtraction(RewriterBase &rewriter, + SubsetInsertionOpInterface op, + tensor::EmptyOp emptyTensorOp, Operation *user); + /// Try to eliminate "tensor.empty" ops inside `op`. /// /// This function overload accepts an existing `OneShotAnalysisState`, which /// contains in-place bufferization decisions. This overload is useful if an /// existing analysis should be reused for empty tensor elimination. -LogicalResult eliminateEmptyTensors(RewriterBase &rewriter, Operation *op, - OneShotAnalysisState &state); +LogicalResult eliminateEmptyTensors( + RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state, + ControlBuildSubsetExtractionFn subsetsExtractionFn = buildSubsetExtraction); /// Within the given operation, hoist buffers from loops where possible. See /// "BufferLoopHoistingPass" for more information. diff --git a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp index abc0635a2cdff..98c3d8d0adc6d 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp @@ -93,8 +93,31 @@ findValidInsertionPoint(Operation *emptyTensorOp, Operation *user, return nullptr; } +Value mlir::bufferization::buildSubsetExtraction(RewriterBase &rewriter, + SubsetInsertionOpInterface op, + tensor::EmptyOp emptyTensorOp, + Operation *user) { + + mlir::OpBuilder::InsertionGuard guard(rewriter); + // All values that are needed to create the replacement op. + SmallVector neededValues = op.getValuesNeededToBuildSubsetExtraction(); + // Find a suitable insertion point. If no suitable insertion point + // for the replacement can be found, return an empty value to skip + // this replacement. + Operation *insertionPoint = + findValidInsertionPoint(emptyTensorOp, user, neededValues); + if (!insertionPoint) + return {}; + + rewriter.setInsertionPoint(insertionPoint); + Value replacement = + op.buildSubsetExtraction(rewriter, emptyTensorOp->getLoc()); + return replacement; +} + LogicalResult mlir::bufferization::eliminateEmptyTensors( - RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state) { + RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state, + ControlBuildSubsetExtractionFn subsetsExtractionFn) { OpBuilder::InsertionGuard g(rewriter); llvm::DenseSet visitedOpOperands; op->walk([&](SubsetInsertionOpInterface op) { @@ -105,10 +128,6 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors( if (!state.isInPlace(source)) return WalkResult::skip(); - // All values that are needed to create the replacement op. - SmallVector neededValues = - op.getValuesNeededToBuildSubsetExtraction(); - // Find tensor.empty ops on the reverse SSA use-def chain. Only follow // equivalent tensors. I.e., stop when there are ops such as extract_slice // on the path. @@ -129,8 +148,8 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors( &visitedOpOperands); for (Value v : emptyTensors) { - Operation *emptyTensorOp = v.getDefiningOp(); - + auto emptyTensorOp = v.getDefiningOp(); + assert(emptyTensorOp && "expected tensor.empty op"); // Find the use to be replaced from the use-def chain. auto iter = llvm::find_if( visitedOpOperands, [&emptyTensorOp](OpOperand *opOperand) { @@ -142,17 +161,7 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors( continue; OpOperand *useToBeReplaced = *iter; Operation *user = useToBeReplaced->getOwner(); - - // Find a suitable insertion point. If no suitable insertion point for - // the replacement can be found, skip this replacement. - Operation *insertionPoint = - findValidInsertionPoint(emptyTensorOp, user, neededValues); - if (!insertionPoint) - continue; - - rewriter.setInsertionPoint(insertionPoint); - Value replacement = - op.buildSubsetExtraction(rewriter, emptyTensorOp->getLoc()); + auto replacement = subsetsExtractionFn(rewriter, op, emptyTensorOp, user); if (!replacement) continue; if (emptyTensorOp == replacement.getDefiningOp())