From c752f28d0e4c5356daf3f5457dd815a8a8828632 Mon Sep 17 00:00:00 2001 From: Jie Luo Date: Wed, 21 May 2025 13:10:03 -0700 Subject: [PATCH] python pyi print "import datetime" for Duration/Timestamp field PiperOrigin-RevId: 761639872 --- .../protobuf/compiler/python/pyi_generator.cc | 50 ++++++++++++------- 1 file changed, 33 insertions(+), 17 deletions(-) diff --git a/src/google/protobuf/compiler/python/pyi_generator.cc b/src/google/protobuf/compiler/python/pyi_generator.cc index d687259cf9c73..31e1fb8545485 100644 --- a/src/google/protobuf/compiler/python/pyi_generator.cc +++ b/src/google/protobuf/compiler/python/pyi_generator.cc @@ -73,6 +73,7 @@ struct ImportModules { bool has_union = false; // typing.Union bool has_callable = false; // typing.Callable bool has_well_known_type = false; + bool has_datetime = false; }; // Checks whether a descriptor name matches a well-known type. @@ -112,8 +113,15 @@ void CheckImportModules(const Descriptor* descriptor, if (field->is_map()) { import_modules->has_mapping = true; const FieldDescriptor* value_des = field->message_type()->field(1); - if (value_des->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE || - value_des->cpp_type() == FieldDescriptor::CPPTYPE_ENUM) { + if (value_des->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { + import_modules->has_union = true; + const absl::string_view name = value_des->message_type()->full_name(); + if (name == "google.protobuf.Duration" || + name == "google.protobuf.Timestamp") { + import_modules->has_datetime = true; + } + } + if (value_des->cpp_type() == FieldDescriptor::CPPTYPE_ENUM) { import_modules->has_union = true; } } else { @@ -123,6 +131,11 @@ void CheckImportModules(const Descriptor* descriptor, if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { import_modules->has_union = true; import_modules->has_mapping = true; + const absl::string_view name = field->message_type()->full_name(); + if (name == "google.protobuf.Duration" || + name == "google.protobuf.Timestamp") { + import_modules->has_datetime = true; + } } if (field->cpp_type() == FieldDescriptor::CPPTYPE_ENUM) { import_modules->has_union = true; @@ -170,21 +183,6 @@ void PyiGenerator::PrintImportForDescriptor( } void PyiGenerator::PrintImports() const { - // Prints imported dependent _pb2 files. - absl::flat_hash_set seen_aliases; - bool has_importlib = false; - for (int i = 0; i < file_->dependency_count(); ++i) { - const FileDescriptor* dep = file_->dependency(i); - if (strip_nonfunctional_codegen_ && IsKnownFeatureProto(dep->name())) { - continue; - } - PrintImportForDescriptor(*dep, &seen_aliases, &has_importlib); - for (int j = 0; j < dep->public_dependency_count(); ++j) { - PrintImportForDescriptor(*dep->public_dependency(j), &seen_aliases, - &has_importlib); - } - } - // Checks what modules should be imported. ImportModules import_modules; if (file_->message_type_count() > 0) { @@ -201,6 +199,24 @@ void PyiGenerator::PrintImports() const { for (int i = 0; i < file_->message_type_count(); i++) { CheckImportModules(file_->message_type(i), &import_modules); } + if (import_modules.has_datetime) { + printer_->Print("import datetime\n\n"); + } + + // Prints imported dependent _pb2 files. + absl::flat_hash_set seen_aliases; + bool has_importlib = false; + for (int i = 0; i < file_->dependency_count(); ++i) { + const FileDescriptor* dep = file_->dependency(i); + if (strip_nonfunctional_codegen_ && IsKnownFeatureProto(dep->name())) { + continue; + } + PrintImportForDescriptor(*dep, &seen_aliases, &has_importlib); + for (int j = 0; j < dep->public_dependency_count(); ++j) { + PrintImportForDescriptor(*dep->public_dependency(j), &seen_aliases, + &has_importlib); + } + } // Prints modules (e.g. _containers, _messages, typing) that are // required in the proto file.