From 5a1cd6d4260fb7741ad08d8137adceb833d8df1e Mon Sep 17 00:00:00 2001 From: mboeck Date: Thu, 26 Jun 2025 15:24:11 +0200 Subject: [PATCH 1/3] [mlir][tensor] Relax input type requirement on `tensor.splat` `tensor.splat` is currently restricted to only accepting input values that are of integer, index or float type. This is much more restrictive than the tensor type itself as well as any lowerings of it. This PR therefore removes this restriction by using `AnyType` for the input value. Whether the type is actually valid or not for a tensor remains verified through the type equality of the result tensor element type and the input type. --- mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td | 6 ++---- mlir/test/Dialect/Tensor/bufferize.mlir | 15 +++++++++++++++ mlir/test/Dialect/Tensor/invalid.mlir | 7 ++++--- mlir/test/Dialect/Tensor/ops.mlir | 6 +++++- 4 files changed, 26 insertions(+), 8 deletions(-) diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td index 47962f75558ea..7d396e5c64c28 100644 --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td @@ -1771,8 +1771,7 @@ def Tensor_SplatOp : Tensor_Op<"splat", [ ]> { let summary = "tensor splat or broadcast operation"; let description = [{ - Broadcast the operand to all elements of the result tensor. The operand is - required to be of integer/index/float type. + Broadcast the operand to all elements of the result tensor. An additional argument of type `index` must be provided for each dynamic dimension present in the result type. @@ -1795,8 +1794,7 @@ def Tensor_SplatOp : Tensor_Op<"splat", [ ``` }]; - let arguments = (ins AnyTypeOf<[AnySignlessInteger, Index, AnyFloat], - "integer/index/float type">:$input, + let arguments = (ins AnyType:$input, Variadic:$dynamicSizes); let results = (outs AnyRankedTensor:$aggregate); diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir index c0adc8a49bf70..e202a6b3f3e7a 100644 --- a/mlir/test/Dialect/Tensor/bufferize.mlir +++ b/mlir/test/Dialect/Tensor/bufferize.mlir @@ -615,6 +615,21 @@ func.func @tensor.splat(%f: f32) -> tensor<10x2x4xf32> { // ----- +// CHECK-LABEL: func @tensor.splat_other( +// CHECK-SAME: %[[F:.*]]: !llvm.ptr) +// CHECK-DAG: %[[ALLOC:.*]] = memref.alloc() {{.*}} : memref<10x2x4x!llvm.ptr> +// CHECK: %[[ALLOC_T:.*]] = bufferization.to_tensor %[[ALLOC]] +// CHECK: %[[MAPPED:.*]] = linalg.map +// CHECK: outs(%[[ALLOC_T]] : tensor<10x2x4x!llvm.ptr>) +// CHECK: linalg.yield %[[F]] +// CHECK: return %[[MAPPED]] : tensor<10x2x4x!llvm.ptr> +func.func @tensor.splat_other(%f: !llvm.ptr) -> tensor<10x2x4x!llvm.ptr> { + %t = tensor.splat %f : tensor<10x2x4x!llvm.ptr> + return %t : tensor<10x2x4x!llvm.ptr> +} + +// ----- + // CHECK-LABEL: func @tensor.concat( // CHECK-SAME: %[[F:.*]]: tensor<8xf32>) // CHECK: %[[F_MEMREF:.*]] = bufferization.to_buffer %[[F]] diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir index f35d52e700084..665657a67dc61 100644 --- a/mlir/test/Dialect/Tensor/invalid.mlir +++ b/mlir/test/Dialect/Tensor/invalid.mlir @@ -466,9 +466,10 @@ func.func @invalid_splat(%v : f32) { // ----- -func.func @invalid_splat(%v : vector<8xf32>) { - // expected-error@+1 {{must be integer/index/float type}} - %w = tensor.splat %v : tensor<8xvector<8xf32>> +// expected-note@+1 {{prior use here}} +func.func @invalid_splat(%v : f32) { + // expected-error@+1 {{expects different type than prior uses: 'i32' vs 'f32'}} + %w = tensor.splat %v : tensor<1xi32> return } diff --git a/mlir/test/Dialect/Tensor/ops.mlir b/mlir/test/Dialect/Tensor/ops.mlir index 930986211cb6d..0fd4b87508a79 100644 --- a/mlir/test/Dialect/Tensor/ops.mlir +++ b/mlir/test/Dialect/Tensor/ops.mlir @@ -314,12 +314,16 @@ func.func @pad_to_static_size(%arg0: tensor, %ub0: index, %ub1: index, // CHECK-LABEL: func @test_splat_op // CHECK-SAME: [[S:%arg[0-9]+]]: f32 -func.func @test_splat_op(%s : f32) { +// CHECK-SAME: [[P:%arg[0-9]+]]: !llvm.ptr +func.func @test_splat_op(%s : f32, %p : !llvm.ptr) { // CHECK: tensor.splat [[S]] : tensor<8xf32> %v = tensor.splat %s : tensor<8xf32> // CHECK: tensor.splat [[S]] : tensor<4xf32> %u = "tensor.splat"(%s) : (f32) -> tensor<4xf32> + + // CHECK: tensor.splat [[P]] : tensor<8x!llvm.ptr> + %w = tensor.splat %p : tensor<8x!llvm.ptr> return } From cd1b78655556d643b110069a5f18d15b78387481 Mon Sep 17 00:00:00 2001 From: mboeck Date: Thu, 26 Jun 2025 16:05:21 +0200 Subject: [PATCH 2/3] address review comments --- mlir/test/Dialect/Tensor/ops.mlir | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/mlir/test/Dialect/Tensor/ops.mlir b/mlir/test/Dialect/Tensor/ops.mlir index 0fd4b87508a79..681a934ba0698 100644 --- a/mlir/test/Dialect/Tensor/ops.mlir +++ b/mlir/test/Dialect/Tensor/ops.mlir @@ -313,16 +313,16 @@ func.func @pad_to_static_size(%arg0: tensor, %ub0: index, %ub1: index, // ----- // CHECK-LABEL: func @test_splat_op -// CHECK-SAME: [[S:%arg[0-9]+]]: f32 -// CHECK-SAME: [[P:%arg[0-9]+]]: !llvm.ptr +// CHECK-SAME: %[[S:.*]]: f32 +// CHECK-SAME: %[[P:.*]]: !llvm.ptr func.func @test_splat_op(%s : f32, %p : !llvm.ptr) { - // CHECK: tensor.splat [[S]] : tensor<8xf32> + // CHECK: tensor.splat %[[S]] : tensor<8xf32> %v = tensor.splat %s : tensor<8xf32> - // CHECK: tensor.splat [[S]] : tensor<4xf32> + // CHECK: tensor.splat %[[S]] : tensor<4xf32> %u = "tensor.splat"(%s) : (f32) -> tensor<4xf32> - // CHECK: tensor.splat [[P]] : tensor<8x!llvm.ptr> + // CHECK: tensor.splat %[[P]] : tensor<8x!llvm.ptr> %w = tensor.splat %p : tensor<8x!llvm.ptr> return } From 0e4676ec12c322f6fbfa056f9660531e81612eca Mon Sep 17 00:00:00 2001 From: mboeck Date: Thu, 26 Jun 2025 16:05:38 +0200 Subject: [PATCH 3/3] use type legal memref type in bufferization test --- mlir/test/Dialect/Tensor/bufferize.mlir | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir index e202a6b3f3e7a..296ca02564e35 100644 --- a/mlir/test/Dialect/Tensor/bufferize.mlir +++ b/mlir/test/Dialect/Tensor/bufferize.mlir @@ -616,16 +616,16 @@ func.func @tensor.splat(%f: f32) -> tensor<10x2x4xf32> { // ----- // CHECK-LABEL: func @tensor.splat_other( -// CHECK-SAME: %[[F:.*]]: !llvm.ptr) -// CHECK-DAG: %[[ALLOC:.*]] = memref.alloc() {{.*}} : memref<10x2x4x!llvm.ptr> +// CHECK-SAME: %[[F:.*]]: !test.memref_element) +// CHECK-DAG: %[[ALLOC:.*]] = memref.alloc() {{.*}} : memref<10x2x4x!test.memref_element> // CHECK: %[[ALLOC_T:.*]] = bufferization.to_tensor %[[ALLOC]] // CHECK: %[[MAPPED:.*]] = linalg.map -// CHECK: outs(%[[ALLOC_T]] : tensor<10x2x4x!llvm.ptr>) +// CHECK: outs(%[[ALLOC_T]] : tensor<10x2x4x!test.memref_element>) // CHECK: linalg.yield %[[F]] -// CHECK: return %[[MAPPED]] : tensor<10x2x4x!llvm.ptr> -func.func @tensor.splat_other(%f: !llvm.ptr) -> tensor<10x2x4x!llvm.ptr> { - %t = tensor.splat %f : tensor<10x2x4x!llvm.ptr> - return %t : tensor<10x2x4x!llvm.ptr> +// CHECK: return %[[MAPPED]] : tensor<10x2x4x!test.memref_element> +func.func @tensor.splat_other(%f: !test.memref_element) -> tensor<10x2x4x!test.memref_element> { + %t = tensor.splat %f : tensor<10x2x4x!test.memref_element> + return %t : tensor<10x2x4x!test.memref_element> } // -----