Skip to content

Commit 5619bfa

Browse files
committed
fix: resolve typevar before generating attrs init method
1 parent 6e25daa commit 5619bfa

File tree

2 files changed

+137
-12
lines changed

2 files changed

+137
-12
lines changed

mypy/plugins/attrs.py

Lines changed: 34 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,16 @@
1515
SymbolTableNode, MDEF, JsonDict, OverloadedFuncDef, ARG_NAMED_OPT, ARG_NAMED,
1616
TypeVarExpr, PlaceholderNode
1717
)
18+
from mypy.plugin import SemanticAnalyzerPluginInterface
1819
from mypy.plugins.common import (
19-
_get_argument, _get_bool_argument, _get_decorator_bool_argument, add_method
20+
_get_argument, _get_bool_argument, _get_decorator_bool_argument, add_method,
21+
deserialize_and_fixup_type
2022
)
2123
from mypy.types import (
2224
Type, AnyType, TypeOfAny, CallableType, NoneType, TypeVarDef, TypeVarType,
2325
Overloaded, UnionType, FunctionLike, get_proper_type
2426
)
25-
from mypy.typeops import make_simplified_union
27+
from mypy.typeops import make_simplified_union, map_type_from_supertype
2628
from mypy.typevars import fill_typevars
2729
from mypy.util import unmangle
2830
from mypy.server.trigger import make_wildcard_trigger
@@ -62,19 +64,22 @@ class Attribute:
6264

6365
def __init__(self, name: str, info: TypeInfo,
6466
has_default: bool, init: bool, kw_only: bool, converter: Converter,
65-
context: Context) -> None:
67+
context: Context,
68+
init_type: Optional[Type]) -> None:
6669
self.name = name
6770
self.info = info
6871
self.has_default = has_default
6972
self.init = init
7073
self.kw_only = kw_only
7174
self.converter = converter
7275
self.context = context
76+
self.init_type = init_type
7377

7478
def argument(self, ctx: 'mypy.plugin.ClassDefContext') -> Argument:
7579
"""Return this attribute as an argument to __init__."""
7680
assert self.init
77-
init_type = self.info[self.name].type
81+
82+
init_type = self.init_type or self.info[self.name].type
7883

7984
if self.converter.name:
8085
# When a converter is set the init_type is overridden by the first argument
@@ -160,20 +165,33 @@ def serialize(self) -> JsonDict:
160165
'converter_is_attr_converters_optional': self.converter.is_attr_converters_optional,
161166
'context_line': self.context.line,
162167
'context_column': self.context.column,
168+
'init_type': self.init_type.serialize() if self.init_type else None,
163169
}
164170

165171
@classmethod
166-
def deserialize(cls, info: TypeInfo, data: JsonDict) -> 'Attribute':
172+
def deserialize(cls, info: TypeInfo,
173+
data: JsonDict,
174+
api: SemanticAnalyzerPluginInterface) -> 'Attribute':
167175
"""Return the Attribute that was serialized."""
168-
return Attribute(
169-
data['name'],
176+
raw_init_type = data['init_type']
177+
init_type = deserialize_and_fixup_type(raw_init_type, api) if raw_init_type else None
178+
179+
return Attribute(data['name'],
170180
info,
171181
data['has_default'],
172182
data['init'],
173183
data['kw_only'],
174184
Converter(data['converter_name'], data['converter_is_attr_converters_optional']),
175-
Context(line=data['context_line'], column=data['context_column'])
176-
)
185+
Context(line=data['context_line'], column=data['context_column']),
186+
init_type)
187+
188+
def expand_typevar_from_subtype(self, sub_type: TypeInfo) -> None:
189+
"""Expands type vars in the context of a subtype when an attribute is inherited
190+
from a generic super type."""
191+
if not isinstance(self.init_type, TypeVarType):
192+
return
193+
194+
self.init_type = map_type_from_supertype(self.init_type, sub_type, self.info)
177195

178196

