diff --git a/lib/ClangImporter/ClangDerivedConformances.cpp b/lib/ClangImporter/ClangDerivedConformances.cpp index 89bd5d8ddd66d..59a7a60410aaa 100644 --- a/lib/ClangImporter/ClangDerivedConformances.cpp +++ b/lib/ClangImporter/ClangDerivedConformances.cpp @@ -586,10 +586,33 @@ void swift::conformToCxxSetIfNeeded(ClangImporter::Implementation &impl, if (!valueType || !sizeType) return; + auto insertId = ctx.getIdentifier("__insertUnsafe"); + auto inserts = lookupDirectWithoutExtensions(decl, insertId); + FuncDecl *insert = nullptr; + for (auto candidate : inserts) { + if (auto candidateMethod = dyn_cast(candidate)) { + if (!candidateMethod->hasParameterList()) + continue; + auto params = candidateMethod->getParameters(); + if (params->size() != 1) + continue; + auto param = params->front(); + if (param->getType()->getCanonicalType() != + valueType->getUnderlyingType()->getCanonicalType()) + continue; + insert = candidateMethod; + break; + } + } + if (!insert) + return; + impl.addSynthesizedTypealias(decl, ctx.Id_Element, valueType->getUnderlyingType()); impl.addSynthesizedTypealias(decl, ctx.getIdentifier("Size"), sizeType->getUnderlyingType()); + impl.addSynthesizedTypealias(decl, ctx.getIdentifier("InsertionResult"), + insert->getResultInterfaceType()); impl.addSynthesizedProtocolAttrs(decl, {KnownProtocolKind::CxxSet}); } diff --git a/stdlib/public/Cxx/CxxSet.swift b/stdlib/public/Cxx/CxxSet.swift index 02dec94fb1193..e3f143da45b60 100644 --- a/stdlib/public/Cxx/CxxSet.swift +++ b/stdlib/public/Cxx/CxxSet.swift @@ -13,11 +13,31 @@ public protocol CxxSet { associatedtype Element associatedtype Size: BinaryInteger + associatedtype InsertionResult // std::pair + + init() + + @discardableResult + mutating func __insertUnsafe(_ element: Element) -> InsertionResult func count(_ element: Element) -> Size } extension CxxSet { + /// Creates a C++ set containing the elements of a Swift Sequence. + /// + /// This initializes the set by copying every element of the sequence. + /// + /// - Complexity: O(*n*), where *n* is the number of elements in the Swift + /// sequence + @inlinable + public init(_ sequence: S) where S.Element == Element { + self.init() + for item in sequence { + self.__insertUnsafe(item) + } + } + @inlinable public func contains(_ element: Element) -> Bool { return count(element) > 0 diff --git a/test/Interop/Cxx/stdlib/use-std-set.swift b/test/Interop/Cxx/stdlib/use-std-set.swift index 56f7b7672f166..95f0110908f10 100644 --- a/test/Interop/Cxx/stdlib/use-std-set.swift +++ b/test/Interop/Cxx/stdlib/use-std-set.swift @@ -47,4 +47,18 @@ StdSetTestSuite.test("MultisetOfCInt.contains") { expectFalse(s.contains(3)) } +StdSetTestSuite.test("SetOfCInt.init()") { + let s = SetOfCInt([1, 3, 5]) + expectTrue(s.contains(1)) + expectFalse(s.contains(2)) + expectTrue(s.contains(3)) +} + +StdSetTestSuite.test("UnorderedSetOfCInt.init()") { + let s = UnorderedSetOfCInt([1, 3, 5]) + expectTrue(s.contains(1)) + expectFalse(s.contains(2)) + expectTrue(s.contains(3)) +} + runAllTests()