Skip to content

Commit 10f74ca

Browse files
committed
Improve matching of typing names in the pyi parser.
Adds support for aliasing names like TypeVar and simplifies name matching by requiring the caller to pass in only one variant of a name (e.g., "typing.Callable") and having the matcher automatically try all variants ("Callable", "typing.Callable", "collections.abc.Callable"). See python/typeshed#10201 for motivation. Fixes #1430. PiperOrigin-RevId: 534721075
1 parent b8ef372 commit 10f74ca

File tree

7 files changed

+82
-91
lines changed

7 files changed

+82
-91
lines changed

pytype/pyi/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ py_library(
6666
.classdef
6767
.metadata
6868
.types
69+
pytype.utils
6970
pytype.pytd.pytd_for_parser
7071
pytype.pytd.parse.parse
7172
)

pytype/pyi/classdef.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,10 @@
33
import collections
44
import sys
55

6-
from typing import cast, Dict, List
6+
from typing import cast, Callable, Dict, List
77

88
from pytype.pyi import types
99
from pytype.pytd import pytd
10-
from pytype.pytd import pytd_utils
1110
from pytype.pytd.parse import node as pytd_node
1211

1312
# pylint: disable=g-import-not-at-top
@@ -17,18 +16,17 @@
1716
from typed_ast import ast3
1817
# pylint: enable=g-import-not-at-top
1918

20-
21-
_PROTOCOL_ALIASES = ("typing.Protocol", "typing_extensions.Protocol")
2219
ParseError = types.ParseError
2320

2421

25-
def get_bases(bases: List[pytd.Type]) -> List[pytd.Type]:
22+
def get_bases(
23+
bases: List[pytd.Type], type_match: Callable[..., bool]) -> List[pytd.Type]:
2624
"""Collect base classes."""
2725

2826
bases_out = []
2927
namedtuple_index = None
3028
for i, p in enumerate(bases):
31-
if p.name and pytd_utils.MatchesFullName(p, _PROTOCOL_ALIASES):
29+
if p.name and type_match(p.name, "typing.Protocol"):
3230
if isinstance(p, pytd.GenericType):
3331
# From PEP 544: "`Protocol[T, S, ...]` is allowed as a shorthand for
3432
# `Protocol, Generic[T, S, ...]`."

pytype/pyi/definitions.py

Lines changed: 41 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44
import dataclasses
55
import sys
66

7-
from typing import Any, Dict, List, Optional, Union
7+
from typing import Any, Dict, List, Optional, Tuple, Union
88

9+
from pytype import utils
910
from pytype.pyi import classdef
1011
from pytype.pyi import metadata
1112
from pytype.pyi import types
@@ -19,6 +20,7 @@
1920
from pytype.pytd.codegen import namedtuple
2021
from pytype.pytd.codegen import pytdgen
2122
from pytype.pytd.parse import node as pytd_node
23+
from pytype.pytd.parse import parser_constants
2224

2325
# pylint: disable=g-import-not-at-top
2426
if sys.version_info >= (3, 8):
@@ -31,13 +33,6 @@
3133
# Typing members that represent sets of types.
3234
_TYPING_SETS = ("typing.Intersection", "typing.Optional", "typing.Union")
3335

34-
# Aliases for some typing.X types
35-
_ANNOTATED_TYPES = ("typing.Annotated", "typing_extensions.Annotated")
36-
_CALLABLE_TYPES = ("typing.Callable", "collections.abc.Callable")
37-
_CONCATENATE_TYPES = ("typing.Concatenate", "typing_extensions.Concatenate")
38-
_LITERAL_TYPES = ("typing.Literal", "typing_extensions.Literal")
39-
_TUPLE_TYPES = ("tuple", "builtins.tuple", "typing.Tuple")
40-
4136

4237
class StringParseError(ParseError):
4338
pass
@@ -493,23 +488,43 @@ def add_import(self, from_package, import_list):
493488
if t:
494489
self.aliases[t.new_name] = t.pytd_alias()
495490

