 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -2911,8 +2912,9 @@ static func::FuncOp createGpuAttentionKernel(ModuleOp module,
   MemRefType resMemRefType =
       MemRefType::get({qShape[0], sequenceLengthQ, sequenceLengthK},
                       cast<ShapedType>(qkTensor.getType()).getElementType());
-  Value resMemref =
-      builder.create<bufferization::ToBufferOp>(loc, resMemRefType, qkTensor);
+  Value resMemref = builder.create<bufferization::ToBufferOp>(
+      loc, cast<mlir::bufferization::BufferLikeType>(resMemRefType),
+      qkTensor);
   Value outMemref = preSoftmaxElemwiseBlock->addArgument(resMemRefType, loc);
   builder.create<memref::CopyOp>(loc, resMemref, outMemref);
   builder.create<rock::YieldOp>(loc);
@@ -3002,8 +3004,9 @@ createGpuConvElementwiseGemmKernel(ModuleOp module, const GenParams &params) {
   MemRefType resMemRefType =
       MemRefType::get({aShape[0], firstGemmSize.m, firstGemmSize.n},
                       cast<ShapedType>(abTensor.getType()).getElementType());
-  Value resMemref =
-      builder.create<bufferization::ToBufferOp>(loc, resMemRefType, abTensor);
+  Value resMemref = builder.create<bufferization::ToBufferOp>(
+      loc, cast<mlir::bufferization::BufferLikeType>(resMemRefType),
+      abTensor);
   Value outMemref = preSecondGemmBlock->addArgument(resMemRefType, loc);
   builder.create<memref::CopyOp>(loc, resMemref, outMemref);
   builder.create<rock::YieldOp>(loc);
@@ -3098,8 +3101,9 @@ createGpuGemmElementwiseGemmKernel(ModuleOp module, const GenParams &params) {
   MemRefType resMemRefType =
       MemRefType::get({aShape[0], gemmM, gemmN},
                       cast<ShapedType>(abTensor.getType()).getElementType());
-  Value resMemref =
-      builder.create<bufferization::ToBufferOp>(loc, resMemRefType, abTensor);
+  Value resMemref = builder.create<bufferization::ToBufferOp>(
+      loc, cast<mlir::bufferization::BufferLikeType>(resMemRefType),
+      abTensor);
   Value outMemref = preSecondGemmBlock->addArgument(resMemRefType, loc);
   builder.create<memref::CopyOp>(loc, resMemref, outMemref);
   builder.create<rock::YieldOp>(loc);
@@ -3280,7 +3284,7 @@ createCpuConvElementwiseGemmKernelWithMlir(ModuleOp module,
                            bool isWritable = false) {
     constexpr bool isRestrict{true};
     Value flatTensor = builder.create<bufferization::ToTensorOp>(
-        loc, block->getArgument(blockArgIndex).getType(),
+        loc, memref::getTensorTypeFromMemRefType(block->getArgument(blockArgIndex).getType()),
         block->getArgument(blockArgIndex), isRestrict, isWritable);
     ArrayRef<int64_t> origShape =
         cast<ShapedType>(argTypes[blockArgIndex]).getShape();
@@ -3418,11 +3422,11 @@ createCpuConvElementwiseGemmKernelWithMlir(ModuleOp module,
   }
 
   Value output = block->getArguments().back();
-  auto outputType = cast<MemRefType>(output.getType());
+  auto outputType = cast<bufferization::BufferLikeType>(output.getType());
 
   ImplicitLocOpBuilder implicitBuilder(loc, builder);
-  auto shapeValue =
-      tosa::getTosaConstShape(implicitBuilder, outputType.getShape());
+  auto shapeValue = tosa::getTosaConstShape(
+      implicitBuilder, cast<ShapedType>(outputType).getShape());
   auto flatResultTensor =
       builder.create<tosa::ReshapeOp>(loc, resultTensor, shapeValue);
 
@@ -3460,7 +3464,7 @@ createCpuGemmElementwiseGemmKernelWithMlir(ModuleOp module,
                            bool isWritable = false) {
     constexpr bool isRestrict{true};
     Value flatTensor = builder.create<bufferization::ToTensorOp>(
-        loc, block->getArgument(blockArgIndex).getType(),
+        loc, memref::getTensorTypeFromMemRefType(block->getArgument(blockArgIndex).getType()),
         block->getArgument(blockArgIndex), isRestrict, isWritable);
     ArrayRef<int64_t> origShape =
         cast<ShapedType>(argTypes[blockArgIndex]).getShape();
@@ -3534,11 +3538,11 @@ createCpuGemmElementwiseGemmKernelWithMlir(ModuleOp module,
   }
 
   Value output = block->getArguments().back();
-  auto outputType = cast<MemRefType>(output.getType());
+  auto outputType = cast<mlir::bufferization::BufferLikeType>(output.getType());
 
   ImplicitLocOpBuilder implicitBuilder(loc, builder);
-  auto shapeValue =
-      tosa::getTosaConstShape(implicitBuilder, outputType.getShape());
+  auto shapeValue = tosa::getTosaConstShape(
+      implicitBuilder, cast<ShapedType>(outputType).getShape());
   auto flatResultTensor =
       builder.create<tosa::ReshapeOp>(loc, resultTensor, shapeValue);
 
@@ -3576,7 +3580,7 @@ static func::FuncOp createCpuAttentionKernelWithMlir(ModuleOp module,
                            bool isWritable = false) {
     constexpr bool isRestrict{true};
     Value flatTensor = builder.create<bufferization::ToTensorOp>(
-        loc, block->getArgument(blockArgIndex).getType(),
+        loc, memref::getTensorTypeFromMemRefType(block->getArgument(blockArgIndex).getType()),
         block->getArgument(blockArgIndex), isRestrict, isWritable);
     ArrayRef<int64_t> origShape =
         cast<ShapedType>(argTypes[blockArgIndex]).getShape();
@@ -3792,10 +3796,10 @@ static func::FuncOp createCpuAttentionKernelWithMlir(ModuleOp module,
   }
 
   Value output = block->getArguments().back();
-  auto outputType = cast<MemRefType>(output.getType());
+  auto outputType = cast<mlir::bufferization::BufferLikeType>(output.getType());
   ImplicitLocOpBuilder implicitBuilder(loc, builder);
-  auto shapeValue =
-      tosa::getTosaConstShape(implicitBuilder, outputType.getShape());
+  auto shapeValue = tosa::getTosaConstShape(
+      implicitBuilder, cast<ShapedType>(outputType).getShape());
   auto flatResultTensor =
       builder.create<tosa::ReshapeOp>(loc, resultTensor, shapeValue);
 
@@ -3806,9 +3810,9 @@ static func::FuncOp createCpuAttentionKernelWithMlir(ModuleOp module,
 
   // return LSE (log-sum-exp)
   if (returnLSE) {
-    auto lseOutType = cast<MemRefType>(lseOut.getType());
-    auto lseShapeValue =
-        tosa::getTosaConstShape(implicitBuilder, lseOutType.getShape());
+    auto lseOutType = cast<bufferization::BufferLikeType>(lseOut.getType());
+    auto lseShapeValue = tosa::getTosaConstShape(
+        implicitBuilder, cast<ShapedType>(lseOutType).getShape());
     auto flatLseTensor =
         builder.create<tosa::ReshapeOp>(loc, lseTensor, lseShapeValue);
 
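
Every hunk above applies the same three API adjustments: bufferization::ToBufferOp is built with an explicit bufferization::BufferLikeType result type (obtained by casting the concrete MemRefType), bufferization::ToTensorOp is built with a tensor type derived via memref::getTensorTypeFromMemRefType instead of reusing the memref type, and shape queries on the interface-typed value go through an explicit cast<ShapedType>. The snippet below is a minimal, non-authoritative sketch of that pattern as used in this diff; the helper names (wrapTensorAsBuffer, wrapBufferAsTensor) and the standalone setup are illustrative and not part of the commit.

// Hedged sketch, not part of the commit: shows the builder-call pattern this
// diff converges on. Helper and variable names are hypothetical.
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

// tensor -> memref: ToBufferOp now takes a BufferLikeType result type, so the
// concrete MemRefType is cast to the interface before building the op.
static Value wrapTensorAsBuffer(OpBuilder &builder, Location loc,
                                MemRefType memrefType, Value tensorVal) {
  return builder.create<bufferization::ToBufferOp>(
      loc, cast<bufferization::BufferLikeType>(memrefType), tensorVal);
}

// memref -> tensor: ToTensorOp takes the tensor type explicitly, derived from
// the memref type rather than reusing the memref type itself.
static Value wrapBufferAsTensor(OpBuilder &builder, Location loc,
                                Value memrefVal, bool isRestrict,
                                bool isWritable) {
  return builder.create<bufferization::ToTensorOp>(
      loc, memref::getTensorTypeFromMemRefType(memrefVal.getType()), memrefVal,
      isRestrict, isWritable);
}

// Shape queries on a BufferLikeType value cast to ShapedType first, e.g.:
//   auto shape = cast<ShapedType>(bufferLikeVal.getType()).getShape();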