From be85978d35ff35bb1a2e21d1ededa01c9a90c0ec Mon Sep 17 00:00:00 2001 From: "Golubev, Andrey" Date: Wed, 4 Jun 2025 15:03:26 +0000 Subject: [PATCH 1/5] [mlir][bufferization] Support custom types (1/N) Following the introduction of TensorLike and BufferLike type interfaces (see 00eaff3e9c897c263a879416d0f151d7ca7eeaff), introduce minimal changes required to bufferize a custom tensor operation into a custom buffer operation. To achieve this, a new conversion dialect interface is added that abstracts away the differences between existing (tensor -> memref) and custom conversions. The scope of the changes is intentionally limited (for example, BufferizableOpInterface is untouched) in order to first understand the basics and reach consensus design-wise. --- .../IR/BufferizableOpInterface.h | 17 +++- .../IR/BufferizationConversionInterface.h | 72 +++++++++++++++++ .../Bufferization/IR/BufferizationOps.td | 48 +++++++----- .../IR/UnstructuredControlFlow.h | 5 +- .../BufferizableOpInterfaceImpl.cpp | 14 ++-- .../IR/BufferizableOpInterface.cpp | 77 ++++++++++++------- .../IR/BufferizationConversionInterface.cpp | 67 ++++++++++++++++ .../Bufferization/IR/BufferizationOps.cpp | 21 ++--- .../Dialect/Bufferization/IR/CMakeLists.txt | 1 + .../Bufferization/Transforms/Bufferize.cpp | 8 +- .../FuncBufferizableOpInterfaceImpl.cpp | 8 +- .../BufferizableOpInterfaceImpl.cpp | 51 ++++++------ .../Transforms/Utils/CodegenUtils.cpp | 4 +- .../BufferizableOpInterfaceImpl.cpp | 14 ++-- .../Transforms/one-shot-bufferize.mlir | 21 ++++- mlir/test/lib/Dialect/Test/TestDialect.cpp | 49 ++++++++++++ mlir/test/lib/Dialect/Test/TestOpDefs.cpp | 23 ++++++ mlir/test/lib/Dialect/Test/TestOps.h | 1 + mlir/test/lib/Dialect/Test/TestOps.td | 58 +++++++++++++- 19 files changed, 451 insertions(+), 108 deletions(-) create mode 100644 mlir/include/mlir/Dialect/Bufferization/IR/BufferizationConversionInterface.h create mode 100644 mlir/lib/Dialect/Bufferization/IR/BufferizationConversionInterface.cpp diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h index 2fb795f16ae2c..768778df046a6 100644 --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h @@ -17,6 +17,7 @@ #include #include "mlir/Dialect/Bufferization/IR/BufferizationEnums.h.inc" +#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h" namespace mlir { class OpBuilder; @@ -615,7 +616,7 @@ FailureOr getBuffer(RewriterBase &rewriter, Value value, /// IR, this function can be used. /// /// This function is a wrapper around BufferizableOpInterface::getBufferType. -FailureOr getBufferType(Value value, +FailureOr getBufferType(Value value, const BufferizationOptions &options, const BufferizationState &state); @@ -629,7 +630,7 @@ FailureOr getBufferType(Value value, /// IR, this function can be used. /// /// This function is a wrapper around `BufferizableOpInterface::getBufferType`. -FailureOr getBufferType(Value value, +FailureOr getBufferType(Value value, const BufferizationOptions &options, const BufferizationState &state, SmallVector &invocationStack); @@ -739,6 +740,18 @@ AliasingValueList unknownGetAliasingValues(OpOperand &opOperand); /// This is the default implementation of /// BufferizableOpInterface::hasTensorSemantics bool defaultHasTensorSemantics(Operation *op); + +/// This is a helper function used when buffer type is guaranteed to be memref. +FailureOr castToMemRef(FailureOr bufferType); + +/// This function is a free-standing helper that relies on +/// bufferization::ConversionInterface to verify the types in tensor and buffer +/// worlds match. +bool typesMatchAfterBufferization(Operation &op, Value tensor, Value buffer); + +/// This function is a free-standing helper that relies on +/// bufferization::ConversionInterface to perform the conversion. +Type getTensorFromBuffer(Type buffer); } // namespace detail } // namespace bufferization diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationConversionInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationConversionInterface.h new file mode 100644 index 0000000000000..4164d1dcb9ea6 --- /dev/null +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationConversionInterface.h @@ -0,0 +1,72 @@ +//===- BufferizationConversionInterface.h - Dialect Interface ---*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATIONCONVERSIONINTERFACE_H_ +#define MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATIONCONVERSIONINTERFACE_H_ + +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" +#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h" +#include "mlir/IR/DialectInterface.h" + +namespace mlir { +namespace bufferization { + +/// This class defines a virtual interface for conversions between tensor-like +/// and buffer-like types. +struct ConversionDialectInterface + : DialectInterface::Base { + using Base::Base; + + /// Hook to customize tensor-like -> buffer-like conversion within a given + /// dialect. Returns a buffer-like type for the specific tensor-like type. + virtual FailureOr getBufferType( + Value value, const BufferizationOptions &options, + const BufferizationState &state, + function_ref emitError) const = 0; + + /// Hook to customize type checking between tensor-like and buffer-like types. + /// Given tensor `T` and buffer `B = getBufferType(T, ...)`, the call to + /// `typesMatch(T, B)` must return true. + virtual LogicalResult typesMatch( + TensorLikeType tensor, BufferLikeType buffer, + function_ref emitError) const = 0; + + /// Hook to customize buffer-like -> tensor-like conversion, which is the + /// opposite of bufferization. + virtual TensorLikeType getTensorFromBuffer(BufferLikeType buffer) const = 0; +}; + +/// Interface collection for conversion between tensor-like and buffer-like +/// types, dispatches to a concrete interface implementation based on the +/// dialect to which the given type belongs. +struct ConversionInterface + : DialectInterfaceCollection { + using Base::Base; + + /// Dispatches to ConversionDialectInterface::getBufferType() of the dialect + /// associated with the value type. + FailureOr getBufferType( + Value value, const BufferizationOptions &options, + const BufferizationState &state, + function_ref emitError) const; + + /// Dispatches to ConversionDialectInterface::typesMatch() of the dialect + /// associated with the value type. + LogicalResult + typesMatch(TensorLikeType tensor, BufferLikeType buffer, + function_ref emitError) const; + + /// Dispatches to ConversionDialectInterface::getTensorFromBuffer() of the + /// dialect associated with the value type. + TensorLikeType getTensorFromBuffer(BufferLikeType buffer) const; +}; + +} // namespace bufferization +} // namespace mlir + +#endif // MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATIONTYPEINTERFACES_H_ diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td index 6051aea849971..3d301a0657200 100644 --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td @@ -12,6 +12,7 @@ include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.td" include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td" include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td" +include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td" include "mlir/Dialect/Bufferization/IR/BufferizationBase.td" include "mlir/Interfaces/DestinationStyleOpInterface.td" include "mlir/Interfaces/InferTypeOpInterface.td" @@ -386,20 +387,28 @@ def Bufferization_DeallocTensorOp : Bufferization_Op<"dealloc_tensor", // ToTensorOp //===----------------------------------------------------------------------===// +class Bufferization_TensorAndBufferMatch : PredOpTrait< + "specified tensor and buffer types match", + CPred< + "::mlir::bufferization::detail::typesMatchAfterBufferization(" + "$_op, $" # tensor # ", $" # buffer #")" + > +>; + def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [ BufferizableOpInterface, SameOperandsAndResultShape, SameOperandsAndResultElementType, - AllElementTypesMatch<["memref", "result"]> + Bufferization_TensorAndBufferMatch<"result", "buffer"> ]> { - let summary = "create a tensor from a `memref`"; + let summary = "create a buffer-like type from a tensor-like type"; let description = [{ - An operation that creates a tensor from a `memref`. The result value is a - tensor whose shape and element type match the memref operand. + An operation that creates a tensor from a buffer. The result value is a + tensor-like type whose shape and element type match the buffer-like operand. The opposite of this op is `to_buffer`. Together, these two ops are useful for source/target materializations when doing type conversions - involving tensors and memrefs. + involving tensors and buffers. Example: @@ -441,11 +450,11 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [ away. However, such IR is no longer bufferizable with One-Shot Bufferize. }]; - let arguments = (ins Arg]>:$memref, + [MemReadAt<0, FullEffect>]>:$buffer, UnitAttr:$restrict, UnitAttr:$writable); - let results = (outs AnyTensor:$result); + let results = (outs Bufferization_TensorLikeTypeInterface:$result); let extraClassDeclaration = [{ /// The result of a to_tensor is always a tensor. @@ -472,19 +481,19 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [ FailureOr getBufferType( Value value, const BufferizationOptions &options, const BufferizationState &state, SmallVector &invocationStack) { - return ::llvm::cast(getMemref().getType()); + return ::llvm::cast(getBuffer().getType()); } }]; let assemblyFormat = [{ - $memref (`restrict` $restrict^)? (`writable` $writable^)? attr-dict - `:` type($memref) `to` type($result) + $buffer (`restrict` $restrict^)? (`writable` $writable^)? attr-dict + `:` type($buffer) `to` type($result) }]; let builders = [ - OpBuilder<(ins "Value":$memref, CArg<"bool", "false">:$restrict, CArg<"bool", "false">:$writeable), [{ - auto rtt = memref::getTensorTypeFromMemRefType(memref.getType()); - build($_builder, $_state, rtt, memref, restrict, writeable); + OpBuilder<(ins "Value":$buffer, CArg<"bool", "false">:$restrict, CArg<"bool", "false">:$writeable), [{ + auto rtt = bufferization::detail::getTensorFromBuffer(buffer.getType()); + build($_builder, $_state, rtt, buffer, restrict, writeable); }]> ]; @@ -502,10 +511,9 @@ def Bufferization_ToBufferOp : Bufferization_Op<"to_buffer", [ SameOperandsAndResultShape, SameOperandsAndResultElementType, Pure, - AllShapesMatch<["memref", "tensor"]>, - AllElementTypesMatch<["memref", "tensor"]> + Bufferization_TensorAndBufferMatch<"tensor", "buffer"> ]> { - let summary = "cast a tensor to memref"; + let summary = "cast a tensor-like type to buffer-like type"; let description = [{ An operation that returns the future buffer of a `tensor`. @@ -523,8 +531,8 @@ def Bufferization_ToBufferOp : Bufferization_Op<"to_buffer", [ the returned buffer) will not be written to. }]; - let arguments = (ins AnyTensor:$tensor, UnitAttr:$read_only); - let results = (outs AnyRankedOrUnrankedMemRef:$memref); + let arguments = (ins Bufferization_TensorLikeTypeInterface:$tensor, UnitAttr:$read_only); + let results = (outs Bufferization_BufferLikeTypeInterface:$buffer); let extraClassDeclaration = [{ //===------------------------------------------------------------------===// @@ -559,7 +567,7 @@ def Bufferization_ToBufferOp : Bufferization_Op<"to_buffer", [ }]; let assemblyFormat = [{ - $tensor (`read_only` $read_only^)? attr-dict `:` type($tensor) `to` type($memref) + $tensor (`read_only` $read_only^)? attr-dict `:` type($tensor) `to` type($buffer) }]; let hasFolder = 1; diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h index a441b8b66659e..f56c10555f02c 100644 --- a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h @@ -65,12 +65,13 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel // The operand was already bufferized. Take its type directly. callerType = memrefType; } else { - FailureOr maybeCallerType = + FailureOr maybeCallerType = bufferization::getBufferType(opOperand->get(), options, state, invocationStack); if (failed(maybeCallerType)) return failure(); - callerType = *maybeCallerType; + assert(isa(*maybeCallerType) && "expected memref type"); + callerType = cast(*maybeCallerType); } if (!bufferType) { diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp index a57d58ab28d28..021a557f68b4b 100644 --- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp @@ -164,8 +164,8 @@ struct SelectOpInterface // buffers have different types, they differ only in their layout map. Cast // both of them to the most dynamic MemRef type. if (trueBuffer.getType() != falseBuffer.getType()) { - auto targetType = - bufferization::getBufferType(selectOp.getResult(), options, state); + auto targetType = bufferization::detail::castToMemRef( + bufferization::getBufferType(selectOp.getResult(), options, state)); if (failed(targetType)) return failure(); if (trueBuffer.getType() != *targetType) @@ -187,10 +187,12 @@ struct SelectOpInterface SmallVector &invocationStack) const { auto selectOp = cast(op); assert(value == selectOp.getResult() && "invalid value"); - auto trueType = bufferization::getBufferType( - selectOp.getTrueValue(), options, state, invocationStack); - auto falseType = bufferization::getBufferType( - selectOp.getFalseValue(), options, state, invocationStack); + auto trueType = + bufferization::detail::castToMemRef(bufferization::getBufferType( + selectOp.getTrueValue(), options, state, invocationStack)); + auto falseType = + bufferization::detail::castToMemRef(bufferization::getBufferType( + selectOp.getFalseValue(), options, state, invocationStack)); if (failed(trueType) || failed(falseType)) return failure(); if (*trueType == *falseType) diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp index dd43647682ea2..bd79cbc80dd2a 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -8,6 +8,7 @@ #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Bufferization/IR/BufferizationConversionInterface.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" @@ -211,8 +212,8 @@ FailureOr bufferization::allocateTensorForShapedValue( // Add 'memory_space' attribute. Not needed if 'copy' operand is specified. if (copy) return allocTensorOp.getResult(); - FailureOr copyBufferType = - getBufferType(tensor, options, state); + auto copyBufferType = + detail::castToMemRef(getBufferType(tensor, options, state)); if (failed(copyBufferType)) return failure(); std::optional memorySpace = copyBufferType->getMemorySpace(); @@ -672,28 +673,28 @@ FailureOr bufferization::getBuffer(RewriterBase &rewriter, Value value, const BufferizationOptions &options, const BufferizationState &state) { #ifndef NDEBUG - auto tensorType = llvm::dyn_cast(value.getType()); + auto tensorType = llvm::dyn_cast(value.getType()); assert(tensorType && "unexpected non-tensor type"); #endif // NDEBUG // Replace "%t = to_tensor %m" with %m. if (auto toTensorOp = value.getDefiningOp()) - return toTensorOp.getMemref(); + return toTensorOp.getBuffer(); // Insert to_buffer op. OpBuilder::InsertionGuard g(rewriter); setInsertionPointAfter(rewriter, value); - FailureOr memrefType = getBufferType(value, options, state); - if (failed(memrefType)) + FailureOr bufferType = getBufferType(value, options, state); + if (failed(bufferType)) return failure(); - ensureToBufferOpIsValid(value, *memrefType); + ensureToBufferOpIsValid(value, *bufferType); return rewriter - .create(value.getLoc(), *memrefType, value) + .create(value.getLoc(), *bufferType, value) .getResult(); } /// Return the buffer type for a given Value (tensor) after bufferization. -FailureOr +FailureOr bufferization::getBufferType(Value value, const BufferizationOptions &options, const BufferizationState &state) { SmallVector invocationStack; @@ -701,11 +702,11 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options, } /// Return the buffer type for a given Value (tensor) after bufferization. -FailureOr +FailureOr bufferization::getBufferType(Value value, const BufferizationOptions &options, const BufferizationState &state, SmallVector &invocationStack) { - assert(llvm::isa(value.getType()) && + assert(llvm::isa(value.getType()) && "unexpected non-tensor type"); invocationStack.push_back(value); auto popFromStack = @@ -717,14 +718,11 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options, if (bufferizableOp) return bufferizableOp.getBufferType(value, options, state, invocationStack); - // Op is not bufferizable. - auto memSpace = - options.defaultMemorySpaceFn(cast(value.getType())); - if (!memSpace.has_value()) - return op->emitError("could not infer memory space"); - - return getMemRefType(cast(value.getType()), options, - /*layout=*/{}, *memSpace); + // Op is not bufferizable, use conversion interface. + bufferization::ConversionInterface iface(value.getContext()); + return iface.getBufferType(value, options, state, [&](const Twine &message) { + return op->emitError(message); + }); } bool bufferization::hasTensorSemantics(Operation *op) { @@ -744,12 +742,11 @@ void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter, SmallVector replacements; for (OpResult opResult : op->getOpResults()) { Value replacement = values[opResult.getResultNumber()]; - if (llvm::isa(opResult.getType())) { + if (llvm::isa(opResult.getType())) { // The OpResult is a tensor. Such values are replaced with memrefs during // bufferization. - assert((llvm::isa(replacement.getType()) || - llvm::isa(replacement.getType())) && - "tensor op result should be replaced with a memref value"); + assert(llvm::isa(replacement.getType()) && + "tensor op result should be replaced with a buffer value"); // The existing uses of the OpResult still expect a tensor. Insert a // ToTensorOp. Throughout bufferization, this ToTensorOp will gradually // loose all of its users and eventually DCE away. @@ -969,8 +966,8 @@ FailureOr bufferization::detail::defaultGetBufferType( // If the OpResult has an equivalent OpOperand, both OpResult and // OpOperand bufferize to the exact same buffer type. Value equivalentOperand = aliases.getAliases().front().opOperand->get(); - return getBufferType(equivalentOperand, options, bufferizationState, - invocationStack); + return castToMemRef(getBufferType(equivalentOperand, options, + bufferizationState, invocationStack)); } // If we do not know the memory space and there is no default memory space, @@ -1030,7 +1027,7 @@ bufferization::detail::unknownGetAliasingValues(OpOperand &opOperand) { } bool bufferization::detail::defaultHasTensorSemantics(Operation *op) { - auto isaTensor = [](Type t) { return isa(t); }; + auto isaTensor = [](Type t) { return isa(t); }; bool hasTensorBlockArgument = any_of(op->getRegions(), [&](Region &r) { return any_of(r.getBlocks(), [&](Block &b) { return any_of(b.getArguments(), [&](BlockArgument bbArg) { @@ -1045,3 +1042,31 @@ bool bufferization::detail::defaultHasTensorSemantics(Operation *op) { return true; return any_of(op->getOperandTypes(), isaTensor); } + +FailureOr +bufferization::detail::castToMemRef(FailureOr bufferType) { + if (failed(bufferType)) + return failure(); + assert(isa(*bufferType) && "expected memref type"); + return cast(*bufferType); +} + +bool bufferization::detail::typesMatchAfterBufferization(Operation &op, + Value tensor, + Value buffer) { + assert(isa(tensor.getType()) && "expected TensorLikeType"); + assert(isa(buffer.getType()) && "expected BufferLikeType"); + + // Op is not bufferizable, use conversion interface. + bufferization::ConversionInterface iface(op.getContext()); + return succeeded(iface.typesMatch( + cast(tensor.getType()), + cast(buffer.getType()), + [&](const Twine &message) { return op.emitError(message); })); +} + +Type bufferization::detail::getTensorFromBuffer(Type buffer) { + assert(isa(buffer) && "expected BufferLikeType"); + bufferization::ConversionInterface iface(buffer.getContext()); + return iface.getTensorFromBuffer(cast(buffer)); +} diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationConversionInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationConversionInterface.cpp new file mode 100644 index 0000000000000..287e9bf85002f --- /dev/null +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationConversionInterface.cpp @@ -0,0 +1,67 @@ +//===- BufferizationConversionInterface.cpp - Dialect Interface ---=------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Bufferization/IR/BufferizationConversionInterface.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" // getTensorTypeFromMemRefType + +namespace mlir { +namespace bufferization { + +FailureOr ConversionInterface::getBufferType( + Value value, const BufferizationOptions &options, + const BufferizationState &state, + function_ref emitError) const { + Dialect *dialect = &value.getType().getDialect(); + if (const ConversionDialectInterface *iface = getInterfaceFor(dialect)) + return iface->getBufferType(value, options, state, emitError); + + // Fall back to tensor -> memref conversion. + auto memSpace = + options.defaultMemorySpaceFn(cast(value.getType())); + if (!memSpace.has_value()) + return emitError("could not infer memory space"); + + return cast( + getMemRefType(value, options, /*layout=*/{}, *memSpace)); +} + +LogicalResult ConversionInterface::typesMatch( + TensorLikeType tensor, BufferLikeType buffer, + function_ref emitError) const { + Dialect *dialect = &tensor.getDialect(); + if (const ConversionDialectInterface *iface = getInterfaceFor(dialect)) + return iface->typesMatch(tensor, buffer, emitError); + + // Fall back to tensor, memref checking. + assert(isa(tensor) && "expected tensor type"); + assert(isa(buffer) && "expected memref type"); + + if (cast(tensor).getShape() != + cast(buffer).getShape()) { + return emitError("shapes do not match"); + } + + if (cast(tensor).getElementType() != + cast(buffer).getElementType()) { + return emitError("element types do not match"); + } + + return success(); +} + +TensorLikeType +ConversionInterface::getTensorFromBuffer(BufferLikeType buffer) const { + Dialect *dialect = &buffer.getDialect(); + if (const ConversionDialectInterface *iface = getInterfaceFor(dialect)) + return iface->getTensorFromBuffer(buffer); + + return cast(memref::getTensorTypeFromMemRefType(buffer)); +} + +} // namespace bufferization +} // namespace mlir diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp index dc54ac94aed32..79af1e8fee79f 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp @@ -90,12 +90,12 @@ LogicalResult mlir::bufferization::foldToBufferToTensorPair( if (!bufferToTensor) return failure(); - Type srcType = bufferToTensor.getMemref().getType(); + Type srcType = bufferToTensor.getBuffer().getType(); Type destType = toBuffer.getType(); // Directly rewrite if the type did not change. if (srcType == destType) { - rewriter.replaceOp(toBuffer, bufferToTensor.getMemref()); + rewriter.replaceOp(toBuffer, bufferToTensor.getBuffer()); return success(); } @@ -106,7 +106,7 @@ LogicalResult mlir::bufferization::foldToBufferToTensorPair( // Ranked memref -> Ranked memref cast. if (rankedSrcType && rankedDestType) { FailureOr replacement = castOrReallocMemRefValue( - rewriter, bufferToTensor.getMemref(), rankedDestType, options); + rewriter, bufferToTensor.getBuffer(), rankedDestType, options); if (failed(replacement)) return failure(); @@ -124,7 +124,7 @@ LogicalResult mlir::bufferization::foldToBufferToTensorPair( assert(memref::CastOp::areCastCompatible(srcType, destType) && "expected that types are cast compatible"); rewriter.replaceOpWithNewOp(toBuffer, destType, - bufferToTensor.getMemref()); + bufferToTensor.getBuffer()); return success(); } @@ -233,8 +233,9 @@ AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options, if (getMemorySpace().has_value()) { memorySpace = *getMemorySpace(); } else if (getCopy()) { - auto copyBufferType = bufferization::getBufferType(getCopy(), options, - state, invocationStack); + auto copyBufferType = + bufferization::detail::castToMemRef(bufferization::getBufferType( + getCopy(), options, state, invocationStack)); if (failed(copyBufferType)) return failure(); memorySpace = copyBufferType->getMemorySpace(); @@ -744,7 +745,7 @@ bool ToTensorOp::isWritable(Value value, const AnalysisState &state) { } OpFoldResult ToTensorOp::fold(FoldAdaptor) { - if (auto toBuffer = getMemref().getDefiningOp()) + if (auto toBuffer = getBuffer().getDefiningOp()) // Approximate alias analysis by conservatively folding only when no there // is no interleaved operation. if (toBuffer->getBlock() == this->getOperation()->getBlock() && @@ -764,7 +765,7 @@ struct DimOfToTensorFolder : public OpRewritePattern { return failure(); rewriter.replaceOpWithNewOp( - dimOp, memrefToTensorOp.getMemref(), dimOp.getIndex()); + dimOp, memrefToTensorOp.getBuffer(), dimOp.getIndex()); return success(); } }; @@ -781,8 +782,8 @@ void ToTensorOp::getCanonicalizationPatterns(RewritePatternSet &results, OpFoldResult ToBufferOp::fold(FoldAdaptor) { if (auto memrefToTensor = getTensor().getDefiningOp()) - if (memrefToTensor.getMemref().getType() == getType()) - return memrefToTensor.getMemref(); + if (memrefToTensor.getBuffer().getType() == getType()) + return memrefToTensor.getBuffer(); return {}; } diff --git a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt index 63dcc1eb233e9..a47c1569e4c33 100644 --- a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt @@ -6,6 +6,7 @@ add_mlir_dialect_library(MLIRBufferizationDialect BufferizationDialect.cpp BufferViewFlowOpInterface.cpp UnstructuredControlFlow.cpp + BufferizationConversionInterface.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Bufferization diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp index 7e9b9119ce949..6472ef3eff2ac 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp @@ -412,11 +412,11 @@ bufferization::bufferizeBlockSignature(Block *block, RewriterBase &rewriter, continue; } - FailureOr memrefType = + FailureOr bufferType = bufferization::getBufferType(bbArg, options, state); - if (failed(memrefType)) + if (failed(bufferType)) return failure(); - newTypes.push_back(*memrefType); + newTypes.push_back(*bufferType); } // Change the type of all block arguments. @@ -463,7 +463,7 @@ bufferization::bufferizeBlockSignature(Block *block, RewriterBase &rewriter, newOperands.push_back(operand); continue; } - FailureOr operandBufferType = + FailureOr operandBufferType = bufferization::getBufferType(operand, options, state); if (failed(operandBufferType)) return failure(); diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp index a0168da44b7b3..453ed43bcadd2 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp @@ -255,7 +255,7 @@ struct CallOpInterface } // Returning a memref. - FailureOr resultType = + FailureOr resultType = bufferization::getBufferType(result, options, state); if (failed(resultType)) return failure(); @@ -290,13 +290,13 @@ struct CallOpInterface // The called function was not bufferized yet. This can happen when // there cycles in the function call graph. Compute the bufferized // result type. - FailureOr maybeMemRefType = + FailureOr maybeBufferType = bufferization::getBufferType( funcOp.getArgument(opOperand.getOperandNumber()), options, state); - if (failed(maybeMemRefType)) + if (failed(maybeBufferType)) return failure(); - memRefType = *maybeMemRefType; + memRefType = *maybeBufferType; } // Since we don't yet have a clear layout story, to_buffer may diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp index 46fa77a7dc4e6..efa9fc1a070aa 100644 --- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp @@ -108,7 +108,7 @@ struct ConditionOpInterface getBuffer(rewriter, value, options, state); if (failed(maybeBuffer)) return failure(); - FailureOr resultType = bufferization::getBufferType( + FailureOr resultType = bufferization::getBufferType( whileOp.getAfterArguments()[it.index()], options, state); if (failed(resultType)) return failure(); @@ -292,8 +292,9 @@ struct IfOpInterface // True branch was already bufferized. thenBufferType = cast(thenValue.getType()); } else { - auto maybeBufferType = bufferization::getBufferType( - thenValue, options, state, invocationStack); + auto maybeBufferType = + bufferization::detail::castToMemRef(bufferization::getBufferType( + thenValue, options, state, invocationStack)); if (failed(maybeBufferType)) return failure(); thenBufferType = *maybeBufferType; @@ -302,8 +303,9 @@ struct IfOpInterface // False branch was already bufferized. elseBufferType = cast(elseValue.getType()); } else { - auto maybeBufferType = bufferization::getBufferType( - elseValue, options, state, invocationStack); + auto maybeBufferType = + bufferization::detail::castToMemRef(bufferization::getBufferType( + elseValue, options, state, invocationStack)); if (failed(maybeBufferType)) return failure(); elseBufferType = *maybeBufferType; @@ -406,9 +408,7 @@ struct IndexSwitchOpInterface return bufferType; auto maybeBufferType = bufferization::getBufferType( yieldedValue, options, state, invocationStack); - if (failed(maybeBufferType)) - return failure(); - return maybeBufferType; + return bufferization::detail::castToMemRef(maybeBufferType); }; // Compute buffer type of the default case. @@ -527,8 +527,8 @@ static FailureOr computeLoopRegionIterArgBufferType( const BufferizationOptions &options, const BufferizationState &state, SmallVector &invocationStack) { // Determine the buffer type of the init_arg. - auto initArgBufferType = - bufferization::getBufferType(initArg, options, state, invocationStack); + auto initArgBufferType = bufferization::detail::castToMemRef( + bufferization::getBufferType(initArg, options, state, invocationStack)); if (failed(initArgBufferType)) return failure(); @@ -554,8 +554,9 @@ static FailureOr computeLoopRegionIterArgBufferType( } else { // Note: This typically triggers a recursive call for the buffer type of // the iter_arg. - auto maybeBufferType = bufferization::getBufferType(yieldedValue, options, - state, invocationStack); + auto maybeBufferType = + bufferization::detail::castToMemRef(bufferization::getBufferType( + yieldedValue, options, state, invocationStack)); if (failed(maybeBufferType)) return failure(); yieldedValueBufferType = *maybeBufferType; @@ -718,8 +719,12 @@ struct ForOpInterface if (auto opResult = dyn_cast(value)) { // The type of an OpResult must match the corresponding iter_arg type. BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult); - return bufferization::getBufferType(bbArg, options, state, - invocationStack); + auto bufferType = + bufferization::getBufferType(bbArg, options, state, invocationStack); + if (failed(bufferType)) + return failure(); + assert(isa(*bufferType) && "expected memref type"); + return cast(*bufferType); } // Compute result/argument number. @@ -1078,8 +1083,8 @@ struct WhileOpInterface // scf.condition was already bufferized. return cast(conditionYieldedVal.getType()); } - return bufferization::getBufferType(conditionYieldedVal, options, state, - invocationStack); + return bufferization::detail::castToMemRef(bufferization::getBufferType( + conditionYieldedVal, options, state, invocationStack)); } /// Assert that yielded values of an scf.while op are equivalent to their @@ -1185,14 +1190,14 @@ struct YieldOpInterface // We may have to cast the value before yielding it. if (isa( yieldOp->getParentOp())) { - FailureOr resultType = bufferization::getBufferType( + FailureOr resultType = bufferization::getBufferType( yieldOp->getParentOp()->getResult(it.index()), options, state); if (failed(resultType)) return failure(); buffer = castBuffer(rewriter, buffer, *resultType); } else if (auto whileOp = dyn_cast(yieldOp->getParentOp())) { - FailureOr resultType = bufferization::getBufferType( + FailureOr resultType = bufferization::getBufferType( whileOp.getBeforeArguments()[it.index()], options, state); if (failed(resultType)) return failure(); @@ -1307,15 +1312,15 @@ struct ForallOpInterface if (auto bbArg = dyn_cast(value)) // A tensor block argument has the same bufferized type as the // corresponding output operand. - return bufferization::getBufferType( - forallOp.getTiedOpOperand(bbArg)->get(), options, state, - invocationStack); + return bufferization::detail::castToMemRef( + bufferization::getBufferType(forallOp.getTiedOpOperand(bbArg)->get(), + options, state, invocationStack)); // The bufferized result type is the same as the bufferized type of the // corresponding output operand. - return bufferization::getBufferType( + return bufferization::detail::castToMemRef(bufferization::getBufferType( forallOp.getOutputs()[cast(value).getResultNumber()], options, - state, invocationStack); + state, invocationStack)); } bool isRepetitiveRegion(Operation *op, unsigned index) const { diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp index 57291064eba22..1bd9563b3db07 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp @@ -549,8 +549,8 @@ TypedValue sparse_tensor::genToMemref(OpBuilder &builder, Location loc, Value tensor) { auto tTp = llvm::cast(tensor.getType()); auto mTp = MemRefType::get(tTp.getShape(), tTp.getElementType()); - return builder.create(loc, mTp, tensor) - .getResult(); + return cast>( + builder.create(loc, mTp, tensor).getResult()); } Value sparse_tensor::createOrFoldSliceOffsetOp(OpBuilder &builder, Location loc, diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp index 4b778b768d136..40b710f17fe44 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp @@ -54,8 +54,9 @@ struct CastOpInterface const BufferizationState &state, SmallVector &invocationStack) const { auto castOp = cast(op); - auto maybeSrcBufferType = bufferization::getBufferType( - castOp.getSource(), options, state, invocationStack); + auto maybeSrcBufferType = + bufferization::detail::castToMemRef(bufferization::getBufferType( + castOp.getSource(), options, state, invocationStack)); if (failed(maybeSrcBufferType)) return failure(); Attribute memorySpace = maybeSrcBufferType->getMemorySpace(); @@ -500,8 +501,8 @@ struct FromElementsOpInterface /*copy=*/false); if (failed(tensorAlloc)) return failure(); - FailureOr memrefType = - bufferization::getBufferType(*tensorAlloc, options, state); + FailureOr memrefType = bufferization::detail::castToMemRef( + bufferization::getBufferType(*tensorAlloc, options, state)); if (failed(memrefType)) return failure(); Value buffer = rewriter.create( @@ -758,8 +759,9 @@ struct PadOpInterface SmallVector &invocationStack) const { // Infer memory space from the source tensor. auto padOp = cast(op); - auto maybeSrcBufferType = bufferization::getBufferType( - padOp.getSource(), options, state, invocationStack); + auto maybeSrcBufferType = + bufferization::detail::castToMemRef(bufferization::getBufferType( + padOp.getSource(), options, state, invocationStack)); if (failed(maybeSrcBufferType)) return failure(); MemRefLayoutAttrInterface layout; diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir index cd19e3a5e82aa..da3c26ce36ba5 100644 --- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir @@ -268,4 +268,23 @@ func.func @materialize_in_dest_raw(%f: f32, %f2: f32, %idx: index) -> (tensor<5x %r = tensor.extract %dest_filled[%idx] : tensor<5xf32> return %0, %r : tensor<5xf32>, f32 -} \ No newline at end of file +} + +// ----- + +// CHECK-LABEL: func.func @test_dialect_op( +// CHECK-SAME: %[[ARG:.*]]: !test.test_tensor<[32, 64], f64> +// CHECK-SAME: ) -> !test.test_tensor<[32, 128], f64> { +func.func @test_dialect_op(%arg: !test.test_tensor<[32, 64], f64>) + -> !test.test_tensor<[32, 128], f64> { + // CHECK: %[[MEMREF:.*]] = bufferization.to_buffer %[[ARG]] + // CHECK: %[[DUMMY:.*]] = "test.dummy_memref_op"(%[[MEMREF]]) + // CHECK-SAME: : (!test.test_memref<[32, 64], f64>) + // CHECK-SAME: -> !test.test_memref<[32, 128], f64> + // CHECK: %[[OUT:.*]] = bufferization.to_tensor %[[DUMMY]] + %out = "test.dummy_tensor_op"(%arg) : (!test.test_tensor<[32, 64], f64>) + -> !test.test_tensor<[32, 128], f64> + + // CHECK: return %[[OUT]] + return %out : !test.test_tensor<[32, 128], f64> +} diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp index 1bbf2cc7481d9..03985874f910d 100644 --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -11,6 +11,7 @@ #include "TestTypes.h" #include "mlir/Bytecode/BytecodeImplementation.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Bufferization/IR/BufferizationConversionInterface.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/AsmState.h" @@ -284,6 +285,53 @@ getDynamicCustomParserPrinterOp(TestDialect *dialect) { verifier, regionVerifier, parser, printer); } +namespace { + +struct TestConverter : bufferization::ConversionDialectInterface { + TestConverter(Dialect *dialect) + : bufferization::ConversionDialectInterface(dialect) {} + + FailureOr + getBufferType(Value value, const bufferization::BufferizationOptions &options, + const bufferization::BufferizationState &state, + function_ref emitError) + const override { + auto testTensor = dyn_cast(value.getType()); + if (!testTensor) + return emitError("expected TestTensorType"); + + return cast( + TestMemrefType::get(value.getContext(), testTensor.getShape(), + testTensor.getElementType(), nullptr)); + } + + LogicalResult typesMatch(bufferization::TensorLikeType tensor, + bufferization::BufferLikeType buffer, + function_ref + emitError) const override { + auto testTensor = dyn_cast(tensor); + auto testMemref = dyn_cast(buffer); + if (!testTensor || !testMemref) + return emitError("expected TestTensorType and TestMemrefType"); + + const bool valid = + testTensor.getShape() == testMemref.getShape() && + testTensor.getElementType() == testMemref.getElementType(); + return success(valid); + } + + bufferization::TensorLikeType + getTensorFromBuffer(bufferization::BufferLikeType buffer) const override { + auto testMemref = dyn_cast(buffer); + assert(testMemref && "expected TestMemrefType"); + return cast( + TestTensorType::get(testMemref.getContext(), testMemref.getShape(), + testMemref.getElementType())); + } +}; + +} // namespace + //===----------------------------------------------------------------------===// // TestDialect //===----------------------------------------------------------------------===// @@ -333,6 +381,7 @@ void TestDialect::initialize() { registerDynamicOp(getDynamicCustomParserPrinterOp(this)); registerInterfaces(); allowUnknownOperations(); + addInterface(); // Instantiate our fallback op interface that we'll use on specific // unregistered op. diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp index b5a8bd10d6b68..78e44c6ec7a9b 100644 --- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp +++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp @@ -8,6 +8,7 @@ #include "TestDialect.h" #include "TestOps.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Verifier.h" #include "mlir/Interfaces/FunctionImplementation.h" @@ -1387,3 +1388,25 @@ TestMultiSlotAlloca::handleDestructuringComplete( const DestructurableMemorySlot &slot, OpBuilder &builder) { return createNewMultiAllocaWithoutSlot(slot, builder, *this); } + +::mlir::LogicalResult test::TestDummyTensorOp::bufferize( + ::mlir::RewriterBase &rewriter, + const ::mlir::bufferization::BufferizationOptions &options, + ::mlir::bufferization::BufferizationState &state) { + auto buffer = + mlir::bufferization::getBuffer(rewriter, getInput(), options, state); + if (mlir::failed(buffer)) + return failure(); + + const auto outType = getOutput().getType(); + const auto bufferizedOutType = test::TestMemrefType::get( + getContext(), outType.getShape(), outType.getElementType(), nullptr); + // replace op with memref analogy + auto dummyMemrefOp = rewriter.create( + getLoc(), bufferizedOutType, *buffer); + + mlir::bufferization::replaceOpWithBufferizedValues(rewriter, getOperation(), + dummyMemrefOp.getResult()); + + return mlir::success(); +} diff --git a/mlir/test/lib/Dialect/Test/TestOps.h b/mlir/test/lib/Dialect/Test/TestOps.h index c2ee5f9ab9a57..b414b47c87425 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.h +++ b/mlir/test/lib/Dialect/Test/TestOps.h @@ -13,6 +13,7 @@ #include "TestInterfaces.h" #include "TestTypes.h" #include "mlir/Bytecode/BytecodeImplementation.h" +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/DLTI/DLTI.h" #include "mlir/Dialect/DLTI/Traits.h" #include "mlir/Dialect/Func/IR/FuncOps.h" diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index 59330fdb1bb2c..79bcd9c2e0a9a 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -31,7 +31,7 @@ include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/LoopLikeInterface.td" include "mlir/Interfaces/MemorySlotInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" - +include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td" // Include the attribute definitions. include "TestAttrDefs.td" @@ -2825,7 +2825,7 @@ def TestNVVMRequiresSMArchCondOp : let assemblyFormat = "attr-dict"; } -def TestNVVMRequirestSMArchCondMultiOp : +def TestNVVMRequirestSMArchCondMultiOp : TEST_Op<"nvvm_requires_sm_90a_or_sm_100a", [NVVMRequiresSMa<[90, 100]>]> { let arguments = (ins ); let assemblyFormat = "attr-dict"; @@ -3552,4 +3552,58 @@ def TestAllocWithMultipleResults : TEST_Op<"alloc_with_multiple_results"> { }]; } +//===----------------------------------------------------------------------===// +// Test Ops bufferization +//===----------------------------------------------------------------------===// + +def TestDummyTensorOp : TEST_Op<"dummy_tensor_op", [BufferizableOpInterface]> { + let arguments = (ins + Arg:$input + ); + let results = (outs + Arg:$output + ); + let extraClassDeclaration = [{ + // BufferizableOpInterface + bool bufferizesToMemoryRead(mlir::OpOperand&, + const mlir::bufferization::AnalysisState&); + + bool bufferizesToMemoryWrite(mlir::OpOperand&, + const mlir::bufferization::AnalysisState&); + + mlir::bufferization::AliasingValueList getAliasingValues(mlir::OpOperand&, + const mlir::bufferization::AnalysisState&); + + mlir::LogicalResult bufferize( + mlir::RewriterBase& rewriter, + const mlir::bufferization::BufferizationOptions& options, + mlir::bufferization::BufferizationState &state); + }]; + + let extraClassDefinition = [{ + bool test::TestDummyTensorOp::bufferizesToMemoryRead(::mlir::OpOperand&, + const ::mlir::bufferization::AnalysisState&) { + return true; + } + bool test::TestDummyTensorOp::bufferizesToMemoryWrite(::mlir::OpOperand&, + const ::mlir::bufferization::AnalysisState&) { + return true; + } + ::mlir::bufferization::AliasingValueList + test::TestDummyTensorOp::getAliasingValues(::mlir::OpOperand&, + const ::mlir::bufferization::AnalysisState&) { + return {}; + } + }]; +} + +def TestDummyMemrefOp : TEST_Op<"dummy_memref_op", []> { + let arguments = (ins + Arg:$input + ); + let results = (outs + Arg:$output + ); +} + #endif // TEST_OPS From 7ef1183d5025a386acd16375242cbdd69aecc5eb Mon Sep 17 00:00:00 2001 From: "Golubev, Andrey" Date: Tue, 17 Jun 2025 08:55:13 +0000 Subject: [PATCH 2/5] [NFC] Remove type-inferring builder of ToTensorOp The builder is ambiguous given customizable tensor-like -> buffer-like conversion and is thus removed. The places where reverse bufferization has to happen rely on the pre-existing functionality. --- .../Bufferization/IR/BufferizableOpInterface.h | 4 ---- .../IR/BufferizationConversionInterface.h | 8 -------- .../Dialect/Bufferization/IR/BufferizationOps.td | 7 ------- .../Bufferization/IR/BufferizableOpInterface.cpp | 10 +++------- .../IR/BufferizationConversionInterface.cpp | 9 --------- .../Dialect/Bufferization/IR/BufferizationOps.cpp | 5 +++-- .../Transforms/ConvertToDestinationStyle.cpp | 9 ++++++--- .../Transforms/BufferizableOpInterfaceImpl.cpp | 2 +- .../SparseTensor/Transforms/SparseGPUCodegen.cpp | 15 +++++++++------ .../Transforms/SparseTensorCodegen.cpp | 3 ++- .../Transforms/SparseTensorConversion.cpp | 4 +++- mlir/test/lib/Dialect/Test/TestDialect.cpp | 9 --------- 12 files changed, 27 insertions(+), 58 deletions(-) diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h index 768778df046a6..b2a6420667d6f 100644 --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h @@ -748,10 +748,6 @@ FailureOr castToMemRef(FailureOr bufferType); /// bufferization::ConversionInterface to verify the types in tensor and buffer /// worlds match. bool typesMatchAfterBufferization(Operation &op, Value tensor, Value buffer); - -/// This function is a free-standing helper that relies on -/// bufferization::ConversionInterface to perform the conversion. -Type getTensorFromBuffer(Type buffer); } // namespace detail } // namespace bufferization diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationConversionInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationConversionInterface.h index 4164d1dcb9ea6..6afdd3d4cb74e 100644 --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationConversionInterface.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationConversionInterface.h @@ -35,10 +35,6 @@ struct ConversionDialectInterface virtual LogicalResult typesMatch( TensorLikeType tensor, BufferLikeType buffer, function_ref emitError) const = 0; - - /// Hook to customize buffer-like -> tensor-like conversion, which is the - /// opposite of bufferization. - virtual TensorLikeType getTensorFromBuffer(BufferLikeType buffer) const = 0; }; /// Interface collection for conversion between tensor-like and buffer-like @@ -60,10 +56,6 @@ struct ConversionInterface LogicalResult typesMatch(TensorLikeType tensor, BufferLikeType buffer, function_ref emitError) const; - - /// Dispatches to ConversionDialectInterface::getTensorFromBuffer() of the - /// dialect associated with the value type. - TensorLikeType getTensorFromBuffer(BufferLikeType buffer) const; }; } // namespace bufferization diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td index 3d301a0657200..6fc99d14b66d0 100644 --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td @@ -490,13 +490,6 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [ `:` type($buffer) `to` type($result) }]; - let builders = [ - OpBuilder<(ins "Value":$buffer, CArg<"bool", "false">:$restrict, CArg<"bool", "false">:$writeable), [{ - auto rtt = bufferization::detail::getTensorFromBuffer(buffer.getType()); - build($_builder, $_state, rtt, buffer, restrict, writeable); - }]> - ]; - let hasCanonicalizer = 1; let hasFolder = 1; } diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp index bd79cbc80dd2a..9a3ab5855eef6 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -172,7 +172,9 @@ FailureOr bufferization::allocateTensorForShapedValue( if (llvm::isa(shapedValue.getType())) { tensor = shapedValue; } else if (llvm::isa(shapedValue.getType())) { - tensor = b.create(loc, shapedValue); + tensor = b.create( + loc, memref::getTensorTypeFromMemRefType(shapedValue.getType()), + shapedValue); } else if (llvm::isa(shapedValue.getType()) || llvm::isa(shapedValue.getType())) { return getOwnerOfValue(shapedValue) @@ -1064,9 +1066,3 @@ bool bufferization::detail::typesMatchAfterBufferization(Operation &op, cast(buffer.getType()), [&](const Twine &message) { return op.emitError(message); })); } - -Type bufferization::detail::getTensorFromBuffer(Type buffer) { - assert(isa(buffer) && "expected BufferLikeType"); - bufferization::ConversionInterface iface(buffer.getContext()); - return iface.getTensorFromBuffer(cast(buffer)); -} diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationConversionInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationConversionInterface.cpp index 287e9bf85002f..3084854f58cfe 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizationConversionInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationConversionInterface.cpp @@ -54,14 +54,5 @@ LogicalResult ConversionInterface::typesMatch( return success(); } -TensorLikeType -ConversionInterface::getTensorFromBuffer(BufferLikeType buffer) const { - Dialect *dialect = &buffer.getDialect(); - if (const ConversionDialectInterface *iface = getInterfaceFor(dialect)) - return iface->getTensorFromBuffer(buffer); - - return cast(memref::getTensorTypeFromMemRefType(buffer)); -} - } // namespace bufferization } // namespace mlir diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp index 79af1e8fee79f..451446f35b105 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp @@ -643,8 +643,9 @@ Value MaterializeInDestinationOp::buildSubsetExtraction(OpBuilder &builder, assert(getRestrict() && "expected that ops with memrefs dest have 'restrict'"); setRestrict(false); - return builder.create(loc, getDest(), /*restrict=*/true, - getWritable()); + return builder.create( + loc, memref::getTensorTypeFromMemRefType(getDest().getType()), getDest(), + /*restrict=*/true, getWritable()); } bool MaterializeInDestinationOp::isEquivalentSubset( diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp index 94a4b9011c16b..573420f6a9aa9 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp @@ -252,7 +252,8 @@ Value linalg::bufferizeToAllocation( // Create bufferization.to_tensor with "restrict" and "writable". The returned // tensor is a new buffer allocation, so it does not alias with any buffer. Value toTensorOp = rewriter.create( - loc, alloc, /*restrict=*/true, /*writable=*/true); + loc, padOp.getResult().getType(), alloc, /*restrict=*/true, + /*writable=*/true); rewriter.replaceOp(padOp, toTensorOp); return alloc; } @@ -340,7 +341,8 @@ Value linalg::bufferizeToAllocation( // Create bufferization.to_tensor with "restrict" and "writable". The returned // tensor is a new buffer allocation, so it does not alias with any buffer. Value toTensorOp = rewriter.create( - loc, alloc, /*restrict=*/true, /*writable=*/true); + loc, allocTensorOp.getResult().getType(), alloc, /*restrict=*/true, + /*writable=*/true); rewriter.replaceOp(allocTensorOp, toTensorOp); return alloc; } @@ -567,7 +569,8 @@ Value linalg::bufferizeToAllocation( createMemcpy(rewriter, op->getLoc(), operand->get(), alloc, options); } rewriter.modifyOpInPlace(op, [&]() { - auto toTensorOp = rewriter.create(op->getLoc(), alloc); + auto toTensorOp = rewriter.create( + op->getLoc(), operand->get().getType(), alloc); operand->set(toTensorOp); if (options.bufferizeDestinationOnly) { rewriter.modifyOpInPlace(toTensorOp, [&]() { diff --git a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp index dc91117a51936..8a471c12d21e4 100644 --- a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp @@ -67,7 +67,7 @@ struct AssumingOpInterface for (const auto &it : llvm::enumerate(assumingOp->getResultTypes())) { if (isa(it.value())) { newResults.push_back(rewriter.create( - assumingOp.getLoc(), newOp->getResult(it.index()))); + assumingOp.getLoc(), it.value(), newOp->getResult(it.index()))); } else { newResults.push_back(newOp->getResult(it.index())); } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp index e5f2418367a58..e89b34d457ff8 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp @@ -651,7 +651,7 @@ static LogicalResult rewriteSpMV(PatternRewriter &rewriter, tokens.clear(); // Done. - rewriter.replaceOpWithNewOp(op, memY); + rewriter.replaceOpWithNewOp(op, y.getType(), memY); return success(); } @@ -752,7 +752,7 @@ static LogicalResult rewriteSpMM(PatternRewriter &rewriter, tokens.clear(); // Done. - rewriter.replaceOpWithNewOp(op, bufC); + rewriter.replaceOpWithNewOp(op, c.getType(), bufC); return success(); } @@ -925,9 +925,12 @@ static LogicalResult rewriteSpGEMM(PatternRewriter &rewriter, tokens.clear(); // Done. - Value vt = rewriter.create(loc, valH); - Value rt = rewriter.create(loc, rowH); - Value ct = rewriter.create(loc, colH); + Value vt = rewriter.create( + loc, memref::getTensorTypeFromMemRefType(valH.getType()), valH); + Value rt = rewriter.create( + loc, memref::getTensorTypeFromMemRefType(rowH.getType()), rowH); + Value ct = rewriter.create( + loc, memref::getTensorTypeFromMemRefType(colH.getType()), colH); rewriter.replaceOpWithNewOp(op, c.getType(), ValueRange{rt, ct}, vt); return success(); @@ -1043,7 +1046,7 @@ static LogicalResult rewrite2To4SpMM(PatternRewriter &rewriter, tokens.clear(); // Done. - rewriter.replaceOpWithNewOp(op, bufC); + rewriter.replaceOpWithNewOp(op, C.getType(), bufC); return success(); } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp index e5f9717c3fbaa..14ced56b8365f 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp @@ -1471,7 +1471,8 @@ struct SparseDisassembleOpConverter // Converts MemRefs back to Tensors. SmallVector retValues = llvm::to_vector( llvm::map_range(retMem, [&rewriter, loc](Value v) -> Value { - return rewriter.create(loc, v); + return rewriter.create( + loc, memref::getTensorTypeFromMemRefType(v.getType()), v); })); // Appends the actual memory length used in each buffer returned. retValues.append(retLen.begin(), retLen.end()); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp index 9ffa64dc821d8..7f0b657687442 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp @@ -867,7 +867,9 @@ class SparseTensorDisassembleConverter // Converts MemRefs back to Tensors. assert(retVal.size() + retLen.size() == op.getNumResults()); for (unsigned i = 0, sz = retVal.size(); i < sz; i++) { - auto tensor = rewriter.create(loc, retVal[i]); + auto tensor = rewriter.create( + loc, memref::getTensorTypeFromMemRefType(retVal[i].getType()), + retVal[i]); retVal[i] = rewriter.create(loc, op.getResultTypes()[i], tensor); } diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp index 03985874f910d..26d07f65ed0f9 100644 --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -319,15 +319,6 @@ struct TestConverter : bufferization::ConversionDialectInterface { testTensor.getElementType() == testMemref.getElementType(); return success(valid); } - - bufferization::TensorLikeType - getTensorFromBuffer(bufferization::BufferLikeType buffer) const override { - auto testMemref = dyn_cast(buffer); - assert(testMemref && "expected TestMemrefType"); - return cast( - TestTensorType::get(testMemref.getContext(), testMemref.getShape(), - testMemref.getElementType())); - } }; } // namespace From b05a291d2fe092ca669bc2c66fc490faf72eedbc Mon Sep 17 00:00:00 2001 From: "Golubev, Andrey" Date: Tue, 17 Jun 2025 09:34:41 +0000 Subject: [PATCH 3/5] Switch from ConversionDialectInterface to TensorLike API Noteworthy changes: * bufferization::getMemRefType() accepts a TensorType instead of Value to achieve broader applicability * BufferizationOptions::UnknownTypeConverterFn accepts a TensorType instead of Value to allow it being used in the updated getMemRefType() --- .../IR/BufferizableOpInterface.h | 4 +- .../IR/BufferizationConversionInterface.h | 64 ------------------- .../IR/BufferizationTypeInterfaces.h | 7 ++ .../IR/BufferizationTypeInterfaces.td | 27 +++++++- .../IR/BufferizableOpInterface.cpp | 19 ++---- .../IR/BufferizationConversionInterface.cpp | 58 ----------------- .../Bufferization/IR/BufferizationDialect.cpp | 35 +++++++++- .../IR/BufferizationTypeInterfaces.cpp | 21 ++++++ .../Dialect/Bufferization/IR/CMakeLists.txt | 2 +- mlir/test/lib/Dialect/Test/TestDialect.cpp | 40 ------------ mlir/test/lib/Dialect/Test/TestTypeDefs.td | 10 +++ mlir/test/lib/Dialect/Test/TestTypes.cpp | 21 ++++++ 12 files changed, 127 insertions(+), 181 deletions(-) delete mode 100644 mlir/include/mlir/Dialect/Bufferization/IR/BufferizationConversionInterface.h delete mode 100644 mlir/lib/Dialect/Bufferization/IR/BufferizationConversionInterface.cpp create mode 100644 mlir/lib/Dialect/Bufferization/IR/BufferizationTypeInterfaces.cpp diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h index b2a6420667d6f..c97e90e0ee1f3 100644 --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h @@ -745,8 +745,8 @@ bool defaultHasTensorSemantics(Operation *op); FailureOr castToMemRef(FailureOr bufferType); /// This function is a free-standing helper that relies on -/// bufferization::ConversionInterface to verify the types in tensor and buffer -/// worlds match. +/// bufferization::TensorLikeTypeInterface to verify the types in tensor and +/// buffer worlds match. bool typesMatchAfterBufferization(Operation &op, Value tensor, Value buffer); } // namespace detail diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationConversionInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationConversionInterface.h deleted file mode 100644 index 6afdd3d4cb74e..0000000000000 --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationConversionInterface.h +++ /dev/null @@ -1,64 +0,0 @@ -//===- BufferizationConversionInterface.h - Dialect Interface ---*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATIONCONVERSIONINTERFACE_H_ -#define MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATIONCONVERSIONINTERFACE_H_ - -#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" -#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h" -#include "mlir/IR/DialectInterface.h" - -namespace mlir { -namespace bufferization { - -/// This class defines a virtual interface for conversions between tensor-like -/// and buffer-like types. -struct ConversionDialectInterface - : DialectInterface::Base { - using Base::Base; - - /// Hook to customize tensor-like -> buffer-like conversion within a given - /// dialect. Returns a buffer-like type for the specific tensor-like type. - virtual FailureOr getBufferType( - Value value, const BufferizationOptions &options, - const BufferizationState &state, - function_ref emitError) const = 0; - - /// Hook to customize type checking between tensor-like and buffer-like types. - /// Given tensor `T` and buffer `B = getBufferType(T, ...)`, the call to - /// `typesMatch(T, B)` must return true. - virtual LogicalResult typesMatch( - TensorLikeType tensor, BufferLikeType buffer, - function_ref emitError) const = 0; -}; - -/// Interface collection for conversion between tensor-like and buffer-like -/// types, dispatches to a concrete interface implementation based on the -/// dialect to which the given type belongs. -struct ConversionInterface - : DialectInterfaceCollection { - using Base::Base; - - /// Dispatches to ConversionDialectInterface::getBufferType() of the dialect - /// associated with the value type. - FailureOr getBufferType( - Value value, const BufferizationOptions &options, - const BufferizationState &state, - function_ref emitError) const; - - /// Dispatches to ConversionDialectInterface::typesMatch() of the dialect - /// associated with the value type. - LogicalResult - typesMatch(TensorLikeType tensor, BufferLikeType buffer, - function_ref emitError) const; -}; - -} // namespace bufferization -} // namespace mlir - -#endif // MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATIONTYPEINTERFACES_H_ diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h index 5faa1479ee542..cbb6054fcf886 100644 --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h @@ -13,8 +13,15 @@ // Bufferization Type Interfaces //===----------------------------------------------------------------------===// +#include "mlir/IR/Diagnostics.h" #include "mlir/IR/Types.h" +namespace mlir::bufferization { +struct BufferizationOptions; +class BufferizationState; +class BufferLikeType; +} // namespace mlir::bufferization + #include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h.inc" #endif // MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATIONTYPEINTERFACES_H_ diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td index f19224a295648..9bca41a0284fa 100644 --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td @@ -21,10 +21,31 @@ def Bufferization_TensorLikeTypeInterface let description = [{ Indicates that this type is a tensor type (similarly to a MLIR builtin tensor) for bufferization purposes. - - The interface currently has no methods as it is used by types to opt into - being supported by the bufferization procedures. }]; + + let methods = [ + InterfaceMethod<[{ + Returns a BufferLike type for this TensorLike type. + }], + /*retTy=*/"::mlir::FailureOr<::mlir::bufferization::BufferLikeType>", + /*methodName=*/"getBufferType", + /*args=*/(ins + "const ::mlir::bufferization::BufferizationOptions &":$options, + "const ::mlir::bufferization::BufferizationState &":$state, + "::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError + ) + >, + InterfaceMethod<[{ + Returns whether a BufferLike type is compatible to this TensorLike type. + The BufferLike type is assumed to be created by getBufferType(). + }], + /*retTy=*/"::mlir::LogicalResult", + /*methodName=*/"verifyCompatibleBufferType", + /*args=*/(ins + "::mlir::bufferization::BufferLikeType":$bufferType, + "::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError) + > + ]; } def Bufferization_BufferLikeTypeInterface diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp index 9a3ab5855eef6..d33b853b8c203 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -8,7 +8,6 @@ #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" -#include "mlir/Dialect/Bufferization/IR/BufferizationConversionInterface.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" @@ -720,11 +719,9 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options, if (bufferizableOp) return bufferizableOp.getBufferType(value, options, state, invocationStack); - // Op is not bufferizable, use conversion interface. - bufferization::ConversionInterface iface(value.getContext()); - return iface.getBufferType(value, options, state, [&](const Twine &message) { - return op->emitError(message); - }); + // Op is not bufferizable. + return cast(value.getType()) + .getBufferType(options, state, [&]() { return op->emitError(); }); } bool bufferization::hasTensorSemantics(Operation *op) { @@ -1059,10 +1056,8 @@ bool bufferization::detail::typesMatchAfterBufferization(Operation &op, assert(isa(tensor.getType()) && "expected TensorLikeType"); assert(isa(buffer.getType()) && "expected BufferLikeType"); - // Op is not bufferizable, use conversion interface. - bufferization::ConversionInterface iface(op.getContext()); - return succeeded(iface.typesMatch( - cast(tensor.getType()), - cast(buffer.getType()), - [&](const Twine &message) { return op.emitError(message); })); + return mlir::succeeded( + cast(tensor.getType()) + .verifyCompatibleBufferType(cast(buffer.getType()), + [&]() { return op.emitError(); })); } diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationConversionInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationConversionInterface.cpp deleted file mode 100644 index 3084854f58cfe..0000000000000 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizationConversionInterface.cpp +++ /dev/null @@ -1,58 +0,0 @@ -//===- BufferizationConversionInterface.cpp - Dialect Interface ---=------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "mlir/Dialect/Bufferization/IR/BufferizationConversionInterface.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" // getTensorTypeFromMemRefType - -namespace mlir { -namespace bufferization { - -FailureOr ConversionInterface::getBufferType( - Value value, const BufferizationOptions &options, - const BufferizationState &state, - function_ref emitError) const { - Dialect *dialect = &value.getType().getDialect(); - if (const ConversionDialectInterface *iface = getInterfaceFor(dialect)) - return iface->getBufferType(value, options, state, emitError); - - // Fall back to tensor -> memref conversion. - auto memSpace = - options.defaultMemorySpaceFn(cast(value.getType())); - if (!memSpace.has_value()) - return emitError("could not infer memory space"); - - return cast( - getMemRefType(value, options, /*layout=*/{}, *memSpace)); -} - -LogicalResult ConversionInterface::typesMatch( - TensorLikeType tensor, BufferLikeType buffer, - function_ref emitError) const { - Dialect *dialect = &tensor.getDialect(); - if (const ConversionDialectInterface *iface = getInterfaceFor(dialect)) - return iface->typesMatch(tensor, buffer, emitError); - - // Fall back to tensor, memref checking. - assert(isa(tensor) && "expected tensor type"); - assert(isa(buffer) && "expected memref type"); - - if (cast(tensor).getShape() != - cast(buffer).getShape()) { - return emitError("shapes do not match"); - } - - if (cast(tensor).getElementType() != - cast(buffer).getElementType()) { - return emitError("element types do not match"); - } - - return success(); -} - -} // namespace bufferization -} // namespace mlir diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp index d8eac01c2dea0..f92c1b4b18062 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp @@ -57,7 +57,40 @@ struct BufferizationInlinerInterface : public DialectInlinerInterface { template struct BuiltinTensorExternalModel : TensorLikeType::ExternalModel, - Tensor> {}; + Tensor> { + llvm::FailureOr getBufferType( + mlir::Type tensor, const BufferizationOptions &options, + const BufferizationState &state, + llvm::function_ref emitError) const { + auto tensorType = cast(tensor); + // Fall back to tensor -> memref conversion. + auto memSpace = options.defaultMemorySpaceFn(tensorType); + if (!memSpace.has_value()) + return emitError() << "could not infer memory space"; + + return cast( + getMemRefType(tensorType, options, /*layout=*/{}, *memSpace)); + } + + mlir::LogicalResult verifyCompatibleBufferType( + mlir::Type tensor, BufferLikeType bufferType, + llvm::function_ref emitError) const { + // Fall back to tensor, memref checking. + assert(isa(tensor) && "expected tensor type"); + assert(isa(bufferType) && "expected memref type"); + + auto tensorType = cast(tensor); + auto memrefType = cast(bufferType); + + if (tensorType.getShape() != memrefType.getShape()) + return emitError() << "shapes do not match"; + + if (tensorType.getElementType() != memrefType.getElementType()) + return emitError() << "element types do not match"; + + return mlir::success(); + } +}; template struct BuiltinMemRefExternalModel diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationTypeInterfaces.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationTypeInterfaces.cpp new file mode 100644 index 0000000000000..0e973915c6fc9 --- /dev/null +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationTypeInterfaces.cpp @@ -0,0 +1,21 @@ +//===- BufferizationTypeInterfaces.cpp - Type Interfaces --------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h" + +//===----------------------------------------------------------------------===// +// Bufferization Type Interfaces +//===----------------------------------------------------------------------===// + +namespace mlir { +namespace bufferization { + +#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.cpp.inc" + +} // namespace bufferization +} // namespace mlir diff --git a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt index a47c1569e4c33..5d8f0060f2c3f 100644 --- a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt @@ -6,7 +6,7 @@ add_mlir_dialect_library(MLIRBufferizationDialect BufferizationDialect.cpp BufferViewFlowOpInterface.cpp UnstructuredControlFlow.cpp - BufferizationConversionInterface.cpp + BufferizationTypeInterfaces.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Bufferization diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp index 26d07f65ed0f9..1bbf2cc7481d9 100644 --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -11,7 +11,6 @@ #include "TestTypes.h" #include "mlir/Bytecode/BytecodeImplementation.h" #include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Bufferization/IR/BufferizationConversionInterface.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/AsmState.h" @@ -285,44 +284,6 @@ getDynamicCustomParserPrinterOp(TestDialect *dialect) { verifier, regionVerifier, parser, printer); } -namespace { - -struct TestConverter : bufferization::ConversionDialectInterface { - TestConverter(Dialect *dialect) - : bufferization::ConversionDialectInterface(dialect) {} - - FailureOr - getBufferType(Value value, const bufferization::BufferizationOptions &options, - const bufferization::BufferizationState &state, - function_ref emitError) - const override { - auto testTensor = dyn_cast(value.getType()); - if (!testTensor) - return emitError("expected TestTensorType"); - - return cast( - TestMemrefType::get(value.getContext(), testTensor.getShape(), - testTensor.getElementType(), nullptr)); - } - - LogicalResult typesMatch(bufferization::TensorLikeType tensor, - bufferization::BufferLikeType buffer, - function_ref - emitError) const override { - auto testTensor = dyn_cast(tensor); - auto testMemref = dyn_cast(buffer); - if (!testTensor || !testMemref) - return emitError("expected TestTensorType and TestMemrefType"); - - const bool valid = - testTensor.getShape() == testMemref.getShape() && - testTensor.getElementType() == testMemref.getElementType(); - return success(valid); - } -}; - -} // namespace - //===----------------------------------------------------------------------===// // TestDialect //===----------------------------------------------------------------------===// @@ -372,7 +333,6 @@ void TestDialect::initialize() { registerDynamicOp(getDynamicCustomParserPrinterOp(this)); registerInterfaces(); allowUnknownOperations(); - addInterface(); // Instantiate our fallback op interface that we'll use on specific // unregistered op. diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td index 09294e84960f2..5594e71bf7ca0 100644 --- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td +++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td @@ -428,6 +428,16 @@ def TestTensorType : Test_Type<"TestTensor", return test::TestTensorType::get( getContext(), shape.value_or(getShape()), elementType); } + + // TensorLikeTypeInterface: + ::mlir::FailureOr<::mlir::bufferization::BufferLikeType> + getBufferType(const ::mlir::bufferization::BufferizationOptions& options, + const ::mlir::bufferization::BufferizationState& state, + ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError); + + ::mlir::LogicalResult verifyCompatibleBufferType( + ::mlir::bufferization::BufferLikeType bufferType, + ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError); }]; } diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp index 5c784dcee6e15..8ee8ddb39f202 100644 --- a/mlir/test/lib/Dialect/Test/TestTypes.cpp +++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp @@ -545,3 +545,24 @@ TestTypeOpAsmTypeInterfaceType::getAlias(::llvm::raw_ostream &os) const { os << "op_asm_type_interface_type"; return ::mlir::OpAsmDialectInterface::AliasResult::FinalAlias; } + +::mlir::FailureOr<::mlir::bufferization::BufferLikeType> +TestTensorType::getBufferType( + const ::mlir::bufferization::BufferizationOptions &, + const ::mlir::bufferization::BufferizationState &, + ::llvm::function_ref<::mlir::InFlightDiagnostic()>) { + return cast( + TestMemrefType::get(getContext(), getShape(), getElementType(), nullptr)); +} + +::mlir::LogicalResult TestTensorType::verifyCompatibleBufferType( + ::mlir::bufferization::BufferLikeType bufferType, + ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) { + auto testMemref = dyn_cast(bufferType); + if (!testMemref) + return emitError() << "expected TestMemrefType"; + + const bool valid = getShape() == testMemref.getShape() && + getElementType() == testMemref.getElementType(); + return mlir::success(valid); +} From 2de5728af40426cfa809a70a6fc5a3dcbb7c968b Mon Sep 17 00:00:00 2001 From: "Golubev, Andrey" Date: Wed, 18 Jun 2025 13:15:28 +0000 Subject: [PATCH 4/5] Address code review feedback * Remove BufferizationState from TensorLikeType::getBufferType() * Rename castToMemRef to asMemRefType (+ add extra docs) * Improve ToTensorOp's docs * Apply minor suggestions --- .../Bufferization/IR/BufferizableOpInterface.h | 7 ++++++- .../Dialect/Bufferization/IR/BufferizationOps.td | 5 ++++- .../IR/BufferizationTypeInterfaces.td | 1 - .../Transforms/BufferizableOpInterfaceImpl.cpp | 6 +++--- .../Bufferization/IR/BufferizableOpInterface.cpp | 15 ++++++--------- .../Bufferization/IR/BufferizationDialect.cpp | 3 --- .../Bufferization/IR/BufferizationOps.cpp | 2 +- .../Transforms/BufferizableOpInterfaceImpl.cpp | 16 ++++++++-------- .../Transforms/BufferizableOpInterfaceImpl.cpp | 6 +++--- mlir/test/lib/Dialect/Test/TestTypeDefs.td | 1 - mlir/test/lib/Dialect/Test/TestTypes.cpp | 1 - 11 files changed, 31 insertions(+), 32 deletions(-) diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h index c97e90e0ee1f3..c1529a36465ac 100644 --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h @@ -742,7 +742,12 @@ AliasingValueList unknownGetAliasingValues(OpOperand &opOperand); bool defaultHasTensorSemantics(Operation *op); /// This is a helper function used when buffer type is guaranteed to be memref. -FailureOr castToMemRef(FailureOr bufferType); +/// It performs two actions: failure state checking and an explicit llvm::cast<> +/// from the buffer-like type interface to a BaseMemRefType. This allows easier +/// management of differences in C++ types at the API boundaries. Valid buffer +/// type is casted to the memref type. Otherwise, the failure state is +/// propagated i.e. asMemRefType(mlir::failure()) returns mlir::failure(). +FailureOr asMemRefType(FailureOr bufferType); /// This function is a free-standing helper that relies on /// bufferization::TensorLikeTypeInterface to verify the types in tensor and diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td index 6fc99d14b66d0..ec126a965ccbb 100644 --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td @@ -404,7 +404,10 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [ let summary = "create a buffer-like type from a tensor-like type"; let description = [{ An operation that creates a tensor from a buffer. The result value is a - tensor-like type whose shape and element type match the buffer-like operand. + tensor-like type that must match the corresponding buffer-like operand as + per TensorLikeType::verifyCompatibleBufferType(). For builtins (TensorType + and BaseMemRefType), this means that shapes and element types match between + the tensor and the buffer. The opposite of this op is `to_buffer`. Together, these two ops are useful for source/target materializations when doing type conversions diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td index 9bca41a0284fa..fb6fc4f5ad964 100644 --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td @@ -31,7 +31,6 @@ def Bufferization_TensorLikeTypeInterface /*methodName=*/"getBufferType", /*args=*/(ins "const ::mlir::bufferization::BufferizationOptions &":$options, - "const ::mlir::bufferization::BufferizationState &":$state, "::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError ) >, diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp index 021a557f68b4b..85d1b5ac73bf4 100644 --- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp @@ -164,7 +164,7 @@ struct SelectOpInterface // buffers have different types, they differ only in their layout map. Cast // both of them to the most dynamic MemRef type. if (trueBuffer.getType() != falseBuffer.getType()) { - auto targetType = bufferization::detail::castToMemRef( + auto targetType = bufferization::detail::asMemRefType( bufferization::getBufferType(selectOp.getResult(), options, state)); if (failed(targetType)) return failure(); @@ -188,10 +188,10 @@ struct SelectOpInterface auto selectOp = cast(op); assert(value == selectOp.getResult() && "invalid value"); auto trueType = - bufferization::detail::castToMemRef(bufferization::getBufferType( + bufferization::detail::asMemRefType(bufferization::getBufferType( selectOp.getTrueValue(), options, state, invocationStack)); auto falseType = - bufferization::detail::castToMemRef(bufferization::getBufferType( + bufferization::detail::asMemRefType(bufferization::getBufferType( selectOp.getFalseValue(), options, state, invocationStack)); if (failed(trueType) || failed(falseType)) return failure(); diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp index d33b853b8c203..2ab182c9b7b2e 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -214,7 +214,7 @@ FailureOr bufferization::allocateTensorForShapedValue( if (copy) return allocTensorOp.getResult(); auto copyBufferType = - detail::castToMemRef(getBufferType(tensor, options, state)); + detail::asMemRefType(getBufferType(tensor, options, state)); if (failed(copyBufferType)) return failure(); std::optional memorySpace = copyBufferType->getMemorySpace(); @@ -720,8 +720,9 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options, return bufferizableOp.getBufferType(value, options, state, invocationStack); // Op is not bufferizable. - return cast(value.getType()) - .getBufferType(options, state, [&]() { return op->emitError(); }); + return cast(value.getType()).getBufferType(options, [&]() { + return op->emitError(); + }); } bool bufferization::hasTensorSemantics(Operation *op) { @@ -965,7 +966,7 @@ FailureOr bufferization::detail::defaultGetBufferType( // If the OpResult has an equivalent OpOperand, both OpResult and // OpOperand bufferize to the exact same buffer type. Value equivalentOperand = aliases.getAliases().front().opOperand->get(); - return castToMemRef(getBufferType(equivalentOperand, options, + return asMemRefType(getBufferType(equivalentOperand, options, bufferizationState, invocationStack)); } @@ -1043,19 +1044,15 @@ bool bufferization::detail::defaultHasTensorSemantics(Operation *op) { } FailureOr -bufferization::detail::castToMemRef(FailureOr bufferType) { +bufferization::detail::asMemRefType(FailureOr bufferType) { if (failed(bufferType)) return failure(); - assert(isa(*bufferType) && "expected memref type"); return cast(*bufferType); } bool bufferization::detail::typesMatchAfterBufferization(Operation &op, Value tensor, Value buffer) { - assert(isa(tensor.getType()) && "expected TensorLikeType"); - assert(isa(buffer.getType()) && "expected BufferLikeType"); - return mlir::succeeded( cast(tensor.getType()) .verifyCompatibleBufferType(cast(buffer.getType()), diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp index f92c1b4b18062..6c08cdfb669f3 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp @@ -60,10 +60,8 @@ struct BuiltinTensorExternalModel Tensor> { llvm::FailureOr getBufferType( mlir::Type tensor, const BufferizationOptions &options, - const BufferizationState &state, llvm::function_ref emitError) const { auto tensorType = cast(tensor); - // Fall back to tensor -> memref conversion. auto memSpace = options.defaultMemorySpaceFn(tensorType); if (!memSpace.has_value()) return emitError() << "could not infer memory space"; @@ -75,7 +73,6 @@ struct BuiltinTensorExternalModel mlir::LogicalResult verifyCompatibleBufferType( mlir::Type tensor, BufferLikeType bufferType, llvm::function_ref emitError) const { - // Fall back to tensor, memref checking. assert(isa(tensor) && "expected tensor type"); assert(isa(bufferType) && "expected memref type"); diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp index 451446f35b105..9bd87d66c7d36 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp @@ -234,7 +234,7 @@ AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options, memorySpace = *getMemorySpace(); } else if (getCopy()) { auto copyBufferType = - bufferization::detail::castToMemRef(bufferization::getBufferType( + bufferization::detail::asMemRefType(bufferization::getBufferType( getCopy(), options, state, invocationStack)); if (failed(copyBufferType)) return failure(); diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp index efa9fc1a070aa..58562536be61f 100644 --- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp @@ -293,7 +293,7 @@ struct IfOpInterface thenBufferType = cast(thenValue.getType()); } else { auto maybeBufferType = - bufferization::detail::castToMemRef(bufferization::getBufferType( + bufferization::detail::asMemRefType(bufferization::getBufferType( thenValue, options, state, invocationStack)); if (failed(maybeBufferType)) return failure(); @@ -304,7 +304,7 @@ struct IfOpInterface elseBufferType = cast(elseValue.getType()); } else { auto maybeBufferType = - bufferization::detail::castToMemRef(bufferization::getBufferType( + bufferization::detail::asMemRefType(bufferization::getBufferType( elseValue, options, state, invocationStack)); if (failed(maybeBufferType)) return failure(); @@ -408,7 +408,7 @@ struct IndexSwitchOpInterface return bufferType; auto maybeBufferType = bufferization::getBufferType( yieldedValue, options, state, invocationStack); - return bufferization::detail::castToMemRef(maybeBufferType); + return bufferization::detail::asMemRefType(maybeBufferType); }; // Compute buffer type of the default case. @@ -527,7 +527,7 @@ static FailureOr computeLoopRegionIterArgBufferType( const BufferizationOptions &options, const BufferizationState &state, SmallVector &invocationStack) { // Determine the buffer type of the init_arg. - auto initArgBufferType = bufferization::detail::castToMemRef( + auto initArgBufferType = bufferization::detail::asMemRefType( bufferization::getBufferType(initArg, options, state, invocationStack)); if (failed(initArgBufferType)) return failure(); @@ -555,7 +555,7 @@ static FailureOr computeLoopRegionIterArgBufferType( // Note: This typically triggers a recursive call for the buffer type of // the iter_arg. auto maybeBufferType = - bufferization::detail::castToMemRef(bufferization::getBufferType( + bufferization::detail::asMemRefType(bufferization::getBufferType( yieldedValue, options, state, invocationStack)); if (failed(maybeBufferType)) return failure(); @@ -1083,7 +1083,7 @@ struct WhileOpInterface // scf.condition was already bufferized. return cast(conditionYieldedVal.getType()); } - return bufferization::detail::castToMemRef(bufferization::getBufferType( + return bufferization::detail::asMemRefType(bufferization::getBufferType( conditionYieldedVal, options, state, invocationStack)); } @@ -1312,13 +1312,13 @@ struct ForallOpInterface if (auto bbArg = dyn_cast(value)) // A tensor block argument has the same bufferized type as the // corresponding output operand. - return bufferization::detail::castToMemRef( + return bufferization::detail::asMemRefType( bufferization::getBufferType(forallOp.getTiedOpOperand(bbArg)->get(), options, state, invocationStack)); // The bufferized result type is the same as the bufferized type of the // corresponding output operand. - return bufferization::detail::castToMemRef(bufferization::getBufferType( + return bufferization::detail::asMemRefType(bufferization::getBufferType( forallOp.getOutputs()[cast(value).getResultNumber()], options, state, invocationStack)); } diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp index 40b710f17fe44..729c048db4560 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp @@ -55,7 +55,7 @@ struct CastOpInterface SmallVector &invocationStack) const { auto castOp = cast(op); auto maybeSrcBufferType = - bufferization::detail::castToMemRef(bufferization::getBufferType( + bufferization::detail::asMemRefType(bufferization::getBufferType( castOp.getSource(), options, state, invocationStack)); if (failed(maybeSrcBufferType)) return failure(); @@ -501,7 +501,7 @@ struct FromElementsOpInterface /*copy=*/false); if (failed(tensorAlloc)) return failure(); - FailureOr memrefType = bufferization::detail::castToMemRef( + FailureOr memrefType = bufferization::detail::asMemRefType( bufferization::getBufferType(*tensorAlloc, options, state)); if (failed(memrefType)) return failure(); @@ -760,7 +760,7 @@ struct PadOpInterface // Infer memory space from the source tensor. auto padOp = cast(op); auto maybeSrcBufferType = - bufferization::detail::castToMemRef(bufferization::getBufferType( + bufferization::detail::asMemRefType(bufferization::getBufferType( padOp.getSource(), options, state, invocationStack)); if (failed(maybeSrcBufferType)) return failure(); diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td index 5594e71bf7ca0..03261f37c815d 100644 --- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td +++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td @@ -432,7 +432,6 @@ def TestTensorType : Test_Type<"TestTensor", // TensorLikeTypeInterface: ::mlir::FailureOr<::mlir::bufferization::BufferLikeType> getBufferType(const ::mlir::bufferization::BufferizationOptions& options, - const ::mlir::bufferization::BufferizationState& state, ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError); ::mlir::LogicalResult verifyCompatibleBufferType( diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp index 8ee8ddb39f202..2fc2f90ef6bc0 100644 --- a/mlir/test/lib/Dialect/Test/TestTypes.cpp +++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp @@ -549,7 +549,6 @@ TestTypeOpAsmTypeInterfaceType::getAlias(::llvm::raw_ostream &os) const { ::mlir::FailureOr<::mlir::bufferization::BufferLikeType> TestTensorType::getBufferType( const ::mlir::bufferization::BufferizationOptions &, - const ::mlir::bufferization::BufferizationState &, ::llvm::function_ref<::mlir::InFlightDiagnostic()>) { return cast( TestMemrefType::get(getContext(), getShape(), getElementType(), nullptr)); From 4d052ff415bedae65dc2c298f780d2ce2d540631 Mon Sep 17 00:00:00 2001 From: "Golubev, Andrey" Date: Wed, 18 Jun 2025 13:34:30 +0000 Subject: [PATCH 5/5] Update ToTensorOp::getType() to return TensorLikeType --- .../mlir/Dialect/Bufferization/IR/BufferizationOps.td | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td index ec126a965ccbb..32c53ea9c494a 100644 --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td @@ -461,11 +461,8 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [ let extraClassDeclaration = [{ /// The result of a to_tensor is always a tensor. - TensorType getType() { - Type resultType = getResult().getType(); - if (::llvm::isa(resultType)) - return ::llvm::cast(resultType); - return {}; + ::mlir::bufferization::TensorLikeType getType() { + return getResult().getType(); } //===------------------------------------------------------------------===//