diff --git a/lib/ClangImporter/ClangDerivedConformances.cpp b/lib/ClangImporter/ClangDerivedConformances.cpp index 89bd5d8ddd66d..7991e4987ff96 100644 --- a/lib/ClangImporter/ClangDerivedConformances.cpp +++ b/lib/ClangImporter/ClangDerivedConformances.cpp @@ -565,6 +565,23 @@ void swift::conformToCxxSequenceIfNeeded( } } +static bool isStdSetType(const clang::CXXRecordDecl *clangDecl) { + return isStdDecl(clangDecl, {"set", "unordered_set", "multiset"}); +} + +bool swift::isUnsafeStdMethod(const clang::CXXMethodDecl *methodDecl) { + auto parentDecl = + dyn_cast(methodDecl->getDeclContext()); + if (!parentDecl) + return false; + if (!isStdSetType(parentDecl)) + return false; + if (methodDecl->getDeclName().isIdentifier() && + methodDecl->getName() == "insert") + return true; + return false; +} + void swift::conformToCxxSetIfNeeded(ClangImporter::Implementation &impl, NominalTypeDecl *decl, const clang::CXXRecordDecl *clangDecl) { @@ -576,7 +593,7 @@ void swift::conformToCxxSetIfNeeded(ClangImporter::Implementation &impl, // Only auto-conform types from the C++ standard library. Custom user types // might have a similar interface but different semantics. - if (!isStdDecl(clangDecl, {"set", "unordered_set", "multiset"})) + if (!isStdSetType(clangDecl)) return; auto valueType = lookupDirectSingleWithoutExtensions( @@ -586,10 +603,33 @@ void swift::conformToCxxSetIfNeeded(ClangImporter::Implementation &impl, if (!valueType || !sizeType) return; + auto insertId = ctx.getIdentifier("__insertUnsafe"); + auto inserts = lookupDirectWithoutExtensions(decl, insertId); + FuncDecl *insert = nullptr; + for (auto candidate : inserts) { + if (auto candidateMethod = dyn_cast(candidate)) { + if (!candidateMethod->hasParameterList()) + continue; + auto params = candidateMethod->getParameters(); + if (params->size() != 1) + continue; + auto param = params->front(); + if (param->getType()->getCanonicalType() != + valueType->getUnderlyingType()->getCanonicalType()) + continue; + insert = candidateMethod; + break; + } + } + if (!insert) + return; + impl.addSynthesizedTypealias(decl, ctx.Id_Element, valueType->getUnderlyingType()); impl.addSynthesizedTypealias(decl, ctx.getIdentifier("Size"), sizeType->getUnderlyingType()); + impl.addSynthesizedTypealias(decl, ctx.getIdentifier("InsertionResult"), + insert->getResultInterfaceType()); impl.addSynthesizedProtocolAttrs(decl, {KnownProtocolKind::CxxSet}); } diff --git a/lib/ClangImporter/ClangDerivedConformances.h b/lib/ClangImporter/ClangDerivedConformances.h index 46616c677b702..3faeb3efdee05 100644 --- a/lib/ClangImporter/ClangDerivedConformances.h +++ b/lib/ClangImporter/ClangDerivedConformances.h @@ -20,6 +20,8 @@ namespace swift { bool isIterator(const clang::CXXRecordDecl *clangDecl); +bool isUnsafeStdMethod(const clang::CXXMethodDecl *methodDecl); + /// If the decl is a C++ input iterator, synthesize a conformance to the /// UnsafeCxxInputIterator protocol, which is defined in the Cxx module. void conformToCxxIteratorIfNeeded(ClangImporter::Implementation &impl, diff --git a/lib/ClangImporter/ClangImporter.cpp b/lib/ClangImporter/ClangImporter.cpp index 3ec0611bcbe1d..d7a322ff0eeb1 100644 --- a/lib/ClangImporter/ClangImporter.cpp +++ b/lib/ClangImporter/ClangImporter.cpp @@ -6747,6 +6747,11 @@ bool IsSafeUseOfCxxDecl::evaluate(Evaluator &evaluator, method->getReturnType()->isReferenceType()) return false; + // Check if it's one of the known unsafe methods we currently + // mark as safe by default. + if (isUnsafeStdMethod(method)) + return false; + // Try to figure out the semantics of the return type. If it's a // pointer/iterator, it's unsafe. if (auto returnType = dyn_cast( diff --git a/stdlib/public/Cxx/CxxSet.swift b/stdlib/public/Cxx/CxxSet.swift index 02dec94fb1193..d468f6deed52e 100644 --- a/stdlib/public/Cxx/CxxSet.swift +++ b/stdlib/public/Cxx/CxxSet.swift @@ -13,11 +13,31 @@ public protocol CxxSet { associatedtype Element associatedtype Size: BinaryInteger + associatedtype InsertionResult // std::pair + + init() + + @discardableResult + mutating func __insertUnsafe(_ element: Element) -> InsertionResult func count(_ element: Element) -> Size } extension CxxSet { + /// Creates a C++ set containing the elements of a Swift Sequence. + /// + /// This initializes the set by copying every element of the sequence. + /// + /// - Complexity: O(*n*), where *n* is the number of elements in the Swift + /// sequence + @inlinable + public init(_ sequence: __shared S) where S.Element == Element { + self.init() + for item in sequence { + self.__insertUnsafe(item) + } + } + @inlinable public func contains(_ element: Element) -> Bool { return count(element) > 0 diff --git a/test/Interop/Cxx/stdlib/use-std-set.swift b/test/Interop/Cxx/stdlib/use-std-set.swift index 56f7b7672f166..95f0110908f10 100644 --- a/test/Interop/Cxx/stdlib/use-std-set.swift +++ b/test/Interop/Cxx/stdlib/use-std-set.swift @@ -47,4 +47,18 @@ StdSetTestSuite.test("MultisetOfCInt.contains") { expectFalse(s.contains(3)) } +StdSetTestSuite.test("SetOfCInt.init()") { + let s = SetOfCInt([1, 3, 5]) + expectTrue(s.contains(1)) + expectFalse(s.contains(2)) + expectTrue(s.contains(3)) +} + +StdSetTestSuite.test("UnorderedSetOfCInt.init()") { + let s = UnorderedSetOfCInt([1, 3, 5]) + expectTrue(s.contains(1)) + expectFalse(s.contains(2)) + expectTrue(s.contains(3)) +} + runAllTests()