496-
def _matches_full_name(self, t, full_name):
497-
"""Whether t.name matches full_name in format {module}.{member}."""
498-
return pytd_utils.MatchesFullName(
499-
t, full_name, self.module_info.module_name, self.aliases)
491+
def _resolve_alias(self, name: str) -> str:
492+
if name in self.aliases:
493+
alias = self.aliases[name].type
494+
if isinstance(alias, pytd.NamedType):
495+
name = alias.name
496+
elif isinstance(alias, pytd.Module):
497+
name = alias.module_name
498+
return name
499+
500+
def matches_type(self, name: str, target: Union[str, Tuple[str, ...]]):
501+
"""Checks whether 'name' matches the 'target' type."""
502+
if isinstance(target, tuple):
503+
return any(self.matches_type(name, t) for t in target)
504+
assert "." in target, "'target' must be a fully qualified type name"
505+
if "." in name:
506+
prefix, name_base = name.rsplit(".", 1)
507+
name = f"{self._resolve_alias(prefix)}.{name_base}"
508+
else:
509+
name = self._resolve_alias(name)
510+
name = utils.strip_prefix(name, parser_constants.EXTERNAL_NAME_PREFIX)
511+
if name == target:
512+
return True
513+
module, target_base = target.rsplit(".", 1)
514+
if name == target_base:
515+
return True
516+
if module == "builtins":
517+
return self.matches_type(name, f"typing.{target_base.title()}")
518+
equivalent_modules = {"typing", "collections.abc", "typing_extensions"}
519+
if module not in equivalent_modules:
520+
return False
521+
return any(name == f"{mod}.{target_base}" for mod in equivalent_modules)
500522

501523
def _matches_named_type(self, t, names):
502524
"""Whether t is a NamedType matching any of names."""
503525
if not isinstance(t, pytd.NamedType):
504526
return False
505-
for name in names:
506-
if "." in name:
507-
if self._matches_full_name(t, name):
508-
return True
509-
else:
510-
if t.name == name:
511-
return True
512-
return False
527+
return self.matches_type(t.name, names)
513528

514529
def _is_empty_tuple(self, t):
515530
return isinstance(t, pytd.TupleType) and not t.parameters
@@ -551,22 +566,22 @@ def _remove_unsupported_features(self, parameters):
551566

552567
def _parameterized_type(self, base_type: Any, parameters):
553568
"""Return a parameterized type."""
554-
if self._matches_named_type(base_type, _LITERAL_TYPES):
569+
if self._matches_named_type(base_type, "typing.Literal"):
555570
return pytd_literal(parameters, self.aliases)
556-
elif self._matches_named_type(base_type, _ANNOTATED_TYPES):
571+
elif self._matches_named_type(base_type, "typing.Annotated"):
557572
return pytd_annotated(parameters)
558573
self._verify_no_literal_parameters(base_type, parameters)
559574
arg_is_paramspec = False
560-
if self._matches_named_type(base_type, _TUPLE_TYPES):
575+
if self._matches_named_type(base_type, "builtins.tuple"):
561576
if len(parameters) == 2 and parameters[1] is self.ELLIPSIS:
562577
parameters = parameters[:1]
563578
builder = pytd.GenericType
564579
else:
565580
builder = pytdgen.heterogeneous_tuple
566-
elif self._matches_named_type(base_type, _CONCATENATE_TYPES):
581+
elif self._matches_named_type(base_type, "typing.Concatenate"):
567582
assert parameters
568583
builder = pytd.Concatenate
569-
elif self._matches_named_type(base_type, _CALLABLE_TYPES):
584+
elif self._matches_named_type(base_type, "typing.Callable"):
570585
if parameters[0] is self.ELLIPSIS:
571586
parameters = (pytd.AnythingType(),) + parameters[1:]
572587
if parameters and isinstance(parameters[0], pytd.NamedType):
@@ -661,7 +676,7 @@ def build_class(
661676
self, class_name, bases, keywords, decorators, defs
662677
) -> pytd.Class:
663678
"""Build a pytd.Class from definitions collected from an ast node."""
664-
bases = classdef.get_bases(bases)
679+
bases = classdef.get_bases(bases, self.matches_type)
665680
keywords = classdef.get_keywords(keywords)
666681
constants, methods, aliases, slots, classes = _split_definitions(defs)
667682

pytype/pyi/parser.py

Lines changed: 23 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -30,20 +30,6 @@
3030
# reexport as parser.ParseError
3131
ParseError = types.ParseError
3232

