diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index fbbf817ecff98..5fab2ee1194e8 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -2034,9 +2034,9 @@ def Vector_ScatterOp : Vector_Op<"scatter">, Arguments<(ins Arg:$base, Variadic:$indices, - VectorOfRankAndType<[1], [AnyInteger, Index]>:$index_vec, - VectorOfRankAndType<[1], [I1]>:$mask, - VectorOfRank<[1]>:$valueToStore)> { + VectorOfNonZeroRankOf<[AnyInteger, Index]>:$index_vec, + VectorOfNonZeroRankOf<[I1]>:$mask, + AnyVectorOfNonZeroRank:$valueToStore)> { let summary = [{ scatters elements from a vector into memory as defined by an index vector @@ -2044,9 +2044,9 @@ def Vector_ScatterOp : }]; let description = [{ - The scatter operation stores elements from a 1-D vector into memory as - defined by a base with indices and an additional 1-D index vector, but - only if the corresponding bit in a 1-D mask vector is set. Otherwise, no + The scatter operation stores elements from a n-D vector into memory as + defined by a base with indices and an additional n-D index vector, but + only if the corresponding bit in a n-D mask vector is set. Otherwise, no action is taken for that element. Informally the semantics are: ``` if (mask[0]) base[index[0]] = value[0] diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 357152eba8003..213f7375b8d13 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -263,22 +263,25 @@ class VectorGatherOpConversion LogicalResult matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + Location loc = gather->getLoc(); MemRefType memRefType = dyn_cast(gather.getBaseType()); assert(memRefType && "The base should be bufferized"); if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter()))) - return failure(); + return rewriter.notifyMatchFailure(gather, "memref type not supported"); VectorType vType = gather.getVectorType(); - if (vType.getRank() > 1) - return failure(); - - Location loc = gather->getLoc(); + if (vType.getRank() > 1) { + return rewriter.notifyMatchFailure( + gather, "only 1-D vectors can be lowered to LLVM"); + } // Resolve alignment. unsigned align; - if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align))) - return failure(); + if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align))) { + return rewriter.notifyMatchFailure(gather, + "could not resolve memref alignment"); + } // Resolve address. Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(), @@ -309,15 +312,22 @@ class VectorScatterOpConversion MemRefType memRefType = scatter.getMemRefType(); if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter()))) - return failure(); + return rewriter.notifyMatchFailure(scatter, "memref type not supported"); + + VectorType vType = scatter.getVectorType(); + if (vType.getRank() > 1) { + return rewriter.notifyMatchFailure( + scatter, "only 1-D vectors can be lowered to LLVM"); + } // Resolve alignment. unsigned align; - if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align))) - return failure(); + if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align))) { + return rewriter.notifyMatchFailure(scatter, + "could not resolve memref alignment"); + } // Resolve address. - VectorType vType = scatter.getVectorType(); Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(), adaptor.getIndices(), rewriter); Value ptrs = diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index d4c1da30d498d..d006a1498f350 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -5340,9 +5340,9 @@ LogicalResult ScatterOp::verify() { return emitOpError("base and valueToStore element type should match"); if (llvm::size(getIndices()) != memType.getRank()) return emitOpError("requires ") << memType.getRank() << " indices"; - if (valueVType.getDimSize(0) != indVType.getDimSize(0)) + if (valueVType.getShape() != indVType.getShape()) return emitOpError("expected valueToStore dim to match indices dim"); - if (valueVType.getDimSize(0) != maskVType.getDimSize(0)) + if (valueVType.getShape() != maskVType.getShape()) return emitOpError("expected valueToStore dim to match mask dim"); return success(); } diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir index 5404fdda033ee..ba1da84719106 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -1719,6 +1719,40 @@ func.func @gather_with_zero_mask_scalable(%arg0: memref, %arg1: vector<2x // ----- +//===----------------------------------------------------------------------===// +// vector.scatter +//===----------------------------------------------------------------------===// + +// Multi-Dimensional scatters are not supported yet. Check that we do not lower +// them. + +func.func @scatter_with_mask(%arg0: memref, %arg1: vector<2x3xi32>, %arg2: vector<2x3xf32>) { + %0 = arith.constant 0: index + %1 = vector.constant_mask [2, 2] : vector<2x3xi1> + vector.scatter %arg0[%0][%arg1], %1, %arg2 : memref, vector<2x3xi32>, vector<2x3xi1>, vector<2x3xf32> + return +} + +// CHECK-LABEL: func @scatter_with_mask +// CHECK: vector.scatter + +// ----- + +func.func @scatter_with_mask_scalable(%arg0: memref, %arg1: vector<2x[3]xi32>, %arg2: vector<2x[3]xf32>) { + %0 = arith.constant 0: index + // vector.constant_mask only supports 'none set' or 'all set' scalable + // dimensions, hence [2, 3] rather than [2, 2] as in the example for fixed + // width vectors above. + %1 = vector.constant_mask [2, 3] : vector<2x[3]xi1> + vector.scatter %arg0[%0][%arg1], %1, %arg2 : memref, vector<2x[3]xi32>, vector<2x[3]xi1>, vector<2x[3]xf32> + return +} + +// CHECK-LABEL: func @scatter_with_mask_scalable +// CHECK: vector.scatter + +// ----- + //===----------------------------------------------------------------------===// // vector.interleave //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir index 57e348c7d5991..1b89e8eb5069b 100644 --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -1484,7 +1484,7 @@ func.func @scatter_memref_mismatch(%base: memref, %indices: vector<16xi func.func @scatter_rank_mismatch(%base: memref, %indices: vector<16xi32>, %mask: vector<16xi1>, %value: vector<2x16xf32>) { %c0 = arith.constant 0 : index - // expected-error@+1 {{'vector.scatter' op operand #4 must be of ranks 1, but got 'vector<2x16xf32>'}} + // expected-error@+1 {{'vector.scatter' op expected valueToStore dim to match indices dim}} vector.scatter %base[%c0][%indices], %mask, %value : memref, vector<16xi32>, vector<16xi1>, vector<2x16xf32> } diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir index 67484e06f456d..8ae1e9f9d0c64 100644 --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -882,6 +882,16 @@ func.func @gather_and_scatter2d(%base: memref, %v: vector<16xi32>, %mas return } +// CHECK-LABEL: @gather_and_scatter_multi_dims +func.func @gather_and_scatter_multi_dims(%base: memref, %v: vector<2x16xi32>, %mask: vector<2x16xi1>, %pass_thru: vector<2x16xf32>) -> vector<2x16xf32> { + %c0 = arith.constant 0 : index + // CHECK: %[[X:.*]] = vector.gather %{{.*}}[%{{.*}}] [%{{.*}}], %{{.*}}, %{{.*}} : memref, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32> into vector<2x16xf32> + %0 = vector.gather %base[%c0][%v], %mask, %pass_thru : memref, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32> into vector<2x16xf32> + // CHECK: vector.scatter %{{.*}}[%{{.*}}] [%{{.*}}], %{{.*}}, %[[X]] : memref, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32> + vector.scatter %base[%c0][%v], %mask, %0 : memref, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32> + return %0 : vector<2x16xf32> +} + // CHECK-LABEL: @gather_on_tensor func.func @gather_on_tensor(%base: tensor, %v: vector<16xi32>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) -> vector<16xf32> { %c0 = arith.constant 0 : index @@ -890,14 +900,6 @@ func.func @gather_on_tensor(%base: tensor, %v: vector<16xi32>, %mask: vec return %0 : vector<16xf32> } -// CHECK-LABEL: @gather_multi_dims -func.func @gather_multi_dims(%base: tensor, %v: vector<2x16xi32>, %mask: vector<2x16xi1>, %pass_thru: vector<2x16xf32>) -> vector<2x16xf32> { - %c0 = arith.constant 0 : index - // CHECK: vector.gather %{{.*}}[%{{.*}}] [%{{.*}}], %{{.*}}, %{{.*}} : tensor, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32> into vector<2x16xf32> - %0 = vector.gather %base[%c0][%v], %mask, %pass_thru : tensor, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32> into vector<2x16xf32> - return %0 : vector<2x16xf32> -} - // CHECK-LABEL: @expand_and_compress func.func @expand_and_compress(%base: memref, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) { %c0 = arith.constant 0 : index