Skip to content

Commit 9b019ee

Browse files
ezbr authored and copybara-github committed
Add prefetching of subsequent extensions in ExtensionSet::ForEach.
PiperOrigin-RevId: 671457336
1 parent f72e5ce commit 9b019ee

File tree

4 files changed

+244
-131
lines changed

4 files changed

+244
-131
lines changed

src/google/protobuf/extension_set.cc

Lines changed: 46 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -186,7 +186,8 @@ void ExtensionSet::RegisterMessageExtension(const MessageLite* extendee,
186186
ExtensionSet::~ExtensionSet() {
187187
// Deletes all allocated extensions.
188188
if (arena_ == nullptr) {
189-
ForEach([](int /* number */, Extension& ext) { ext.Free(); });
189+
ForEach([](int /* number */, Extension& ext) { ext.Free(); },
190+
PrefetchNta{});
190191
if (PROTOBUF_PREDICT_FALSE(is_large())) {
191192
delete map_.large;
192193
} else {
@@ -225,7 +226,7 @@ bool ExtensionSet::HasLazy(int number) const {
225226

226227
int ExtensionSet::NumExtensions() const {
227228
int result = 0;
228-
ForEach([&result](int /* number */, const Extension& ext) {
229+
ForEachNoPrefetch([&result](int /* number */, const Extension& ext) {
229230
if (!ext.is_cleared) {
230231
++result;
231232
}
@@ -308,6 +309,7 @@ enum { REPEATED_FIELD, OPTIONAL_FIELD };
308309
ABSL_DCHECK_EQ(cpp_type(extension->type), \
309310
WireFormatLite::CPPTYPE_##UPPERCASE); \
310311
extension->is_repeated = false; \
312+
extension->is_pointer = false; \
311313
} else { \
312314
ABSL_DCHECK_TYPE(*extension, OPTIONAL_FIELD, UPPERCASE); \
313315
} \
@@ -351,6 +353,7 @@ enum { REPEATED_FIELD, OPTIONAL_FIELD };
351353
ABSL_DCHECK_EQ(cpp_type(extension->type), \
352354
WireFormatLite::CPPTYPE_##UPPERCASE); \
353355
extension->is_repeated = true; \
356+
extension->is_pointer = true; \
354357
extension->is_packed = packed; \
355358
extension->ptr.repeated_##LOWERCASE##_value = \
356359
Arena::Create<RepeatedField<LOWERCASE>>(arena_); \
@@ -391,6 +394,7 @@ void* ExtensionSet::MutableRawRepeatedField(int number, FieldType field_type,
391394
// extension.
392395
if (MaybeNewExtension(number, desc, &extension)) {
393396
extension->is_repeated = true;
397+
extension->is_pointer = true;
394398
extension->type = field_type;
395399
extension->is_packed = packed;
396400

@@ -487,6 +491,7 @@ void ExtensionSet::SetEnum(int number, FieldType type, int value,
487491
extension->type = type;
488492
ABSL_DCHECK_EQ(cpp_type(extension->type), WireFormatLite::CPPTYPE_ENUM);
489493
extension->is_repeated = false;
494+
extension->is_pointer = false;
490495
} else {
491496
ABSL_DCHECK_TYPE(*extension, OPTIONAL_FIELD, ENUM);
492497
}
@@ -522,6 +527,7 @@ void ExtensionSet::AddEnum(int number, FieldType type, bool packed, int value,
522527
extension->type = type;
523528
ABSL_DCHECK_EQ(cpp_type(extension->type), WireFormatLite::CPPTYPE_ENUM);
524529
extension->is_repeated = true;
530+
extension->is_pointer = true;
525531
extension->is_packed = packed;
526532
extension->ptr.repeated_enum_value =
527533
Arena::Create<RepeatedField<int>>(arena_);
@@ -554,6 +560,7 @@ std::string* ExtensionSet::MutableString(int number, FieldType type,
554560
extension->type = type;
555561
ABSL_DCHECK_EQ(cpp_type(extension->type), WireFormatLite::CPPTYPE_STRING);
556562
extension->is_repeated = false;
563+
extension->is_pointer = true;
557564
extension->ptr.string_value = Arena::Create<std::string>(arena_);
558565
} else {
559566
ABSL_DCHECK_TYPE(*extension, OPTIONAL_FIELD, STRING);
@@ -584,6 +591,7 @@ std::string* ExtensionSet::AddString(int number, FieldType type,
584591
extension->type = type;
585592
ABSL_DCHECK_EQ(cpp_type(extension->type), WireFormatLite::CPPTYPE_STRING);
586593
extension->is_repeated = true;
594+
extension->is_pointer = true;
587595
extension->is_packed = false;
588596
extension->ptr.repeated_string_value =
589597
Arena::Create<RepeatedPtrField<std::string>>(arena_);
@@ -626,6 +634,7 @@ MessageLite* ExtensionSet::MutableMessage(int number, FieldType type,
626634
extension->type = type;
627635
ABSL_DCHECK_EQ(cpp_type(extension->type), WireFormatLite::CPPTYPE_MESSAGE);
628636
extension->is_repeated = false;
637+
extension->is_pointer = true;
629638
extension->is_lazy = false;
630639
extension->ptr.message_value = prototype.New(arena_);
631640
extension->is_cleared = false;
@@ -663,6 +672,7 @@ void ExtensionSet::SetAllocatedMessage(int number, FieldType type,
663672
extension->type = type;
664673
ABSL_DCHECK_EQ(cpp_type(extension->type), WireFormatLite::CPPTYPE_MESSAGE);
665674
extension->is_repeated = false;
675+
extension->is_pointer = true;
666676
extension->is_lazy = false;
667677
if (message_arena == arena) {
668678
extension->ptr.message_value = message;
@@ -707,6 +717,7 @@ void ExtensionSet::UnsafeArenaSetAllocatedMessage(
707717
extension->type = type;
708718
ABSL_DCHECK_EQ(cpp_type(extension->type), WireFormatLite::CPPTYPE_MESSAGE);
709719
extension->is_repeated = false;
720+
extension->is_pointer = true;
710721
extension->is_lazy = false;
711722
extension->ptr.message_value = message;
712723
} else {
@@ -805,6 +816,7 @@ MessageLite* ExtensionSet::AddMessage(int number, FieldType type,
805816
extension->type = type;
806817
ABSL_DCHECK_EQ(cpp_type(extension->type), WireFormatLite::CPPTYPE_MESSAGE);
807818
extension->is_repeated = true;
819+
extension->is_pointer = true;
808820
extension->ptr.repeated_message_value =
809821
Arena::Create<RepeatedPtrField<MessageLite>>(arena_);
810822
} else {
@@ -920,7 +932,7 @@ void ExtensionSet::SwapElements(int number, int index1, int index2) {
920932
// ===================================================================
921933

922934
void ExtensionSet::Clear() {
923-
ForEach([](int /* number */, Extension& ext) { ext.Clear(); });
935+
ForEach([](int /* number */, Extension& ext) { ext.Clear(); }, Prefetch{});
924936
}
925937

926938
namespace {
@@ -969,9 +981,11 @@ void ExtensionSet::MergeFrom(const MessageLite* extendee,
969981
other.map_.large->end()));
970982
}
971983
}
972-
other.ForEach([extendee, this, &other](int number, const Extension& ext) {
973-
this->InternalExtensionMergeFrom(extendee, number, ext, other.arena_);
974-
});
984+
other.ForEach(
985+
[extendee, this, &other](int number, const Extension& ext) {
986+
this->InternalExtensionMergeFrom(extendee, number, ext, other.arena_);
987+
},
988+
Prefetch{});
975989
}
976990

977991
void ExtensionSet::InternalExtensionMergeFrom(const MessageLite* extendee,
@@ -987,6 +1001,7 @@ void ExtensionSet::InternalExtensionMergeFrom(const MessageLite* extendee,
9871001
extension->type = other_extension.type;
9881002
extension->is_packed = other_extension.is_packed;
9891003
extension->is_repeated = true;
1004+
extension->is_pointer = true;
9901005
} else {
9911006
ABSL_DCHECK_EQ(extension->type, other_extension.type);
9921007
ABSL_DCHECK_EQ(extension->is_packed, other_extension.is_packed);
@@ -1049,6 +1064,7 @@ void ExtensionSet::InternalExtensionMergeFrom(const MessageLite* extendee,
10491064
extension->type = other_extension.type;
10501065
extension->is_packed = other_extension.is_packed;
10511066
extension->is_repeated = false;
1067+
extension->is_pointer = true;
10521068
if (other_extension.is_lazy) {
10531069
extension->is_lazy = true;
10541070
extension->ptr.lazymessage_value =
@@ -1226,6 +1242,13 @@ const char* ExtensionSet::ParseMessageSetItem(
12261242
metadata, ctx);
12271243
}
12281244

1245+
bool ExtensionSet::FieldTypeIsPointer(FieldType type) {
1246+
return type == WireFormatLite::TYPE_STRING ||
1247+
type == WireFormatLite::TYPE_BYTES ||
1248+
type == WireFormatLite::TYPE_GROUP ||
1249+
type == WireFormatLite::TYPE_MESSAGE;
1250+
}
1251+
12291252
uint8_t* ExtensionSet::_InternalSerializeImpl(
12301253
const MessageLite* extendee, int start_field_number, int end_field_number,
12311254
uint8_t* target, io::EpsCopyOutputStream* stream) const {
@@ -1252,19 +1275,23 @@ uint8_t* ExtensionSet::InternalSerializeMessageSetWithCachedSizesToArray(
12521275
const MessageLite* extendee, uint8_t* target,
12531276
io::EpsCopyOutputStream* stream) const {
12541277
const ExtensionSet* extension_set = this;
1255-
ForEach([&target, extendee, stream, extension_set](int number,
1256-
const Extension& ext) {
1257-
target = ext.InternalSerializeMessageSetItemWithCachedSizesToArray(
1258-
extendee, extension_set, number, target, stream);
1259-
});
1278+
ForEach(
1279+
[&target, extendee, stream, extension_set](int number,
1280+
const Extension& ext) {
1281+
target = ext.InternalSerializeMessageSetItemWithCachedSizesToArray(
1282+
extendee, extension_set, number, target, stream);
1283+
},
1284+
Prefetch{});
12601285
return target;
12611286
}
12621287

12631288
size_t ExtensionSet::ByteSize() const {
12641289
size_t total_size = 0;
1265-
ForEach([&total_size](int number, const Extension& ext) {
1266-
total_size += ext.ByteSize(number);
1267-
});
1290+
ForEach(
1291+
[&total_size](int number, const Extension& ext) {
1292+
total_size += ext.ByteSize(number);
1293+
},
1294+
Prefetch{});
12681295
return total_size;
12691296
}
12701297

@@ -1932,9 +1959,11 @@ size_t ExtensionSet::Extension::MessageSetItemByteSize(int number) const {
19321959

19331960
size_t ExtensionSet::MessageSetByteSize() const {
19341961
size_t total_size = 0;
1935-
ForEach([&total_size](int number, const Extension& ext) {
1936-
total_size += ext.MessageSetItemByteSize(number);
1937-
});
1962+
ForEach(
1963+
[&total_size](int number, const Extension& ext) {
1964+
total_size += ext.MessageSetItemByteSize(number);
1965+
},
1966+
Prefetch{});
19381967
return total_size;
19391968
}
19401969

src/google/protobuf/extension_set.h

Lines changed: 85 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -28,6 +28,8 @@
2828

2929
#include "google/protobuf/stubs/common.h"
3030
#include "absl/base/call_once.h"
31+
#include "absl/base/casts.h"
32+
#include "absl/base/prefetch.h"
3133
#include "absl/container/btree_map.h"
3234
#include "absl/log/absl_check.h"
3335
#include "google/protobuf/internal_visibility.h"
@@ -555,6 +557,8 @@ class PROTOBUF_EXPORT ExtensionSet {
555557

556558
friend void internal::InitializeLazyExtensionSet();
557559

560+
static bool FieldTypeIsPointer(FieldType type);
561+
558562
const int32_t& GetRefInt32(int number, const int32_t& default_value) const;
559563
const int64_t& GetRefInt64(int number, const int64_t& default_value) const;
560564
const uint32_t& GetRefUInt32(int number, const uint32_t& default_value) const;
@@ -670,6 +674,12 @@ class PROTOBUF_EXPORT ExtensionSet {
670674
size_t SpaceUsedExcludingSelfLong() const;
671675
bool IsInitialized(const ExtensionSet* ext_set, const MessageLite* extendee,
672676
int number, Arena* arena) const;
677+
const void* PrefetchPtr() const {
678+
ABSL_DCHECK_EQ(is_pointer, is_repeated || FieldTypeIsPointer(type));
679+
// We don't want to prefetch invalid/null pointers so if there isn't a
680+
// pointer to prefetch, then return `this`.
681+
return is_pointer ? absl::bit_cast<const void*>(ptr) : this;
682+
}
673683

674684
// The order of these fields packs Extension into 24 bytes when using 8
675685
// byte alignment. Consider this when adding or removing fields here.
@@ -708,20 +718,23 @@ class PROTOBUF_EXPORT ExtensionSet {
708718
FieldType type;
709719
bool is_repeated;
710720

721+
// Whether the extension is a pointer. This is used for prefetching.
722+
bool is_pointer : 1;
723+
711724
// For singular types, indicates if the extension is "cleared". This
712725
// happens when an extension is set and then later cleared by the caller.
713726
// We want to keep the Extension object around for reuse, so instead of
714727
// removing it from the map, we just set is_cleared = true. This has no
715728
// meaning for repeated types; for those, the size of the RepeatedField
716729
// simply becomes zero when cleared.
717-
bool is_cleared : 4;
730+
bool is_cleared : 1;
718731

719732
// For singular message types, indicates whether lazy parsing is enabled
720733
// for this extension. This field is only valid when type == TYPE_MESSAGE
721734
// and !is_repeated because we only support lazy parsing for singular
722735
// message types currently. If is_lazy = true, the extension is stored in
723736
// lazymessage_value. Otherwise, the extension will be message_value.
724-
bool is_lazy : 4;
737+
bool is_lazy : 1;
725738

726739
// For repeated types, this indicates if the [packed=true] option is set.
727740
bool is_packed;
@@ -779,32 +792,93 @@ class PROTOBUF_EXPORT ExtensionSet {
779792
return PROTOBUF_PREDICT_FALSE(is_large()) ? map_.large->size() : flat_size_;
780793
}
781794

795+
// For use as `PrefetchFunctor`s in `ForEach`.
796+
struct Prefetch {
797+
void operator()(const void* ptr) const { absl::PrefetchToLocalCache(ptr); }
798+
};
799+
struct PrefetchNta {
800+
void operator()(const void* ptr) const {
801+
absl::PrefetchToLocalCacheNta(ptr);
802+
}
803+
};
804+
805+
template <typename Iterator, typename KeyValueFunctor,
806+
typename PrefetchFunctor>
807+
static KeyValueFunctor ForEachPrefetchImpl(Iterator it, Iterator end,
808+
KeyValueFunctor func,
809+
PrefetchFunctor prefetch_func) {
810+
// Note: based on arena's ChunkList::Cleanup().
811+
// Prefetch distance 16 performs better than 8 in load tests.
812+
constexpr int kPrefetchDistance = 16;
813+
Iterator prefetch = it;
814+
// Prefetch the first kPrefetchDistance extensions.
815+
for (int i = 0; prefetch != end && i < kPrefetchDistance; ++prefetch, ++i) {
816+
prefetch_func(prefetch->second.PrefetchPtr());
817+
}
818+
// For the middle extensions, call func and then prefetch the extension
819+
// kPrefetchDistance after the current one.
820+
for (; prefetch != end; ++it, ++prefetch) {
821+
func(it->first, it->second);
822+
prefetch_func(prefetch->second.PrefetchPtr());
823+
}
824+
// Call func on the rest without prefetching.
825+
for (; it != end; ++it) func(it->first, it->second);
826+
return std::move(func);
827+
}
828+
782829
// Similar to std::for_each.
783830
// Each Iterator is decomposed into ->first and ->second fields, so
784831
// that the KeyValueFunctor can be agnostic vis-a-vis KeyValue-vs-std::pair.
832+
// Applies a functor to the <int, Extension&> pairs in sorted order and
833+
// prefetches ahead.
834+
template <typename KeyValueFunctor, typename PrefetchFunctor>
835+
KeyValueFunctor ForEach(KeyValueFunctor func, PrefetchFunctor prefetch_func) {
836+
if (PROTOBUF_PREDICT_FALSE(is_large())) {
837+
return ForEachPrefetchImpl(map_.large->begin(), map_.large->end(),
838+
std::move(func), std::move(prefetch_func));
839+
}
840+
return ForEachPrefetchImpl(flat_begin(), flat_end(), std::move(func),
841+
std::move(prefetch_func));
842+
}
843+
// As above, but const.
844+
template <typename KeyValueFunctor, typename PrefetchFunctor>
845+
KeyValueFunctor ForEach(KeyValueFunctor func,
846+
PrefetchFunctor prefetch_func) const {
847+
if (PROTOBUF_PREDICT_FALSE(is_large())) {
848+
return ForEachPrefetchImpl(map_.large->begin(), map_.large->end(),
849+
std::move(func), std::move(prefetch_func));
850+
}
851+
return ForEachPrefetchImpl(flat_begin(), flat_end(), std::move(func),
852+
std::move(prefetch_func));
853+
}
854+
855+
// As above, but without prefetching. This is for use in cases where we never
856+
// use the pointed-to extension values in `func`.
785857
template <typename Iterator, typename KeyValueFunctor>
786-
static KeyValueFunctor ForEach(Iterator begin, Iterator end,
787-
KeyValueFunctor func) {
858+
static KeyValueFunctor ForEachNoPrefetch(Iterator begin, Iterator end,
859+
KeyValueFunctor func) {
788860
for (Iterator it = begin; it != end; ++it) func(it->first, it->second);
789861
return std::move(func);
790862
}
791863

792864
// Applies a functor to the <int, Extension&> pairs in sorted order.
793865
template <typename KeyValueFunctor>
794-
KeyValueFunctor ForEach(KeyValueFunctor func) {
866+
KeyValueFunctor ForEachNoPrefetch(KeyValueFunctor func) {
795867
if (PROTOBUF_PREDICT_FALSE(is_large())) {
796-
return ForEach(map_.large->begin(), map_.large->end(), std::move(func));
868+
return ForEachNoPrefetch(map_.large->begin(), map_.large->end(),
869+
std::move(func));
797870
}
798-
return ForEach(flat_begin(), flat_end(), std::move(func));
871+
return ForEachNoPrefetch(flat_begin(), flat_end(), std::move(func));
799872
}
800873

801-
// Applies a functor to the <int, const Extension&> pairs in sorted order.
874+
// As above, but const.
802875
template <typename KeyValueFunctor>
803-
KeyValueFunctor ForEach(KeyValueFunctor func) const {
876+
KeyValueFunctor ForEachNoPrefetch(KeyValueFunctor func) const {
804877
if (PROTOBUF_PREDICT_FALSE(is_large())) {
805-
return ForEach(map_.large->begin(), map_.large->end(), std::move(func));
878+
return ForEachNoPrefetch(map_.large->begin(), map_.large->end(),
879+
std::move(func));
806880
}
807-
return ForEach(flat_begin(), flat_end(), std::move(func));
881+
return ForEachNoPrefetch(flat_begin(), flat_end(), std::move(func));
808882
}
809883

810884
// Merges existing Extension from other_extension

0 commit comments

Comments
 (0)