diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index 2fb795f16ae2c..c1529a36465ac 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -17,6 +17,7 @@
 #include <optional>
 
 #include "mlir/Dialect/Bufferization/IR/BufferizationEnums.h.inc"
+#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h"
 
 namespace mlir {
 class OpBuilder;
@@ -615,7 +616,7 @@ FailureOr<Value> getBuffer(RewriterBase &rewriter, Value value,
 /// IR, this function can be used.
 ///
 /// This function is a wrapper around BufferizableOpInterface::getBufferType.
-FailureOr<BaseMemRefType> getBufferType(Value value,
+FailureOr<BufferLikeType> getBufferType(Value value,
                                         const BufferizationOptions &options,
                                         const BufferizationState &state);
 
@@ -629,7 +630,7 @@ FailureOr<BaseMemRefType> getBufferType(Value value,
 /// IR, this function can be used.
 ///
 /// This function is a wrapper around `BufferizableOpInterface::getBufferType`.
-FailureOr<BaseMemRefType> getBufferType(Value value,
+FailureOr<BufferLikeType> getBufferType(Value value,
                                         const BufferizationOptions &options,
                                         const BufferizationState &state,
                                         SmallVector<Value> &invocationStack);
@@ -739,6 +740,19 @@ AliasingValueList unknownGetAliasingValues(OpOperand &opOperand);
 /// This is the default implementation of
 /// BufferizableOpInterface::hasTensorSemantics
 bool defaultHasTensorSemantics(Operation *op);
+
+/// This is a helper function used when the buffer type is guaranteed to be a
+/// memref. It performs two actions: failure-state checking and an explicit
+/// llvm::cast<> from the buffer-like type interface to BaseMemRefType. This
+/// allows easier management of differences in C++ types at the API boundaries.
+/// A valid buffer type is cast to the memref type; otherwise, the failure
+/// state is propagated, i.e. asMemRefType(mlir::failure()) returns
+/// mlir::failure().
+FailureOr<BaseMemRefType> asMemRefType(FailureOr<BufferLikeType> bufferType);
+
+/// This function is a free-standing helper that relies on
+/// bufferization::TensorLikeTypeInterface to verify that the types in the
+/// tensor and buffer worlds match.
+bool typesMatchAfterBufferization(Operation &op, Value tensor, Value buffer);
 } // namespace detail
 } // namespace bufferization
 
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index 6051aea849971..32c53ea9c494a 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -12,6 +12,7 @@
 include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.td"
 include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td"
 include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td"
+include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td"
 include "mlir/Dialect/Bufferization/IR/BufferizationBase.td"
 include "mlir/Interfaces/DestinationStyleOpInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
@@ -386,20 +387,31 @@ def Bufferization_DeallocTensorOp : Bufferization_Op<"dealloc_tensor",
 // ToTensorOp
 //===----------------------------------------------------------------------===//
 
+class Bufferization_TensorAndBufferMatch<string tensor, string buffer> : PredOpTrait<
+  "specified tensor and buffer types match",
+  CPred<
+    "::mlir::bufferization::detail::typesMatchAfterBufferization("
+      "$_op, $" # tensor # ", $" # buffer #")"
+  >
+>;
+
 def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
     BufferizableOpInterface,
     SameOperandsAndResultShape,
     SameOperandsAndResultElementType,
-    AllElementTypesMatch<["memref", "result"]>
+    Bufferization_TensorAndBufferMatch<"result", "buffer">
  ]> {
-  let summary = "create a tensor from a `memref`";
+  let summary = "create a tensor-like type from a buffer-like type";
   let description = [{
-    An operation that creates a tensor from a `memref`. The result value is a
-    tensor whose shape and element type match the memref operand.
+    An operation that creates a tensor from a buffer. The result value is a
+    tensor-like type that must match the corresponding buffer-like operand as
+    per TensorLikeType::verifyCompatibleBufferType(). For builtins (TensorType
+    and BaseMemRefType), this means that shapes and element types match between
+    the tensor and the buffer.
 
     The opposite of this op is `to_buffer`. Together, these two ops are
     useful for source/target materializations when doing type conversions
-    involving tensors and memrefs.
+    involving tensors and buffers.
 
     Example:
 
@@ -441,19 +453,16 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
     away. However, such IR is no longer bufferizable with One-Shot Bufferize.
   }];
 
-  let arguments = (ins Arg<AnyRankedOrUnrankedMemRef,
-                           "the reference to load from",
-                           [MemReadAt<0, FullEffect>]>:$memref,
+  let arguments = (ins Arg<Bufferization_BufferLikeTypeInterface,
+                           "the reference to load from",
+                           [MemReadAt<0, FullEffect>]>:$buffer,
                        UnitAttr:$restrict,
                        UnitAttr:$writable);
-  let results = (outs AnyTensor:$result);
+  let results = (outs Bufferization_TensorLikeTypeInterface:$result);
 
   let extraClassDeclaration = [{
     /// The result of a to_tensor is always a tensor.
-    TensorType getType() {
-      Type resultType = getResult().getType();
-      if (::llvm::isa<TensorType>(resultType))
-        return ::llvm::cast<TensorType>(resultType);
-      return {};
+    ::mlir::bufferization::TensorLikeType getType() {
+      return getResult().getType();
     }
 
     //===------------------------------------------------------------------===//
@@ -472,22 +481,15 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
     FailureOr<BufferLikeType> getBufferType(
         Value value, const BufferizationOptions &options,
         const BufferizationState &state,
         SmallVector<Value> &invocationStack) {
-      return ::llvm::cast<BufferLikeType>(getMemref().getType());
+      return ::llvm::cast<BufferLikeType>(getBuffer().getType());
     }
   }];
   let assemblyFormat = [{
-    $memref (`restrict` $restrict^)? (`writable` $writable^)? attr-dict
-      `:` type($memref) `to` type($result)
+    $buffer (`restrict` $restrict^)? (`writable` $writable^)? attr-dict
+      `:` type($buffer) `to` type($result)
   }];
 
-  let builders = [
-    OpBuilder<(ins "Value":$memref, CArg<"bool", "false">:$restrict, CArg<"bool", "false">:$writeable), [{
-      auto rtt = memref::getTensorTypeFromMemRefType(memref.getType());
-      build($_builder, $_state, rtt, memref, restrict, writeable);
-    }]>
-  ];
-
   let hasCanonicalizer = 1;
   let hasFolder = 1;
 }
@@ -502,10 +504,9 @@ def Bufferization_ToBufferOp : Bufferization_Op<"to_buffer", [
     SameOperandsAndResultShape,
     SameOperandsAndResultElementType,
     Pure,
-    AllShapesMatch<["memref", "tensor"]>,
-    AllElementTypesMatch<["memref", "tensor"]>
+    Bufferization_TensorAndBufferMatch<"tensor", "buffer">
  ]> {
-  let summary = "cast a tensor to memref";
+  let summary = "cast a tensor-like type to a buffer-like type";
   let description = [{
     An operation that returns the future buffer of a `tensor`.
 
@@ -523,8 +524,8 @@ def Bufferization_ToBufferOp : Bufferization_Op<"to_buffer", [
     the returned buffer) will not be written to.
   }];
 
-  let arguments = (ins AnyTensor:$tensor, UnitAttr:$read_only);
-  let results = (outs AnyRankedOrUnrankedMemRef:$memref);
+  let arguments = (ins Bufferization_TensorLikeTypeInterface:$tensor, UnitAttr:$read_only);
+  let results = (outs Bufferization_BufferLikeTypeInterface:$buffer);
 
   let extraClassDeclaration = [{
     //===------------------------------------------------------------------===//
@@ -559,7 +560,7 @@ def Bufferization_ToBufferOp : Bufferization_Op<"to_buffer", [
   }];
 
   let assemblyFormat = [{
-    $tensor (`read_only` $read_only^)? attr-dict `:` type($tensor) `to` type($memref)
+    $tensor (`read_only` $read_only^)? attr-dict `:` type($tensor) `to` type($buffer)
   }];
 
   let hasFolder = 1;
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
index 5faa1479ee542..cbb6054fcf886 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
@@ -13,8 +13,15 @@
 // Bufferization Type Interfaces
 //===----------------------------------------------------------------------===//
 
+#include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/Types.h"
 
+namespace mlir::bufferization {
+struct BufferizationOptions;
+class BufferizationState;
+class BufferLikeType;
+} // namespace mlir::bufferization
+
 #include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h.inc"
 
 #endif // MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATIONTYPEINTERFACES_H_
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td
index f19224a295648..fb6fc4f5ad964 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td
@@ -21,10 +21,30 @@ def Bufferization_TensorLikeTypeInterface
   let description = [{
     Indicates that this type is a tensor type (similarly to a MLIR builtin
     tensor) for bufferization purposes.
-
-    The interface currently has no methods as it is used by types to opt into
-    being supported by the bufferization procedures.
   }];
+
+  let methods = [
+    InterfaceMethod<[{
+      Returns a BufferLike type for this TensorLike type.
+    }],
+    /*retTy=*/"::mlir::FailureOr<::mlir::bufferization::BufferLikeType>",
+    /*methodName=*/"getBufferType",
+    /*args=*/(ins
+      "const ::mlir::bufferization::BufferizationOptions &":$options,
+      "::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError
+    )
+    >,
+    InterfaceMethod<[{
+      Returns whether a BufferLike type is compatible with this TensorLike
+      type. The BufferLike type is assumed to be created by getBufferType().
+    }],
+    /*retTy=*/"::mlir::LogicalResult",
+    /*methodName=*/"verifyCompatibleBufferType",
+    /*args=*/(ins
+      "::mlir::bufferization::BufferLikeType":$bufferType,
+      "::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError)
+    >
+  ];
 }
 
 def Bufferization_BufferLikeTypeInterface
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
index a441b8b66659e..f56c10555f02c 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
@@ -65,12 +65,13 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
         // The operand was already bufferized. Take its type directly.
callerType = memrefType; } else { - FailureOr maybeCallerType = + FailureOr maybeCallerType = bufferization::getBufferType(opOperand->get(), options, state, invocationStack); if (failed(maybeCallerType)) return failure(); - callerType = *maybeCallerType; + assert(isa(*maybeCallerType) && "expected memref type"); + callerType = cast(*maybeCallerType); } if (!bufferType) { diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp index a57d58ab28d28..85d1b5ac73bf4 100644 --- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp @@ -164,8 +164,8 @@ struct SelectOpInterface // buffers have different types, they differ only in their layout map. Cast // both of them to the most dynamic MemRef type. if (trueBuffer.getType() != falseBuffer.getType()) { - auto targetType = - bufferization::getBufferType(selectOp.getResult(), options, state); + auto targetType = bufferization::detail::asMemRefType( + bufferization::getBufferType(selectOp.getResult(), options, state)); if (failed(targetType)) return failure(); if (trueBuffer.getType() != *targetType) @@ -187,10 +187,12 @@ struct SelectOpInterface SmallVector &invocationStack) const { auto selectOp = cast(op); assert(value == selectOp.getResult() && "invalid value"); - auto trueType = bufferization::getBufferType( - selectOp.getTrueValue(), options, state, invocationStack); - auto falseType = bufferization::getBufferType( - selectOp.getFalseValue(), options, state, invocationStack); + auto trueType = + bufferization::detail::asMemRefType(bufferization::getBufferType( + selectOp.getTrueValue(), options, state, invocationStack)); + auto falseType = + bufferization::detail::asMemRefType(bufferization::getBufferType( + selectOp.getFalseValue(), options, state, invocationStack)); if (failed(trueType) || failed(falseType)) return failure(); if (*trueType == *falseType) diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp index dd43647682ea2..2ab182c9b7b2e 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -171,7 +171,9 @@ FailureOr bufferization::allocateTensorForShapedValue( if (llvm::isa(shapedValue.getType())) { tensor = shapedValue; } else if (llvm::isa(shapedValue.getType())) { - tensor = b.create(loc, shapedValue); + tensor = b.create( + loc, memref::getTensorTypeFromMemRefType(shapedValue.getType()), + shapedValue); } else if (llvm::isa(shapedValue.getType()) || llvm::isa(shapedValue.getType())) { return getOwnerOfValue(shapedValue) @@ -211,8 +213,8 @@ FailureOr bufferization::allocateTensorForShapedValue( // Add 'memory_space' attribute. Not needed if 'copy' operand is specified. 
if (copy) return allocTensorOp.getResult(); - FailureOr copyBufferType = - getBufferType(tensor, options, state); + auto copyBufferType = + detail::asMemRefType(getBufferType(tensor, options, state)); if (failed(copyBufferType)) return failure(); std::optional memorySpace = copyBufferType->getMemorySpace(); @@ -672,28 +674,28 @@ FailureOr bufferization::getBuffer(RewriterBase &rewriter, Value value, const BufferizationOptions &options, const BufferizationState &state) { #ifndef NDEBUG - auto tensorType = llvm::dyn_cast(value.getType()); + auto tensorType = llvm::dyn_cast(value.getType()); assert(tensorType && "unexpected non-tensor type"); #endif // NDEBUG // Replace "%t = to_tensor %m" with %m. if (auto toTensorOp = value.getDefiningOp()) - return toTensorOp.getMemref(); + return toTensorOp.getBuffer(); // Insert to_buffer op. OpBuilder::InsertionGuard g(rewriter); setInsertionPointAfter(rewriter, value); - FailureOr memrefType = getBufferType(value, options, state); - if (failed(memrefType)) + FailureOr bufferType = getBufferType(value, options, state); + if (failed(bufferType)) return failure(); - ensureToBufferOpIsValid(value, *memrefType); + ensureToBufferOpIsValid(value, *bufferType); return rewriter - .create(value.getLoc(), *memrefType, value) + .create(value.getLoc(), *bufferType, value) .getResult(); } /// Return the buffer type for a given Value (tensor) after bufferization. -FailureOr +FailureOr bufferization::getBufferType(Value value, const BufferizationOptions &options, const BufferizationState &state) { SmallVector invocationStack; @@ -701,11 +703,11 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options, } /// Return the buffer type for a given Value (tensor) after bufferization. -FailureOr +FailureOr bufferization::getBufferType(Value value, const BufferizationOptions &options, const BufferizationState &state, SmallVector &invocationStack) { - assert(llvm::isa(value.getType()) && + assert(llvm::isa(value.getType()) && "unexpected non-tensor type"); invocationStack.push_back(value); auto popFromStack = @@ -718,13 +720,9 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options, return bufferizableOp.getBufferType(value, options, state, invocationStack); // Op is not bufferizable. - auto memSpace = - options.defaultMemorySpaceFn(cast(value.getType())); - if (!memSpace.has_value()) - return op->emitError("could not infer memory space"); - - return getMemRefType(cast(value.getType()), options, - /*layout=*/{}, *memSpace); + return cast(value.getType()).getBufferType(options, [&]() { + return op->emitError(); + }); } bool bufferization::hasTensorSemantics(Operation *op) { @@ -744,12 +742,11 @@ void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter, SmallVector replacements; for (OpResult opResult : op->getOpResults()) { Value replacement = values[opResult.getResultNumber()]; - if (llvm::isa(opResult.getType())) { + if (llvm::isa(opResult.getType())) { // The OpResult is a tensor. Such values are replaced with memrefs during // bufferization. - assert((llvm::isa(replacement.getType()) || - llvm::isa(replacement.getType())) && - "tensor op result should be replaced with a memref value"); + assert(llvm::isa(replacement.getType()) && + "tensor op result should be replaced with a buffer value"); // The existing uses of the OpResult still expect a tensor. Insert a // ToTensorOp. Throughout bufferization, this ToTensorOp will gradually // loose all of its users and eventually DCE away. 
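// Illustration (not from this patch; hypothetical IR, assuming builtin types):
// the helpers above keep bridging newly allocated buffers back into tensor
// land with to_tensor, whose operand may now be any buffer-like type:
%alloc = memref.alloc() : memref<5xf32>
%tensor = bufferization.to_tensor %alloc restrict writable : memref<5xf32> to tensor<5xf32>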
@@ -969,8 +966,8 @@ FailureOr bufferization::detail::defaultGetBufferType( // If the OpResult has an equivalent OpOperand, both OpResult and // OpOperand bufferize to the exact same buffer type. Value equivalentOperand = aliases.getAliases().front().opOperand->get(); - return getBufferType(equivalentOperand, options, bufferizationState, - invocationStack); + return asMemRefType(getBufferType(equivalentOperand, options, + bufferizationState, invocationStack)); } // If we do not know the memory space and there is no default memory space, @@ -1030,7 +1027,7 @@ bufferization::detail::unknownGetAliasingValues(OpOperand &opOperand) { } bool bufferization::detail::defaultHasTensorSemantics(Operation *op) { - auto isaTensor = [](Type t) { return isa(t); }; + auto isaTensor = [](Type t) { return isa(t); }; bool hasTensorBlockArgument = any_of(op->getRegions(), [&](Region &r) { return any_of(r.getBlocks(), [&](Block &b) { return any_of(b.getArguments(), [&](BlockArgument bbArg) { @@ -1045,3 +1042,19 @@ bool bufferization::detail::defaultHasTensorSemantics(Operation *op) { return true; return any_of(op->getOperandTypes(), isaTensor); } + +FailureOr +bufferization::detail::asMemRefType(FailureOr bufferType) { + if (failed(bufferType)) + return failure(); + return cast(*bufferType); +} + +bool bufferization::detail::typesMatchAfterBufferization(Operation &op, + Value tensor, + Value buffer) { + return mlir::succeeded( + cast(tensor.getType()) + .verifyCompatibleBufferType(cast(buffer.getType()), + [&]() { return op.emitError(); })); +} diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp index d8eac01c2dea0..6c08cdfb669f3 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp @@ -57,7 +57,37 @@ struct BufferizationInlinerInterface : public DialectInlinerInterface { template struct BuiltinTensorExternalModel : TensorLikeType::ExternalModel, - Tensor> {}; + Tensor> { + llvm::FailureOr getBufferType( + mlir::Type tensor, const BufferizationOptions &options, + llvm::function_ref emitError) const { + auto tensorType = cast(tensor); + auto memSpace = options.defaultMemorySpaceFn(tensorType); + if (!memSpace.has_value()) + return emitError() << "could not infer memory space"; + + return cast( + getMemRefType(tensorType, options, /*layout=*/{}, *memSpace)); + } + + mlir::LogicalResult verifyCompatibleBufferType( + mlir::Type tensor, BufferLikeType bufferType, + llvm::function_ref emitError) const { + assert(isa(tensor) && "expected tensor type"); + assert(isa(bufferType) && "expected memref type"); + + auto tensorType = cast(tensor); + auto memrefType = cast(bufferType); + + if (tensorType.getShape() != memrefType.getShape()) + return emitError() << "shapes do not match"; + + if (tensorType.getElementType() != memrefType.getElementType()) + return emitError() << "element types do not match"; + + return mlir::success(); + } +}; template struct BuiltinMemRefExternalModel diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp index dc54ac94aed32..9bd87d66c7d36 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp @@ -90,12 +90,12 @@ LogicalResult mlir::bufferization::foldToBufferToTensorPair( if (!bufferToTensor) return failure(); - Type srcType = bufferToTensor.getMemref().getType(); + Type srcType = 
bufferToTensor.getBuffer().getType(); Type destType = toBuffer.getType(); // Directly rewrite if the type did not change. if (srcType == destType) { - rewriter.replaceOp(toBuffer, bufferToTensor.getMemref()); + rewriter.replaceOp(toBuffer, bufferToTensor.getBuffer()); return success(); } @@ -106,7 +106,7 @@ LogicalResult mlir::bufferization::foldToBufferToTensorPair( // Ranked memref -> Ranked memref cast. if (rankedSrcType && rankedDestType) { FailureOr replacement = castOrReallocMemRefValue( - rewriter, bufferToTensor.getMemref(), rankedDestType, options); + rewriter, bufferToTensor.getBuffer(), rankedDestType, options); if (failed(replacement)) return failure(); @@ -124,7 +124,7 @@ LogicalResult mlir::bufferization::foldToBufferToTensorPair( assert(memref::CastOp::areCastCompatible(srcType, destType) && "expected that types are cast compatible"); rewriter.replaceOpWithNewOp(toBuffer, destType, - bufferToTensor.getMemref()); + bufferToTensor.getBuffer()); return success(); } @@ -233,8 +233,9 @@ AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options, if (getMemorySpace().has_value()) { memorySpace = *getMemorySpace(); } else if (getCopy()) { - auto copyBufferType = bufferization::getBufferType(getCopy(), options, - state, invocationStack); + auto copyBufferType = + bufferization::detail::asMemRefType(bufferization::getBufferType( + getCopy(), options, state, invocationStack)); if (failed(copyBufferType)) return failure(); memorySpace = copyBufferType->getMemorySpace(); @@ -642,8 +643,9 @@ Value MaterializeInDestinationOp::buildSubsetExtraction(OpBuilder &builder, assert(getRestrict() && "expected that ops with memrefs dest have 'restrict'"); setRestrict(false); - return builder.create(loc, getDest(), /*restrict=*/true, - getWritable()); + return builder.create( + loc, memref::getTensorTypeFromMemRefType(getDest().getType()), getDest(), + /*restrict=*/true, getWritable()); } bool MaterializeInDestinationOp::isEquivalentSubset( @@ -744,7 +746,7 @@ bool ToTensorOp::isWritable(Value value, const AnalysisState &state) { } OpFoldResult ToTensorOp::fold(FoldAdaptor) { - if (auto toBuffer = getMemref().getDefiningOp()) + if (auto toBuffer = getBuffer().getDefiningOp()) // Approximate alias analysis by conservatively folding only when no there // is no interleaved operation. if (toBuffer->getBlock() == this->getOperation()->getBlock() && @@ -764,7 +766,7 @@ struct DimOfToTensorFolder : public OpRewritePattern { return failure(); rewriter.replaceOpWithNewOp( - dimOp, memrefToTensorOp.getMemref(), dimOp.getIndex()); + dimOp, memrefToTensorOp.getBuffer(), dimOp.getIndex()); return success(); } }; @@ -781,8 +783,8 @@ void ToTensorOp::getCanonicalizationPatterns(RewritePatternSet &results, OpFoldResult ToBufferOp::fold(FoldAdaptor) { if (auto memrefToTensor = getTensor().getDefiningOp()) - if (memrefToTensor.getMemref().getType() == getType()) - return memrefToTensor.getMemref(); + if (memrefToTensor.getBuffer().getType() == getType()) + return memrefToTensor.getBuffer(); return {}; } diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationTypeInterfaces.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationTypeInterfaces.cpp new file mode 100644 index 0000000000000..0e973915c6fc9 --- /dev/null +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationTypeInterfaces.cpp @@ -0,0 +1,21 @@ +//===- BufferizationTypeInterfaces.cpp - Type Interfaces --------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h" + +//===----------------------------------------------------------------------===// +// Bufferization Type Interfaces +//===----------------------------------------------------------------------===// + +namespace mlir { +namespace bufferization { + +#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.cpp.inc" + +} // namespace bufferization +} // namespace mlir diff --git a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt index 63dcc1eb233e9..5d8f0060f2c3f 100644 --- a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt @@ -6,6 +6,7 @@ add_mlir_dialect_library(MLIRBufferizationDialect BufferizationDialect.cpp BufferViewFlowOpInterface.cpp UnstructuredControlFlow.cpp + BufferizationTypeInterfaces.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Bufferization diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp index 7e9b9119ce949..6472ef3eff2ac 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp @@ -412,11 +412,11 @@ bufferization::bufferizeBlockSignature(Block *block, RewriterBase &rewriter, continue; } - FailureOr memrefType = + FailureOr bufferType = bufferization::getBufferType(bbArg, options, state); - if (failed(memrefType)) + if (failed(bufferType)) return failure(); - newTypes.push_back(*memrefType); + newTypes.push_back(*bufferType); } // Change the type of all block arguments. @@ -463,7 +463,7 @@ bufferization::bufferizeBlockSignature(Block *block, RewriterBase &rewriter, newOperands.push_back(operand); continue; } - FailureOr operandBufferType = + FailureOr operandBufferType = bufferization::getBufferType(operand, options, state); if (failed(operandBufferType)) return failure(); diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp index a0168da44b7b3..453ed43bcadd2 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp @@ -255,7 +255,7 @@ struct CallOpInterface } // Returning a memref. - FailureOr resultType = + FailureOr resultType = bufferization::getBufferType(result, options, state); if (failed(resultType)) return failure(); @@ -290,13 +290,13 @@ struct CallOpInterface // The called function was not bufferized yet. This can happen when // there cycles in the function call graph. Compute the bufferized // result type. 
- FailureOr maybeMemRefType = + FailureOr maybeBufferType = bufferization::getBufferType( funcOp.getArgument(opOperand.getOperandNumber()), options, state); - if (failed(maybeMemRefType)) + if (failed(maybeBufferType)) return failure(); - memRefType = *maybeMemRefType; + memRefType = *maybeBufferType; } // Since we don't yet have a clear layout story, to_buffer may diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp index 94a4b9011c16b..573420f6a9aa9 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp @@ -252,7 +252,8 @@ Value linalg::bufferizeToAllocation( // Create bufferization.to_tensor with "restrict" and "writable". The returned // tensor is a new buffer allocation, so it does not alias with any buffer. Value toTensorOp = rewriter.create( - loc, alloc, /*restrict=*/true, /*writable=*/true); + loc, padOp.getResult().getType(), alloc, /*restrict=*/true, + /*writable=*/true); rewriter.replaceOp(padOp, toTensorOp); return alloc; } @@ -340,7 +341,8 @@ Value linalg::bufferizeToAllocation( // Create bufferization.to_tensor with "restrict" and "writable". The returned // tensor is a new buffer allocation, so it does not alias with any buffer. Value toTensorOp = rewriter.create( - loc, alloc, /*restrict=*/true, /*writable=*/true); + loc, allocTensorOp.getResult().getType(), alloc, /*restrict=*/true, + /*writable=*/true); rewriter.replaceOp(allocTensorOp, toTensorOp); return alloc; } @@ -567,7 +569,8 @@ Value linalg::bufferizeToAllocation( createMemcpy(rewriter, op->getLoc(), operand->get(), alloc, options); } rewriter.modifyOpInPlace(op, [&]() { - auto toTensorOp = rewriter.create(op->getLoc(), alloc); + auto toTensorOp = rewriter.create( + op->getLoc(), operand->get().getType(), alloc); operand->set(toTensorOp); if (options.bufferizeDestinationOnly) { rewriter.modifyOpInPlace(toTensorOp, [&]() { diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp index 46fa77a7dc4e6..58562536be61f 100644 --- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp @@ -108,7 +108,7 @@ struct ConditionOpInterface getBuffer(rewriter, value, options, state); if (failed(maybeBuffer)) return failure(); - FailureOr resultType = bufferization::getBufferType( + FailureOr resultType = bufferization::getBufferType( whileOp.getAfterArguments()[it.index()], options, state); if (failed(resultType)) return failure(); @@ -292,8 +292,9 @@ struct IfOpInterface // True branch was already bufferized. thenBufferType = cast(thenValue.getType()); } else { - auto maybeBufferType = bufferization::getBufferType( - thenValue, options, state, invocationStack); + auto maybeBufferType = + bufferization::detail::asMemRefType(bufferization::getBufferType( + thenValue, options, state, invocationStack)); if (failed(maybeBufferType)) return failure(); thenBufferType = *maybeBufferType; @@ -302,8 +303,9 @@ struct IfOpInterface // False branch was already bufferized. 
elseBufferType = cast(elseValue.getType()); } else { - auto maybeBufferType = bufferization::getBufferType( - elseValue, options, state, invocationStack); + auto maybeBufferType = + bufferization::detail::asMemRefType(bufferization::getBufferType( + elseValue, options, state, invocationStack)); if (failed(maybeBufferType)) return failure(); elseBufferType = *maybeBufferType; @@ -406,9 +408,7 @@ struct IndexSwitchOpInterface return bufferType; auto maybeBufferType = bufferization::getBufferType( yieldedValue, options, state, invocationStack); - if (failed(maybeBufferType)) - return failure(); - return maybeBufferType; + return bufferization::detail::asMemRefType(maybeBufferType); }; // Compute buffer type of the default case. @@ -527,8 +527,8 @@ static FailureOr computeLoopRegionIterArgBufferType( const BufferizationOptions &options, const BufferizationState &state, SmallVector &invocationStack) { // Determine the buffer type of the init_arg. - auto initArgBufferType = - bufferization::getBufferType(initArg, options, state, invocationStack); + auto initArgBufferType = bufferization::detail::asMemRefType( + bufferization::getBufferType(initArg, options, state, invocationStack)); if (failed(initArgBufferType)) return failure(); @@ -554,8 +554,9 @@ static FailureOr computeLoopRegionIterArgBufferType( } else { // Note: This typically triggers a recursive call for the buffer type of // the iter_arg. - auto maybeBufferType = bufferization::getBufferType(yieldedValue, options, - state, invocationStack); + auto maybeBufferType = + bufferization::detail::asMemRefType(bufferization::getBufferType( + yieldedValue, options, state, invocationStack)); if (failed(maybeBufferType)) return failure(); yieldedValueBufferType = *maybeBufferType; @@ -718,8 +719,12 @@ struct ForOpInterface if (auto opResult = dyn_cast(value)) { // The type of an OpResult must match the corresponding iter_arg type. BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult); - return bufferization::getBufferType(bbArg, options, state, - invocationStack); + auto bufferType = + bufferization::getBufferType(bbArg, options, state, invocationStack); + if (failed(bufferType)) + return failure(); + assert(isa(*bufferType) && "expected memref type"); + return cast(*bufferType); } // Compute result/argument number. @@ -1078,8 +1083,8 @@ struct WhileOpInterface // scf.condition was already bufferized. return cast(conditionYieldedVal.getType()); } - return bufferization::getBufferType(conditionYieldedVal, options, state, - invocationStack); + return bufferization::detail::asMemRefType(bufferization::getBufferType( + conditionYieldedVal, options, state, invocationStack)); } /// Assert that yielded values of an scf.while op are equivalent to their @@ -1185,14 +1190,14 @@ struct YieldOpInterface // We may have to cast the value before yielding it. 
if (isa( yieldOp->getParentOp())) { - FailureOr resultType = bufferization::getBufferType( + FailureOr resultType = bufferization::getBufferType( yieldOp->getParentOp()->getResult(it.index()), options, state); if (failed(resultType)) return failure(); buffer = castBuffer(rewriter, buffer, *resultType); } else if (auto whileOp = dyn_cast(yieldOp->getParentOp())) { - FailureOr resultType = bufferization::getBufferType( + FailureOr resultType = bufferization::getBufferType( whileOp.getBeforeArguments()[it.index()], options, state); if (failed(resultType)) return failure(); @@ -1307,15 +1312,15 @@ struct ForallOpInterface if (auto bbArg = dyn_cast(value)) // A tensor block argument has the same bufferized type as the // corresponding output operand. - return bufferization::getBufferType( - forallOp.getTiedOpOperand(bbArg)->get(), options, state, - invocationStack); + return bufferization::detail::asMemRefType( + bufferization::getBufferType(forallOp.getTiedOpOperand(bbArg)->get(), + options, state, invocationStack)); // The bufferized result type is the same as the bufferized type of the // corresponding output operand. - return bufferization::getBufferType( + return bufferization::detail::asMemRefType(bufferization::getBufferType( forallOp.getOutputs()[cast(value).getResultNumber()], options, - state, invocationStack); + state, invocationStack)); } bool isRepetitiveRegion(Operation *op, unsigned index) const { diff --git a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp index dc91117a51936..8a471c12d21e4 100644 --- a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp @@ -67,7 +67,7 @@ struct AssumingOpInterface for (const auto &it : llvm::enumerate(assumingOp->getResultTypes())) { if (isa(it.value())) { newResults.push_back(rewriter.create( - assumingOp.getLoc(), newOp->getResult(it.index()))); + assumingOp.getLoc(), it.value(), newOp->getResult(it.index()))); } else { newResults.push_back(newOp->getResult(it.index())); } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp index e5f2418367a58..e89b34d457ff8 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp @@ -651,7 +651,7 @@ static LogicalResult rewriteSpMV(PatternRewriter &rewriter, tokens.clear(); // Done. - rewriter.replaceOpWithNewOp(op, memY); + rewriter.replaceOpWithNewOp(op, y.getType(), memY); return success(); } @@ -752,7 +752,7 @@ static LogicalResult rewriteSpMM(PatternRewriter &rewriter, tokens.clear(); // Done. - rewriter.replaceOpWithNewOp(op, bufC); + rewriter.replaceOpWithNewOp(op, c.getType(), bufC); return success(); } @@ -925,9 +925,12 @@ static LogicalResult rewriteSpGEMM(PatternRewriter &rewriter, tokens.clear(); // Done. 
- Value vt = rewriter.create(loc, valH); - Value rt = rewriter.create(loc, rowH); - Value ct = rewriter.create(loc, colH); + Value vt = rewriter.create( + loc, memref::getTensorTypeFromMemRefType(valH.getType()), valH); + Value rt = rewriter.create( + loc, memref::getTensorTypeFromMemRefType(rowH.getType()), rowH); + Value ct = rewriter.create( + loc, memref::getTensorTypeFromMemRefType(colH.getType()), colH); rewriter.replaceOpWithNewOp(op, c.getType(), ValueRange{rt, ct}, vt); return success(); @@ -1043,7 +1046,7 @@ static LogicalResult rewrite2To4SpMM(PatternRewriter &rewriter, tokens.clear(); // Done. - rewriter.replaceOpWithNewOp(op, bufC); + rewriter.replaceOpWithNewOp(op, C.getType(), bufC); return success(); } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp index e5f9717c3fbaa..14ced56b8365f 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp @@ -1471,7 +1471,8 @@ struct SparseDisassembleOpConverter // Converts MemRefs back to Tensors. SmallVector retValues = llvm::to_vector( llvm::map_range(retMem, [&rewriter, loc](Value v) -> Value { - return rewriter.create(loc, v); + return rewriter.create( + loc, memref::getTensorTypeFromMemRefType(v.getType()), v); })); // Appends the actual memory length used in each buffer returned. retValues.append(retLen.begin(), retLen.end()); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp index 9ffa64dc821d8..7f0b657687442 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp @@ -867,7 +867,9 @@ class SparseTensorDisassembleConverter // Converts MemRefs back to Tensors. 
assert(retVal.size() + retLen.size() == op.getNumResults()); for (unsigned i = 0, sz = retVal.size(); i < sz; i++) { - auto tensor = rewriter.create(loc, retVal[i]); + auto tensor = rewriter.create( + loc, memref::getTensorTypeFromMemRefType(retVal[i].getType()), + retVal[i]); retVal[i] = rewriter.create(loc, op.getResultTypes()[i], tensor); } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp index 57291064eba22..1bd9563b3db07 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp @@ -549,8 +549,8 @@ TypedValue sparse_tensor::genToMemref(OpBuilder &builder, Location loc, Value tensor) { auto tTp = llvm::cast(tensor.getType()); auto mTp = MemRefType::get(tTp.getShape(), tTp.getElementType()); - return builder.create(loc, mTp, tensor) - .getResult(); + return cast>( + builder.create(loc, mTp, tensor).getResult()); } Value sparse_tensor::createOrFoldSliceOffsetOp(OpBuilder &builder, Location loc, diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp index 4b778b768d136..729c048db4560 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp @@ -54,8 +54,9 @@ struct CastOpInterface const BufferizationState &state, SmallVector &invocationStack) const { auto castOp = cast(op); - auto maybeSrcBufferType = bufferization::getBufferType( - castOp.getSource(), options, state, invocationStack); + auto maybeSrcBufferType = + bufferization::detail::asMemRefType(bufferization::getBufferType( + castOp.getSource(), options, state, invocationStack)); if (failed(maybeSrcBufferType)) return failure(); Attribute memorySpace = maybeSrcBufferType->getMemorySpace(); @@ -500,8 +501,8 @@ struct FromElementsOpInterface /*copy=*/false); if (failed(tensorAlloc)) return failure(); - FailureOr memrefType = - bufferization::getBufferType(*tensorAlloc, options, state); + FailureOr memrefType = bufferization::detail::asMemRefType( + bufferization::getBufferType(*tensorAlloc, options, state)); if (failed(memrefType)) return failure(); Value buffer = rewriter.create( @@ -758,8 +759,9 @@ struct PadOpInterface SmallVector &invocationStack) const { // Infer memory space from the source tensor. 
auto padOp = cast(op); - auto maybeSrcBufferType = bufferization::getBufferType( - padOp.getSource(), options, state, invocationStack); + auto maybeSrcBufferType = + bufferization::detail::asMemRefType(bufferization::getBufferType( + padOp.getSource(), options, state, invocationStack)); if (failed(maybeSrcBufferType)) return failure(); MemRefLayoutAttrInterface layout; diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir index cd19e3a5e82aa..da3c26ce36ba5 100644 --- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir @@ -268,4 +268,23 @@ func.func @materialize_in_dest_raw(%f: f32, %f2: f32, %idx: index) -> (tensor<5x %r = tensor.extract %dest_filled[%idx] : tensor<5xf32> return %0, %r : tensor<5xf32>, f32 -} \ No newline at end of file +} + +// ----- + +// CHECK-LABEL: func.func @test_dialect_op( +// CHECK-SAME: %[[ARG:.*]]: !test.test_tensor<[32, 64], f64> +// CHECK-SAME: ) -> !test.test_tensor<[32, 128], f64> { +func.func @test_dialect_op(%arg: !test.test_tensor<[32, 64], f64>) + -> !test.test_tensor<[32, 128], f64> { + // CHECK: %[[MEMREF:.*]] = bufferization.to_buffer %[[ARG]] + // CHECK: %[[DUMMY:.*]] = "test.dummy_memref_op"(%[[MEMREF]]) + // CHECK-SAME: : (!test.test_memref<[32, 64], f64>) + // CHECK-SAME: -> !test.test_memref<[32, 128], f64> + // CHECK: %[[OUT:.*]] = bufferization.to_tensor %[[DUMMY]] + %out = "test.dummy_tensor_op"(%arg) : (!test.test_tensor<[32, 64], f64>) + -> !test.test_tensor<[32, 128], f64> + + // CHECK: return %[[OUT]] + return %out : !test.test_tensor<[32, 128], f64> +} diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp index b5a8bd10d6b68..78e44c6ec7a9b 100644 --- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp +++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp @@ -8,6 +8,7 @@ #include "TestDialect.h" #include "TestOps.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Verifier.h" #include "mlir/Interfaces/FunctionImplementation.h" @@ -1387,3 +1388,25 @@ TestMultiSlotAlloca::handleDestructuringComplete( const DestructurableMemorySlot &slot, OpBuilder &builder) { return createNewMultiAllocaWithoutSlot(slot, builder, *this); } + +::mlir::LogicalResult test::TestDummyTensorOp::bufferize( + ::mlir::RewriterBase &rewriter, + const ::mlir::bufferization::BufferizationOptions &options, + ::mlir::bufferization::BufferizationState &state) { + auto buffer = + mlir::bufferization::getBuffer(rewriter, getInput(), options, state); + if (mlir::failed(buffer)) + return failure(); + + const auto outType = getOutput().getType(); + const auto bufferizedOutType = test::TestMemrefType::get( + getContext(), outType.getShape(), outType.getElementType(), nullptr); + // replace op with memref analogy + auto dummyMemrefOp = rewriter.create( + getLoc(), bufferizedOutType, *buffer); + + mlir::bufferization::replaceOpWithBufferizedValues(rewriter, getOperation(), + dummyMemrefOp.getResult()); + + return mlir::success(); +} diff --git a/mlir/test/lib/Dialect/Test/TestOps.h b/mlir/test/lib/Dialect/Test/TestOps.h index c2ee5f9ab9a57..b414b47c87425 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.h +++ b/mlir/test/lib/Dialect/Test/TestOps.h @@ -13,6 +13,7 @@ #include "TestInterfaces.h" #include "TestTypes.h" #include "mlir/Bytecode/BytecodeImplementation.h" +#include 
"mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/DLTI/DLTI.h" #include "mlir/Dialect/DLTI/Traits.h" #include "mlir/Dialect/Func/IR/FuncOps.h" diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index 59330fdb1bb2c..79bcd9c2e0a9a 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -31,7 +31,7 @@ include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/LoopLikeInterface.td" include "mlir/Interfaces/MemorySlotInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" - +include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td" // Include the attribute definitions. include "TestAttrDefs.td" @@ -2825,7 +2825,7 @@ def TestNVVMRequiresSMArchCondOp : let assemblyFormat = "attr-dict"; } -def TestNVVMRequirestSMArchCondMultiOp : +def TestNVVMRequirestSMArchCondMultiOp : TEST_Op<"nvvm_requires_sm_90a_or_sm_100a", [NVVMRequiresSMa<[90, 100]>]> { let arguments = (ins ); let assemblyFormat = "attr-dict"; @@ -3552,4 +3552,58 @@ def TestAllocWithMultipleResults : TEST_Op<"alloc_with_multiple_results"> { }]; } +//===----------------------------------------------------------------------===// +// Test Ops bufferization +//===----------------------------------------------------------------------===// + +def TestDummyTensorOp : TEST_Op<"dummy_tensor_op", [BufferizableOpInterface]> { + let arguments = (ins + Arg:$input + ); + let results = (outs + Arg:$output + ); + let extraClassDeclaration = [{ + // BufferizableOpInterface + bool bufferizesToMemoryRead(mlir::OpOperand&, + const mlir::bufferization::AnalysisState&); + + bool bufferizesToMemoryWrite(mlir::OpOperand&, + const mlir::bufferization::AnalysisState&); + + mlir::bufferization::AliasingValueList getAliasingValues(mlir::OpOperand&, + const mlir::bufferization::AnalysisState&); + + mlir::LogicalResult bufferize( + mlir::RewriterBase& rewriter, + const mlir::bufferization::BufferizationOptions& options, + mlir::bufferization::BufferizationState &state); + }]; + + let extraClassDefinition = [{ + bool test::TestDummyTensorOp::bufferizesToMemoryRead(::mlir::OpOperand&, + const ::mlir::bufferization::AnalysisState&) { + return true; + } + bool test::TestDummyTensorOp::bufferizesToMemoryWrite(::mlir::OpOperand&, + const ::mlir::bufferization::AnalysisState&) { + return true; + } + ::mlir::bufferization::AliasingValueList + test::TestDummyTensorOp::getAliasingValues(::mlir::OpOperand&, + const ::mlir::bufferization::AnalysisState&) { + return {}; + } + }]; +} + +def TestDummyMemrefOp : TEST_Op<"dummy_memref_op", []> { + let arguments = (ins + Arg:$input + ); + let results = (outs + Arg:$output + ); +} + #endif // TEST_OPS diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td index 09294e84960f2..03261f37c815d 100644 --- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td +++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td @@ -428,6 +428,15 @@ def TestTensorType : Test_Type<"TestTensor", return test::TestTensorType::get( getContext(), shape.value_or(getShape()), elementType); } + + // TensorLikeTypeInterface: + ::mlir::FailureOr<::mlir::bufferization::BufferLikeType> + getBufferType(const ::mlir::bufferization::BufferizationOptions& options, + ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError); + + ::mlir::LogicalResult verifyCompatibleBufferType( + ::mlir::bufferization::BufferLikeType bufferType, + ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError); }]; } 
diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp
index 5c784dcee6e15..2fc2f90ef6bc0 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp
@@ -545,3 +545,23 @@ TestTypeOpAsmTypeInterfaceType::getAlias(::llvm::raw_ostream &os) const {
   os << "op_asm_type_interface_type";
   return ::mlir::OpAsmDialectInterface::AliasResult::FinalAlias;
 }
+
+::mlir::FailureOr<::mlir::bufferization::BufferLikeType>
+TestTensorType::getBufferType(
+    const ::mlir::bufferization::BufferizationOptions &,
+    ::llvm::function_ref<::mlir::InFlightDiagnostic()>) {
+  return cast<::mlir::bufferization::BufferLikeType>(
+      TestMemrefType::get(getContext(), getShape(), getElementType(), nullptr));
+}
+
+::mlir::LogicalResult TestTensorType::verifyCompatibleBufferType(
+    ::mlir::bufferization::BufferLikeType bufferType,
+    ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) {
+  auto testMemref = dyn_cast<TestMemrefType>(bufferType);
+  if (!testMemref)
+    return emitError() << "expected TestMemrefType";
+
+  const bool valid = getShape() == testMemref.getShape() &&
+                     getElementType() == testMemref.getElementType();
+  return mlir::success(valid);
+}
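// Illustration (not from this patch; hypothetical IR, %t and %arg are assumed
// SSA values of the printed tensor types): with the
// Bufferization_TensorAndBufferMatch trait, to_tensor/to_buffer verify each
// tensor/buffer pair via TensorLikeType::verifyCompatibleBufferType, which for
// the builtin and test types above reduces to matching shapes and element
// types:
%m = bufferization.to_buffer %t : tensor<4x?xf32> to memref<4x?xf32>
%u = bufferization.to_tensor %m : memref<4x?xf32> to tensor<4x?xf32>
%b = bufferization.to_buffer %arg : !test.test_tensor<[32, 64], f64> to !test.test_memref<[32, 64], f64>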