-
Notifications
You must be signed in to change notification settings - Fork 14.4k
[mlir][bufferization] Support custom types (1/N) #142986
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
be85978
7ef1183
b05a291
2de5728
4d052ff
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,6 +12,7 @@ | |
include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.td" | ||
include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td" | ||
include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td" | ||
include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td" | ||
include "mlir/Dialect/Bufferization/IR/BufferizationBase.td" | ||
include "mlir/Interfaces/DestinationStyleOpInterface.td" | ||
include "mlir/Interfaces/InferTypeOpInterface.td" | ||
|
@@ -386,20 +387,31 @@ def Bufferization_DeallocTensorOp : Bufferization_Op<"dealloc_tensor", | |
// ToTensorOp | ||
//===----------------------------------------------------------------------===// | ||
|
||
class Bufferization_TensorAndBufferMatch<string tensor, string buffer> : PredOpTrait< | ||
"specified tensor and buffer types match", | ||
CPred< | ||
"::mlir::bufferization::detail::typesMatchAfterBufferization(" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @matthias-springer this would be the other problematic place. With the current approach, we want to validate that bufferization is "valid" on a tensor <-> buffer level. The current logic checks Instead, I think we should either restore the old comparison logic (which was changed in ced2fc7) or - more likely - have this put into an interface so that it's a customization point. But then, which interface? TensorLike? BufferLike? Since it's a type matching function, it's kind of valid to be in both. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How about some kind of double dispatch? E.g., for
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I wouldn't go with double dispatch to be honest because there's majorly no difference between "tensor equivalent to buffer" vs "buffer equivalent to tensor" (we have both things which do not change between the two calls). For the time being, I guess we just put it somewhere? (either to buffer-like or to tensor-like). Perhaps with more changes it would be clearer what to do here. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sry, double dispatch is the wrong name. It's more like querying both interfaces. The reason why I'm suggesting this is to support custom conversions for builtin types. The type interface implementation of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see what you mean now. It's actually more (or less?) straightforward:
there's never really a situation where we'd bufferize builtin into non-builtin or non-builtin into builtin, but the case is interesting (I'd perhaps add support for this separately if that has any use). edit: I guess for 2. we can keep the "default" logic - which is what's on There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Implemented the half of the querying for TensorLikeType. |
||
"$_op, $" # tensor # ", $" # buffer #")" | ||
> | ||
>; | ||
|
||
def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [ | ||
BufferizableOpInterface, | ||
SameOperandsAndResultShape, | ||
SameOperandsAndResultElementType, | ||
AllElementTypesMatch<["memref", "result"]> | ||
Bufferization_TensorAndBufferMatch<"result", "buffer"> | ||
]> { | ||
let summary = "create a tensor from a `memref`"; | ||
let summary = "create a buffer-like type from a tensor-like type"; | ||
let description = [{ | ||
An operation that creates a tensor from a `memref`. The result value is a | ||
tensor whose shape and element type match the memref operand. | ||
An operation that creates a tensor from a buffer. The result value is a | ||
tensor-like type that must match the corresponding buffer-like operand as | ||
per TensorLikeType::verifyCompatibleBufferType(). For builtins (TensorType | ||
and BaseMemRefType), this means that shapes and element types match between | ||
the tensor and the buffer. | ||
|
||
The opposite of this op is `to_buffer`. Together, these two ops are | ||
useful for source/target materializations when doing type conversions | ||
involving tensors and memrefs. | ||
involving tensors and buffers. | ||
|
||
Example: | ||
|
||
|
@@ -441,19 +453,16 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [ | |
away. However, such IR is no longer bufferizable with One-Shot Bufferize. | ||
}]; | ||
|
||
let arguments = (ins Arg<AnyRankedOrUnrankedMemRef, | ||
let arguments = (ins Arg<Bufferization_BufferLikeTypeInterface, | ||
"the reference to load from", | ||
[MemReadAt<0, FullEffect>]>:$memref, | ||
[MemReadAt<0, FullEffect>]>:$buffer, | ||
UnitAttr:$restrict, UnitAttr:$writable); | ||
let results = (outs AnyTensor:$result); | ||
let results = (outs Bufferization_TensorLikeTypeInterface:$result); | ||
|
||
let extraClassDeclaration = [{ | ||
/// The result of a to_tensor is always a tensor. | ||
andrey-golubev marked this conversation as resolved.
Show resolved
Hide resolved
|
||
TensorType getType() { | ||
Type resultType = getResult().getType(); | ||
if (::llvm::isa<TensorType>(resultType)) | ||
return ::llvm::cast<TensorType>(resultType); | ||
return {}; | ||
::mlir::bufferization::TensorLikeType getType() { | ||
return getResult().getType(); | ||
} | ||
|
||
//===------------------------------------------------------------------===// | ||
|
@@ -472,22 +481,15 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [ | |
FailureOr<BaseMemRefType> getBufferType( | ||
Value value, const BufferizationOptions &options, | ||
const BufferizationState &state, SmallVector<Value> &invocationStack) { | ||
return ::llvm::cast<BaseMemRefType>(getMemref().getType()); | ||
return ::llvm::cast<BaseMemRefType>(getBuffer().getType()); | ||
} | ||
}]; | ||
|
||
let assemblyFormat = [{ | ||
$memref (`restrict` $restrict^)? (`writable` $writable^)? attr-dict | ||
`:` type($memref) `to` type($result) | ||
$buffer (`restrict` $restrict^)? (`writable` $writable^)? attr-dict | ||
`:` type($buffer) `to` type($result) | ||
}]; | ||
|
||
let builders = [ | ||
OpBuilder<(ins "Value":$memref, CArg<"bool", "false">:$restrict, CArg<"bool", "false">:$writeable), [{ | ||
auto rtt = memref::getTensorTypeFromMemRefType(memref.getType()); | ||
build($_builder, $_state, rtt, memref, restrict, writeable); | ||
}]> | ||
]; | ||
|
||
let hasCanonicalizer = 1; | ||
let hasFolder = 1; | ||
} | ||
|
@@ -502,10 +504,9 @@ def Bufferization_ToBufferOp : Bufferization_Op<"to_buffer", [ | |
SameOperandsAndResultShape, | ||
SameOperandsAndResultElementType, | ||
Pure, | ||
AllShapesMatch<["memref", "tensor"]>, | ||
AllElementTypesMatch<["memref", "tensor"]> | ||
Bufferization_TensorAndBufferMatch<"tensor", "buffer"> | ||
]> { | ||
let summary = "cast a tensor to memref"; | ||
let summary = "cast a tensor-like type to buffer-like type"; | ||
let description = [{ | ||
An operation that returns the future buffer of a `tensor`. | ||
|
||
|
@@ -523,8 +524,8 @@ def Bufferization_ToBufferOp : Bufferization_Op<"to_buffer", [ | |
the returned buffer) will not be written to. | ||
}]; | ||
|
||
let arguments = (ins AnyTensor:$tensor, UnitAttr:$read_only); | ||
let results = (outs AnyRankedOrUnrankedMemRef:$memref); | ||
let arguments = (ins Bufferization_TensorLikeTypeInterface:$tensor, UnitAttr:$read_only); | ||
let results = (outs Bufferization_BufferLikeTypeInterface:$buffer); | ||
|
||
let extraClassDeclaration = [{ | ||
//===------------------------------------------------------------------===// | ||
|
@@ -559,7 +560,7 @@ def Bufferization_ToBufferOp : Bufferization_Op<"to_buffer", [ | |
}]; | ||
|
||
let assemblyFormat = [{ | ||
$tensor (`read_only` $read_only^)? attr-dict `:` type($tensor) `to` type($memref) | ||
$tensor (`read_only` $read_only^)? attr-dict `:` type($tensor) `to` type($buffer) | ||
}]; | ||
|
||
let hasFolder = 1; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is weird that you declare this in
detail
namespace, but you use it like public API in other places. My understanding is that the methods withindetail
namespace are not intended to be used widely. With your change, it is used in many other dialects like Linalg/SCF/Arith/etc. Are you going to remove this function? Is it just for the transition state?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, ideally, I want to get rid of this once there's a direct support of Buffer-Like (even if it's "always a memref"). The follow-up patch already removes some of the instances. It is a bridge code that exists to have less boilerplate between BufferLike <-> BaseMemrefType conversions in C++, which is why it's in
detail
- it's not really intended to be used (and there are no promises on API / usability of the function).Long term, if it turns out to be useful, we can promote it to a "public" function, right now it seems to be an implementation detail related to bufferization.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see, thanks for the explanation!