179197
def _determine_eq_order(ctx: 'mypy.plugin.ClassDefContext') -> bool:
@@ -350,7 +368,8 @@ def _analyze_class(ctx: 'mypy.plugin.ClassDefContext',
350368
# Only add an attribute if it hasn't been defined before. This
351369
# allows for overwriting attribute definitions by subclassing.
352370
if data['name'] not in taken_attr_names:
353-
a = Attribute.deserialize(super_info, data)
371+
a = Attribute.deserialize(super_info, data, ctx.api)
372+
a.expand_typevar_from_subtype(ctx.cls.info)
354373
super_attrs.append(a)
355374
taken_attr_names.add(a.name)
356375
attributes = super_attrs + list(own_attrs.values())
@@ -451,7 +470,9 @@ def _attribute_from_auto_attrib(ctx: 'mypy.plugin.ClassDefContext',
451470
name = unmangle(lhs.name)
452471
# `x: int` (without equal sign) assigns rvalue to TempNode(AnyType())
453472
has_rhs = not isinstance(rvalue, TempNode)
454-
return Attribute(name, ctx.cls.info, has_rhs, True, kw_only, Converter(), stmt)
473+
sym = ctx.cls.info.names.get(name)
474+
init_type = sym.type if sym else None
475+
return Attribute(name, ctx.cls.info, has_rhs, True, kw_only, Converter(), stmt, init_type)
455476

456477

457478
def _attribute_from_attrib_maker(ctx: 'mypy.plugin.ClassDefContext',
@@ -517,7 +538,8 @@ def _attribute_from_attrib_maker(ctx: 'mypy.plugin.ClassDefContext',
517538
converter_info = _parse_converter(ctx, converter)
518539

519540
name = unmangle(lhs.name)
520-
return Attribute(name, ctx.cls.info, attr_has_default, init, kw_only, converter_info, stmt)
541+
return Attribute(name, ctx.cls.info, attr_has_default, init,
542+
kw_only, converter_info, stmt, init_type)
521543

522544

523545
def _parse_converter(ctx: 'mypy.plugin.ClassDefContext',

test-data/unit/check-attr.test

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,109 @@ A([1], '2') # E: Cannot infer type argument 1 of "A"
414414

415415
[builtins fixtures/list.pyi]
416416

417+
418+
[case testAttrsUntypedGenericInheritance]
419+
from typing import Generic, TypeVar
420+
import attr
421+
422+
T = TypeVar("T")
423+
424+
@attr.s(auto_attribs=True)
425+
class Base(Generic[T]):
426+
attr: T
427+
428+
@attr.s(auto_attribs=True)
429+
class Sub(Base):
430+
pass
431+
432+
sub = Sub(attr=1)
433+
reveal_type(sub) # N: Revealed type is '__main__.Sub'
434+
reveal_type(sub.attr) # N: Revealed type is 'Any'
435+
436+
[builtins fixtures/bool.pyi]
437+
438+
439+
[case testAttrsGenericInheritance]
440+
from typing import Generic, TypeVar
441+
import attr
442+
443+
S = TypeVar("S")
444+
T = TypeVar("T")
445+
446+
@attr.s(auto_attribs=True)
447+
class Base(Generic[T]):
448+
attr: T
449+
450+
@attr.s(auto_attribs=True)
451+
class Sub(Base[S]):
452+
pass
453+
454+
sub_int = Sub[int](attr=1)
455+
reveal_type(sub_int) # N: Revealed type is '__main__.Sub[builtins.int*]'
456+
reveal_type(sub_int.attr) # N: Revealed type is 'builtins.int*'
457+
458+
sub_str = Sub[str](attr='ok')
459+
reveal_type(sub_str) # N: Revealed type is '__main__.Sub[builtins.str*]'
460+
reveal_type(sub_str.attr) # N: Revealed type is 'builtins.str*'
461+
462+
[builtins fixtures/bool.pyi]
463+
464+
465+
[case testAttrsGenericInheritance]
466+
from typing import Generic, TypeVar
467+
import attr
468+
469+
T1 = TypeVar("T1")
470+
T2 = TypeVar("T2")
471+
T3 = TypeVar("T3")
472+
473+
@attr.s(auto_attribs=True)
474+
class Base(Generic[T1, T2, T3]):
475+
one: T1
476+
two: T2
477+
three: T3
478+
479+
@attr.s(auto_attribs=True)
480+
class Sub(Base[int, str, float]):
481+
pass
482+
483+
sub = Sub(one=1, two='ok', three=3.14)
484+
reveal_type(sub) # N: Revealed type is '__main__.Sub'
485+
reveal_type(sub.one) # N: Revealed type is 'builtins.int*'
486+
reveal_type(sub.two) # N: Revealed type is 'builtins.str*'
487+
reveal_type(sub.three) # N: Revealed type is 'builtins.float*'
488+
489+
[builtins fixtures/bool.pyi]
490+
491+
492+
[case testAttrsMultiGenericInheritance]
493+
from typing import Generic, TypeVar
494+
import attr
495+
496+
T = TypeVar("T")
497+
498+
@attr.s(auto_attribs=True, eq=False)
499+
class Base(Generic[T]):
500+
base_attr: T
501+
502+
S = TypeVar("S")
503+
504+
@attr.s(auto_attribs=True, eq=False)
505+
class Middle(Base[int], Generic[S]):
506+
middle_attr: S
507+
508+
@attr.s(auto_attribs=True, eq=False)
509+
class Sub(Middle[str]):
510+
pass
511+
512+
sub = Sub(base_attr=1, middle_attr='ok')
513+
reveal_type(sub) # N: Revealed type is '__main__.Sub'
514+
reveal_type(sub.base_attr) # N: Revealed type is 'builtins.int*'
515+
reveal_type(sub.middle_attr) # N: Revealed type is 'builtins.str*'
516+
517+
[builtins fixtures/bool.pyi]
518+
519+
417520
[case testAttrsGenericClassmethod]
418521
from typing import TypeVar, Generic, Optional
419522
import attr

0 commit comments

Comments
 (0)