diff --git a/mypy/plugin.py b/mypy/plugin.py index 72d28c39436a..2f571d7eecc6 100644 --- a/mypy/plugin.py +++ b/mypy/plugin.py @@ -692,9 +692,33 @@ def get_class_decorator_hook(self, fullname: str The plugin can modify a TypeInfo _in place_ (for example add some generated methods to the symbol table). This hook is called after the class body was - semantically analyzed. + semantically analyzed, but *there may still be placeholders* (typically + caused by forward references). - The hook is called with full names of all class decorators, for example + NOTE: Usually get_class_decorator_hook_2 is the better option, since it + guarantees that there are no placeholders. + + The hook is called with full names of all class decorators. + + The hook can be called multiple times per class, so it must be + idempotent. + """ + return None + + def get_class_decorator_hook_2(self, fullname: str + ) -> Optional[Callable[[ClassDefContext], bool]]: + """Update class definition for given class decorators. + + Similar to get_class_decorator_hook, but this runs in a later pass when + placeholders have been resolved. + + The hook can return False if some base class hasn't been + processed yet using class hooks. It causes all class hooks + (that are run in this same pass) to be invoked another time for + the file(s) currently being processed. + + The hook can be called multiple times per class, so it must be + idempotent. """ return None @@ -815,6 +839,10 @@ def get_class_decorator_hook(self, fullname: str ) -> Optional[Callable[[ClassDefContext], None]]: return self._find_hook(lambda plugin: plugin.get_class_decorator_hook(fullname)) + def get_class_decorator_hook_2(self, fullname: str + ) -> Optional[Callable[[ClassDefContext], bool]]: + return self._find_hook(lambda plugin: plugin.get_class_decorator_hook_2(fullname)) + def get_metaclass_hook(self, fullname: str ) -> Optional[Callable[[ClassDefContext], None]]: return self._find_hook(lambda plugin: plugin.get_metaclass_hook(fullname)) diff --git a/mypy/plugins/dataclasses.py b/mypy/plugins/dataclasses.py index 62b4c89bd674..24077bb4a549 100644 --- a/mypy/plugins/dataclasses.py +++ b/mypy/plugins/dataclasses.py @@ -107,10 +107,19 @@ def expand_typevar_from_subtype(self, sub_type: TypeInfo) -> None: class DataclassTransformer: + """Implement the behavior of @dataclass. + + Note that this may be executed multiple times on the same class, so + everything here must be idempotent. + + This runs after the main semantic analysis pass, so you can assume that + there are no placeholders. + """ + def __init__(self, ctx: ClassDefContext) -> None: self._ctx = ctx - def transform(self) -> None: + def transform(self) -> bool: """Apply all the necessary transformations to the underlying dataclass so as to ensure it is fully type checked according to the rules in PEP 557. @@ -119,12 +128,11 @@ def transform(self) -> None: info = self._ctx.cls.info attributes = self.collect_attributes() if attributes is None: - # Some definitions are not ready, defer() should be already called. - return + # Some definitions are not ready. We need another pass. + return False for attr in attributes: if attr.type is None: - ctx.api.defer() - return + return False decorator_arguments = { 'init': _get_decorator_bool_argument(self._ctx, 'init', True), 'eq': _get_decorator_bool_argument(self._ctx, 'eq', True), @@ -236,6 +244,8 @@ def transform(self) -> None: 'frozen': decorator_arguments['frozen'], } + return True + def add_slots(self, info: TypeInfo, attributes: List[DataclassAttribute], @@ -294,6 +304,9 @@ def collect_attributes(self) -> Optional[List[DataclassAttribute]]: b: SomeOtherType = ... are collected. + + Return None if some dataclass base class hasn't been processed + yet and thus we'll need to ask for another pass. """ # First, collect attributes belonging to the current class. ctx = self._ctx @@ -315,14 +328,11 @@ def collect_attributes(self) -> Optional[List[DataclassAttribute]]: sym = cls.info.names.get(lhs.name) if sym is None: - # This name is likely blocked by a star import. We don't need to defer because - # defer() is already called by mark_incomplete(). + # There was probably a semantic analysis error. continue node = sym.node - if isinstance(node, PlaceholderNode): - # This node is not ready yet. - return None + assert not isinstance(node, PlaceholderNode) assert isinstance(node, Var) # x: ClassVar[int] is ignored by dataclasses. @@ -390,6 +400,9 @@ def collect_attributes(self) -> Optional[List[DataclassAttribute]]: # we'll have unmodified attrs laying around. all_attrs = attrs.copy() for info in cls.info.mro[1:-1]: + if 'dataclass_tag' in info.metadata and 'dataclass' not in info.metadata: + # We haven't processed the base class yet. Need another pass. + return None if 'dataclass' not in info.metadata: continue @@ -517,11 +530,21 @@ def _add_dataclass_fields_magic_attribute(self) -> None: ) -def dataclass_class_maker_callback(ctx: ClassDefContext) -> None: +def dataclass_tag_callback(ctx: ClassDefContext) -> None: + """Record that we have a dataclass in the main semantic analysis pass. + + The later pass implemented by DataclassTransformer will use this + to detect dataclasses in base classes. + """ + # The value is ignored, only the existence matters. + ctx.cls.info.metadata['dataclass_tag'] = {} + + +def dataclass_class_maker_callback(ctx: ClassDefContext) -> bool: """Hooks into the class typechecking process to add support for dataclasses. """ transformer = DataclassTransformer(ctx) - transformer.transform() + return transformer.transform() def _collect_field_args(expr: Expression, diff --git a/mypy/plugins/default.py b/mypy/plugins/default.py index a7fa2cfaa868..50e0e8cb4315 100644 --- a/mypy/plugins/default.py +++ b/mypy/plugins/default.py @@ -117,12 +117,21 @@ def get_class_decorator_hook(self, fullname: str auto_attribs_default=None, ) elif fullname in dataclasses.dataclass_makers: - return dataclasses.dataclass_class_maker_callback + return dataclasses.dataclass_tag_callback elif fullname in functools.functools_total_ordering_makers: return functools.functools_total_ordering_maker_callback return None + def get_class_decorator_hook_2(self, fullname: str + ) -> Optional[Callable[[ClassDefContext], bool]]: + from mypy.plugins import dataclasses + + if fullname in dataclasses.dataclass_makers: + return dataclasses.dataclass_class_maker_callback + + return None + def contextmanager_callback(ctx: FunctionContext) -> Type: """Infer a better return type for 'contextlib.contextmanager'.""" diff --git a/mypy/semanal.py b/mypy/semanal.py index 555cb749074e..d68928ef21ad 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -1234,43 +1234,44 @@ def analyze_namedtuple_classdef(self, defn: ClassDef) -> bool: def apply_class_plugin_hooks(self, defn: ClassDef) -> None: """Apply a plugin hook that may infer a more precise definition for a class.""" - def get_fullname(expr: Expression) -> Optional[str]: - if isinstance(expr, CallExpr): - return get_fullname(expr.callee) - elif isinstance(expr, IndexExpr): - return get_fullname(expr.base) - elif isinstance(expr, RefExpr): - if expr.fullname: - return expr.fullname - # If we don't have a fullname look it up. This happens because base classes are - # analyzed in a different manner (see exprtotype.py) and therefore those AST - # nodes will not have full names. - sym = self.lookup_type_node(expr) - if sym: - return sym.fullname - return None for decorator in defn.decorators: - decorator_name = get_fullname(decorator) + decorator_name = self.get_fullname_for_hook(decorator) if decorator_name: hook = self.plugin.get_class_decorator_hook(decorator_name) if hook: hook(ClassDefContext(defn, decorator, self)) if defn.metaclass: - metaclass_name = get_fullname(defn.metaclass) + metaclass_name = self.get_fullname_for_hook(defn.metaclass) if metaclass_name: hook = self.plugin.get_metaclass_hook(metaclass_name) if hook: hook(ClassDefContext(defn, defn.metaclass, self)) for base_expr in defn.base_type_exprs: - base_name = get_fullname(base_expr) + base_name = self.get_fullname_for_hook(base_expr) if base_name: hook = self.plugin.get_base_class_hook(base_name) if hook: hook(ClassDefContext(defn, base_expr, self)) + def get_fullname_for_hook(self, expr: Expression) -> Optional[str]: + if isinstance(expr, CallExpr): + return self.get_fullname_for_hook(expr.callee) + elif isinstance(expr, IndexExpr): + return self.get_fullname_for_hook(expr.base) + elif isinstance(expr, RefExpr): + if expr.fullname: + return expr.fullname + # If we don't have a fullname look it up. This happens because base classes are + # analyzed in a different manner (see exprtotype.py) and therefore those AST + # nodes will not have full names. + sym = self.lookup_type_node(expr) + if sym: + return sym.fullname + return None + def analyze_class_keywords(self, defn: ClassDef) -> None: for value in defn.keywords.values(): value.accept(self) diff --git a/mypy/semanal_main.py b/mypy/semanal_main.py index 7a82032b46b7..bb0af8edc46f 100644 --- a/mypy/semanal_main.py +++ b/mypy/semanal_main.py @@ -45,6 +45,8 @@ from mypy.checker import FineGrainedDeferredNode from mypy.server.aststrip import SavedAttributes from mypy.util import is_typeshed_file +from mypy.options import Options +from mypy.plugin import ClassDefContext import mypy.build if TYPE_CHECKING: @@ -82,6 +84,8 @@ def semantic_analysis_for_scc(graph: 'Graph', scc: List[str], errors: Errors) -> apply_semantic_analyzer_patches(patches) # This pass might need fallbacks calculated above. check_type_arguments(graph, scc, errors) + # Run class decorator hooks (they requite complete MROs and no placeholders). + apply_class_plugin_hooks(graph, scc, errors) calculate_class_properties(graph, scc, errors) check_blockers(graph, scc) # Clean-up builtins, so that TypeVar etc. are not accessible without importing. @@ -132,6 +136,7 @@ def semantic_analysis_for_targets( check_type_arguments_in_targets(nodes, state, state.manager.errors) calculate_class_properties(graph, [state.id], state.manager.errors) + apply_class_plugin_hooks(graph, [state.id], state.manager.errors) def restore_saved_attrs(saved_attrs: SavedAttributes) -> None: @@ -382,14 +387,62 @@ def check_type_arguments_in_targets(targets: List[FineGrainedDeferredNode], stat target.node.accept(analyzer) +def apply_class_plugin_hooks(graph: 'Graph', scc: List[str], errors: Errors) -> None: + """Apply class plugin hooks within a SCC. + + We run these after to the main semantic analysis so that the hooks + don't need to deal with incomplete definitions such as placeholder + types. + + Note that some hooks incorrectly run during the main semantic + analysis pass, for historical reasons. + """ + num_passes = 0 + incomplete = True + # If we encounter a base class that has not been processed, we'll run another + # pass. This should eventually reach a fixed point. + while incomplete: + assert num_passes < 10, "Internal error: too many class plugin hook passes" + num_passes += 1 + incomplete = False + for module in scc: + state = graph[module] + tree = state.tree + assert tree + for _, node, _ in tree.local_definitions(): + if isinstance(node.node, TypeInfo): + if not apply_hooks_to_class(state.manager.semantic_analyzer, + module, node.node, state.options, tree, errors): + incomplete = True + + +def apply_hooks_to_class(self: SemanticAnalyzer, + module: str, + info: TypeInfo, + options: Options, + file_node: MypyFile, + errors: Errors) -> bool: + # TODO: Move more class-related hooks here? + defn = info.defn + ok = True + for decorator in defn.decorators: + with self.file_context(file_node, options, info): + decorator_name = self.get_fullname_for_hook(decorator) + if decorator_name: + hook = self.plugin.get_class_decorator_hook_2(decorator_name) + if hook: + ok = ok and hook(ClassDefContext(defn, decorator, self)) + return ok + + def calculate_class_properties(graph: 'Graph', scc: List[str], errors: Errors) -> None: for module in scc: - tree = graph[module].tree + state = graph[module] + tree = state.tree assert tree for _, node, _ in tree.local_definitions(): if isinstance(node.node, TypeInfo): - saved = (module, node.node, None) # module, class, function - with errors.scope.saved_scope(saved) if errors.scope else nullcontext(): + with state.manager.semantic_analyzer.file_context(tree, state.options, node.node): calculate_class_abstract_status(node.node, tree.is_stub, errors) check_protocol_status(node.node, errors) calculate_class_vars(node.node) diff --git a/test-data/unit/check-dataclasses.test b/test-data/unit/check-dataclasses.test index bce1ee24a31a..4cddc59b0153 100644 --- a/test-data/unit/check-dataclasses.test +++ b/test-data/unit/check-dataclasses.test @@ -1207,7 +1207,7 @@ class A2: from dataclasses import dataclass @dataclass -class A: # E: Name "x" already defined (possibly by an import) +class A: x: int = 0 x: int = 0 # E: Name "x" already defined on line 7 @@ -1619,3 +1619,69 @@ Child(x='', y='') # E: Argument "y" to "Child" has incompatible type "str"; exp Child(x='', y=1) Child(x=None, y=None) [builtins fixtures/dataclasses.pyi] + +[case testDataclassGenericInheritanceSpecialCase1] +# flags: --python-version 3.7 +from dataclasses import dataclass +from typing import Generic, TypeVar, List + +T = TypeVar("T") + +@dataclass +class Parent(Generic[T]): + x: List[T] + +@dataclass +class Child1(Parent["Child2"]): ... + +@dataclass +class Child2(Parent["Child1"]): ... + +def f(c: Child2) -> None: + reveal_type(Child1([c]).x) # N: Revealed type is "builtins.list[__main__.Child2]" + +def g(c: Child1) -> None: + reveal_type(Child2([c]).x) # N: Revealed type is "builtins.list[__main__.Child1]" +[builtins fixtures/dataclasses.pyi] + +[case testDataclassGenericInheritanceSpecialCase2] +# flags: --python-version 3.7 +from dataclasses import dataclass +from typing import Generic, TypeVar + +T = TypeVar("T") + +# A subclass might be analyzed before base in import cycles. They are +# defined here in reversed order to simulate this. + +@dataclass +class Child1(Parent["Child2"]): + x: int + +@dataclass +class Child2(Parent["Child1"]): + y: int + +@dataclass +class Parent(Generic[T]): + key: str + +Child1(x=1, key='') +Child2(y=1, key='') +[builtins fixtures/dataclasses.pyi] + +[case testDataclassGenericWithBound] +# flags: --python-version 3.7 +from dataclasses import dataclass +from typing import Generic, TypeVar + +T = TypeVar("T", bound="C") + +@dataclass +class C(Generic[T]): + x: int + +c: C[C] +d: C[str] # E: Type argument "str" of "C" must be a subtype of "C[Any]" +C(x=2) +[builtins fixtures/dataclasses.pyi]