Skip to content

Commit 14fa27f

Browse files
committed
Refine parent type when narrowing "lookup" expressions
This diff adds support for the following pattern: ```python from typing import Enum, List from enum import Enum class Key(Enum): A = 1 B = 2 class Foo: key: Literal[Key.A] blah: List[int] class Bar: key: Literal[Key.B] something: List[str] x: Union[Foo, Bar] if x.key is Key.A: reveal_type(x) # Revealed type is 'Foo' else: reveal_type(x) # Revealed type is 'Bar' ``` In short, when we do `x.key is Key.A`, we "propagate" the information we discovered about `x.key` up one level to refine the type of `x`. We perform this propagation only when `x` is a Union and only when we are doing member or index lookups into instances, typeddicts, namedtuples, and tuples. For indexing operations, we have one additional limitation: we *must* use a literal expression in order for narrowing to work at all. Using Literal types or Final instances won't work; See python#7905 for more details. To put it another way, this adds support for tagged unions, I guess. This more or less resolves python#7344. We currently don't have support for narrowing based on string or int literals, but that's a separate issue and should be resolved by python#7169 (which I resumed work on earlier this week).
1 parent 84126ab commit 14fa27f

File tree

6 files changed

+562
-22
lines changed

6 files changed

+562
-22
lines changed

mypy/checker.py

Lines changed: 173 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from typing import (
88
Dict, Set, List, cast, Tuple, TypeVar, Union, Optional, NamedTuple, Iterator, Iterable,
9-
Sequence
9+
Mapping, Sequence
1010
)
1111
from typing_extensions import Final
1212

