Skip to content

Commit f905655

Browse files
committed
Changes in rocMLIR due to llvm/llvm-project#142986
1 parent 289459b commit f905655

File tree

3 files changed

+34
-27
lines changed

3 files changed

+34
-27
lines changed

mlir/lib/Conversion/TosaToRock/TosaToRock.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@
1212

1313
#include "mlir/Conversion/TosaToRock/TosaToRock.h"
1414
#include "mlir/Dialect/Arith/IR/Arith.h"
15+
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
1516
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
17+
#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h"
1618
#include "mlir/Dialect/Func/IR/FuncOps.h"
1719
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1820
#include "mlir/Dialect/Rock/IR/Rock.h"
@@ -1195,8 +1197,8 @@ struct ConvElementwiseGemmRewritePattern
11951197
RankedTensorType resTensorType = cast<RankedTensorType>(res.getType());
11961198
MemRefType resMemRefType = MemRefType::get(
11971199
resTensorType.getShape(), resTensorType.getElementType());
1198-
Value resMemref =
1199-
rewriter.create<bufferization::ToBufferOp>(loc, resMemRefType, res);
1200+
Value resMemref = rewriter.create<bufferization::ToBufferOp>(
1201+
loc, cast<mlir::bufferization::BufferLikeType>(resMemRefType), res);
12001202
Value outMemref =
12011203
preSecondGemmElemwiseBlock->addArgument(resMemRefType, loc);
12021204
rewriter.create<memref::CopyOp>(loc, resMemref, outMemref);
@@ -1279,8 +1281,8 @@ struct GemmElementwiseGemmRewritePattern
12791281
RankedTensorType resTensorType = cast<RankedTensorType>(res.getType());
12801282
MemRefType resMemRefType = MemRefType::get(
12811283
resTensorType.getShape(), resTensorType.getElementType());
1282-
Value resMemref =
1283-
rewriter.create<bufferization::ToBufferOp>(loc, resMemRefType, res);
1284+
Value resMemref = rewriter.create<bufferization::ToBufferOp>(
1285+
loc, cast<mlir::bufferization::BufferLikeType>(resMemRefType), res);
12841286
Value outMemref =
12851287
preSecondGemmElemwiseBlock->addArgument(resMemRefType, loc);
12861288
rewriter.create<memref::CopyOp>(loc, resMemref, outMemref);
@@ -1714,7 +1716,7 @@ struct AttentionRewritePattern : public OpRewritePattern<tosa::MatMulOp> {
17141716
MemRefType resMemRefType = MemRefType::get(
17151717
resTensorType.getShape(), resTensorType.getElementType());
17161718
Value resMemref =
1717-
rewriter.create<bufferization::ToBufferOp>(loc, resMemRefType, res);
1719+
rewriter.create<bufferization::ToBufferOp>(loc, cast<bufferization::BufferLikeType>(resMemRefType), res);
17181720
Value outMemref =
17191721
preSoftmaxElemwiseBlock->addArgument(resMemRefType, loc);
17201722
rewriter.create<memref::CopyOp>(loc, resMemref, outMemref);

mlir/lib/Dialect/Rock/utility/builderUtils.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,8 @@ Value getAsTensor(OpBuilder &builder, Location loc, mlir::Value value,
209209
bool isWritable) {
210210
constexpr bool isRestrict{true};
211211
Value origTensor = builder.create<bufferization::ToTensorOp>(
212-
loc, value.getType(), value, isRestrict, isWritable);
212+
loc, memref::getTensorTypeFromMemRefType(value.getType()), value,
213+
isRestrict, isWritable);
213214
return origTensor;
214215
}
215216

