Skip to content

Commit 63389c0

Browse files
acozzette authored and copybara-github committed
Add Python support for retention attribute
PiperOrigin-RevId: 511914565
1 parent bcb20bb commit 63389c0

File tree

6 files changed

+251
-74
lines changed

6 files changed

+251
-74
lines changed

python/google/protobuf/internal/generator_test.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
from google.protobuf import unittest_mset_pb2
5050
from google.protobuf import unittest_mset_wire_format_pb2
5151
from google.protobuf import unittest_pb2
52+
from google.protobuf import unittest_retention_pb2
5253
from google.protobuf import unittest_custom_options_pb2
5354
from google.protobuf import unittest_no_generic_services_pb2
5455

@@ -152,6 +153,82 @@ def testMessageWithCustomOptions(self):
152153
# TODO(gps): We really should test for the presence of the enum_opt1
153154
# extension and for its value to be set to -789.
154155

156+
# Options that are explicitly marked RETENTION_SOURCE should not be present
157+
# in the descriptors in the binary.
158+
def testOptionRetention(self):
159+
# Direct options
160+
options = unittest_retention_pb2.DESCRIPTOR.GetOptions()
161+
self.assertTrue(options.HasExtension(unittest_retention_pb2.plain_option))
162+
self.assertTrue(
163+
options.HasExtension(unittest_retention_pb2.runtime_retention_option)
164+
)
165+
self.assertFalse(
166+
options.HasExtension(unittest_retention_pb2.source_retention_option)
167+
)
168+
169+
def check_options_message_is_stripped_correctly(options):
170+
self.assertEqual(options.plain_field, 1)
171+
self.assertEqual(options.runtime_retention_field, 2)
172+
self.assertFalse(options.HasField('source_retention_field'))
173+
self.assertEqual(options.source_retention_field, 0)
174+
175+
# Verify that our test OptionsMessage is stripped correctly on all
176+
# different entity types.
177+
check_options_message_is_stripped_correctly(
178+
options.Extensions[unittest_retention_pb2.file_option]
179+
)
180+
check_options_message_is_stripped_correctly(
181+
unittest_retention_pb2.TopLevelMessage.DESCRIPTOR.GetOptions().Extensions[
182+
unittest_retention_pb2.message_option
183+
]
184+
)
185+
check_options_message_is_stripped_correctly(
186+
unittest_retention_pb2.TopLevelMessage.NestedMessage.DESCRIPTOR.GetOptions().Extensions[
187+
unittest_retention_pb2.message_option
188+
]
189+
)
190+
check_options_message_is_stripped_correctly(
191+
unittest_retention_pb2._TOPLEVELENUM.GetOptions().Extensions[
192+
unittest_retention_pb2.enum_option
193+
]
194+
)
195+
check_options_message_is_stripped_correctly(
196+
unittest_retention_pb2._TOPLEVELMESSAGE_NESTEDENUM.GetOptions().Extensions[
197+
unittest_retention_pb2.enum_option
198+
]
199+
)
200+
check_options_message_is_stripped_correctly(
201+
unittest_retention_pb2._TOPLEVELENUM.values[0]
202+
.GetOptions()
203+
.Extensions[unittest_retention_pb2.enum_entry_option]
204+
)
205+
check_options_message_is_stripped_correctly(
206+
unittest_retention_pb2.DESCRIPTOR.extensions_by_name['i']
207+
.GetOptions()
208+
.Extensions[unittest_retention_pb2.field_option]
209+
)
210+
check_options_message_is_stripped_correctly(
211+
unittest_retention_pb2.TopLevelMessage.DESCRIPTOR.fields[0]
212+
.GetOptions()
213+
.Extensions[unittest_retention_pb2.field_option]
214+
)
215+
check_options_message_is_stripped_correctly(
216+
unittest_retention_pb2.TopLevelMessage.DESCRIPTOR.oneofs[0]
217+
.GetOptions()
218+
.Extensions[unittest_retention_pb2.oneof_option]
219+
)
220+
check_options_message_is_stripped_correctly(
221+
unittest_retention_pb2.DESCRIPTOR.services_by_name['Service']
222+
.GetOptions()
223+
.Extensions[unittest_retention_pb2.service_option]
224+
)
225+
check_options_message_is_stripped_correctly(
226+
unittest_retention_pb2.DESCRIPTOR.services_by_name['Service']
227+
.methods[0]
228+
.GetOptions()
229+
.Extensions[unittest_retention_pb2.method_option]
230+
)
231+
155232
def testNestedTypes(self):
156233
self.assertEqual(
157234
set(unittest_pb2.TestAllTypes.DESCRIPTOR.nested_types),

src/google/protobuf/compiler/python/BUILD.bazel

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ cc_library(
2828
deps = [
2929
"//src/google/protobuf:protobuf_nowkt",
3030
"//src/google/protobuf/compiler:code_generator",
31+
"//src/google/protobuf/compiler:retention",
3132
"@com_google_absl//absl/strings",
3233
"@com_google_absl//absl/synchronization",
3334
],
@@ -61,9 +62,12 @@ pkg_files(
6162

6263
filegroup(
6364
name = "test_srcs",
64-
srcs = glob([
65-
"*_test.cc",
66-
"*unittest.cc",
67-
], allow_empty = True),
65+
srcs = glob(
66+
[
67+
"*_test.cc",
68+
"*unittest.cc",
69+
],
70+
allow_empty = True,
71+
),
6872
visibility = ["//src/google/protobuf/compiler:__pkg__"],
6973
)

src/google/protobuf/compiler/python/generator.cc

Lines changed: 54 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
#include "absl/strings/substitute.h"
6565
#include "google/protobuf/compiler/python/helpers.h"
6666
#include "google/protobuf/compiler/python/pyi_generator.h"
67+
#include "google/protobuf/compiler/retention.h"
6768
#include "google/protobuf/descriptor.h"
6869
#include "google/protobuf/descriptor.pb.h"
6970
#include "google/protobuf/io/printer.h"
@@ -249,8 +250,7 @@ bool Generator::Generate(const FileDescriptor* file,
249250

250251
std::string filename = GetFileName(file, ".py");
251252

252-
FileDescriptorProto fdp;
253-
file_->CopyTo(&fdp);
253+
FileDescriptorProto fdp = StripSourceRetentionOptions(*file_);
254254
fdp.SerializeToString(&file_descriptor_serialized_);
255255

256256
if (!opensource_runtime_ && GeneratingDescriptorProto()) {
@@ -342,7 +342,7 @@ bool Generator::Generate(const FileDescriptor* file,
342342
FixAllDescriptorOptions();
343343

344344
// Set serialized_start and serialized_end.
345-
SetSerializedPbInterval();
345+
SetSerializedPbInterval(fdp);
346346

347347
printer_->Outdent();
348348
if (HasGenericServices(file)) {
@@ -442,7 +442,8 @@ void Generator::PrintFileDescriptor() const {
442442
m["name"] = file_->name();
443443
m["package"] = file_->package();
444444
m["syntax"] = StringifySyntax(file_->syntax());
445-
m["options"] = OptionsValue(file_->options().SerializeAsString());
445+
m["options"] = OptionsValue(
446+
StripLocalSourceRetentionOptions(*file_).SerializeAsString());
446447
m["serialized_descriptor"] = absl::CHexEscape(file_descriptor_serialized_);
447448
if (GeneratingDescriptorProto()) {
448449
printer_->Print("if _descriptor._USE_C_DESCRIPTORS == False:\n");
@@ -528,7 +529,8 @@ void Generator::PrintEnum(const EnumDescriptor& enum_descriptor) const {
528529
" create_key=_descriptor._internal_create_key,\n"
529530
" values=[\n";
530531
std::string options_string;
531-
enum_descriptor.options().SerializeToString(&options_string);
532+
StripLocalSourceRetentionOptions(enum_descriptor)
533+
.SerializeToString(&options_string);
532534
printer_->Print(m, enum_descriptor_template);
533535
printer_->Indent();
534536
printer_->Indent();
@@ -681,7 +683,8 @@ void Generator::PrintDescriptor(const Descriptor& message_descriptor) const {
681683
printer_->Outdent();
682684
printer_->Print("],\n");
683685
std::string options_string;
684-
message_descriptor.options().SerializeToString(&options_string);
686+
StripLocalSourceRetentionOptions(message_descriptor)
687+
.SerializeToString(&options_string);
685688
printer_->Print(
686689
"serialized_options=$options_value$,\n"
687690
"is_extendable=$extendable$,\n"
@@ -708,7 +711,8 @@ void Generator::PrintDescriptor(const Descriptor& message_descriptor) const {
708711
m["name"] = desc->name();
709712
m["full_name"] = desc->full_name();
710713
m["index"] = absl::StrCat(desc->index());
711-
options_string = OptionsValue(desc->options().SerializeAsString());
714+
options_string = OptionsValue(
715+
StripLocalSourceRetentionOptions(*desc).SerializeAsString());
712716
if (options_string == "None") {
713717
m["serialized_options"] = "";
714718
} else {
@@ -1050,7 +1054,8 @@ void Generator::PrintEnumValueDescriptor(
10501054
// TODO(robinson): Fix up EnumValueDescriptor "type" fields.
10511055
// More circular references. ::sigh::
10521056
std::string options_string;
1053-
descriptor.options().SerializeToString(&options_string);
1057+
StripLocalSourceRetentionOptions(descriptor)
1058+
.SerializeToString(&options_string);
10541059
absl::flat_hash_map<absl::string_view, std::string> m;
10551060
m["name"] = descriptor.name();
10561061
m["index"] = absl::StrCat(descriptor.index());
@@ -1078,7 +1083,7 @@ std::string Generator::OptionsValue(
10781083
void Generator::PrintFieldDescriptor(const FieldDescriptor& field,
10791084
bool is_extension) const {
10801085
std::string options_string;
1081-
field.options().SerializeToString(&options_string);
1086+
StripLocalSourceRetentionOptions(field).SerializeToString(&options_string);
10821087
absl::flat_hash_map<absl::string_view, std::string> m;
10831088
m["name"] = field.name();
10841089
m["full_name"] = field.full_name();
@@ -1216,21 +1221,17 @@ std::string Generator::InternalPackage() const {
12161221
: "google3.net.google.protobuf.python.internal";
12171222
}
12181223

1219-
// Prints standard constructor arguments serialized_start and serialized_end.
1224+
// Prints descriptor offsets _serialized_start and _serialized_end.
12201225
// Args:
1221-
// descriptor: The cpp descriptor to have a serialized reference.
1222-
// proto: A proto
1226+
// descriptor_proto: The descriptor proto to have a serialized reference.
12231227
// Example printer output:
1224-
// serialized_start=41,
1225-
// serialized_end=43,
1226-
//
1227-
template <typename DescriptorT, typename DescriptorProtoT>
1228-
void Generator::PrintSerializedPbInterval(const DescriptorT& descriptor,
1229-
DescriptorProtoT& proto,
1230-
absl::string_view name) const {
1231-
descriptor.CopyTo(&proto);
1228+
// _globals['_MYMESSAGE']._serialized_start=47
1229+
// _globals['_MYMESSAGE']._serialized_end=76
1230+
template <typename DescriptorProtoT>
1231+
void Generator::PrintSerializedPbInterval(
1232+
const DescriptorProtoT& descriptor_proto, absl::string_view name) const {
12321233
std::string sp;
1233-
proto.SerializeToString(&sp);
1234+
descriptor_proto.SerializeToString(&sp);
12341235
int offset = file_descriptor_serialized_.find(sp);
12351236
ABSL_CHECK_GE(offset, 0);
12361237

@@ -1254,51 +1255,56 @@ void PrintDescriptorOptionsFixingCode(absl::string_view descriptor,
12541255
}
12551256
} // namespace
12561257

1257-
void Generator::SetSerializedPbInterval() const {
1258+
// Generates the start and end offsets for each entity in the serialized file
1259+
// descriptor. The file argument must exactly match what was serialized into
1260+
// file_descriptor_serialized_, and should already have had any
1261+
// source-retention options stripped out. This is important because we need an
1262+
// exact byte-for-byte match so that we can successfully find the correct
1263+
// offsets in the serialized descriptors.
1264+
void Generator::SetSerializedPbInterval(const FileDescriptorProto& file) const {
12581265
// Top level enums.
12591266
for (int i = 0; i < file_->enum_type_count(); ++i) {
1260-
EnumDescriptorProto proto;
12611267
const EnumDescriptor& descriptor = *file_->enum_type(i);
1262-
PrintSerializedPbInterval(descriptor, proto,
1268+
PrintSerializedPbInterval(file.enum_type(i),
12631269
ModuleLevelDescriptorName(descriptor));
12641270
}
12651271

12661272
// Messages.
12671273
for (int i = 0; i < file_->message_type_count(); ++i) {
1268-
SetMessagePbInterval(*file_->message_type(i));
1274+
SetMessagePbInterval(file.message_type(i), *file_->message_type(i));
12691275
}
12701276

12711277
// Services.
12721278
for (int i = 0; i < file_->service_count(); ++i) {
1273-
ServiceDescriptorProto proto;
12741279
const ServiceDescriptor& service = *file_->service(i);
1275-
PrintSerializedPbInterval(service, proto,
1280+
PrintSerializedPbInterval(file.service(i),
12761281
ModuleLevelServiceDescriptorName(service));
12771282
}
12781283
}
12791284

1280-
void Generator::SetMessagePbInterval(const Descriptor& descriptor) const {
1281-
DescriptorProto message_proto;
1282-
PrintSerializedPbInterval(descriptor, message_proto,
1285+
void Generator::SetMessagePbInterval(const DescriptorProto& message_proto,
1286+
const Descriptor& descriptor) const {
1287+
PrintSerializedPbInterval(message_proto,
12831288
ModuleLevelDescriptorName(descriptor));
12841289

12851290
// Nested messages.
12861291
for (int i = 0; i < descriptor.nested_type_count(); ++i) {
1287-
SetMessagePbInterval(*descriptor.nested_type(i));
1292+
SetMessagePbInterval(message_proto.nested_type(i),
1293+
*descriptor.nested_type(i));
12881294
}
12891295

12901296
for (int i = 0; i < descriptor.enum_type_count(); ++i) {
1291-
EnumDescriptorProto proto;
12921297
const EnumDescriptor& enum_des = *descriptor.enum_type(i);
1293-
PrintSerializedPbInterval(enum_des, proto,
1298+
PrintSerializedPbInterval(message_proto.enum_type(i),
12941299
ModuleLevelDescriptorName(enum_des));
12951300
}
12961301
}
12971302

12981303
// Prints expressions that set the options field of all descriptors.
12991304
void Generator::FixAllDescriptorOptions() const {
13001305
// Prints an expression that sets the file descriptor's options.
1301-
std::string file_options = OptionsValue(file_->options().SerializeAsString());
1306+
std::string file_options = OptionsValue(
1307+
StripLocalSourceRetentionOptions(*file_).SerializeAsString());
13021308
if (file_options != "None") {
13031309
PrintDescriptorOptionsFixingCode(kDescriptorKey, file_options, printer_);
13041310
} else {
@@ -1326,7 +1332,8 @@ void Generator::FixAllDescriptorOptions() const {
13261332
}
13271333

13281334
void Generator::FixOptionsForOneof(const OneofDescriptor& oneof) const {
1329-
std::string oneof_options = OptionsValue(oneof.options().SerializeAsString());
1335+
std::string oneof_options =
1336+
OptionsValue(StripLocalSourceRetentionOptions(oneof).SerializeAsString());
13301337
if (oneof_options != "None") {
13311338
std::string oneof_name = absl::Substitute(
13321339
"$0.$1['$2']", ModuleLevelDescriptorName(*oneof.containing_type()),
@@ -1339,15 +1346,15 @@ void Generator::FixOptionsForOneof(const OneofDescriptor& oneof) const {
13391346
// value descriptors.
13401347
void Generator::FixOptionsForEnum(const EnumDescriptor& enum_descriptor) const {
13411348
std::string descriptor_name = ModuleLevelDescriptorName(enum_descriptor);
1342-
std::string enum_options =
1343-
OptionsValue(enum_descriptor.options().SerializeAsString());
1349+
std::string enum_options = OptionsValue(
1350+
StripLocalSourceRetentionOptions(enum_descriptor).SerializeAsString());
13441351
if (enum_options != "None") {
13451352
PrintDescriptorOptionsFixingCode(descriptor_name, enum_options, printer_);
13461353
}
13471354
for (int i = 0; i < enum_descriptor.value_count(); ++i) {
13481355
const EnumValueDescriptor& value_descriptor = *enum_descriptor.value(i);
1349-
std::string value_options =
1350-
OptionsValue(value_descriptor.options().SerializeAsString());
1356+
std::string value_options = OptionsValue(
1357+
StripLocalSourceRetentionOptions(value_descriptor).SerializeAsString());
13511358
if (value_options != "None") {
13521359
PrintDescriptorOptionsFixingCode(
13531360
absl::StrFormat("%s.values_by_name[\"%s\"]", descriptor_name.c_str(),
@@ -1363,17 +1370,17 @@ void Generator::FixOptionsForService(
13631370
const ServiceDescriptor& service_descriptor) const {
13641371
std::string descriptor_name =
13651372
ModuleLevelServiceDescriptorName(service_descriptor);
1366-
std::string service_options =
1367-
OptionsValue(service_descriptor.options().SerializeAsString());
1373+
std::string service_options = OptionsValue(
1374+
StripLocalSourceRetentionOptions(service_descriptor).SerializeAsString());
13681375
if (service_options != "None") {
13691376
PrintDescriptorOptionsFixingCode(descriptor_name, service_options,
13701377
printer_);
13711378
}
13721379

13731380
for (int i = 0; i < service_descriptor.method_count(); ++i) {
13741381
const MethodDescriptor* method = service_descriptor.method(i);
1375-
std::string method_options =
1376-
OptionsValue(method->options().SerializeAsString());
1382+
std::string method_options = OptionsValue(
1383+
StripLocalSourceRetentionOptions(*method).SerializeAsString());
13771384
if (method_options != "None") {
13781385
std::string method_name = absl::StrCat(
13791386
descriptor_name, ".methods_by_name['", method->name(), "']");
@@ -1385,7 +1392,8 @@ void Generator::FixOptionsForService(
13851392
// Prints expressions that set the options for field descriptors (including
13861393
// extensions).
13871394
void Generator::FixOptionsForField(const FieldDescriptor& field) const {
1388-
std::string field_options = OptionsValue(field.options().SerializeAsString());
1395+
std::string field_options =
1396+
OptionsValue(StripLocalSourceRetentionOptions(field).SerializeAsString());
13891397
if (field_options != "None") {
13901398
std::string field_name;
13911399
if (field.is_extension()) {
@@ -1430,8 +1438,8 @@ void Generator::FixOptionsForMessage(const Descriptor& descriptor) const {
14301438
FixOptionsForField(field);
14311439
}
14321440
// Message option for this message.
1433-
std::string message_options =
1434-
OptionsValue(descriptor.options().SerializeAsString());
1441+
std::string message_options = OptionsValue(
1442+
StripLocalSourceRetentionOptions(descriptor).SerializeAsString());
14351443
if (message_options != "None") {
14361444
std::string descriptor_name = ModuleLevelDescriptorName(descriptor);
14371445
PrintDescriptorOptionsFixingCode(descriptor_name, message_options,

0 commit comments

Comments
 (0)