33-
_TYPEVAR_IDS = ("TypeVar", "typing.TypeVar")
34-
_PARAMSPEC_IDS = (
35-
"ParamSpec", "typing.ParamSpec", "typing_extensions.ParamSpec")
36-
_TYPING_NAMEDTUPLE_IDS = ("NamedTuple", "typing.NamedTuple")
37-
_COLL_NAMEDTUPLE_IDS = ("namedtuple", "collections.namedtuple")
38-
_TYPEDDICT_IDS = (
39-
"TypedDict", "typing.TypedDict", "typing_extensions.TypedDict")
40-
_NEWTYPE_IDS = ("NewType", "typing.NewType")
41-
_ANNOTATED_IDS = (
42-
"Annotated", "typing.Annotated", "typing_extensions.Annotated")
43-
_FINAL_IDS = ("typing.Final", "typing_extensions.Final")
44-
_TYPE_ALIAS_IDS = ("typing.TypeAlias", "typing_extensions.TypeAlias")
45-
_TYPING_LITERAL_IDS = ("Literal", "typing.Literal", "typing_extensions.Literal")
46-
4733
#------------------------------------------------------
4834
# imports
4935

@@ -218,7 +204,8 @@ def _convert_typing_annotated(self, node):
218204
def enter_Subscript(self, node):
219205
if isinstance(node.value, ast3.Attribute):
220206
node.value = _attribute_to_name(node.value).id
221-
if getattr(node.value, "id", None) in _ANNOTATED_IDS:
207+
if self.defs.matches_type(getattr(node.value, "id", ""),
208+
"typing.Annotated"):
222209
self._convert_typing_annotated(node)
223210
self.subscripted.append(node.value)
224211

@@ -244,8 +231,13 @@ def visit_Attribute(self, node):
244231
def visit_BinOp(self, node):
245232
if self.subscripted:
246233
last = self.subscripted[-1]
247-
if ((isinstance(last, ast3.Name) and last.id in _TYPING_LITERAL_IDS) or
248-
isinstance(last, str) and last in _TYPING_LITERAL_IDS):
234+
if isinstance(last, ast3.Name):
235+
last_id = last.id
236+
elif isinstance(last, str):
237+
last_id = last
238+
else:
239+
last_id = ""
240+
if self.defs.matches_type(last_id, "typing.Literal"):
249241
raise ParseError("Expressions are not allowed in typing.Literal.")
250242
if isinstance(node.op, ast3.BitOr):
251243
return self.defs.new_type("typing.Union", [node.left, node.right])
@@ -446,7 +438,7 @@ def visit_AnnAssign(self, node):
446438
typ = pytd.NamedType("tuple")
447439
val = None
448440
elif typ.name:
449-
if pytd_utils.MatchesFullName(typ, _FINAL_IDS):
441+
if self.defs.matches_type(typ.name, "typing.Final"):
450442
if isinstance(node.value, types.Pyval):
451443
# to_pytd_literal raises an exception if the value is a float, but
452444
# checking upfront allows us to generate a nicer error message.
@@ -460,7 +452,7 @@ def visit_AnnAssign(self, node):
460452
elif isinstance(val, pytd.NamedType):
461453
typ = pytd.Literal(val)
462454
val = None
463-
elif pytd_utils.MatchesFullName(typ, _TYPE_ALIAS_IDS):
455+
elif self.defs.matches_type(typ.name, "typing.TypeAlias"):
464456
typ = val
465457
val = None
466458
is_alias = True
@@ -668,31 +660,32 @@ def enter_Call(self, node):
668660
# passing them to internal functions directly in visit_Call.
669661
if isinstance(node.func, ast3.Attribute):
670662
node.func = _attribute_to_name(node.func)
671-
if node.func.id in _TYPEVAR_IDS:
663+
if self.defs.matches_type(node.func.id, "typing.TypeVar"):
672664
self._convert_typevar_args(node)
673-
elif node.func.id in _PARAMSPEC_IDS:
665+
elif self.defs.matches_type(node.func.id, "typing.ParamSpec"):
674666
self._convert_paramspec_args(node)
675-
elif node.func.id in _TYPING_NAMEDTUPLE_IDS:
667+
elif self.defs.matches_type(node.func.id, "typing.NamedTuple"):
676668
self._convert_typing_namedtuple_args(node)
677-
elif node.func.id in _COLL_NAMEDTUPLE_IDS:
669+
elif self.defs.matches_type(node.func.id, "collections.namedtuple"):
678670
self._convert_collections_namedtuple_args(node)
679-
elif node.func.id in _TYPEDDICT_IDS:
671+
elif self.defs.matches_type(node.func.id, "typing.TypedDict"):
680672
self._convert_typed_dict_args(node)
681-
elif node.func.id in _NEWTYPE_IDS:
673+
elif self.defs.matches_type(node.func.id, "typing.NewType"):
682674
return self._convert_newtype_args(node)
683675