@@ -43,11 +43,13 @@
4343
)
4444
import mypy.checkexpr
4545
from mypy.checkmember import (
46-
analyze_descriptor_access, type_object_type,
46+
analyze_member_access, analyze_descriptor_access, type_object_type,
4747
)
4848
from mypy.typeops import (
4949
map_type_from_supertype, bind_self, erase_to_bound, make_simplified_union,
50-
erase_def_to_union_or_bound, erase_to_union_or_bound,
50+
erase_def_to_union_or_bound, erase_to_union_or_bound, coerce_to_literal,
51+
try_getting_str_literals_from_type, try_getting_int_literals_from_type,
52+
tuple_fallback, is_singleton_type, try_expanding_enum_to_union,
5153
true_only, false_only, function_type,
5254
)
5355
from mypy import message_registry
@@ -72,9 +74,6 @@
7274
from mypy.plugin import Plugin, CheckerPluginInterface
7375
from mypy.sharedparse import BINARY_MAGIC_METHODS
7476
from mypy.scope import Scope
75-
from mypy.typeops import (
76-
tuple_fallback, coerce_to_literal, is_singleton_type, try_expanding_enum_to_union
77-
)
7877
from mypy import state, errorcodes as codes
7978
from mypy.traverser import has_return_statement, all_return_statements
8079
from mypy.errorcodes import ErrorCode
@@ -3709,6 +3708,12 @@ def find_isinstance_check(self, node: Expression
37093708
37103709
Guaranteed to not return None, None. (But may return {}, {})
37113710
"""
3711+
if_map, else_map = self.find_isinstance_check_helper(node)
3712+
new_if_map = self.propagate_up_typemap_info(self.type_map, if_map)
3713+
new_else_map = self.propagate_up_typemap_info(self.type_map, else_map)
3714+
return new_if_map, new_else_map
3715+
3716+
def find_isinstance_check_helper(self, node: Expression) -> Tuple[TypeMap, TypeMap]:
37123717
type_map = self.type_map
37133718
if is_true_literal(node):
37143719
return {}, None
@@ -3835,28 +3840,185 @@ def find_isinstance_check(self, node: Expression
38353840
else None)
38363841
return if_map, else_map
38373842
elif isinstance(node, OpExpr) and node.op == 'and':
3838-
left_if_vars, left_else_vars = self.find_isinstance_check(node.left)
3839-
right_if_vars, right_else_vars = self.find_isinstance_check(node.right)
3843+
left_if_vars, left_else_vars = self.find_isinstance_check_helper(node.left)
3844+
right_if_vars, right_else_vars = self.find_isinstance_check_helper(node.right)
38403845

38413846
# (e1 and e2) is true if both e1 and e2 are true,
38423847
# and false if at least one of e1 and e2 is false.
38433848
return (and_conditional_maps(left_if_vars, right_if_vars),
38443849
or_conditional_maps(left_else_vars, right_else_vars))
38453850
elif isinstance(node, OpExpr) and node.op == 'or':
3846-
left_if_vars, left_else_vars = self.find_isinstance_check(node.left)
3847-
right_if_vars, right_else_vars = self.find_isinstance_check(node.right)
3851+
left_if_vars, left_else_vars = self.find_isinstance_check_helper(node.left)
3852+
right_if_vars, right_else_vars = self.find_isinstance_check_helper(node.right)
38483853

38493854
# (e1 or e2) is true if at least one of e1 or e2 is true,
38503855
# and false if both e1 and e2 are false.
38513856
return (or_conditional_maps(left_if_vars, right_if_vars),
38523857
and_conditional_maps(left_else_vars, right_else_vars))
38533858
elif isinstance(node, UnaryExpr) and node.op == 'not':
3854-
left, right = self.find_isinstance_check(node.expr)
3859+
left, right = self.find_isinstance_check_helper(node.expr)
38553860
return right, left
38563861

38573862
# Not a supported isinstance check
38583863
return {}, {}
38593864

3865+
def propagate_up_typemap_info(self,
3866+
existing_types: Mapping[Expression, Type],
3867+
new_types: TypeMap) -> TypeMap:
3868+
"""Attempts refining parent expressions of any MemberExpr or IndexExprs in new_types.
3869+
3870+
Specifically, this function accepts two mappings of expression to original types:
3871+
the original mapping (existing_types), and a new mapping (new_types) intended to
3872+
update the original.
3873+
3874+
This function iterates through new_types and attempts to use the information to try
3875+
refining any parent types that happen to be unions.
3876+
3877+
For example, suppose there are two types "A = Tuple[int, int]" and "B = Tuple[str, str]".
3878+
Next, suppose that 'new_types' specifies the expression 'foo[0]' has a refined type
3879+
of 'int' and that 'foo' was previously deduced to be of type Union[A, B].
3880+
3881+
Then, this function will observe that since A[0] is an int and B[0] is not, the type of
3882+
'foo' can be further refined from Union[A, B] into just B.
3883+
3884+
We perform this kind of "parent narrowing" for member lookup expressions and indexing
3885+
expressions into tuples, namedtuples, and typeddicts. This narrowing is also performed
3886+
only once, for the immediate parents of any "lookup" expressions in `new_types`.
3887+
3888+
We return the newly refined map. This map is guaranteed to be a superset of 'new_types'.
3889+
"""
3890+
if new_types is None:
3891+
return None
3892+
output_map = {}
3893+
for expr, expr_type in new_types.items():
3894+
# The original inferred type should always be present in the output map, of course
3895+
output_map[expr] = expr_type
3896+
3897+
# Next, try using this information to refine the parent type, if applicable.
3898+
# Note that we currently refine just the immediate parent.
3899+
#
3900+
# TODO: Should we also try recursively refining any parents of the parents?
3901+
#
3902+
# One quick-and-dirty way of doing this would be to have the caller repeatedly run
3903+
# this function until we reach fixpoint; another way would be to modify
3904+
# 'refine_parent_type' to run in a loop. Both approaches seem expensive though.
3905+
new_mapping = self.refine_parent_type(existing_types, expr, expr_type)
3906+
for parent_expr, proposed_parent_type in new_mapping.items():
3907+
# We don't try inferring anything if we've already inferred something for
3908+
# the parent expression.
3909+
# TODO: Consider picking the narrower type instead of always discarding this?
3910+
if parent_expr in new_types:
3911+
continue
3912+
output_map[parent_expr] = proposed_parent_type
3913+
return output_map
3914+
3915+
def refine_parent_type(self,
3916+
existing_types: Mapping[Expression, Type],
3917+
expr: Expression,
3918+
expr_type: Type) -> Mapping[Expression, Type]:
3919+
"""Checks if the given expr is a 'lookup operation' into a union and refines the parent type
3920+
based on the 'expr_type'.
3921+
3922+
For more details about what a 'lookup operation' is and how we use the expr_type to refine
3923+
the parent type, see the docstring in 'propagate_up_typemap_info'.
3924+
"""
3925+
3926+
# First, check if this expression is one that's attempting to
3927+
# "lookup" some key in the parent type. If so, save the parent type
3928+
# and create function that will try replaying the same lookup
3929+
# operation against arbitrary types.
3930+
if isinstance(expr, MemberExpr):
3931+
parent_expr = expr.expr
3932+
parent_type = existing_types.get(parent_expr)
3933+
member_name = expr.name
3934+
3935+
def replay_lookup(new_parent_type: ProperType) -> Optional[Type]:
3936+
msg_copy = self.msg.clean_copy()
3937+
msg_copy.disable_count = 0
3938+
member_type = analyze_member_access(
3939+
name=member_name,
3940+
typ=new_parent_type,
3941+
context=parent_expr,
3942+
is_lvalue=False,
3943+
is_super=False,
3944+
is_operator=False,
3945+
msg=msg_copy,
3946+
original_type=new_parent_type,
3947+
chk=self,
3948+
in_literal_context=False,
3949+
)
3950+
if msg_copy.is_errors():
3951+
return None
3952+
else:
3953+
return member_type
3954+
elif isinstance(expr, IndexExpr):
3955+
parent_expr = expr.base
3956+
parent_type = existing_types.get(parent_expr)
3957+
3958+
index_type = existing_types.get(expr.index)
3959+
if index_type is None:
3960+
return {}
3961+
3962+
str_literals = try_getting_str_literals_from_type(index_type)
3963+
if str_literals is not None:
3964+
def replay_lookup(new_parent_type: ProperType) -> Optional[Type]:
3965+
if not isinstance(new_parent_type, TypedDictType):
3966+
return None
3967+
try:
3968+
assert str_literals is not None
3969+
member_types = [new_parent_type.items[key] for key in str_literals]
3970+
except KeyError:
3971+
return None
3972+
return make_simplified_union(member_types)
3973+
else:
3974+
int_literals = try_getting_int_literals_from_type(index_type)
3975+
if int_literals is not None:
3976+
def replay_lookup(new_parent_type: ProperType) -> Optional[Type]:
3977+
if not isinstance(new_parent_type, TupleType):
3978+
return None
3979+
try:
3980+
assert int_literals is not None
3981+
member_types = [new_parent_type.items[key] for key in int_literals]
3982+
except IndexError:
3983+
return None
3984+
return make_simplified_union(member_types)
3985+
else:
3986+
return {}
3987+
else:
3988+
return {}
3989+
3990+
# If we somehow didn't previously derive the parent type, abort:
3991+
# something went wrong at an earlier stage.
3992+
if parent_type is None:
3993+
return {}
3994+
3995+
# We currently only try refining the parent type if it's a Union.
3996+
parent_type = get_proper_type(parent_type)
3997+
if not isinstance(parent_type, UnionType):
3998+
return {}
3999+
4000+
# Take each element in the parent union and replay the original lookup procedure
4001+
# to figure out which parents are compatible.
4002+
new_parent_types = []
4003+
for item in parent_type.items:
4004+
item = get_proper_type(item)
4005+
member_type = replay_lookup(item)
4006+
if member_type is None:
4007+
# We were unable to obtain the member type. So, we give up on refining this
4008+
# parent type entirely.
4009+
return {}
4010+
4011+
if is_overlapping_types(member_type, expr_type):
4012+
new_parent_types.append(item)
4013+
4014+
# If none of the parent types overlap (if we derived an empty union), something
4015+
# went wrong. We should never hit this case, but deriving the uninhabited type or
4016+
# reporting an error both seem unhelpful. So we abort.
4017+
if not new_parent_types:
4018+
return {}
4019+
4020+
return {parent_expr: make_simplified_union(new_parent_types)}
4021+
38604022
#
38614023
# Helpers
38624024
#

mypy/checkexpr.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2704,6 +2704,9 @@ def visit_index_with_type(self, left_type: Type, e: IndexExpr,
27042704
index = e.index
27052705
left_type = get_proper_type(left_type)
27062706

2707+
# Visit the index, just to make sure we have a type for it available
2708+
self.accept(index)
2709+
27072710
if isinstance(left_type, UnionType):
27082711
original_type = original_type or left_type
27092712
return make_simplified_union([self.visit_index_with_type(typ, e,

mypy/test/testcheck.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
'check-isinstance.test',
4747
'check-lists.test',
4848
'check-namedtuple.test',
49+
'check-narrowing.test',
4950
'check-typeddict.test',
5051
'check-type-aliases.test',
5152
'check-ignore.test',

mypy/typeops.py

Lines changed: 68 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,17 @@
55
since these may assume that MROs are ready.
66
"""
77

8-
from typing import cast, Optional, List, Sequence, Set
8+
from typing import cast, Optional, List, Sequence, Set, TypeVar, Type as TypingType
99
import sys
1010

1111
from mypy.types import (
1212
TupleType, Instance, FunctionLike, Type, CallableType, TypeVarDef, Overloaded,
13-
TypeVarType, UninhabitedType, FormalArgument, UnionType, NoneType,
13+
TypeVarType, UninhabitedType, FormalArgument, UnionType, NoneType, TypedDictType,
1414
AnyType, TypeOfAny, TypeType, ProperType, LiteralType, get_proper_type, get_proper_types,
1515
copy_type, TypeAliasType
1616
)
1717
from mypy.nodes import (
18-
FuncBase, FuncItem, OverloadedFuncDef, TypeInfo, TypeVar, ARG_STAR, ARG_STAR2, ARG_POS,
18+
FuncBase, FuncItem, OverloadedFuncDef, TypeInfo, ARG_STAR, ARG_STAR2, ARG_POS,
1919
Expression, StrExpr, Var
2020
)
2121
from mypy.maptype import map_instance_to_supertype
@@ -43,6 +43,25 @@ def tuple_fallback(typ: TupleType) -> Instance:
4343
return Instance(info, [join_type_list(typ.items)])
4444

4545

46+
def try_getting_instance_fallback(typ: ProperType) -> Optional[Instance]:
47+
"""Returns the Instance fallback for this type if one exists.
48+
49+
Otherwise, returns None.
50+
"""
51+
if isinstance(typ, Instance):
52+
return typ
53+
elif isinstance(typ, TupleType):
54+
return tuple_fallback(typ)
55+
elif isinstance(typ, TypedDictType):
56+
return typ.fallback
57+
elif isinstance(typ, FunctionLike):
58+
return typ.fallback
59+
elif isinstance(typ, LiteralType):
60+
return typ.fallback
61+
else:
62+
return None
63+
64+
4665
def type_object_type_from_function(signature: FunctionLike,
4766
info: TypeInfo,
4867
def_info: TypeInfo,
@@ -475,27 +494,66 @@ def try_getting_str_literals(expr: Expression, typ: Type) -> Optional[List[str]]
475494
2. 'typ' is a LiteralType containing a string
476495
3. 'typ' is a UnionType containing only LiteralType of strings
477496
"""
478-
typ = get_proper_type(typ)
479-
480497
if isinstance(expr, StrExpr):
481498
return [expr.value]
482499

500+
# TODO: See if we can eliminate this function and call the below one directly
501+
return try_getting_str_literals_from_type(typ)
502+
503+
504+
def try_getting_str_literals_from_type(typ: Type) -> Optional[List[str]]:
505+
"""If the given expression or type corresponds to a string Literal
506+
or a union of string Literals, returns a list of the underlying strings.
507+
Otherwise, returns None.
508+
509+
For example, if we had the type 'Literal["foo", "bar"]' as input, this function
510+
would return a list of strings ["foo", "bar"].
511+
"""
512+
return try_getting_literals_from_type(typ, str, "builtins.str")
513+
514+
515+
def try_getting_int_literals_from_type(typ: Type) -> Optional[List[int]]:
516+
"""If the given expression or type corresponds to an int Literal
517+
or a union of int Literals, returns a list of the underlying ints.
518+
Otherwise, returns None.
519+
520+
For example, if we had the type 'Literal[1, 2, 3]' as input, this function
521+
would return a list of ints [1, 2, 3].
522+
"""
523+
return try_getting_literals_from_type(typ, int, "builtins.int")
524+
525+
526+
T = TypeVar('T')
527+
528+
529+
def try_getting_literals_from_type(typ: Type,
530+
target_literal_type: TypingType[T],
531+
target_fullname: str) -> Optional[List[T]]:
532+
"""If the given expression or type corresponds to a Literal or
533+
union of Literals where the underlying values corresponds to the given
534+
target type, returns a list of those underlying values. Otherwise,
535+
returns None.
536+
"""
537+
typ = get_proper_type(typ)
538+
483539
if isinstance(typ, Instance) and typ.last_known_value is not None:
484540
possible_literals = [typ.last_known_value] # type: List[Type]
485541
elif isinstance(typ, UnionType):
486542
possible_literals = list(typ.items)
487543
else:
488544
possible_literals = [typ]
489545

490-
strings = []
546+
literals = [] # type: List[T]
491547
for lit in get_proper_types(possible_literals):
492-
if isinstance(lit, LiteralType) and lit.fallback.type.fullname() == 'builtins.str':
548+
if isinstance(lit, LiteralType) and lit.fallback.type.fullname() == target_fullname:
493549
val = lit.value
494-
assert isinstance(val, str)
495-
strings.append(val)
550+
if isinstance(val, target_literal_type):
551+
literals.append(val)
552+
else:
553+
return None
496554
else:
497555
return None
498-
return strings
556+
return literals
499557

500558

501559
def get_enum_values(typ: Instance) -> List[str]:

0 commit comments

Comments
 (0)