diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td index 47962f75558ea..7d396e5c64c28 100644 --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td @@ -1771,8 +1771,7 @@ def Tensor_SplatOp : Tensor_Op<"splat", [ ]> { let summary = "tensor splat or broadcast operation"; let description = [{ - Broadcast the operand to all elements of the result tensor. The operand is - required to be of integer/index/float type. + Broadcast the operand to all elements of the result tensor. An additional argument of type `index` must be provided for each dynamic dimension present in the result type. @@ -1795,8 +1794,7 @@ def Tensor_SplatOp : Tensor_Op<"splat", [ ``` }]; - let arguments = (ins AnyTypeOf<[AnySignlessInteger, Index, AnyFloat], - "integer/index/float type">:$input, + let arguments = (ins AnyType:$input, Variadic:$dynamicSizes); let results = (outs AnyRankedTensor:$aggregate); diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir index c0adc8a49bf70..296ca02564e35 100644 --- a/mlir/test/Dialect/Tensor/bufferize.mlir +++ b/mlir/test/Dialect/Tensor/bufferize.mlir @@ -615,6 +615,21 @@ func.func @tensor.splat(%f: f32) -> tensor<10x2x4xf32> { // ----- +// CHECK-LABEL: func @tensor.splat_other( +// CHECK-SAME: %[[F:.*]]: !test.memref_element) +// CHECK-DAG: %[[ALLOC:.*]] = memref.alloc() {{.*}} : memref<10x2x4x!test.memref_element> +// CHECK: %[[ALLOC_T:.*]] = bufferization.to_tensor %[[ALLOC]] +// CHECK: %[[MAPPED:.*]] = linalg.map +// CHECK: outs(%[[ALLOC_T]] : tensor<10x2x4x!test.memref_element>) +// CHECK: linalg.yield %[[F]] +// CHECK: return %[[MAPPED]] : tensor<10x2x4x!test.memref_element> +func.func @tensor.splat_other(%f: !test.memref_element) -> tensor<10x2x4x!test.memref_element> { + %t = tensor.splat %f : tensor<10x2x4x!test.memref_element> + return %t : tensor<10x2x4x!test.memref_element> +} + +// ----- + // CHECK-LABEL: func @tensor.concat( // CHECK-SAME: %[[F:.*]]: tensor<8xf32>) // CHECK: %[[F_MEMREF:.*]] = bufferization.to_buffer %[[F]] diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir index f35d52e700084..665657a67dc61 100644 --- a/mlir/test/Dialect/Tensor/invalid.mlir +++ b/mlir/test/Dialect/Tensor/invalid.mlir @@ -466,9 +466,10 @@ func.func @invalid_splat(%v : f32) { // ----- -func.func @invalid_splat(%v : vector<8xf32>) { - // expected-error@+1 {{must be integer/index/float type}} - %w = tensor.splat %v : tensor<8xvector<8xf32>> +// expected-note@+1 {{prior use here}} +func.func @invalid_splat(%v : f32) { + // expected-error@+1 {{expects different type than prior uses: 'i32' vs 'f32'}} + %w = tensor.splat %v : tensor<1xi32> return } diff --git a/mlir/test/Dialect/Tensor/ops.mlir b/mlir/test/Dialect/Tensor/ops.mlir index 930986211cb6d..681a934ba0698 100644 --- a/mlir/test/Dialect/Tensor/ops.mlir +++ b/mlir/test/Dialect/Tensor/ops.mlir @@ -313,13 +313,17 @@ func.func @pad_to_static_size(%arg0: tensor, %ub0: index, %ub1: index, // ----- // CHECK-LABEL: func @test_splat_op -// CHECK-SAME: [[S:%arg[0-9]+]]: f32 -func.func @test_splat_op(%s : f32) { - // CHECK: tensor.splat [[S]] : tensor<8xf32> +// CHECK-SAME: %[[S:.*]]: f32 +// CHECK-SAME: %[[P:.*]]: !llvm.ptr +func.func @test_splat_op(%s : f32, %p : !llvm.ptr) { + // CHECK: tensor.splat %[[S]] : tensor<8xf32> %v = tensor.splat %s : tensor<8xf32> - // CHECK: tensor.splat [[S]] : tensor<4xf32> + // CHECK: tensor.splat %[[S]] : tensor<4xf32> %u = "tensor.splat"(%s) : (f32) -> tensor<4xf32> + + // CHECK: tensor.splat %[[P]] : tensor<8x!llvm.ptr> + %w = tensor.splat %p : tensor<8x!llvm.ptr> return }