diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 059288e18049b..f26aa0a1516a6 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -1328,15 +1328,19 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( mapping.map(origArg, argMat); appendRewrite(block, origArg); - // FIXME: We simply pass through the replacement argument if there wasn't a - // converter, which isn't great as it allows implicit type conversions to - // appear. We should properly restructure this code to handle cases where a - // converter isn't provided and also to properly handle the case where an - // argument materialization is actually a temporary source materialization - // (e.g. in the case of 1->N). Type legalOutputType; - if (converter) + if (converter) { legalOutputType = converter->convertType(origArgType); + } else if (replArgs.size() == 1) { + // When there is no type converter, assume that the new block argument + // types are legal. This is reasonable to assume because they were + // specified by the user. + // FIXME: This won't work for 1->N conversions because multiple output + // types are not supported in parts of the dialect conversion. In such a + // case, we currently use the original block argument type (produced by + // the argument materialization). + legalOutputType = replArgs[0].getType(); + } if (legalOutputType && legalOutputType != origArgType) { Value targetMat = buildUnresolvedTargetMaterialization( origArg.getLoc(), argMat, legalOutputType, converter); diff --git a/mlir/test/Transforms/test-legalize-type-conversion.mlir b/mlir/test/Transforms/test-legalize-type-conversion.mlir index 8254be68912c8..d0563fed8e5d9 100644 --- a/mlir/test/Transforms/test-legalize-type-conversion.mlir +++ b/mlir/test/Transforms/test-legalize-type-conversion.mlir @@ -127,3 +127,18 @@ llvm.func @unsupported_func_op_interface() { // CHECK: llvm.return llvm.return } + +// ----- + +// CHECK-LABEL: func @test_signature_conversion_no_converter() +func.func @test_signature_conversion_no_converter() { + // CHECK: "test.signature_conversion_no_converter"() ({ + // CHECK: ^{{.*}}(%[[arg0:.*]]: f64): + "test.signature_conversion_no_converter"() ({ + ^bb0(%arg0: f32): + // CHECK: "test.legal_op_d"(%[[arg0]]) : (f64) -> () + "test.replace_with_legal_op"(%arg0) : (f32) -> () + "test.return"() : () -> () + }) : () -> () + return +} diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index 2d97a02b8076a..2b55bff3538d3 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -1884,6 +1884,7 @@ def LegalOpA : TEST_Op<"legal_op_a">, def LegalOpB : TEST_Op<"legal_op_b">, Results<(outs I32)>; def LegalOpC : TEST_Op<"legal_op_c">, Arguments<(ins I32)>, Results<(outs I32)>; +def LegalOpD : TEST_Op<"legal_op_d">, Arguments<(ins AnyType)>; // Check that the conversion infrastructure can properly undo the creation of // operations where an operation was created before its parent, in this case, diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index 0546523a58c80..91dfb2faa80a1 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -1580,6 +1580,17 @@ struct TestTypeConversionAnotherProducer } }; +struct TestReplaceWithLegalOp : public ConversionPattern { + TestReplaceWithLegalOp(MLIRContext *ctx) + : ConversionPattern("test.replace_with_legal_op", /*benefit=*/1, ctx) {} + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + rewriter.replaceOpWithNewOp(op, operands[0]); + return success(); + } +}; + struct TestTypeConversionDriver : public PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTypeConversionDriver) @@ -1671,6 +1682,7 @@ struct TestTypeConversionDriver // Initialize the conversion target. mlir::ConversionTarget target(getContext()); + target.addLegalOp(); target.addDynamicallyLegalOp([](TestTypeProducerOp op) { auto recursiveType = dyn_cast(op.getType()); return op.getType().isF64() || op.getType().isInteger(64) || @@ -1696,7 +1708,8 @@ struct TestTypeConversionDriver TestSignatureConversionUndo, TestTestSignatureConversionNoConverter>(converter, &getContext()); - patterns.add(&getContext()); + patterns.add( + &getContext()); mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns, converter);