diff --git a/mypy/checker.py b/mypy/checker.py index cb6242008e3d..6b5bcbb5fb23 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -4743,6 +4743,11 @@ def is_private(node_name: str) -> bool: return node_name.startswith('__') and not node_name.endswith('__') +def get_enum_values(typ: Instance) -> List[str]: + """Return the list of values for an Enum.""" + return [name for name, sym in typ.type.names.items() if isinstance(sym.node, Var)] + + def is_singleton_type(typ: Type) -> bool: """Returns 'true' if this type is a "singleton type" -- if there exists exactly only one runtime value associated with this type. @@ -4751,7 +4756,8 @@ def is_singleton_type(typ: Type) -> bool: 'is_singleton_type(t)' returns True if and only if the expression 'a is b' is always true. - Currently, this returns True when given NoneTypes and enum LiteralTypes. + Currently, this returns True when given NoneTypes, enum LiteralTypes and + enum types with a single value. Note that other kinds of LiteralTypes cannot count as singleton types. For example, suppose we do 'a = 100000 + 1' and 'b = 100001'. It is not guaranteed @@ -4761,7 +4767,10 @@ def is_singleton_type(typ: Type) -> bool: typ = get_proper_type(typ) # TODO: Also make this return True if the type is a bool LiteralType. # Also make this return True if the type corresponds to ... (ellipsis) or NotImplemented? - return isinstance(typ, NoneType) or (isinstance(typ, LiteralType) and typ.is_enum_literal()) + return ( + isinstance(typ, NoneType) or (isinstance(typ, LiteralType) and typ.is_enum_literal()) + or (isinstance(typ, Instance) and typ.type.is_enum and len(get_enum_values(typ)) == 1) + ) def try_expanding_enum_to_union(typ: Type, target_fullname: str) -> ProperType: @@ -4808,17 +4817,21 @@ class Status(Enum): def coerce_to_literal(typ: Type) -> ProperType: - """Recursively converts any Instances that have a last_known_value into the - corresponding LiteralType. + """Recursively converts any Instances that have a last_known_value or are + instances of enum types with a single value into the corresponding LiteralType. """ typ = get_proper_type(typ) if isinstance(typ, UnionType): new_items = [coerce_to_literal(item) for item in typ.items] return make_simplified_union(new_items) - elif isinstance(typ, Instance) and typ.last_known_value: - return typ.last_known_value - else: - return typ + elif isinstance(typ, Instance): + if typ.last_known_value: + return typ.last_known_value + elif typ.type.is_enum: + enum_values = get_enum_values(typ) + if len(enum_values) == 1: + return LiteralType(value=enum_values[0], fallback=typ) + return typ def has_bool_item(typ: ProperType) -> bool: diff --git a/test-data/unit/check-enum.test b/test-data/unit/check-enum.test index d61ce527fd35..43355392098c 100644 --- a/test-data/unit/check-enum.test +++ b/test-data/unit/check-enum.test @@ -808,7 +808,7 @@ else: [builtins fixtures/bool.pyi] -[case testEnumReachabilityPEP484Example1] +[case testEnumReachabilityPEP484ExampleWithFinal] # flags: --strict-optional from typing import Union from typing_extensions import Final @@ -833,7 +833,7 @@ def func(x: Union[int, None, Empty] = _empty) -> int: return x + 2 [builtins fixtures/primitives.pyi] -[case testEnumReachabilityPEP484Example2] +[case testEnumReachabilityPEP484ExampleWithMultipleValues] from typing import Union from enum import Enum @@ -852,5 +852,59 @@ def process(response: Union[str, Reason] = '') -> str: # response can be only str, all other possible values exhausted reveal_type(response) # N: Revealed type is 'builtins.str' return 'PROCESSED: ' + response +[builtins fixtures/primitives.pyi] + + +[case testEnumReachabilityPEP484ExampleSingleton] +# flags: --strict-optional +from typing import Union +from typing_extensions import Final +from enum import Enum + +class Empty(Enum): + token = 0 +_empty = Empty.token + +def func(x: Union[int, None, Empty] = _empty) -> int: + boom = x + 42 # E: Unsupported left operand type for + ("None") \ + # E: Unsupported left operand type for + ("Empty") \ + # N: Left operand is of type "Union[int, None, Empty]" + if x is _empty: + reveal_type(x) # N: Revealed type is 'Literal[__main__.Empty.token]' + return 0 + elif x is None: + reveal_type(x) # N: Revealed type is 'None' + return 1 + else: # At this point typechecker knows that x can only have type int + reveal_type(x) # N: Revealed type is 'builtins.int' + return x + 2 +[builtins fixtures/primitives.pyi] + +[case testEnumReachabilityPEP484ExampleSingletonWithMethod] +# flags: --strict-optional +from typing import Union +from typing_extensions import Final +from enum import Enum +class Empty(Enum): + token = lambda x: x + + def f(self) -> int: + return 1 + +_empty = Empty.token + +def func(x: Union[int, None, Empty] = _empty) -> int: + boom = x + 42 # E: Unsupported left operand type for + ("None") \ + # E: Unsupported left operand type for + ("Empty") \ + # N: Left operand is of type "Union[int, None, Empty]" + if x is _empty: + reveal_type(x) # N: Revealed type is 'Literal[__main__.Empty.token]' + return 0 + elif x is None: + reveal_type(x) # N: Revealed type is 'None' + return 1 + else: # At this point typechecker knows that x can only have type int + reveal_type(x) # N: Revealed type is 'builtins.int' + return x + 2 [builtins fixtures/primitives.pyi]