684676
def visit_Call(self, node):
685-
if node.func.id in _TYPEVAR_IDS:
677+
if self.defs.matches_type(node.func.id, "typing.TypeVar"):
686678
if self.level > 0:
687679
raise ParseError("TypeVars need to be defined at module level")
688680
return _TypeVar.from_call(node)
689-
elif node.func.id in _PARAMSPEC_IDS:
681+
elif self.defs.matches_type(node.func.id, "typing.ParamSpec"):
690682
return _ParamSpec.from_call(node)
691-
elif node.func.id in _TYPING_NAMEDTUPLE_IDS + _COLL_NAMEDTUPLE_IDS:
683+
elif self.defs.matches_type(
684+
node.func.id, ("typing.NamedTuple", "collections.namedtuple")):
692685
return self.defs.new_named_tuple(*node.args)
693-
elif node.func.id in _TYPEDDICT_IDS:
686+
elif self.defs.matches_type(node.func.id, "typing.TypedDict"):
694687
return self.defs.new_typed_dict(*node.args, node.keywords)
695-
elif node.func.id in _NEWTYPE_IDS:
688+
elif self.defs.matches_type(node.func.id, "typing.NewType"):
696689
return self.defs.new_new_type(*node.args)
697690
# Convert all other calls to NamedTypes; for example:
698691
# * typing.pyi uses things like

pytype/pyi/parser_test.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -559,6 +559,19 @@ class B:
559559
__match_args__: tuple
560560
""")
561561

562+
def test_typevar_alias(self):
563+
self.check("""
564+
from typing import TypeVar as _TypeVar
565+
T = _TypeVar('T')
566+
def f(x: T) -> T: ...
567+
""", """
568+
from typing import TypeVar, TypeVar as _TypeVar
569+
570+
T = TypeVar('T')
571+
572+
def f(x: T) -> T: ...
573+
""")
574+
562575

563576
class QuotedTypeTest(parser_test_base.ParserTestBase):
564577

pytype/pytd/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,6 @@ py_library(
158158
.pytd_visitors
159159
pytype.utils
160160
pytype.platform_utils.platform_utils
161-
pytype.pytd.parse.parse
162161
)
163162

164163
py_library(

pytype/pytd/pytd_utils.py

Lines changed: 0 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
from pytype.pytd import printer
2323
from pytype.pytd import pytd
2424
from pytype.pytd import pytd_visitors
25-
from pytype.pytd.parse import parser_constants
2625

2726

2827
ANON_PARAM = re.compile(r"_[0-9]+")
@@ -439,30 +438,3 @@ def MergeBaseClass(cls, base):
439438
decorators=decorators,
440439
slots=slots,
441440
template=cls.template or base.template)
442-
443-
444-
def MatchesFullName(t, full_name, current_module_name=None, aliases=None):
445-
"""Whether t.name matches full_name in format {module}.{member}."""
446-
if isinstance(full_name, tuple):
447-
return any(MatchesFullName(t, name, current_module_name, aliases)
448-
for name in full_name)
449-
expected_module_name, expected_name = full_name.rsplit(".", 1)
450-
if current_module_name == expected_module_name:
451-
# full_name is inside the current module, so check for the name without
452-
# the module prefix.
453-
return t.name == expected_name
454-
elif "." not in t.name:
455-
# full_name is not inside the current module, so a local type can't match.
456-
return False
457-
else:
458-
module_name, name = t.name.rsplit(".", 1)
459-
if aliases and module_name in aliases:
460-
# Adjust the module name if it has been aliased with `import x as y`.
461-
# See test_pyi.PYITest.testTypingAlias.
462-
module = aliases[module_name].type
463-
if isinstance(module, pytd.Module):
464-
module_name = module.module_name
465-
expected_module_names = {
466-
expected_module_name,
467-
parser_constants.EXTERNAL_NAME_PREFIX + expected_module_name}
468-
return module_name in expected_module_names and name == expected_name

0 commit comments

Comments
 (0)