diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index c818675993c2c..a934e47794051 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1254,7 +1254,7 @@ static SmallVector<ReassociationIndices>
 getCollapsableIterationSpaceDims(GenericOp genericOp, OpOperand *fusableOperand,
                                  ArrayRef<ReassociationIndices> reassociation) {
   // Some basic checks for this fusion to be valid.
-  if (!genericOp.hasPureTensorSemantics() || genericOp.getNumDpsInits() != 1)
+  if (!genericOp.hasPureTensorSemantics())
     return {};
 
   if (!llvm::all_of(genericOp.getIndexingMapsArray(), [](AffineMap map) {
diff --git a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
index 600f0dea31f4a..f17881d59a266 100644
--- a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
+++ b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
@@ -7,49 +7,55 @@
 #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2)>
 #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d3, d4, d5, d6)>
 #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5, d6, d7)>
+#map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d2, d0, d7, d3, d4, d5, d6)>
 func.func @fuse_by_collapsing(%arg0 : tensor<2x12x5x336x9xi32>,
-    %arg1 : tensor<2x3x4xi32>, %arg2 : tensor<5x6x7x8xi32>) -> tensor<2x3x4x5x6x7x8x9xi32> {
+    %arg1 : tensor<2x3x4xi32>, %arg2 : tensor<5x6x7x8xi32>) -> (tensor<2x3x4x5x6x7x8x9xi32>, tensor<3x4x2x9x5x6x7x8xi32>) {
   %expand = tensor.expand_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] output_shape [2, 3, 4, 5, 6, 7, 8, 9]
       : tensor<2x12x5x336x9xi32> into tensor<2x3x4x5x6x7x8x9xi32>
-  %init = tensor.empty() : tensor<2x3x4x5x6x7x8x9xi32>
-  %generic = linalg.generic {
-      indexing_maps = [#map0, #map1, #map2, #map3],
+  %init_0 = tensor.empty() : tensor<2x3x4x5x6x7x8x9xi32>
+  %init_1 = tensor.empty() : tensor<3x4x2x9x5x6x7x8xi32>
+  %generic:2 = linalg.generic {
+      indexing_maps = [#map0, #map1, #map2, #map3, #map4],
       iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]}
       ins(%expand, %arg1, %arg2 : tensor<2x3x4x5x6x7x8x9xi32>, tensor<2x3x4xi32>, tensor<5x6x7x8xi32>)
-      outs(%init : tensor<2x3x4x5x6x7x8x9xi32>) {
-    ^bb0(%b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32):
+      outs(%init_0, %init_1 : tensor<2x3x4x5x6x7x8x9xi32>, tensor<3x4x2x9x5x6x7x8xi32>) {
+    ^bb0(%b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32, %b4 : i32):
       %t0 = arith.addi %b0, %b1 : i32
       %t1 = arith.addi %t0, %b2 : i32
-      linalg.yield %t1 : i32
-  } -> tensor<2x3x4x5x6x7x8x9xi32>
-  return %generic : tensor<2x3x4x5x6x7x8x9xi32>
+      linalg.yield %t1, %t1 : i32, i32
+  } -> (tensor<2x3x4x5x6x7x8x9xi32>, tensor<3x4x2x9x5x6x7x8xi32>)
+  return %generic#0, %generic#1 : tensor<2x3x4x5x6x7x8x9xi32>, tensor<3x4x2x9x5x6x7x8xi32>
 }
 // CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
 // CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>
 // CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d1, d0, d4, d2, d3)>
 // CHECK: func @fuse_by_collapsing(
 // CHECK-SAME: %[[ARG0:.+]]: tensor<2x12x5x336x9xi32>
 // CHECK-SAME: %[[ARG1:.+]]: tensor<2x3x4xi32>
 // CHECK-SAME: %[[ARG2:.+]]: tensor<5x6x7x8xi32>
-// CHECK-DAG: %[[INIT:.+]] = tensor.empty()
+// CHECK-DAG: %[[INIT0:.+]] = tensor.empty() : tensor<2x3x4x5x6x7x8x9xi32>
+// CHECK-DAG: %[[INIT1:.+]] = tensor.empty() : tensor<3x4x2x9x5x6x7x8xi32>
 // CHECK-DAG: %[[ARG1_RESHAPE:.+]] = tensor.collapse_shape %[[ARG1]] {{\[}}[0], [1, 2]{{\]}}
 // CHECK-DAG: %[[ARG2_RESHAPE:.+]] = tensor.collapse_shape %[[ARG2]] {{\[}}[0], [1, 2, 3]{{\]}}
-// CHECK-DAG: %[[INIT_RESHAPE:.+]] = tensor.collapse_shape %[[INIT]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]{{\]}}
-// CHECK: %[[COLLAPSED_OP:.+]] = linalg.generic
-// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP0]]]
+// CHECK-DAG: %[[INIT0_RESHAPE:.+]] = tensor.collapse_shape %[[INIT0]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]{{\]}}
+// CHECK-DAG: %[[INIT1_RESHAPE:.+]] = tensor.collapse_shape %[[INIT1]] {{\[}}[0, 1], [2], [3], [4], [5, 6, 7]{{\]}}
+// CHECK: %[[COLLAPSED_OP:.+]]:2 = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP0]], #[[MAP3]]]
 // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]
 // CHECK-SAME: ins(%[[ARG0]], %[[ARG1_RESHAPE]], %[[ARG2_RESHAPE]] :
-// CHECK-SAME: outs(%[[INIT_RESHAPE]] :
-// CHECK: %[[RESULT_RESHAPE:.+]] = tensor.expand_shape %[[COLLAPSED_OP]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]{{\]}} output_shape [2, 3, 4, 5, 6, 7, 8, 9]
-// CHECK: return %[[RESULT_RESHAPE]]
+// CHECK-SAME: outs(%[[INIT0_RESHAPE]], %[[INIT1_RESHAPE]] :
+// CHECK: %[[RESULT0_RESHAPE:.+]] = tensor.expand_shape %[[COLLAPSED_OP]]#0 {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]{{\]}} output_shape [2, 3, 4, 5, 6, 7, 8, 9]
+// CHECK: %[[RESULT1_RESHAPE:.+]] = tensor.expand_shape %[[COLLAPSED_OP]]#1 {{\[}}[0, 1], [2], [3], [4], [5, 6, 7]{{\]}} output_shape [3, 4, 2, 9, 5, 6, 7, 8]
+// CHECK: return %[[RESULT0_RESHAPE]], %[[RESULT1_RESHAPE]]
 
 // CONTROL: func @fuse_by_collapsing(
 // CONTROL-SAME: %[[ARG0:.+]]: tensor<2x12x5x336x9xi32>
 // CONTROL-SAME: %[[ARG1:.+]]: tensor<2x3x4xi32>
 // CONTROL-SAME: %[[ARG2:.+]]: tensor<5x6x7x8xi32>
 // CONTROL: %[[EXPAND:.+]] = tensor.expand_shape %[[ARG0]]
-// CONTROL: %[[GENERIC:.+]] = linalg.generic
+// CONTROL: %[[GENERIC:.+]]:2 = linalg.generic
 // CONTROL-SAME: ins(%[[EXPAND]],
-// CONTROL: return %[[GENERIC]]
+// CONTROL: return %[[GENERIC]]#0, %[[GENERIC]]#1
 
 // -----