diff --git a/llvm/include/llvm/ADT/STLExtras.h b/llvm/include/llvm/ADT/STLExtras.h index eea06cfb99ba2..d2010e663ebed 100644 --- a/llvm/include/llvm/ADT/STLExtras.h +++ b/llvm/include/llvm/ADT/STLExtras.h @@ -1032,13 +1032,22 @@ class concat_iterator static constexpr bool ReturnsByValue = !(std::is_reference_v())> && ...); + static constexpr bool ReturnsConvertiblePointer = + std::is_pointer_v && + (std::is_convertible_v()), ValueT> && ...); using reference_type = - typename std::conditional_t; - - using handle_type = - typename std::conditional_t, - ValueT *>; + typename std::conditional_t; + + using optional_value_type = + std::conditional_t, ValueT *>; + // handle_type is used to return an optional value from `getHelper()`. If + // the type resulting from dereferencing all IterTs is a pointer that can be + // converted to `ValueT`, use that pointer type instead to avoid implicit + // conversion issues. + using handle_type = typename std::conditional_t; /// We store both the current and end iterators for each concatenated /// sequence in a tuple of pairs. @@ -1088,7 +1097,7 @@ class concat_iterator if (Begin == End) return {}; - if constexpr (ReturnsByValue) + if constexpr (ReturnsByValue || ReturnsConvertiblePointer) return *Begin; else return &*Begin; @@ -1105,8 +1114,12 @@ class concat_iterator // Loop over them, and return the first result we find. for (auto &GetHelperFn : GetHelperFns) - if (auto P = (this->*GetHelperFn)()) - return *P; + if (auto P = (this->*GetHelperFn)()) { + if constexpr (ReturnsConvertiblePointer) + return P; + else + return *P; + } llvm_unreachable("Attempted to get a pointer from an end concat iterator!"); } diff --git a/llvm/unittests/ADT/STLExtrasTest.cpp b/llvm/unittests/ADT/STLExtrasTest.cpp index 286cfa745fd14..6d893c2295819 100644 --- a/llvm/unittests/ADT/STLExtrasTest.cpp +++ b/llvm/unittests/ADT/STLExtrasTest.cpp @@ -398,6 +398,8 @@ struct some_struct { std::string swap_val; }; +struct derives_from_some_struct : some_struct {}; + std::vector::const_iterator begin(const some_struct &s) { return s.data.begin(); } @@ -532,6 +534,33 @@ TEST(STLExtrasTest, ConcatRangeADL) { EXPECT_THAT(concat(S0, S1), ElementsAre(1, 2, 3, 4)); } +TEST(STLExtrasTest, ConcatRangeRef) { + SmallVector V12{{{1, 2}, "V12[0]"}}; + SmallVector V3456{{{3, 4}, "V3456[0]"}, + {{5, 6}, "V3456[1]"}}; + + // Use concat with `iterator type = some_namespace::some_struct *` and value + // being a reference type. + std::vector Expected = {&V12[0], &V3456[0], + &V3456[1]}; + std::vector Test; + for (auto &i : concat(V12, V3456)) + Test.push_back(&i); + EXPECT_EQ(Expected, Test); +} + +TEST(STLExtrasTest, ConcatRangePtrToDerivedClass) { + some_namespace::some_struct S0{}; + some_namespace::derives_from_some_struct S1{}; + SmallVector V0{&S0}; + SmallVector V1{&S1, &S1}; + + // Use concat over ranges of pointers to different (but related) types. + EXPECT_THAT(concat(V0, V1), + ElementsAre(&S0, static_cast(&S1), + static_cast(&S1))); +} + TEST(STLExtrasTest, MakeFirstSecondRangeADL) { // Make sure that we use the `begin`/`end` functions from `some_namespace`, // using ADL.