diff --git a/mypy/typeops.py b/mypy/typeops.py index dbfeebe42f14..653a2e4df531 100644 --- a/mypy/typeops.py +++ b/mypy/typeops.py @@ -5,7 +5,9 @@ since these may assume that MROs are ready. """ -from typing import cast, Optional, List, Sequence, Set, Iterable, TypeVar, Dict, Tuple, Any, Union +from typing import ( + cast, Optional, List, Sequence, Set, Iterable, TypeVar, Dict, Tuple, Any, Union, Callable +) from typing_extensions import Type as TypingType import itertools import sys @@ -336,6 +338,47 @@ def is_simple_literal(t: ProperType) -> bool: return False +def _get_flattened_proper_types(items: Sequence[Type]) -> Sequence[ProperType]: + """Similar to types.get_proper_types, with flattening of UnionType + + Optimized to avoid allocating a new list whenever possible""" + i: int = 0 + base: int = 0 + n: int = len(items) + + # optimistic fast path + while i < n: + t = items[i] + pt = get_proper_type(t) + if id(t) != id(pt) or isinstance(pt, UnionType): + # we need to allocate, switch to slow path + break + # simplify away any number of bottom type at the start of the input + if i == base and i+1 < n and isinstance(pt, UninhabitedType): + base += 1 + i += 1 + + # optimistic fast path reached end of input, no need to allocate + if i == n: + return cast(Sequence[ProperType], items[base:] if base > 0 else items) + + all_items = list(cast(Sequence[ProperType], items[base:i])) + + while i < n: + pt = get_proper_type(items[i]) + if isinstance(pt, UnionType): + all_items.extend(_get_flattened_proper_types(pt.items)) + else: + all_items.append(pt) + i += 1 + return all_items + + +_simplified_union_cache: List[Dict[Tuple[ProperType, ...], ProperType]] = [ + {} for _ in range(2**3) +] + + def make_simplified_union(items: Sequence[Type], line: int = -1, column: int = -1, *, keep_erased: bool = False, @@ -362,17 +405,35 @@ def make_simplified_union(items: Sequence[Type], back into a sum type. Set it to False when called by try_expanding_sum_type_ to_union(). 
""" - items = get_proper_types(items) - # Step 1: expand all nested unions - while any(isinstance(typ, UnionType) for typ in items): - all_items: List[ProperType] = [] - for typ in items: - if isinstance(typ, UnionType): - all_items.extend(get_proper_types(typ.items)) - else: - all_items.append(typ) - items = all_items + items = _get_flattened_proper_types(items) + + cache_fn: Optional[Callable[[ProperType], None]] = None + + # 1 or 2 elements account for the vast majority of inputs and are not worth caching: + # - they're too small for the quadratic worst-case cost of simplification to really + # manifest + # - the majority of those inputs are only triggered once + # - avoiding the extra allocations is a bigger win + if len(items) == 1: + return items[0] + elif len(items) > 2: + # NB: ideally we would use a frozenset, but that would require normalizing the + # order of entries in the simplified union, or updating the test harness to + # treat Unions as equivalent regardless of item ordering (which is particularly + # tricky when it comes to all tests using string matching on reveal_type output) + cache_key = tuple(items) + # NB: we need to maintain separate caches depending on flags that might impact + # the results of simplification + cache = _simplified_union_cache[ + int(keep_erased) + | int(contract_literals) << 1 + | int(state.strict_optional) << 2 + ] + ret = cache.get(cache_key, None) + if ret is not None: + return ret + cache_fn = lambda v: cache.__setitem__(cache_key, v) # noqa: E731 # Step 2: remove redundant unions simplified_set = _remove_redundant_union_items(items, keep_erased) @@ -381,13 +442,20 @@ def make_simplified_union(items: Sequence[Type], if contract_literals and sum(isinstance(item, LiteralType) for item in simplified_set) > 1: simplified_set = try_contracting_literals_in_union(simplified_set) - return UnionType.make_union(simplified_set, line, column) + ret = UnionType.make_union(simplified_set, line, column) + + if cache_fn: +
cache_fn(ret) + + return ret -def _remove_redundant_union_items(items: List[ProperType], keep_erased: bool) -> List[ProperType]: +def _remove_redundant_union_items(items: Sequence[ProperType], + keep_erased: bool) -> Sequence[ProperType]: from mypy.subtypes import is_proper_subtype removed: Set[int] = set() + truthed: Set[int] = set() seen: Set[Tuple[str, ...]] = set() # NB: having a separate fast path for Union of Literal and slow path for other things @@ -397,6 +465,7 @@ def _remove_redundant_union_items(items: List[ProperType], keep_erased: bool) -> for i, item in enumerate(items): if i in removed: continue + # Avoid slow nested for loop for Union of Literal of strings/enums (issue #9169) k = simple_literal_value_key(item) if k is not None: @@ -434,20 +503,34 @@ def _remove_redundant_union_items(items: List[ProperType], keep_erased: bool) -> continue # actual redundancy checks if ( - is_redundant_literal_instance(item, tj) # XXX? - and is_proper_subtype(tj, item, keep_erased_types=keep_erased) + isinstance(tj, UninhabitedType) + or ( + ( + not isinstance(item, Instance) + or item.last_known_value is None + or ( + isinstance(tj, Instance) + and tj.last_known_value == item.last_known_value + ) + ) + and is_proper_subtype(tj, item, keep_erased_types=keep_erased) + ) ): # We found a redundant item in the union. 
removed.add(j) cbt = cbt or tj.can_be_true cbf = cbf or tj.can_be_false + # if deleted subtypes had more general truthiness, use that if not item.can_be_true and cbt: - items[i] = true_or_false(item) + truthed.add(i) elif not item.can_be_false and cbf: - items[i] = true_or_false(item) + truthed.add(i) - return [items[i] for i in range(len(items)) if i not in removed] + if not removed and not truthed: + return items + return [true_or_false(items[i]) if i in truthed else items[i] + for i in range(len(items)) if i not in removed] def _get_type_special_method_bool_ret_type(t: Type) -> Optional[Type]: @@ -889,17 +972,6 @@ def custom_special_method(typ: Type, name: str, check_all: bool = False) -> bool return False -def is_redundant_literal_instance(general: ProperType, specific: ProperType) -> bool: - if not isinstance(general, Instance) or general.last_known_value is None: - return True - if isinstance(specific, Instance) and specific.last_known_value == general.last_known_value: - return True - if isinstance(specific, UninhabitedType): - return True - - return False - - def separate_union_literals(t: UnionType) -> Tuple[Sequence[LiteralType], Sequence[Type]]: """Separate literals from other members in a union type.""" literal_items = []