mlir/tools/rocmlir-gen/rocmlir-gen.cpp

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "mlir/Dialect/Affine/IR/AffineOps.h"
1616
#include "mlir/Dialect/Arith/IR/Arith.h"
1717
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
18+
#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h"
1819
#include "mlir/Dialect/Func/IR/FuncOps.h"
1920
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
2021
#include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -2911,8 +2912,9 @@ static func::FuncOp createGpuAttentionKernel(ModuleOp module,
29112912
MemRefType resMemRefType =
29122913
MemRefType::get({qShape[0], sequenceLengthQ, sequenceLengthK},
29132914
cast<ShapedType>(qkTensor.getType()).getElementType());
2914-
Value resMemref =
2915-
builder.create<bufferization::ToBufferOp>(loc, resMemRefType, qkTensor);
2915+
Value resMemref = builder.create<bufferization::ToBufferOp>(
2916+
loc, cast<mlir::bufferization::BufferLikeType>(resMemRefType),
2917+
qkTensor);
29162918
Value outMemref = preSoftmaxElemwiseBlock->addArgument(resMemRefType, loc);
29172919
builder.create<memref::CopyOp>(loc, resMemref, outMemref);
29182920
builder.create<rock::YieldOp>(loc);
@@ -3002,8 +3004,9 @@ createGpuConvElementwiseGemmKernel(ModuleOp module, const GenParams &params) {
30023004
MemRefType resMemRefType =
30033005
MemRefType::get({aShape[0], firstGemmSize.m, firstGemmSize.n},
30043006
cast<ShapedType>(abTensor.getType()).getElementType());
3005-
Value resMemref =
3006-
builder.create<bufferization::ToBufferOp>(loc, resMemRefType, abTensor);
3007+
Value resMemref = builder.create<bufferization::ToBufferOp>(
3008+
loc, cast<mlir::bufferization::BufferLikeType>(resMemRefType),
3009+
abTensor);
30073010
Value outMemref = preSecondGemmBlock->addArgument(resMemRefType, loc);
30083011
builder.create<memref::CopyOp>(loc, resMemref, outMemref);
30093012
builder.create<rock::YieldOp>(loc);
@@ -3098,8 +3101,9 @@ createGpuGemmElementwiseGemmKernel(ModuleOp module, const GenParams &params) {
30983101
MemRefType resMemRefType =
30993102
MemRefType::get({aShape[0], gemmM, gemmN},
31003103
cast<ShapedType>(abTensor.getType()).getElementType());
3101-
Value resMemref =
3102-
builder.create<bufferization::ToBufferOp>(loc, resMemRefType, abTensor);
3104+
Value resMemref = builder.create<bufferization::ToBufferOp>(
3105+
loc, cast<mlir::bufferization::BufferLikeType>(resMemRefType),
3106+
abTensor);
31033107
Value outMemref = preSecondGemmBlock->addArgument(resMemRefType, loc);
31043108
builder.create<memref::CopyOp>(loc, resMemref, outMemref);
31053109
builder.create<rock::YieldOp>(loc);
@@ -3280,7 +3284,7 @@ createCpuConvElementwiseGemmKernelWithMlir(ModuleOp module,
32803284
bool isWritable = false) {
32813285
constexpr bool isRestrict{true};
32823286
Value flatTensor = builder.create<bufferization::ToTensorOp>(
3283-
loc, block->getArgument(blockArgIndex).getType(),
3287+
loc, memref::getTensorTypeFromMemRefType(block->getArgument(blockArgIndex).getType()),
32843288
block->getArgument(blockArgIndex), isRestrict, isWritable);
32853289
ArrayRef<int64_t> origShape =
32863290
cast<ShapedType>(argTypes[blockArgIndex]).getShape();
@@ -3418,11 +3422,11 @@ createCpuConvElementwiseGemmKernelWithMlir(ModuleOp module,
34183422
}
34193423

34203424
Value output = block->getArguments().back();
3421-
auto outputType = cast<MemRefType>(output.getType());
3425+
auto outputType = cast<bufferization::BufferLikeType>(output.getType());
34223426

34233427
ImplicitLocOpBuilder implicitBuilder(loc, builder);
3424-
auto shapeValue =
3425-
tosa::getTosaConstShape(implicitBuilder, outputType.getShape());
3428+
auto shapeValue = tosa::getTosaConstShape(
3429+
implicitBuilder, cast<ShapedType>(outputType).getShape());
34263430
auto flatResultTensor =
34273431
builder.create<tosa::ReshapeOp>(loc, resultTensor, shapeValue);
34283432

@@ -3460,7 +3464,7 @@ createCpuGemmElementwiseGemmKernelWithMlir(ModuleOp module,
34603464
bool isWritable = false) {
34613465
constexpr bool isRestrict{true};
34623466
Value flatTensor = builder.create<bufferization::ToTensorOp>(
3463-
loc, block->getArgument(blockArgIndex).getType(),
3467+
loc, memref::getTensorTypeFromMemRefType(block->getArgument(blockArgIndex).getType()),
34643468
block->getArgument(blockArgIndex), isRestrict, isWritable);
34653469
ArrayRef<int64_t> origShape =
34663470
cast<ShapedType>(argTypes[blockArgIndex]).getShape();
@@ -3534,11 +3538,11 @@ createCpuGemmElementwiseGemmKernelWithMlir(ModuleOp module,
35343538
}
35353539

35363540
Value output = block->getArguments().back();
3537-
auto outputType = cast<MemRefType>(output.getType());
3541+
auto outputType = cast<mlir::bufferization::BufferLikeType>(output.getType());
35383542

35393543
ImplicitLocOpBuilder implicitBuilder(loc, builder);
3540-
auto shapeValue =
3541-
tosa::getTosaConstShape(implicitBuilder, outputType.getShape());
3544+
auto shapeValue = tosa::getTosaConstShape(
3545+
implicitBuilder, cast<ShapedType>(outputType).getShape());
35423546
auto flatResultTensor =
35433547
builder.create<tosa::ReshapeOp>(loc, resultTensor, shapeValue);
35443548

@@ -3576,7 +3580,7 @@ static func::FuncOp createCpuAttentionKernelWithMlir(ModuleOp module,
35763580
bool isWritable = false) {
35773581
constexpr bool isRestrict{true};
35783582
Value flatTensor = builder.create<bufferization::ToTensorOp>(
3579-
loc, block->getArgument(blockArgIndex).getType(),
3583+
loc, memref::getTensorTypeFromMemRefType(block->getArgument(blockArgIndex).getType()),
35803584
block->getArgument(blockArgIndex), isRestrict, isWritable);
35813585
ArrayRef<int64_t> origShape =
35823586
cast<ShapedType>(argTypes[blockArgIndex]).getShape();
@@ -3792,10 +3796,10 @@ static func::FuncOp createCpuAttentionKernelWithMlir(ModuleOp module,
37923796
}
37933797

37943798
Value output = block->getArguments().back();
3795-
auto outputType = cast<MemRefType>(output.getType());
3799+
auto outputType = cast<mlir::bufferization::BufferLikeType>(output.getType());
37963800
ImplicitLocOpBuilder implicitBuilder(loc, builder);
3797-
auto shapeValue =
3798-
tosa::getTosaConstShape(implicitBuilder, outputType.getShape());
3801+
auto shapeValue = tosa::getTosaConstShape(
3802+
implicitBuilder, cast<ShapedType>(outputType).getShape());
37993803
auto flatResultTensor =
38003804
builder.create<tosa::ReshapeOp>(loc, resultTensor, shapeValue);
38013805

@@ -3806,9 +3810,9 @@ static func::FuncOp createCpuAttentionKernelWithMlir(ModuleOp module,
38063810

38073811
// return LSE (log-sum-exp)
38083812
if (returnLSE) {
3809-
auto lseOutType = cast<MemRefType>(lseOut.getType());
3810-
auto lseShapeValue =
3811-
tosa::getTosaConstShape(implicitBuilder, lseOutType.getShape());
3813+
auto lseOutType = cast<bufferization::BufferLikeType>(lseOut.getType());
3814+
auto lseShapeValue = tosa::getTosaConstShape(
3815+
implicitBuilder, cast<ShapedType>(lseOutType).getShape());
38123816
auto flatLseTensor =
38133817
builder.create<tosa::ReshapeOp>(loc, lseTensor, lseShapeValue);
38143818

0 commit comments

Comments
 (0)