From 9b4eac81eed87577388474559a00bd72ed09bd80 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Mon, 7 Aug 2023 23:33:43 +0100 Subject: [PATCH 01/15] Add basic support for polymorphic infernce with ParamSpec --- mypy/checkexpr.py | 43 +++++++++++++++---- mypy/constraints.py | 31 ++++++++++--- mypy/solve.py | 4 +- mypy/typeops.py | 19 +++++--- test-data/unit/check-generics.test | 36 ++++++++++++++++ test-data/unit/check-overloading.test | 4 +- .../unit/check-parameter-specification.test | 21 ++++++--- 7 files changed, 126 insertions(+), 32 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 9e46d9ee39cb..dc8c268c34f5 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -122,6 +122,7 @@ false_only, fixup_partial_type, function_type, + get_all_type_vars, get_type_vars, is_literal_type_like, make_simplified_union, @@ -145,6 +146,7 @@ LiteralValue, NoneType, Overloaded, + Parameters, ParamSpecFlavor, ParamSpecType, PartialType, @@ -167,6 +169,7 @@ get_proper_types, has_recursive_types, is_named_instance, + remove_dups, split_with_prefix_and_suffix, ) from mypy.types_utils import ( @@ -5632,18 +5635,24 @@ def __init__(self, poly_tvars: Sequence[TypeVarLikeType]) -> None: self.bound_tvars: set[TypeVarLikeType] = set() self.seen_aliases: set[TypeInfo] = set() - def visit_callable_type(self, t: CallableType) -> Type: - found_vars = set() + def collect_vars(self, t: CallableType | Parameters) -> list[TypeVarLikeType]: + found_vars = [] for arg in t.arg_types: - found_vars |= set(get_type_vars(arg)) & self.poly_tvars + found_vars += [ + tv + for tv in get_all_type_vars(arg) + if tv in self.poly_tvars and tv not in self.bound_tvars + ] + return remove_dups(found_vars) - found_vars -= self.bound_tvars - self.bound_tvars |= found_vars + def visit_callable_type(self, t: CallableType) -> Type: + found_vars = self.collect_vars(t) + self.bound_tvars |= set(found_vars) result = super().visit_callable_type(t) - self.bound_tvars -= found_vars + self.bound_tvars -= set(found_vars) assert isinstance(result, ProperType) and isinstance(result, CallableType) - result.variables = list(result.variables) + list(found_vars) + result.variables = list(result.variables) + found_vars return result def visit_type_var(self, t: TypeVarType) -> Type: @@ -5652,8 +5661,9 @@ def visit_type_var(self, t: TypeVarType) -> Type: return super().visit_type_var(t) def visit_param_spec(self, t: ParamSpecType) -> Type: - # TODO: Support polymorphic apply for ParamSpec. - raise PolyTranslationError() + if t in self.poly_tvars and t not in self.bound_tvars: + raise PolyTranslationError() + return super().visit_param_spec(t) def visit_type_var_tuple(self, t: TypeVarTupleType) -> Type: # TODO: Support polymorphic apply for TypeVarTuple. @@ -5669,6 +5679,21 @@ def visit_type_alias_type(self, t: TypeAliasType) -> Type: raise PolyTranslationError() def visit_instance(self, t: Instance) -> Type: + if t.type.has_param_spec_type: + param_spec_index = next( + i for (i, tv) in enumerate(t.type.defn.type_vars) if isinstance(tv, ParamSpecType) + ) + p = get_proper_type(t.args[param_spec_index]) + if isinstance(p, Parameters): + found_vars = self.collect_vars(p) + self.bound_tvars |= set(found_vars) + new_args = [a.accept(self) for a in t.args] + self.bound_tvars -= set(found_vars) + + repl = new_args[param_spec_index] + assert isinstance(repl, ProperType) and isinstance(repl, Parameters) + repl.variables = list(repl.variables) + list(found_vars) + return t.copy_modified(args=new_args) # There is the same problem with callback protocols as with aliases # (callback protocols are essentially more flexible aliases to callables). # Note: consider supporting bindings in instances, e.g. LRUCache[[x: T], T]. diff --git a/mypy/constraints.py b/mypy/constraints.py index 299c6292a259..0ff48f60f8a2 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -899,7 +899,7 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]: cactual = self.actual.with_unpacked_kwargs() param_spec = template.param_spec() if param_spec is None: - # FIX verify argument counts + # TODO: verify argument counts; more generally, use some "formal to actual" map # TODO: Erase template variables if it is generic? if ( type_state.infer_polymorphic @@ -943,34 +943,52 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]: cactual_args = cactual.arg_types # The lengths should match, but don't crash (it will error elsewhere). for t, a in zip(template_args, cactual_args): + if isinstance(a, ParamSpecType) and not isinstance(t, ParamSpecType): + # This avoids bogus constraints like T <: P.args + # TODO: figure out a more principled way to skip arg_kind mismatch + # (see also a similar to do item in corresponding branch below) + continue # Negate direction due to function argument type contravariance. res.extend(infer_constraints(t, a, neg_op(self.direction))) else: # sometimes, it appears we try to get constraints between two paramspec callables? - # TODO: Direction # TODO: check the prefixes match prefix = param_spec.prefix prefix_len = len(prefix.arg_types) cactual_ps = cactual.param_spec() + if type_state.infer_polymorphic and cactual.variables and not self.skip_neg_op: + # Similar logic to the branch above. + res.extend( + infer_constraints( + cactual, template, neg_op(self.direction), skip_neg_op=True + ) + ) + extra_tvars = True + if not cactual_ps: max_prefix_len = len([k for k in cactual.arg_kinds if k in (ARG_POS, ARG_OPT)]) prefix_len = min(prefix_len, max_prefix_len) res.append( Constraint( param_spec, - SUBTYPE_OF, - cactual.copy_modified( + neg_op(self.direction), + Parameters( arg_types=cactual.arg_types[prefix_len:], arg_kinds=cactual.arg_kinds[prefix_len:], arg_names=cactual.arg_names[prefix_len:], - ret_type=UninhabitedType(), + variables=cactual.variables + if not type_state.infer_polymorphic + else [], ), ) ) else: - res.append(Constraint(param_spec, SUBTYPE_OF, cactual_ps)) + if not param_spec.prefix.arg_types or cactual_ps.prefix.arg_types: + # TODO: figure out a more general logic to reject shorter prefix in actual. + # This may be actually fixed by a more general to do item above. + res.append(Constraint(param_spec, neg_op(self.direction), cactual_ps)) # compare prefixes cactual_prefix = cactual.copy_modified( @@ -979,7 +997,6 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]: arg_names=cactual.arg_names[:prefix_len], ) - # TODO: see above "FIX" comments for param_spec is None case # TODO: this assumes positional arguments for t, a in zip(prefix.arg_types, cactual_prefix.arg_types): res.extend(infer_constraints(t, a, neg_op(self.direction))) diff --git a/mypy/solve.py b/mypy/solve.py index 72b3d6f26618..ea2db3c9a3cf 100644 --- a/mypy/solve.py +++ b/mypy/solve.py @@ -12,7 +12,7 @@ from mypy.join import join_types from mypy.meet import meet_type_list, meet_types from mypy.subtypes import is_subtype -from mypy.typeops import get_type_vars +from mypy.typeops import get_all_type_vars from mypy.types import ( AnyType, Instance, @@ -463,4 +463,4 @@ def check_linear(scc: set[TypeVarId], lowers: Bounds, uppers: Bounds) -> bool: def get_vars(target: Type, vars: list[TypeVarId]) -> set[TypeVarId]: """Find type variables for which we are solving in a target type.""" - return {tv.id for tv in get_type_vars(target)} & set(vars) + return {tv.id for tv in get_all_type_vars(target)} & set(vars) diff --git a/mypy/typeops.py b/mypy/typeops.py index 65ab4340403c..c311c4ac504d 100644 --- a/mypy/typeops.py +++ b/mypy/typeops.py @@ -952,22 +952,31 @@ def coerce_to_literal(typ: Type) -> Type: def get_type_vars(tp: Type) -> list[TypeVarType]: - return tp.accept(TypeVarExtractor()) + return cast("list[TypeVarType]", tp.accept(TypeVarExtractor())) -class TypeVarExtractor(TypeQuery[List[TypeVarType]]): - def __init__(self) -> None: +def get_all_type_vars(tp: Type) -> list[TypeVarLikeType]: + # TODO: should we always use this function instead of get_type_vars() above? + return tp.accept(TypeVarExtractor(include_all=True)) + + +class TypeVarExtractor(TypeQuery[List[TypeVarLikeType]]): + def __init__(self, include_all: bool = False) -> None: super().__init__(self._merge) + self.include_all = include_all - def _merge(self, iter: Iterable[list[TypeVarType]]) -> list[TypeVarType]: + def _merge(self, iter: Iterable[list[TypeVarLikeType]]) -> list[TypeVarLikeType]: out = [] for item in iter: out.extend(item) return out - def visit_type_var(self, t: TypeVarType) -> list[TypeVarType]: + def visit_type_var(self, t: TypeVarType) -> list[TypeVarLikeType]: return [t] + def visit_param_spec(self, t: ParamSpecType) -> list[TypeVarLikeType]: + return [t] if self.include_all else [] + def custom_special_method(typ: Type, name: str, check_all: bool = False) -> bool: """Does this type have a custom special method such as __format__() or __eq__()? diff --git a/test-data/unit/check-generics.test b/test-data/unit/check-generics.test index d1842a74d634..23c739ebb41c 100644 --- a/test-data/unit/check-generics.test +++ b/test-data/unit/check-generics.test @@ -3035,3 +3035,39 @@ reveal_type(dec1(id2)) # N: Revealed type is "def [S in (builtins.int, builtins reveal_type(dec2(id1)) # N: Revealed type is "def [UC <: __main__.C] (UC`5) -> builtins.list[UC`5]" reveal_type(dec2(id2)) # N: Revealed type is "def () -> builtins.list[]" \ # E: Argument 1 to "dec2" has incompatible type "Callable[[V], V]"; expected "Callable[[], ]" + +[case testInferenceAgainstGenericParamSpecBasicInList] +# flags: --new-type-inference +from typing import TypeVar, Callable, List, Tuple +from typing_extensions import ParamSpec + +T = TypeVar('T') +P = ParamSpec('P') +U = TypeVar('U') +V = TypeVar('V') + +def dec(f: Callable[P, T]) -> Callable[P, List[T]]: ... +def id(x: U) -> U: ... +def either(x: U, y: U) -> U: ... +def pair(x: U, y: V) -> Tuple[U, V]: ... +reveal_type(dec(id)) # N: Revealed type is "def [T] (x: T`2) -> builtins.list[T`2]" +reveal_type(dec(either)) # N: Revealed type is "def [T] (x: T`4, y: T`4) -> builtins.list[T`4]" +reveal_type(dec(pair)) # N: Revealed type is "def [U, V] (x: U`-1, y: V`-2) -> builtins.list[Tuple[U`-1, V`-2]]" +[builtins fixtures/list.pyi] + +[case testInferenceAgainstGenericParamSpecBasicDeList] +# flags: --new-type-inference +from typing import TypeVar, Callable, List, Tuple +from typing_extensions import ParamSpec + +T = TypeVar('T') +P = ParamSpec('P') +U = TypeVar('U') +V = TypeVar('V') + +def dec(f: Callable[P, List[T]]) -> Callable[P, T]: ... +def id(x: U) -> U: ... +def either(x: U, y: U) -> U: ... +reveal_type(dec(id)) # N: Revealed type is "def [T] (x: builtins.list[T`2]) -> T`2" +reveal_type(dec(either)) # N: Revealed type is "def [T] (x: builtins.list[T`4], y: builtins.list[T`4]) -> T`4" +[builtins fixtures/list.pyi] diff --git a/test-data/unit/check-overloading.test b/test-data/unit/check-overloading.test index 50acd7d77c8c..466216e11d15 100644 --- a/test-data/unit/check-overloading.test +++ b/test-data/unit/check-overloading.test @@ -6456,7 +6456,7 @@ P = ParamSpec("P") R = TypeVar("R") @overload -def func(x: Callable[Concatenate[Any, P], R]) -> Callable[P, R]: ... # E: Overloaded function signatures 1 and 2 overlap with incompatible return types +def func(x: Callable[Concatenate[Any, P], R]) -> Callable[P, R]: ... @overload def func(x: Callable[P, R]) -> Callable[Concatenate[str, P], R]: ... def func(x: Callable[..., R]) -> Callable[..., R]: ... @@ -6474,7 +6474,7 @@ eggs = lambda: 'eggs' reveal_type(func(eggs)) # N: Revealed type is "def (builtins.str) -> builtins.str" spam: Callable[..., str] = lambda x, y: 'baz' -reveal_type(func(spam)) # N: Revealed type is "def (*Any, **Any) -> builtins.str" +reveal_type(func(spam)) # N: Revealed type is "def (*Any, **Any) -> Any" [builtins fixtures/paramspec.pyi] diff --git a/test-data/unit/check-parameter-specification.test b/test-data/unit/check-parameter-specification.test index 114fe1f8438a..cea1ce089eff 100644 --- a/test-data/unit/check-parameter-specification.test +++ b/test-data/unit/check-parameter-specification.test @@ -1048,10 +1048,10 @@ class Job(Generic[_P, _T]): def generic_f(x: _T) -> _T: ... j = Job(generic_f) -reveal_type(j) # N: Revealed type is "__main__.Job[[x: _T`-1], _T`-1]" +reveal_type(j) # N: Revealed type is "__main__.Job[[x: _T`2], _T`2]" jf = j.into_callable() -reveal_type(jf) # N: Revealed type is "def [_T] (x: _T`-1) -> _T`-1" +reveal_type(jf) # N: Revealed type is "def [_T] (x: _T`2) -> _T`2" reveal_type(jf(1)) # N: Revealed type is "builtins.int" [builtins fixtures/paramspec.pyi] @@ -1307,7 +1307,7 @@ reveal_type(bar(C(fn=foo, x=1))) # N: Revealed type is "__main__.C[[x: builtins [builtins fixtures/paramspec.pyi] [case testParamSpecClassConstructor] -from typing import ParamSpec, Callable +from typing import ParamSpec, Callable, TypeVar P = ParamSpec("P") @@ -1315,7 +1315,10 @@ class SomeClass: def __init__(self, a: str) -> None: pass -def func(t: Callable[P, SomeClass], val: Callable[P, SomeClass]) -> None: +def func(t: Callable[P, SomeClass], val: Callable[P, SomeClass]) -> Callable[P, SomeClass]: + pass + +def func_regular(t: Callable[[T], SomeClass], val: Callable[[T], SomeClass]) -> Callable[[T], SomeClass]: pass def constructor(a: str) -> SomeClass: @@ -1324,9 +1327,13 @@ def constructor(a: str) -> SomeClass: def wrong_constructor(a: bool) -> SomeClass: return SomeClass("a") +def wrong_name_constructor(b: bool) -> SomeClass: + return SomeClass("a") + func(SomeClass, constructor) -func(SomeClass, wrong_constructor) # E: Argument 1 to "func" has incompatible type "Type[SomeClass]"; expected "Callable[[VarArg(), KwArg()], SomeClass]" \ - # E: Argument 2 to "func" has incompatible type "Callable[[bool], SomeClass]"; expected "Callable[[VarArg(), KwArg()], SomeClass]" +reveal_type(func(SomeClass, wrong_constructor)) # N: Revealed type is "def (a: ) -> __main__.SomeClass" +reveal_type(func_regular(SomeClass, wrong_constructor)) # N: Revealed type is "def () -> __main__.SomeClass" +func(SomeClass, wrong_name_constructor) # E: Argument 1 to "func" has incompatible type "Type[SomeClass]"; expected "Callable[[], SomeClass]" [builtins fixtures/paramspec.pyi] [case testParamSpecInTypeAliasBasic] @@ -1547,5 +1554,5 @@ U = TypeVar("U") def dec(f: Callable[P, T]) -> Callable[P, List[T]]: ... def test(x: U) -> U: ... reveal_type(dec) # N: Revealed type is "def [P, T] (f: def (*P.args, **P.kwargs) -> T`-2) -> def (*P.args, **P.kwargs) -> builtins.list[T`-2]" -reveal_type(dec(test)) # N: Revealed type is "def [U] (x: U`-1) -> builtins.list[U`-1]" +reveal_type(dec(test)) # N: Revealed type is "def [T] (x: T`2) -> builtins.list[T`2]" [builtins fixtures/paramspec.pyi] From 1959136f2cbd4b98757cad4aaf53c189c484b201 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Tue, 8 Aug 2023 18:51:35 +0100 Subject: [PATCH 02/15] Some ParamSpec cleanup --- mypy/applytype.py | 2 +- mypy/constraints.py | 19 ++++--------------- mypy/expandtype.py | 4 ++-- mypy/types.py | 5 ++--- 4 files changed, 9 insertions(+), 21 deletions(-) diff --git a/mypy/applytype.py b/mypy/applytype.py index 55a51d4adbb6..b919aaec67ef 100644 --- a/mypy/applytype.py +++ b/mypy/applytype.py @@ -110,7 +110,7 @@ def apply_generic_arguments( nt = id_to_type.get(param_spec.id) if nt is not None: nt = get_proper_type(nt) - if isinstance(nt, (CallableType, Parameters)): + if isinstance(nt, Parameters): callable = callable.expand_param_spec(nt) # Apply arguments to argument types. diff --git a/mypy/constraints.py b/mypy/constraints.py index 0ff48f60f8a2..6e04e0c972b3 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -9,7 +9,7 @@ from mypy.argmap import ArgTypeExpander from mypy.erasetype import erase_typevars from mypy.maptype import map_instance_to_supertype -from mypy.nodes import ARG_OPT, ARG_POS, CONTRAVARIANT, COVARIANT, ArgKind +from mypy.nodes import ARG_OPT, ARG_POS, ARG_STAR, ARG_STAR2, CONTRAVARIANT, COVARIANT, ArgKind from mypy.types import ( TUPLE_LIKE_INSTANCE_NAMES, AnyType, @@ -40,7 +40,6 @@ UninhabitedType, UnionType, UnpackType, - callable_with_ellipsis, get_proper_type, has_recursive_types, has_type_vars, @@ -699,12 +698,7 @@ def visit_instance(self, template: Instance) -> list[Constraint]: elif isinstance(tvar, ParamSpecType) and isinstance(mapped_arg, ParamSpecType): suffix = get_proper_type(instance_arg) - if isinstance(suffix, CallableType): - prefix = mapped_arg.prefix - from_concat = bool(prefix.arg_types) or suffix.from_concatenate - suffix = suffix.copy_modified(from_concatenate=from_concat) - - if isinstance(suffix, (Parameters, CallableType)): + if isinstance(suffix, Parameters): # no such thing as variance for ParamSpecs # TODO: is there a case I am missing? # TODO: constraints between prefixes @@ -769,12 +763,7 @@ def visit_instance(self, template: Instance) -> list[Constraint]: ): suffix = get_proper_type(mapped_arg) - if isinstance(suffix, CallableType): - prefix = template_arg.prefix - from_concat = bool(prefix.arg_types) or suffix.from_concatenate - suffix = suffix.copy_modified(from_concatenate=from_concat) - - if isinstance(suffix, (Parameters, CallableType)): + if isinstance(suffix, Parameters): # no such thing as variance for ParamSpecs # TODO: is there a case I am missing? # TODO: constraints between prefixes @@ -1023,7 +1012,7 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]: Constraint( param_spec, SUBTYPE_OF, - callable_with_ellipsis(any_type, any_type, template.fallback), + Parameters([any_type, any_type], [ARG_STAR, ARG_STAR2], [None, None]), ) ] res.extend(infer_constraints(template.ret_type, any_type, self.direction)) diff --git a/mypy/expandtype.py b/mypy/expandtype.py index b599b49e4c12..a8e4b4df3014 100644 --- a/mypy/expandtype.py +++ b/mypy/expandtype.py @@ -239,7 +239,7 @@ def visit_param_spec(self, t: ParamSpecType) -> Type: # TODO: what does prefix mean in this case? # TODO: why does this case even happen? Instances aren't plural. return repl - elif isinstance(repl, (ParamSpecType, Parameters, CallableType)): + elif isinstance(repl, (ParamSpecType, Parameters)): if isinstance(repl, ParamSpecType): return repl.copy_modified( flavor=t.flavor, @@ -395,7 +395,7 @@ def visit_callable_type(self, t: CallableType) -> CallableType: # must expand both of them with all the argument types, # kinds and names in the replacement. The return type in # the replacement is ignored. - if isinstance(repl, (CallableType, Parameters)): + if isinstance(repl, Parameters): # Substitute *args: P.args, **kwargs: P.kwargs prefix = param_spec.prefix # we need to expand the types in the prefix, so might as well diff --git a/mypy/types.py b/mypy/types.py index d13cff00c06d..9bbe3dbbad35 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -2042,9 +2042,8 @@ def param_spec(self) -> ParamSpecType | None: return arg_type.copy_modified(flavor=ParamSpecFlavor.BARE, prefix=prefix) - def expand_param_spec( - self, c: CallableType | Parameters, no_prefix: bool = False - ) -> CallableType: + def expand_param_spec(self, c: Parameters, no_prefix: bool = False) -> CallableType: + # TODO: try deleting variables from Parameters after new type inference is default. variables = c.variables if no_prefix: From e796c6a557a94287bf0ed9cd177502f6db9887b4 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Tue, 8 Aug 2023 19:46:17 +0100 Subject: [PATCH 03/15] Add pop-off/on tests plus small fix --- mypy/constraints.py | 2 +- mypy/options.py | 2 +- mypy/solve.py | 1 + test-data/unit/check-generics.test | 40 ++++++++++++++++++++++++++++++ 4 files changed, 43 insertions(+), 2 deletions(-) diff --git a/mypy/constraints.py b/mypy/constraints.py index 6e04e0c972b3..e04c9a4ec5ce 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -999,7 +999,7 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]: res.extend(infer_constraints(template_ret_type, cactual_ret_type, self.direction)) if extra_tvars: for c in res: - c.extra_tvars = list(cactual.variables) + c.extra_tvars += cactual.variables return res elif isinstance(self.actual, AnyType): param_spec = template.param_spec() diff --git a/mypy/options.py b/mypy/options.py index 75343acd38bb..e7074f07f7f6 100644 --- a/mypy/options.py +++ b/mypy/options.py @@ -349,7 +349,7 @@ def __init__(self) -> None: # -1 means unlimited. self.many_errors_threshold = defaults.MANY_ERRORS_THRESHOLD # Enable new experimental type inference algorithm. - self.new_type_inference = False + self.new_type_inference = True # Disable recursive type aliases (currently experimental) self.disable_recursive_aliases = False # Deprecated reverse version of the above, do not use. diff --git a/mypy/solve.py b/mypy/solve.py index ea2db3c9a3cf..3fcda0c46b03 100644 --- a/mypy/solve.py +++ b/mypy/solve.py @@ -346,6 +346,7 @@ def normalize_constraints( """ res = constraints.copy() for c in constraints: + # TODO: be careful with ParamSpecType here. if isinstance(c.target, TypeVarType): res.append(Constraint(c.target, neg_op(c.op), c.origin_type_var)) return [c for c in remove_dups(constraints) if c.type_var in vars] diff --git a/test-data/unit/check-generics.test b/test-data/unit/check-generics.test index 23c739ebb41c..6d2bbc8cec7a 100644 --- a/test-data/unit/check-generics.test +++ b/test-data/unit/check-generics.test @@ -3071,3 +3071,43 @@ def either(x: U, y: U) -> U: ... reveal_type(dec(id)) # N: Revealed type is "def [T] (x: builtins.list[T`2]) -> T`2" reveal_type(dec(either)) # N: Revealed type is "def [T] (x: builtins.list[T`4], y: builtins.list[T`4]) -> T`4" [builtins fixtures/list.pyi] + +[case testInferenceAgainstGenericParamSpecPopOff] +# flags: --new-type-inference +from typing import TypeVar, Callable, List, Tuple +from typing_extensions import ParamSpec, Concatenate + +T = TypeVar('T') +S = TypeVar('S') +P = ParamSpec('P') +U = TypeVar('U') +V = TypeVar('V') + +def dec(f: Callable[Concatenate[T, P], S]) -> Callable[P, Callable[[T], S]]: ... +def id(x: U) -> U: ... +def either(x: U, y: U) -> U: ... +def pair(x: U, y: V) -> Tuple[U, V]: ... +reveal_type(dec(id)) # N: Revealed type is "def () -> def [T] (T`1) -> T`1" +reveal_type(dec(either)) # N: Revealed type is "def [T] (y: T`4) -> def (T`4) -> T`4" +reveal_type(dec(pair)) # N: Revealed type is "def [V] (y: V`-2) -> def [T] (T`7) -> Tuple[T`7, V`-2]" +[builtins fixtures/list.pyi] + +[case testInferenceAgainstGenericParamSpecPopOn] +# flags: --new-type-inference +from typing import TypeVar, Callable, List, Tuple +from typing_extensions import ParamSpec, Concatenate + +T = TypeVar('T') +S = TypeVar('S') +P = ParamSpec('P') +U = TypeVar('U') +V = TypeVar('V') + +def dec(f: Callable[P, Callable[[T], S]]) -> Callable[Concatenate[T, P], S]: ... +def id() -> Callable[[U], U]: ... +def either(x: U) -> Callable[[U], U]: ... +def pair(x: U) -> Callable[[V], Tuple[V, U]]: ... +reveal_type(dec(id)) # N: Revealed type is "def [T] (T`2) -> T`2" +reveal_type(dec(either)) # N: Revealed type is "def [T] (T`5, x: T`5) -> T`5" +reveal_type(dec(pair)) # N: Revealed type is "def [T, U] (T`8, x: U`-1) -> Tuple[T`8, U`-1]" +[builtins fixtures/list.pyi] From 177b312306cdc934c8f28dff07274d99b1e8488e Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Tue, 8 Aug 2023 20:22:39 +0100 Subject: [PATCH 04/15] Some more ParamSpec cleanup --- mypy/constraints.py | 12 ++++++------ mypy/join.py | 10 ++++++++-- mypy/meet.py | 5 +++-- mypy/options.py | 2 +- 4 files changed, 18 insertions(+), 11 deletions(-) diff --git a/mypy/constraints.py b/mypy/constraints.py index e04c9a4ec5ce..22af47d7a7fc 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -699,7 +699,7 @@ def visit_instance(self, template: Instance) -> list[Constraint]: suffix = get_proper_type(instance_arg) if isinstance(suffix, Parameters): - # no such thing as variance for ParamSpecs + # No such thing as variance for ParamSpecs, consider them covariant # TODO: is there a case I am missing? # TODO: constraints between prefixes prefix = mapped_arg.prefix @@ -708,9 +708,9 @@ def visit_instance(self, template: Instance) -> list[Constraint]: suffix.arg_kinds[len(prefix.arg_kinds) :], suffix.arg_names[len(prefix.arg_names) :], ) - res.append(Constraint(mapped_arg, SUPERTYPE_OF, suffix)) + res.append(Constraint(mapped_arg, self.direction, suffix)) elif isinstance(suffix, ParamSpecType): - res.append(Constraint(mapped_arg, SUPERTYPE_OF, suffix)) + res.append(Constraint(mapped_arg, self.direction, suffix)) else: # This case should have been handled above. assert not isinstance(tvar, TypeVarTupleType) @@ -764,7 +764,7 @@ def visit_instance(self, template: Instance) -> list[Constraint]: suffix = get_proper_type(mapped_arg) if isinstance(suffix, Parameters): - # no such thing as variance for ParamSpecs + # No such thing as variance for ParamSpecs, consider them covariant # TODO: is there a case I am missing? # TODO: constraints between prefixes prefix = template_arg.prefix @@ -774,9 +774,9 @@ def visit_instance(self, template: Instance) -> list[Constraint]: suffix.arg_kinds[len(prefix.arg_kinds) :], suffix.arg_names[len(prefix.arg_names) :], ) - res.append(Constraint(template_arg, SUPERTYPE_OF, suffix)) + res.append(Constraint(template_arg, self.direction, suffix)) elif isinstance(suffix, ParamSpecType): - res.append(Constraint(template_arg, SUPERTYPE_OF, suffix)) + res.append(Constraint(template_arg, self.direction, suffix)) else: # This case should have been handled above. assert not isinstance(tvar, TypeVarTupleType) diff --git a/mypy/join.py b/mypy/join.py index f4af59f4e50b..806c644a680c 100644 --- a/mypy/join.py +++ b/mypy/join.py @@ -315,8 +315,14 @@ def visit_unpack_type(self, t: UnpackType) -> UnpackType: raise NotImplementedError def visit_parameters(self, t: Parameters) -> ProperType: - if self.s == t: - return t + if isinstance(self.s, Parameters): + if len(t.arg_types) != len(self.s.arg_types): + return self.default(self.s) + return t.copy_modified( + # Note that since during constraint inference we already treat whole ParamSpec as + # contravariant, we should join individual items, not meet them like for Callables + arg_types=[join_types(s_a, t_a) for s_a, t_a in zip(self.s.arg_types, t.arg_types)] + ) else: return self.default(self.s) diff --git a/mypy/meet.py b/mypy/meet.py index 29c4d3663503..e3a22a226575 100644 --- a/mypy/meet.py +++ b/mypy/meet.py @@ -701,11 +701,12 @@ def visit_unpack_type(self, t: UnpackType) -> ProperType: raise NotImplementedError def visit_parameters(self, t: Parameters) -> ProperType: - # TODO: is this the right variance? - if isinstance(self.s, (Parameters, CallableType)): + if isinstance(self.s, Parameters): if len(t.arg_types) != len(self.s.arg_types): return self.default(self.s) return t.copy_modified( + # Note that since during constraint inference we already treat whole ParamSpec as + # contravariant, we should meet individual items, not join them like for Callables arg_types=[meet_types(s_a, t_a) for s_a, t_a in zip(self.s.arg_types, t.arg_types)] ) else: diff --git a/mypy/options.py b/mypy/options.py index e7074f07f7f6..75343acd38bb 100644 --- a/mypy/options.py +++ b/mypy/options.py @@ -349,7 +349,7 @@ def __init__(self) -> None: # -1 means unlimited. self.many_errors_threshold = defaults.MANY_ERRORS_THRESHOLD # Enable new experimental type inference algorithm. - self.new_type_inference = True + self.new_type_inference = False # Disable recursive type aliases (currently experimental) self.disable_recursive_aliases = False # Deprecated reverse version of the above, do not use. From 420f60d6f108382dc91a3846ab7efbe3a32b039d Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Tue, 8 Aug 2023 23:40:56 +0100 Subject: [PATCH 05/15] More tests and fixes --- mypy/checkexpr.py | 14 ++++--- mypy/constraints.py | 19 ++++++++- mypy/solve.py | 17 ++++++-- mypy/type_visitor.py | 2 +- mypy/typeanal.py | 20 +++++++-- test-data/unit/check-generics.test | 42 +++++++++++++++++++ .../unit/check-parameter-specification.test | 3 +- 7 files changed, 101 insertions(+), 16 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index dc8c268c34f5..80fdff852220 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -5638,11 +5638,15 @@ def __init__(self, poly_tvars: Sequence[TypeVarLikeType]) -> None: def collect_vars(self, t: CallableType | Parameters) -> list[TypeVarLikeType]: found_vars = [] for arg in t.arg_types: - found_vars += [ - tv - for tv in get_all_type_vars(arg) - if tv in self.poly_tvars and tv not in self.bound_tvars - ] + for tv in get_all_type_vars(arg): + if isinstance(tv, ParamSpecType): + normalized: TypeVarLikeType = tv.copy_modified( + flavor=ParamSpecFlavor.BARE, prefix=Parameters([], [], []) + ) + else: + normalized = tv + if normalized in self.poly_tvars and normalized not in self.bound_tvars: + found_vars.append(normalized) return remove_dups(found_vars) def visit_callable_type(self, t: CallableType) -> Type: diff --git a/mypy/constraints.py b/mypy/constraints.py index 22af47d7a7fc..2457536db56e 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -580,6 +580,17 @@ def visit_parameters(self, template: Parameters) -> list[Constraint]: # ... which seems like the only case this can happen. Better to fail loudly. if isinstance(self.actual, AnyType): return self.infer_against_any(template.arg_types, self.actual) + if type_state.infer_polymorphic and isinstance(self.actual, Parameters): + res = [] + if len(template.arg_types) == len(self.actual.arg_types): + # TODO: this may assume positional arguments + for tt, at, k in zip( + template.arg_types, self.actual.arg_types, self.actual.arg_kinds + ): + if k in (ARG_STAR, ARG_STAR2): + continue + res.extend(infer_constraints(tt, at, self.direction)) + return res raise RuntimeError("Parameters cannot be constrained to") # Non-leaf types @@ -986,8 +997,12 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]: arg_names=cactual.arg_names[:prefix_len], ) - # TODO: this assumes positional arguments - for t, a in zip(prefix.arg_types, cactual_prefix.arg_types): + # TODO: this may assume positional arguments + for t, a, k in zip( + prefix.arg_types, cactual_prefix.arg_types, cactual_prefix.arg_kinds + ): + if k in (ARG_STAR, ARG_STAR2): + continue res.extend(infer_constraints(t, a, neg_op(self.direction))) template_ret_type, cactual_ret_type = template.ret_type, cactual.ret_type diff --git a/mypy/solve.py b/mypy/solve.py index 3fcda0c46b03..f02932fa753d 100644 --- a/mypy/solve.py +++ b/mypy/solve.py @@ -17,6 +17,7 @@ AnyType, Instance, NoneType, + ParamSpecType, ProperType, Type, TypeOfAny, @@ -346,8 +347,11 @@ def normalize_constraints( """ res = constraints.copy() for c in constraints: - # TODO: be careful with ParamSpecType here. - if isinstance(c.target, TypeVarType): + if ( + isinstance(c.target, TypeVarType) + or isinstance(c.target, ParamSpecType) + and not c.target.prefix.arg_types + ): res.append(Constraint(c.target, neg_op(c.op), c.origin_type_var)) return [c for c in remove_dups(constraints) if c.type_var in vars] @@ -381,7 +385,14 @@ def transitive_closure( remaining = set(constraints) while remaining: c = remaining.pop() - if isinstance(c.target, TypeVarType) and c.target.id in tvars: + # Note that ParamSpec constraint P <: Q may be considered linear only if Q has no prefix, + # for cases like P <: Concatenate[T, Q] we should consider this non-linear and put {P} and + # {T, Q} into separate SCCs. + if ( + isinstance(c.target, TypeVarType) + or isinstance(c.target, ParamSpecType) + and not c.target.prefix.arg_types + ) and c.target.id in tvars: if c.op == SUBTYPE_OF: lower, upper = c.type_var, c.target.id else: diff --git a/mypy/type_visitor.py b/mypy/type_visitor.py index cbfa43a77b81..1860a43eb14f 100644 --- a/mypy/type_visitor.py +++ b/mypy/type_visitor.py @@ -348,7 +348,7 @@ def visit_type_var(self, t: TypeVarType) -> T: return self.query_types([t.upper_bound, t.default] + t.values) def visit_param_spec(self, t: ParamSpecType) -> T: - return self.query_types([t.upper_bound, t.default]) + return self.query_types([t.upper_bound, t.default, t.prefix]) def visit_type_var_tuple(self, t: TypeVarTupleType) -> T: return self.query_types([t.upper_bound, t.default]) diff --git a/mypy/typeanal.py b/mypy/typeanal.py index d894e2cc8c51..a125f8953741 100644 --- a/mypy/typeanal.py +++ b/mypy/typeanal.py @@ -1244,9 +1244,23 @@ def analyze_callable_type(self, t: UnboundType) -> Type: ) else: # Callable[P, RET] (where P is ParamSpec) - maybe_ret = self.analyze_callable_args_for_paramspec( - callable_args, ret_type, fallback - ) or self.analyze_callable_args_for_concatenate(callable_args, ret_type, fallback) + with self.tvar_scope_frame(): + # Temporarily bind ParamSpecs to allow code like this: + # my_fun: Callable[Q, Foo[Q]] + # We usually do this later in visit_callable_type(), but the analysis + # below happens at very early stage. + variables = [] + for name, tvar_expr in self.find_type_var_likes(callable_args): + variables.append(self.tvar_scope.bind_new(name, tvar_expr)) + maybe_ret = self.analyze_callable_args_for_paramspec( + callable_args, ret_type, fallback + ) or self.analyze_callable_args_for_concatenate( + callable_args, ret_type, fallback + ) + if maybe_ret: + maybe_ret = maybe_ret.copy_modified( + ret_type=ret_type.accept(self), variables=variables + ) if maybe_ret is None: # Callable[?, RET] (where ? is something invalid) self.fail( diff --git a/test-data/unit/check-generics.test b/test-data/unit/check-generics.test index 6d2bbc8cec7a..3926815374a4 100644 --- a/test-data/unit/check-generics.test +++ b/test-data/unit/check-generics.test @@ -3090,6 +3090,7 @@ def pair(x: U, y: V) -> Tuple[U, V]: ... reveal_type(dec(id)) # N: Revealed type is "def () -> def [T] (T`1) -> T`1" reveal_type(dec(either)) # N: Revealed type is "def [T] (y: T`4) -> def (T`4) -> T`4" reveal_type(dec(pair)) # N: Revealed type is "def [V] (y: V`-2) -> def [T] (T`7) -> Tuple[T`7, V`-2]" +reveal_type(dec(dec)) # N: Revealed type is "def () -> def [T, P, S] (def (T`-1, *P.args, **P.kwargs) -> S`-3) -> def (*P.args, **P.kwargs) -> def (T`-1) -> S`-3" [builtins fixtures/list.pyi] [case testInferenceAgainstGenericParamSpecPopOn] @@ -3110,4 +3111,45 @@ def pair(x: U) -> Callable[[V], Tuple[V, U]]: ... reveal_type(dec(id)) # N: Revealed type is "def [T] (T`2) -> T`2" reveal_type(dec(either)) # N: Revealed type is "def [T] (T`5, x: T`5) -> T`5" reveal_type(dec(pair)) # N: Revealed type is "def [T, U] (T`8, x: U`-1) -> Tuple[T`8, U`-1]" +# This is counter-intuitive but looks correct, dec matches itself only if P is empty +reveal_type(dec(dec)) # N: Revealed type is "def [T, S] (T`11, f: def () -> def (T`11) -> S`12) -> S`12" +[builtins fixtures/list.pyi] + +[case testInferenceAgainstGenericParamSpecVsParamSpec] +# flags: --new-type-inference +from typing import TypeVar, Callable, List, Tuple, Generic +from typing_extensions import ParamSpec, Concatenate + +T = TypeVar('T') +P = ParamSpec('P') +Q = ParamSpec('Q') + +class Foo(Generic[P]): ... +class Bar(Generic[P, T]): ... + +def dec(f: Callable[P, T]) -> Callable[P, List[T]]: ... +def f(*args: Q.args, **kwargs: Q.kwargs) -> Foo[Q]: ... +reveal_type(dec(f)) # N: Revealed type is "def [P] (*P.args, **P.kwargs) -> builtins.list[__main__.Foo[P`1]]" +g: Callable[Concatenate[int, Q], Foo[Q]] +reveal_type(dec(g)) # N: Revealed type is "def [Q] (builtins.int, *Q.args, **Q.kwargs) -> builtins.list[__main__.Foo[Q`-1]]" +h: Callable[Concatenate[T, Q], Bar[Q, T]] +reveal_type(dec(h)) # N: Revealed type is "def [T, Q] (T`-1, *Q.args, **Q.kwargs) -> builtins.list[__main__.Bar[Q`-2, T`-1]]" +[builtins fixtures/list.pyi] + +[case testInferenceAgainstGenericParamSpecSecondary] +# flags: --new-type-inference +from typing import TypeVar, Callable, List, Tuple, Generic +from typing_extensions import ParamSpec, Concatenate + +T = TypeVar('T') +P = ParamSpec('P') +Q = ParamSpec('Q') + +class Foo(Generic[P]): ... + +def dec(f: Callable[P, Foo[P]]) -> Callable[P, Foo[P]]: ... +g: Callable[[T], Foo[[int]]] +reveal_type(dec(g)) # N: Revealed type is "def (builtins.int) -> __main__.Foo[[builtins.int]]" +h: Callable[Q, Foo[[int]]] +reveal_type(dec(g)) # N: Revealed type is "def (builtins.int) -> __main__.Foo[[builtins.int]]" [builtins fixtures/list.pyi] diff --git a/test-data/unit/check-parameter-specification.test b/test-data/unit/check-parameter-specification.test index cea1ce089eff..d9b1f1052051 100644 --- a/test-data/unit/check-parameter-specification.test +++ b/test-data/unit/check-parameter-specification.test @@ -1473,8 +1473,7 @@ reveal_type(gs) # N: Revealed type is "builtins.list[def (builtins.int, builtin T = TypeVar("T") class C(Generic[T]): ... -C[Callable[P, int]]() # E: The first argument to Callable must be a list of types, parameter specification, or "..." \ - # N: See https://mypy.readthedocs.io/en/stable/kinds_of_types.html#callable-types-and-lambdas +C[Callable[P, int]]() [builtins fixtures/paramspec.pyi] [case testConcatDeferralNoCrash] From d7cfbe9fd1dc282affd03c025a0ac95199af98c6 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Wed, 9 Aug 2023 11:17:07 +0100 Subject: [PATCH 06/15] Support lambdas --- mypy/checker.py | 13 +++++++-- mypy/checkexpr.py | 20 +++++++++++-- test-data/unit/check-generics.test | 31 ++++++++++++++++++++- test-data/unit/check-inference-context.test | 3 +- test-data/unit/check-inference.test | 8 ++++-- 5 files changed, 65 insertions(+), 10 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index b786155079e5..b39ba7d0f1b4 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -4280,12 +4280,14 @@ def check_return_stmt(self, s: ReturnStmt) -> None: return_type = self.return_types[-1] return_type = get_proper_type(return_type) + is_lambda = isinstance(self.scope.top_function(), LambdaExpr) if isinstance(return_type, UninhabitedType): - self.fail(message_registry.NO_RETURN_EXPECTED, s) - return + # Avoid extra error messages for failed inference in lambdas + if not is_lambda or not return_type.ambiguous: + self.fail(message_registry.NO_RETURN_EXPECTED, s) + return if s.expr: - is_lambda = isinstance(self.scope.top_function(), LambdaExpr) declared_none_return = isinstance(return_type, NoneType) declared_any_return = isinstance(return_type, AnyType) @@ -7366,6 +7368,11 @@ def visit_erased_type(self, t: ErasedType) -> bool: # This can happen inside a lambda. return True + def visit_type_var(self, t: TypeVarType) -> bool: + # This is needed to prevent leaking into partial types during + # multi-step type inference. + return t.id.is_meta_var() + class SetNothingToAny(TypeTranslator): """Replace all ambiguous types with Any (to avoid spurious extra errors).""" diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 80fdff852220..3454320ac194 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -1858,6 +1858,8 @@ def infer_function_type_arguments_using_context( # def identity(x: T) -> T: return x # # expects_literal(identity(3)) # Should type-check + # TODO: we may want to add similar exception if all arguments are lambdas, since + # in this case external context is almost everything we have. if not is_generic_instance(ctx) and not is_literal_type_like(ctx): return callable.copy_modified() args = infer_type_arguments(callable.variables, ret_type, erased_ctx) @@ -4677,8 +4679,22 @@ def infer_lambda_type_using_context( # they must be considered as indeterminate. We use ErasedType since it # does not affect type inference results (it is for purposes like this # only). - callable_ctx = get_proper_type(replace_meta_vars(ctx, ErasedType())) - assert isinstance(callable_ctx, CallableType) + if self.chk.options.new_type_inference: + # With new type inference we can preserve argument types even if they + # are generic, since new inference algorithm can handle constraints + # like S <: T (we still erase return type since it's ultimately unknown). + extra_vars = [] + for arg in ctx.arg_types: + meta_vars = [tv for tv in get_all_type_vars(arg) if tv.id.is_meta_var()] + extra_vars.extend([tv for tv in meta_vars if tv not in extra_vars]) + callable_ctx = ctx.copy_modified( + ret_type=replace_meta_vars(ctx.ret_type, ErasedType()), + variables=list(ctx.variables) + extra_vars, + ) + else: + erased_ctx = replace_meta_vars(ctx, ErasedType()) + assert isinstance(erased_ctx, ProperType) and isinstance(erased_ctx, CallableType) + callable_ctx = erased_ctx # The callable_ctx may have a fallback of builtins.type if the context # is a constructor -- but this fallback doesn't make sense for lambdas. diff --git a/test-data/unit/check-generics.test b/test-data/unit/check-generics.test index 3926815374a4..0f98ef91cc66 100644 --- a/test-data/unit/check-generics.test +++ b/test-data/unit/check-generics.test @@ -2713,6 +2713,7 @@ reveal_type(func(1)) # N: Revealed type is "builtins.int" [builtins fixtures/tuple.pyi] [case testGenericLambdaGenericMethodNoCrash] +# flags: --new-type-inference from typing import TypeVar, Union, Callable, Generic S = TypeVar("S") @@ -2723,7 +2724,7 @@ def f(x: Callable[[G[T]], int]) -> T: ... class G(Generic[T]): def g(self, x: S) -> Union[S, T]: ... -f(lambda x: x.g(0)) # E: Cannot infer type argument 1 of "f" +f(lambda x: x.g(0)) # E: Incompatible return value type (got "Union[int, T]", expected "int") [case testDictStarInference] class B: ... @@ -3036,6 +3037,34 @@ reveal_type(dec2(id1)) # N: Revealed type is "def [UC <: __main__.C] (UC`5) -> reveal_type(dec2(id2)) # N: Revealed type is "def () -> builtins.list[]" \ # E: Argument 1 to "dec2" has incompatible type "Callable[[V], V]"; expected "Callable[[], ]" +[case testInferenceAgainstGenericLambdas] +# flags: --new-type-inference +from typing import TypeVar, Callable, List + +S = TypeVar('S') +T = TypeVar('T') + +def dec1(f: Callable[[T], T]) -> Callable[[T], List[T]]: + ... +def dec2(f: Callable[[S], T]) -> Callable[[S], List[T]]: + ... +def dec3(f: Callable[[List[S]], T]) -> Callable[[S], T]: + ... +def dec4(f: Callable[[S], List[T]]) -> Callable[[S], T]: + ... +def dec5(f: Callable[[int], T]) -> Callable[[int], List[T]]: + def g(x: int) -> List[T]: + return [f(x)] * x + return g + +reveal_type(dec1(lambda x: x)) # N: Revealed type is "def [T] (T`2) -> builtins.list[T`2]" +reveal_type(dec2(lambda x: x)) # N: Revealed type is "def [S] (S`3) -> builtins.list[S`3]" +reveal_type(dec3(lambda x: x[0])) # N: Revealed type is "def [S] (S`5) -> S`5" +reveal_type(dec4(lambda x: [x])) # N: Revealed type is "def [S] (S`7) -> S`7" +reveal_type(dec1(lambda x: 1)) # N: Revealed type is "def (builtins.int) -> builtins.list[builtins.int]" +reveal_type(dec5(lambda x: x)) # N: Revealed type is "def (builtins.int) -> builtins.list[builtins.int]" +[builtins fixtures/list.pyi] + [case testInferenceAgainstGenericParamSpecBasicInList] # flags: --new-type-inference from typing import TypeVar, Callable, List, Tuple diff --git a/test-data/unit/check-inference-context.test b/test-data/unit/check-inference-context.test index ba36c1548532..5f25b007dd47 100644 --- a/test-data/unit/check-inference-context.test +++ b/test-data/unit/check-inference-context.test @@ -693,6 +693,7 @@ f(lambda: None) g(lambda: None) [case testIsinstanceInInferredLambda] +# flags: --new-type-inference from typing import TypeVar, Callable, Optional T = TypeVar('T') S = TypeVar('S') @@ -700,7 +701,7 @@ class A: pass class B(A): pass class C(A): pass def f(func: Callable[[T], S], *z: T, r: Optional[S] = None) -> S: pass -f(lambda x: 0 if isinstance(x, B) else 1) # E: Cannot infer type argument 1 of "f" +reveal_type(f(lambda x: 0 if isinstance(x, B) else 1)) # N: Revealed type is "builtins.int" f(lambda x: 0 if isinstance(x, B) else 1, A())() # E: "int" not callable f(lambda x: x if isinstance(x, B) else B(), A(), r=B())() # E: "B" not callable f( diff --git a/test-data/unit/check-inference.test b/test-data/unit/check-inference.test index e0f29a19ec1d..c53cb8b75da3 100644 --- a/test-data/unit/check-inference.test +++ b/test-data/unit/check-inference.test @@ -1375,19 +1375,21 @@ class B: pass [builtins fixtures/list.pyi] [case testUninferableLambda] +# flags: --new-type-inference from typing import TypeVar, Callable X = TypeVar('X') def f(x: Callable[[X], X]) -> X: pass -y = f(lambda x: x) # E: Cannot infer type argument 1 of "f" +y = f(lambda x: x) # E: Need type annotation for "y" [case testUninferableLambdaWithTypeError] +# flags: --new-type-inference from typing import TypeVar, Callable X = TypeVar('X') def f(x: Callable[[X], X], y: str) -> X: pass y = f(lambda x: x, 1) # Fail [out] -main:4: error: Cannot infer type argument 1 of "f" -main:4: error: Argument 2 to "f" has incompatible type "int"; expected "str" +main:5: error: Need type annotation for "y" +main:5: error: Argument 2 to "f" has incompatible type "int"; expected "str" [case testInferLambdaNone] # flags: --no-strict-optional From d4c91462072d3525310eef77a651a11dc0846c42 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Wed, 9 Aug 2023 16:33:02 +0100 Subject: [PATCH 07/15] Fix and support Concatenate --- mypy/applytype.py | 1 - mypy/constraints.py | 119 ++++++++++++++++++----------- mypy/expandtype.py | 100 ++++++++++-------------- mypy/solve.py | 4 +- mypy/subtypes.py | 14 ++-- mypy/test/testtypes.py | 2 +- mypy/typeanal.py | 2 + mypy/types.py | 42 +++++----- test-data/unit/check-generics.test | 41 ++++++++++ 9 files changed, 185 insertions(+), 140 deletions(-) diff --git a/mypy/applytype.py b/mypy/applytype.py index b919aaec67ef..a98797270768 100644 --- a/mypy/applytype.py +++ b/mypy/applytype.py @@ -109,7 +109,6 @@ def apply_generic_arguments( if param_spec is not None: nt = id_to_type.get(param_spec.id) if nt is not None: - nt = get_proper_type(nt) if isinstance(nt, Parameters): callable = callable.expand_param_spec(nt) diff --git a/mypy/constraints.py b/mypy/constraints.py index 2457536db56e..837f47dedcdc 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -576,18 +576,22 @@ def visit_unpack_type(self, template: UnpackType) -> list[Constraint]: raise RuntimeError("Mypy bug: unpack should be handled at a higher level.") def visit_parameters(self, template: Parameters) -> list[Constraint]: - # constraining Any against C[P] turns into infer_against_any([P], Any) - # ... which seems like the only case this can happen. Better to fail loudly. + # Constraining Any against C[P] turns into infer_against_any([P], Any) + # ... which seems like the only case this can happen. Better to fail loudly otherwise. if isinstance(self.actual, AnyType): return self.infer_against_any(template.arg_types, self.actual) if type_state.infer_polymorphic and isinstance(self.actual, Parameters): + # For polymorphic inference we need to be able to infer secondary constraints + # in situations like [x: T] <: P <: [x: int]. res = [] if len(template.arg_types) == len(self.actual.arg_types): - # TODO: this may assume positional arguments - for tt, at, k in zip( - template.arg_types, self.actual.arg_types, self.actual.arg_kinds + for tt, at, tk, ak in zip( + template.arg_types, + self.actual.arg_types, + template.arg_kinds, + self.actual.arg_kinds, ): - if k in (ARG_STAR, ARG_STAR2): + if tk == ARG_STAR and ak != ARG_STAR or tk == ARG_STAR2 and ak != ARG_STAR2: continue res.extend(infer_constraints(tt, at, self.direction)) return res @@ -696,7 +700,6 @@ def visit_instance(self, template: Instance) -> list[Constraint]: # N.B: We use zip instead of indexing because the lengths might have # mismatches during daemon reprocessing. for tvar, mapped_arg, instance_arg in zip(tvars, mapped_args, instance_args): - # TODO(PEP612): More ParamSpec work (or is Parameters the only thing accepted) if isinstance(tvar, TypeVarType): # The constraints for generic type parameters depend on variance. # Include constraints from both directions if invariant. @@ -707,21 +710,27 @@ def visit_instance(self, template: Instance) -> list[Constraint]: infer_constraints(mapped_arg, instance_arg, neg_op(self.direction)) ) elif isinstance(tvar, ParamSpecType) and isinstance(mapped_arg, ParamSpecType): - suffix = get_proper_type(instance_arg) - - if isinstance(suffix, Parameters): - # No such thing as variance for ParamSpecs, consider them covariant - # TODO: is there a case I am missing? + prefix = mapped_arg.prefix + if isinstance(instance_arg, Parameters): + # No such thing as variance for ParamSpecs, consider them invariant # TODO: constraints between prefixes - prefix = mapped_arg.prefix - suffix = suffix.copy_modified( - suffix.arg_types[len(prefix.arg_types) :], - suffix.arg_kinds[len(prefix.arg_kinds) :], - suffix.arg_names[len(prefix.arg_names) :], + suffix: Type = instance_arg.copy_modified( + instance_arg.arg_types[len(prefix.arg_types) :], + instance_arg.arg_kinds[len(prefix.arg_kinds) :], + instance_arg.arg_names[len(prefix.arg_names) :], + ) + res.append(Constraint(mapped_arg, SUBTYPE_OF, suffix)) + res.append(Constraint(mapped_arg, SUPERTYPE_OF, suffix)) + elif isinstance(instance_arg, ParamSpecType): + suffix = instance_arg.copy_modified( + prefix=Parameters( + instance_arg.prefix.arg_types[len(prefix.arg_types) :], + instance_arg.prefix.arg_kinds[len(prefix.arg_kinds) :], + instance_arg.prefix.arg_names[len(prefix.arg_names) :], + ) ) - res.append(Constraint(mapped_arg, self.direction, suffix)) - elif isinstance(suffix, ParamSpecType): - res.append(Constraint(mapped_arg, self.direction, suffix)) + res.append(Constraint(mapped_arg, SUBTYPE_OF, suffix)) + res.append(Constraint(mapped_arg, SUPERTYPE_OF, suffix)) else: # This case should have been handled above. assert not isinstance(tvar, TypeVarTupleType) @@ -772,22 +781,27 @@ def visit_instance(self, template: Instance) -> list[Constraint]: elif isinstance(tvar, ParamSpecType) and isinstance( template_arg, ParamSpecType ): - suffix = get_proper_type(mapped_arg) - - if isinstance(suffix, Parameters): - # No such thing as variance for ParamSpecs, consider them covariant - # TODO: is there a case I am missing? + prefix = template_arg.prefix + if isinstance(mapped_arg, Parameters): + # No such thing as variance for ParamSpecs, consider them invariant # TODO: constraints between prefixes - prefix = template_arg.prefix - - suffix = suffix.copy_modified( - suffix.arg_types[len(prefix.arg_types) :], - suffix.arg_kinds[len(prefix.arg_kinds) :], - suffix.arg_names[len(prefix.arg_names) :], + suffix = mapped_arg.copy_modified( + mapped_arg.arg_types[len(prefix.arg_types) :], + mapped_arg.arg_kinds[len(prefix.arg_kinds) :], + mapped_arg.arg_names[len(prefix.arg_names) :], ) - res.append(Constraint(template_arg, self.direction, suffix)) - elif isinstance(suffix, ParamSpecType): - res.append(Constraint(template_arg, self.direction, suffix)) + res.append(Constraint(template_arg, SUBTYPE_OF, suffix)) + res.append(Constraint(template_arg, SUPERTYPE_OF, suffix)) + elif isinstance(mapped_arg, ParamSpecType): + suffix = mapped_arg.copy_modified( + prefix=Parameters( + mapped_arg.prefix.arg_types[len(prefix.arg_types) :], + mapped_arg.prefix.arg_kinds[len(prefix.arg_kinds) :], + mapped_arg.prefix.arg_names[len(prefix.arg_names) :], + ) + ) + res.append(Constraint(template_arg, SUBTYPE_OF, suffix)) + res.append(Constraint(template_arg, SUPERTYPE_OF, suffix)) else: # This case should have been handled above. assert not isinstance(tvar, TypeVarTupleType) @@ -926,7 +940,8 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]: # We can't infer constraints from arguments if the template is Callable[..., T] # (with literal '...'). if not template.is_ellipsis_args: - if find_unpack_in_list(template.arg_types) is not None: + unpack_present = find_unpack_in_list(template.arg_types) + if unpack_present is not None: ( unpack_constraints, cactual_args_t, @@ -942,17 +957,25 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]: template_args = template.arg_types cactual_args = cactual.arg_types # The lengths should match, but don't crash (it will error elsewhere). - for t, a in zip(template_args, cactual_args): - if isinstance(a, ParamSpecType) and not isinstance(t, ParamSpecType): + for t, a, tk, ak in zip( + template_args, cactual_args, template.arg_kinds, cactual.arg_kinds + ): + # Unpack may have shifted indices. + if not unpack_present: # This avoids bogus constraints like T <: P.args - # TODO: figure out a more principled way to skip arg_kind mismatch - # (see also a similar to do item in corresponding branch below) + if ( + tk == ARG_STAR + and ak != ARG_STAR + or tk == ARG_STAR2 + and ak != ARG_STAR2 + ): + continue + if isinstance(a, ParamSpecType): + # TODO: can we infer something useful for *T vs P? continue # Negate direction due to function argument type contravariance. res.extend(infer_constraints(t, a, neg_op(self.direction))) else: - # sometimes, it appears we try to get constraints between two paramspec callables? - # TODO: check the prefixes match prefix = param_spec.prefix prefix_len = len(prefix.arg_types) @@ -985,19 +1008,23 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]: ) ) else: - if not param_spec.prefix.arg_types or cactual_ps.prefix.arg_types: - # TODO: figure out a more general logic to reject shorter prefix in actual. - # This may be actually fixed by a more general to do item above. + if len(param_spec.prefix.arg_types) <= len(cactual_ps.prefix.arg_types): + cactual_ps = cactual_ps.copy_modified( + prefix=Parameters( + arg_types=cactual_ps.prefix.arg_types[prefix_len:], + arg_kinds=cactual_ps.prefix.arg_kinds[prefix_len:], + arg_names=cactual_ps.prefix.arg_names[prefix_len:], + ) + ) res.append(Constraint(param_spec, neg_op(self.direction), cactual_ps)) - # compare prefixes + # Compare prefixes as well cactual_prefix = cactual.copy_modified( arg_types=cactual.arg_types[:prefix_len], arg_kinds=cactual.arg_kinds[:prefix_len], arg_names=cactual.arg_names[:prefix_len], ) - # TODO: this may assume positional arguments for t, a, k in zip( prefix.arg_types, cactual_prefix.arg_types, cactual_prefix.arg_kinds ): diff --git a/mypy/expandtype.py b/mypy/expandtype.py index a8e4b4df3014..0e98ed048197 100644 --- a/mypy/expandtype.py +++ b/mypy/expandtype.py @@ -231,44 +231,27 @@ def visit_type_var(self, t: TypeVarType) -> Type: return repl def visit_param_spec(self, t: ParamSpecType) -> Type: - # set prefix to something empty so we don't duplicate it - repl = get_proper_type( - self.variables.get(t.id, t.copy_modified(prefix=Parameters([], [], []))) - ) - if isinstance(repl, Instance): - # TODO: what does prefix mean in this case? - # TODO: why does this case even happen? Instances aren't plural. - return repl - elif isinstance(repl, (ParamSpecType, Parameters)): - if isinstance(repl, ParamSpecType): - return repl.copy_modified( - flavor=t.flavor, - prefix=t.prefix.copy_modified( - arg_types=t.prefix.arg_types + repl.prefix.arg_types, - arg_kinds=t.prefix.arg_kinds + repl.prefix.arg_kinds, - arg_names=t.prefix.arg_names + repl.prefix.arg_names, - ), - ) - else: - # if the paramspec is *P.args or **P.kwargs: - if t.flavor != ParamSpecFlavor.BARE: - assert isinstance(repl, CallableType), "Should not be able to get here." - # Is this always the right thing to do? - param_spec = repl.param_spec() - if param_spec: - return param_spec.with_flavor(t.flavor) - else: - return repl - else: - return Parameters( - t.prefix.arg_types + repl.arg_types, - t.prefix.arg_kinds + repl.arg_kinds, - t.prefix.arg_names + repl.arg_names, - variables=[*t.prefix.variables, *repl.variables], - ) - + # Set prefix to something empty, so we don't duplicate it below. + repl = self.variables.get(t.id, t.copy_modified(prefix=Parameters([], [], []))) + if isinstance(repl, ParamSpecType): + return repl.copy_modified( + flavor=t.flavor, + prefix=t.prefix.copy_modified( + arg_types=self.expand_types(t.prefix.arg_types + repl.prefix.arg_types), + arg_kinds=t.prefix.arg_kinds + repl.prefix.arg_kinds, + arg_names=t.prefix.arg_names + repl.prefix.arg_names, + ), + ) + elif isinstance(repl, Parameters): + assert t.flavor == ParamSpecFlavor.BARE + return Parameters( + self.expand_types(t.prefix.arg_types + repl.arg_types), + t.prefix.arg_kinds + repl.arg_kinds, + t.prefix.arg_names + repl.arg_names, + variables=[*t.prefix.variables, *repl.variables], + ) else: - # TODO: should this branch be removed? better not to fail silently + # TODO: replace this with "assert False" return repl def visit_type_var_tuple(self, t: TypeVarTupleType) -> Type: @@ -387,7 +370,7 @@ def interpolate_args_for_unpack( def visit_callable_type(self, t: CallableType) -> CallableType: param_spec = t.param_spec() if param_spec is not None: - repl = get_proper_type(self.variables.get(param_spec.id)) + repl = self.variables.get(param_spec.id) # If a ParamSpec in a callable type is substituted with a # callable type, we can't use normal substitution logic, # since ParamSpec is actually split into two components @@ -396,34 +379,29 @@ def visit_callable_type(self, t: CallableType) -> CallableType: # kinds and names in the replacement. The return type in # the replacement is ignored. if isinstance(repl, Parameters): - # Substitute *args: P.args, **kwargs: P.kwargs - prefix = param_spec.prefix - # we need to expand the types in the prefix, so might as well - # not get them in the first place - t = t.expand_param_spec(repl, no_prefix=True) + # We need to expand both the types in the prefix and the ParamSpec itself + t = t.expand_param_spec(repl) return t.copy_modified( - arg_types=self.expand_types(prefix.arg_types) + t.arg_types, - arg_kinds=prefix.arg_kinds + t.arg_kinds, - arg_names=prefix.arg_names + t.arg_names, + arg_types=self.expand_types(t.arg_types), + arg_kinds=t.arg_kinds, + arg_names=t.arg_names, ret_type=t.ret_type.accept(self), type_guard=(t.type_guard.accept(self) if t.type_guard is not None else None), ) - # TODO: Conceptually, the "len(t.arg_types) == 2" should not be here. However, this - # errors without it. Either figure out how to eliminate this or place an - # explanation for why this is necessary. - elif isinstance(repl, ParamSpecType) and len(t.arg_types) == 2: - # We're substituting one paramspec for another; this can mean that the prefix - # changes. (e.g. sub Concatenate[int, P] for Q) + elif isinstance(repl, ParamSpecType): + # We're substituting one ParamSpec for another; this can mean that the prefix + # changes, e.g. substitute Concatenate[int, P] in place of Q. prefix = repl.prefix - old_prefix = param_spec.prefix - - # Check assumptions. I'm not sure what order to place new prefix vs old prefix: - assert not old_prefix.arg_types or not prefix.arg_types - - t = t.copy_modified( - arg_types=prefix.arg_types + old_prefix.arg_types + t.arg_types, - arg_kinds=prefix.arg_kinds + old_prefix.arg_kinds + t.arg_kinds, - arg_names=prefix.arg_names + old_prefix.arg_names + t.arg_names, + clean_repl = repl.copy_modified(prefix=Parameters([], [], [])) + return t.copy_modified( + arg_types=self.expand_types(t.arg_types[:-2] + prefix.arg_types) + + [ + clean_repl.with_flavor(ParamSpecFlavor.ARGS), + clean_repl.with_flavor(ParamSpecFlavor.KWARGS), + ], + arg_kinds=t.arg_kinds[:-2] + prefix.arg_kinds + t.arg_kinds[-2:], + arg_names=t.arg_names[:-2] + prefix.arg_names + t.arg_names[-2:], + ret_type=t.ret_type.accept(self), ) var_arg = t.var_arg() diff --git a/mypy/solve.py b/mypy/solve.py index f02932fa753d..0c2b71f60d35 100644 --- a/mypy/solve.py +++ b/mypy/solve.py @@ -336,7 +336,9 @@ def is_trivial_bound(tp: ProperType) -> bool: def normalize_constraints( - constraints: list[Constraint], vars: list[TypeVarId] + # TODO: delete this function? + constraints: list[Constraint], + vars: list[TypeVarId], ) -> list[Constraint]: """Normalize list of constraints (to simplify life for the non-linear solver). diff --git a/mypy/subtypes.py b/mypy/subtypes.py index 5712d7375e50..5042f40325ff 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -1698,11 +1698,15 @@ def unify_generic_callable( return_constraint_direction = mypy.constraints.SUBTYPE_OF constraints: list[mypy.constraints.Constraint] = [] - for arg_type, target_arg_type in zip(type.arg_types, target.arg_types): - c = mypy.constraints.infer_constraints( - arg_type, target_arg_type, mypy.constraints.SUPERTYPE_OF - ) - constraints.extend(c) + # There is some special logic for inference in callables, so better use them + # as wholes instead of picking separate arguments. + cs = mypy.constraints.infer_constraints( + type.copy_modified(ret_type=UninhabitedType()), + target.copy_modified(ret_type=UninhabitedType()), + mypy.constraints.SUBTYPE_OF, + skip_neg_op=True, + ) + constraints.extend(cs) if not ignore_return: c = mypy.constraints.infer_constraints( type.ret_type, target.ret_type, return_constraint_direction diff --git a/mypy/test/testtypes.py b/mypy/test/testtypes.py index 59457dfa5d3b..56ac86058ce4 100644 --- a/mypy/test/testtypes.py +++ b/mypy/test/testtypes.py @@ -1464,7 +1464,7 @@ def make_call(*items: tuple[str, str | None]) -> CallExpr: class TestExpandTypeLimitGetProperType(TestCase): # WARNING: do not increase this number unless absolutely necessary, # and you understand what you are doing. - ALLOWED_GET_PROPER_TYPES = 8 + ALLOWED_GET_PROPER_TYPES = 6 @skipUnless(mypy.expandtype.__file__.endswith(".py"), "Skip for compiled mypy") def test_count_get_proper_type(self) -> None: diff --git a/mypy/typeanal.py b/mypy/typeanal.py index a125f8953741..8ac73cdf8aac 100644 --- a/mypy/typeanal.py +++ b/mypy/typeanal.py @@ -1546,6 +1546,7 @@ def anal_type(self, t: Type, nested: bool = True, *, allow_param_spec: bool = Fa if analyzed.prefix.arg_types: self.fail("Invalid location for Concatenate", t, code=codes.VALID_TYPE) self.note("You can use Concatenate as the first argument to Callable", t) + analyzed = AnyType(TypeOfAny.from_error) else: self.fail( f'Invalid location for ParamSpec "{analyzed.name}"', t, code=codes.VALID_TYPE @@ -1555,6 +1556,7 @@ def anal_type(self, t: Type, nested: bool = True, *, allow_param_spec: bool = Fa "'Callable[{}, int]'".format(analyzed.name), t, ) + analyzed = AnyType(TypeOfAny.from_error) return analyzed def anal_var_def(self, var_def: TypeVarLikeType) -> TypeVarLikeType: diff --git a/mypy/types.py b/mypy/types.py index 9bbe3dbbad35..23f5c0b86f7b 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -1577,6 +1577,7 @@ def __init__( self.arg_kinds = arg_kinds self.arg_names = list(arg_names) assert len(arg_types) == len(arg_kinds) == len(arg_names) + assert not any(isinstance(t, (Parameters, ParamSpecType)) for t in arg_types) self.min_args = arg_kinds.count(ARG_POS) self.is_ellipsis_args = is_ellipsis_args self.variables = variables or [] @@ -1788,6 +1789,11 @@ def __init__( ) -> None: super().__init__(line, column) assert len(arg_types) == len(arg_kinds) == len(arg_names) + for t, k in zip(arg_types, arg_kinds): + if isinstance(t, ParamSpecType): + assert not t.prefix.arg_types + # TODO: should we assert that only ARG_STAR contain ParamSpecType? + # See testParamSpecJoin, that relies on passing e.g `P.args` as plain argument. if variables is None: variables = [] self.arg_types = list(arg_types) @@ -2033,35 +2039,21 @@ def param_spec(self) -> ParamSpecType | None: if not isinstance(arg_type, ParamSpecType): return None - # sometimes paramspectypes are analyzed in from mysterious places, - # e.g. def f(prefix..., *args: P.args, **kwargs: P.kwargs) -> ...: ... - prefix = arg_type.prefix - if not prefix.arg_types: - # TODO: confirm that all arg kinds are positional - prefix = Parameters(self.arg_types[:-2], self.arg_kinds[:-2], self.arg_names[:-2]) - + # Prepend prefix for def f(prefix..., *args: P.args, **kwargs: P.kwargs) -> ... + # TODO: confirm that all arg kinds are positional + prefix = Parameters(self.arg_types[:-2], self.arg_kinds[:-2], self.arg_names[:-2]) return arg_type.copy_modified(flavor=ParamSpecFlavor.BARE, prefix=prefix) - def expand_param_spec(self, c: Parameters, no_prefix: bool = False) -> CallableType: + def expand_param_spec(self, c: Parameters) -> CallableType: # TODO: try deleting variables from Parameters after new type inference is default. variables = c.variables - - if no_prefix: - return self.copy_modified( - arg_types=c.arg_types, - arg_kinds=c.arg_kinds, - arg_names=c.arg_names, - is_ellipsis_args=c.is_ellipsis_args, - variables=[*variables, *self.variables], - ) - else: - return self.copy_modified( - arg_types=self.arg_types[:-2] + c.arg_types, - arg_kinds=self.arg_kinds[:-2] + c.arg_kinds, - arg_names=self.arg_names[:-2] + c.arg_names, - is_ellipsis_args=c.is_ellipsis_args, - variables=[*variables, *self.variables], - ) + return self.copy_modified( + arg_types=self.arg_types[:-2] + c.arg_types, + arg_kinds=self.arg_kinds[:-2] + c.arg_kinds, + arg_names=self.arg_names[:-2] + c.arg_names, + is_ellipsis_args=c.is_ellipsis_args, + variables=[*variables, *self.variables], + ) def with_unpacked_kwargs(self) -> NormalizedCallableType: if not self.unpack_kwargs: diff --git a/test-data/unit/check-generics.test b/test-data/unit/check-generics.test index 0f98ef91cc66..f4d12fb69707 100644 --- a/test-data/unit/check-generics.test +++ b/test-data/unit/check-generics.test @@ -3165,6 +3165,25 @@ h: Callable[Concatenate[T, Q], Bar[Q, T]] reveal_type(dec(h)) # N: Revealed type is "def [T, Q] (T`-1, *Q.args, **Q.kwargs) -> builtins.list[__main__.Bar[Q`-2, T`-1]]" [builtins fixtures/list.pyi] +[case testInferenceAgainstGenericParamSpecVsParamSpecConcatenate] +# flags: --new-type-inference +from typing import TypeVar, Callable, List, Tuple, Generic +from typing_extensions import ParamSpec, Concatenate + +T = TypeVar('T') +P = ParamSpec('P') +Q = ParamSpec('Q') + +class Foo(Generic[P]): ... +class Bar(Generic[P, T]): ... + +def dec(f: Callable[P, int]) -> Callable[P, Foo[P]]: ... +h: Callable[Concatenate[T, Q], int] +g: Callable[Concatenate[T, Q], int] +h = g +reveal_type(dec(h)) # N: Revealed type is "def [T, Q] (T`-1, *Q.args, **Q.kwargs) -> __main__.Foo[[T`-1, **Q`-2]]" +[builtins fixtures/list.pyi] + [case testInferenceAgainstGenericParamSpecSecondary] # flags: --new-type-inference from typing import TypeVar, Callable, List, Tuple, Generic @@ -3182,3 +3201,25 @@ reveal_type(dec(g)) # N: Revealed type is "def (builtins.int) -> __main__.Foo[[ h: Callable[Q, Foo[[int]]] reveal_type(dec(g)) # N: Revealed type is "def (builtins.int) -> __main__.Foo[[builtins.int]]" [builtins fixtures/list.pyi] + +[case testInferenceAgainstGenericParamSpecSecondOrder] +# flags: --new-type-inference +from typing import TypeVar, Callable +from typing_extensions import ParamSpec, Concatenate + +T = TypeVar('T') +S = TypeVar('S') +P = ParamSpec('P') +Q = ParamSpec('Q') +U = TypeVar('U') +W = ParamSpec('W') + +def transform( + dec: Callable[[Callable[P, T]], Callable[Q, S]] +) -> Callable[[Callable[Concatenate[int, P], T]], Callable[Concatenate[int, Q], S]]: ... + +def dec(f: Callable[W, U]) -> Callable[W, U]: ... +def dec2(f: Callable[Concatenate[str, W], U]) -> Callable[Concatenate[bytes, W], U]: ... +reveal_type(transform(dec)) # N: Revealed type is "def [P, T] (def (builtins.int, *P.args, **P.kwargs) -> T`2) -> def (builtins.int, *P.args, **P.kwargs) -> T`2" +reveal_type(transform(dec2)) # N: Revealed type is "def [W, T] (def (builtins.int, builtins.str, *W.args, **W.kwargs) -> T`6) -> def (builtins.int, builtins.bytes, *W.args, **W.kwargs) -> T`6" +[builtins fixtures/tuple.pyi] From 0af630f77c681cf57c111b917005c52af97ab901 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Wed, 9 Aug 2023 20:00:58 +0100 Subject: [PATCH 08/15] Fix accidental TypeVar id clash --- mypy/checkexpr.py | 17 ++++++++++++++++- test-data/unit/check-functions.test | 8 ++++---- test-data/unit/check-generics.test | 17 +++++++++++++++++ .../unit/check-parameter-specification.test | 4 ++-- 4 files changed, 39 insertions(+), 7 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 3454320ac194..128ca4129f76 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -17,7 +17,12 @@ from mypy.checkstrformat import StringFormatterChecker from mypy.erasetype import erase_type, remove_instance_last_known_values, replace_meta_vars from mypy.errors import ErrorWatcher, report_internal_error -from mypy.expandtype import expand_type, expand_type_by_instance, freshen_function_type_vars +from mypy.expandtype import ( + expand_type, + expand_type_by_instance, + freshen_all_functions_type_vars, + freshen_function_type_vars, +) from mypy.infer import ArgumentInferContext, infer_function_type_arguments, infer_type_arguments from mypy.literals import literal from mypy.maptype import map_instance_to_supertype @@ -1573,6 +1578,16 @@ def check_callable_call( lambda i: self.accept(args[i]), ) + # This is tricky: return type may contain its own type variables, like in + # def [S] (S) -> def [T] (T) -> tuple[S, T], so we need to update their ids + # to avoid possible id clashes if this call itself appears in a generic + # function body. + ret_type = get_proper_type(callee.ret_type) + if isinstance(ret_type, CallableType) and ret_type.variables: + fresh_ret_type = freshen_all_functions_type_vars(callee.ret_type) + freeze_all_type_vars(fresh_ret_type) + callee = callee.copy_modified(ret_type=fresh_ret_type) + if callee.is_generic(): need_refresh = any( isinstance(v, (ParamSpecType, TypeVarTupleType)) for v in callee.variables diff --git a/test-data/unit/check-functions.test b/test-data/unit/check-functions.test index a8722d8190b9..f49541420cc0 100644 --- a/test-data/unit/check-functions.test +++ b/test-data/unit/check-functions.test @@ -2330,7 +2330,7 @@ T = TypeVar('T') def deco() -> Callable[[T], T]: pass reveal_type(deco) # N: Revealed type is "def () -> def [T] (T`-1) -> T`-1" f = deco() -reveal_type(f) # N: Revealed type is "def [T] (T`-1) -> T`-1" +reveal_type(f) # N: Revealed type is "def [T] (T`1) -> T`1" i = f(3) reveal_type(i) # N: Revealed type is "builtins.int" @@ -2343,7 +2343,7 @@ U = TypeVar('U') def deco(x: U) -> Callable[[T, U], T]: pass reveal_type(deco) # N: Revealed type is "def [U] (x: U`-1) -> def [T] (T`-2, U`-1) -> T`-2" f = deco("foo") -reveal_type(f) # N: Revealed type is "def [T] (T`-2, builtins.str) -> T`-2" +reveal_type(f) # N: Revealed type is "def [T] (T`1, builtins.str) -> T`1" i = f(3, "eggs") reveal_type(i) # N: Revealed type is "builtins.int" @@ -2354,9 +2354,9 @@ T = TypeVar('T') R = TypeVar('R') def deco() -> Callable[[T], Callable[[T, R], R]]: pass f = deco() -reveal_type(f) # N: Revealed type is "def [T] (T`-1) -> def [R] (T`-1, R`-2) -> R`-2" +reveal_type(f) # N: Revealed type is "def [T] (T`2) -> def [R] (T`2, R`1) -> R`1" g = f(3) -reveal_type(g) # N: Revealed type is "def [R] (builtins.int, R`-2) -> R`-2" +reveal_type(g) # N: Revealed type is "def [R] (builtins.int, R`3) -> R`3" s = g(4, "foo") reveal_type(s) # N: Revealed type is "builtins.str" diff --git a/test-data/unit/check-generics.test b/test-data/unit/check-generics.test index f4d12fb69707..219c310c2079 100644 --- a/test-data/unit/check-generics.test +++ b/test-data/unit/check-generics.test @@ -3223,3 +3223,20 @@ def dec2(f: Callable[Concatenate[str, W], U]) -> Callable[Concatenate[bytes, W], reveal_type(transform(dec)) # N: Revealed type is "def [P, T] (def (builtins.int, *P.args, **P.kwargs) -> T`2) -> def (builtins.int, *P.args, **P.kwargs) -> T`2" reveal_type(transform(dec2)) # N: Revealed type is "def [W, T] (def (builtins.int, builtins.str, *W.args, **W.kwargs) -> T`6) -> def (builtins.int, builtins.bytes, *W.args, **W.kwargs) -> T`6" [builtins fixtures/tuple.pyi] + +[case testNoAccidentalVariableClashInNestedGeneric] +# flags: --new-type-inference +from typing import TypeVar, Callable, Generic, Tuple + +T = TypeVar('T') +S = TypeVar('S') +U = TypeVar('U') + +def pipe(x: T, f1: Callable[[T], S], f2: Callable[[S], U]) -> U: ... +def and_then(a: T) -> Callable[[S], Tuple[S, T]]: ... + +def apply(a: S, b: T) -> None: + v1 = and_then(b) + v2: Callable[[Tuple[S, T]], None] + return pipe(a, v1, v2) +[builtins fixtures/tuple.pyi] diff --git a/test-data/unit/check-parameter-specification.test b/test-data/unit/check-parameter-specification.test index d9b1f1052051..c7684a9acb2b 100644 --- a/test-data/unit/check-parameter-specification.test +++ b/test-data/unit/check-parameter-specification.test @@ -1029,7 +1029,7 @@ j = Job(generic_f) reveal_type(j) # N: Revealed type is "__main__.Job[[x: _T`-1]]" jf = j.into_callable() -reveal_type(jf) # N: Revealed type is "def [_T] (x: _T`-1)" +reveal_type(jf) # N: Revealed type is "def [_T] (x: _T`2)" reveal_type(jf(1)) # N: Revealed type is "None" [builtins fixtures/paramspec.pyi] @@ -1051,7 +1051,7 @@ j = Job(generic_f) reveal_type(j) # N: Revealed type is "__main__.Job[[x: _T`2], _T`2]" jf = j.into_callable() -reveal_type(jf) # N: Revealed type is "def [_T] (x: _T`2) -> _T`2" +reveal_type(jf) # N: Revealed type is "def [_T] (x: _T`3) -> _T`3" reveal_type(jf(1)) # N: Revealed type is "builtins.int" [builtins fixtures/paramspec.pyi] From c5c1b76f89a6cc6f2b16b8bdfa4b82411833f65a Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Wed, 9 Aug 2023 20:40:16 +0100 Subject: [PATCH 09/15] Some cleanups --- mypy/checkexpr.py | 5 +++++ mypy/constraints.py | 15 ++++++--------- mypy/solve.py | 30 +----------------------------- 3 files changed, 12 insertions(+), 38 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 128ca4129f76..54c8fb517c2c 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -5715,6 +5715,11 @@ def visit_type_alias_type(self, t: TypeAliasType) -> Type: def visit_instance(self, t: Instance) -> Type: if t.type.has_param_spec_type: + # We need this special-casing to preserve the possibility to store a + # generic function in an instance type. Things like + # forall T . Foo[[x: T], T] + # are not really expressible in current type system, but this looks like + # a useful feature, so let's keep it. param_spec_index = next( i for (i, tv) in enumerate(t.type.defn.type_vars) if isinstance(tv, ParamSpecType) ) diff --git a/mypy/constraints.py b/mypy/constraints.py index 837f47dedcdc..6c9ef198fa0d 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -960,15 +960,12 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]: for t, a, tk, ak in zip( template_args, cactual_args, template.arg_kinds, cactual.arg_kinds ): - # Unpack may have shifted indices. - if not unpack_present: - # This avoids bogus constraints like T <: P.args - if ( - tk == ARG_STAR - and ak != ARG_STAR - or tk == ARG_STAR2 - and ak != ARG_STAR2 - ): + # This avoids bogus constraints like T <: P.args + if (tk == ARG_STAR and ak != ARG_STAR) or ( + tk == ARG_STAR2 and ak != ARG_STAR2 + ): + # Unpack may have shifted indices. + if not unpack_present: continue if isinstance(a, ParamSpecType): # TODO: can we infer something useful for *T vs P? diff --git a/mypy/solve.py b/mypy/solve.py index 0c2b71f60d35..4b2b899c2a8d 100644 --- a/mypy/solve.py +++ b/mypy/solve.py @@ -6,7 +6,7 @@ from typing import Iterable, Sequence from typing_extensions import TypeAlias as _TypeAlias -from mypy.constraints import SUBTYPE_OF, SUPERTYPE_OF, Constraint, infer_constraints, neg_op +from mypy.constraints import SUBTYPE_OF, SUPERTYPE_OF, Constraint, infer_constraints from mypy.expandtype import expand_type from mypy.graph_utils import prepare_sccs, strongly_connected_components, topsort from mypy.join import join_types @@ -27,7 +27,6 @@ UninhabitedType, UnionType, get_proper_type, - remove_dups, ) from mypy.typestate import type_state @@ -63,10 +62,6 @@ def solve_constraints( for c in constraints: extra_vars.extend([v.id for v in c.extra_tvars if v.id not in vars + extra_vars]) originals.update({v.id: v for v in c.extra_tvars if v.id not in originals}) - if allow_polymorphic: - # Constraints like T :> S and S <: T are semantically the same, but they are - # represented differently. Normalize the constraint list w.r.t this equivalence. - constraints = normalize_constraints(constraints, vars + extra_vars) # Collect a list of constraints for each type variable. cmap: dict[TypeVarId, list[Constraint]] = {tv: [] for tv in vars + extra_vars} @@ -335,29 +330,6 @@ def is_trivial_bound(tp: ProperType) -> bool: return isinstance(tp, Instance) and tp.type.fullname == "builtins.object" -def normalize_constraints( - # TODO: delete this function? - constraints: list[Constraint], - vars: list[TypeVarId], -) -> list[Constraint]: - """Normalize list of constraints (to simplify life for the non-linear solver). - - This includes two things currently: - * Complement T :> S by S <: T - * Remove strict duplicates - * Remove constrains for unrelated variables - """ - res = constraints.copy() - for c in constraints: - if ( - isinstance(c.target, TypeVarType) - or isinstance(c.target, ParamSpecType) - and not c.target.prefix.arg_types - ): - res.append(Constraint(c.target, neg_op(c.op), c.origin_type_var)) - return [c for c in remove_dups(constraints) if c.type_var in vars] - - def transitive_closure( tvars: list[TypeVarId], constraints: list[Constraint] ) -> tuple[Graph, Bounds, Bounds]: From 582a4deafd0a4c64ac6a53eb648b9f73502fc28b Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Wed, 9 Aug 2023 23:59:09 +0100 Subject: [PATCH 10/15] Fix urllib crash; some assert tweak --- mypy/applytype.py | 10 +++++++--- mypy/types.py | 2 +- test-data/unit/check-parameter-specification.test | 13 +++++++++++++ 3 files changed, 21 insertions(+), 4 deletions(-) diff --git a/mypy/applytype.py b/mypy/applytype.py index a98797270768..65ebbe89ccde 100644 --- a/mypy/applytype.py +++ b/mypy/applytype.py @@ -9,7 +9,6 @@ AnyType, CallableType, Instance, - Parameters, ParamSpecType, PartialType, TupleType, @@ -109,8 +108,13 @@ def apply_generic_arguments( if param_spec is not None: nt = id_to_type.get(param_spec.id) if nt is not None: - if isinstance(nt, Parameters): - callable = callable.expand_param_spec(nt) + # ParamSpec expansion is special-cased, so we need to always expand callable + # as a whole, not expanding arguments individually. + callable = expand_type(callable, id_to_type) + assert isinstance(callable, CallableType) + return callable.copy_modified( + variables=[tv for tv in tvars if tv.id not in id_to_type] + ) # Apply arguments to argument types. var_arg = callable.var_arg() diff --git a/mypy/types.py b/mypy/types.py index 23f5c0b86f7b..359ca713616b 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -1577,7 +1577,7 @@ def __init__( self.arg_kinds = arg_kinds self.arg_names = list(arg_names) assert len(arg_types) == len(arg_kinds) == len(arg_names) - assert not any(isinstance(t, (Parameters, ParamSpecType)) for t in arg_types) + assert not any(isinstance(t, Parameters) for t in arg_types) self.min_args = arg_kinds.count(ARG_POS) self.is_ellipsis_args = is_ellipsis_args self.variables = variables or [] diff --git a/test-data/unit/check-parameter-specification.test b/test-data/unit/check-parameter-specification.test index c7684a9acb2b..e8161750ba56 100644 --- a/test-data/unit/check-parameter-specification.test +++ b/test-data/unit/check-parameter-specification.test @@ -1555,3 +1555,16 @@ def test(x: U) -> U: ... reveal_type(dec) # N: Revealed type is "def [P, T] (f: def (*P.args, **P.kwargs) -> T`-2) -> def (*P.args, **P.kwargs) -> builtins.list[T`-2]" reveal_type(dec(test)) # N: Revealed type is "def [T] (x: T`2) -> builtins.list[T`2]" [builtins fixtures/paramspec.pyi] + +[case testParamSpecNestedApplyNoCrash] +from typing import Callable, TypeVar +from typing_extensions import ParamSpec + +P = ParamSpec("P") +T = TypeVar("T") + +def apply(fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T: ... +def test() -> None: ... +# TODO: avoid this error, although it may be non-trivial. +apply(apply, test) # E: Argument 2 to "apply" has incompatible type "Callable[[], None]"; expected "Callable[P, T]" +[builtins fixtures/paramspec.pyi] From 4f8afce33b2ec285741d7462eabd33c67ca41e18 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Thu, 10 Aug 2023 10:34:51 +0100 Subject: [PATCH 11/15] Address CR --- test-data/unit/check-generics.test | 14 +++++++++----- test-data/unit/check-inference.test | 6 ++---- test-data/unit/check-parameter-specification.test | 8 ++++++++ 3 files changed, 19 insertions(+), 9 deletions(-) diff --git a/test-data/unit/check-generics.test b/test-data/unit/check-generics.test index 219c310c2079..13111e01f013 100644 --- a/test-data/unit/check-generics.test +++ b/test-data/unit/check-generics.test @@ -3049,7 +3049,9 @@ def dec1(f: Callable[[T], T]) -> Callable[[T], List[T]]: def dec2(f: Callable[[S], T]) -> Callable[[S], List[T]]: ... def dec3(f: Callable[[List[S]], T]) -> Callable[[S], T]: - ... + def g(x: S) -> T: + return f([x]) + return g def dec4(f: Callable[[S], List[T]]) -> Callable[[S], T]: ... def dec5(f: Callable[[int], T]) -> Callable[[int], List[T]]: @@ -3057,12 +3059,14 @@ def dec5(f: Callable[[int], T]) -> Callable[[int], List[T]]: return [f(x)] * x return g -reveal_type(dec1(lambda x: x)) # N: Revealed type is "def [T] (T`2) -> builtins.list[T`2]" -reveal_type(dec2(lambda x: x)) # N: Revealed type is "def [S] (S`3) -> builtins.list[S`3]" -reveal_type(dec3(lambda x: x[0])) # N: Revealed type is "def [S] (S`5) -> S`5" -reveal_type(dec4(lambda x: [x])) # N: Revealed type is "def [S] (S`7) -> S`7" +reveal_type(dec1(lambda x: x)) # N: Revealed type is "def [T] (T`3) -> builtins.list[T`3]" +reveal_type(dec2(lambda x: x)) # N: Revealed type is "def [S] (S`4) -> builtins.list[S`4]" +reveal_type(dec3(lambda x: x[0])) # N: Revealed type is "def [S] (S`6) -> S`6" +reveal_type(dec4(lambda x: [x])) # N: Revealed type is "def [S] (S`8) -> S`8" reveal_type(dec1(lambda x: 1)) # N: Revealed type is "def (builtins.int) -> builtins.list[builtins.int]" reveal_type(dec5(lambda x: x)) # N: Revealed type is "def (builtins.int) -> builtins.list[builtins.int]" +reveal_type(dec3(lambda x: x)) # N: Revealed type is "def [S] (S`15) -> builtins.list[S`15]" +dec4(lambda x: x) # E: Incompatible return value type (got "S", expected "List[object]") [builtins fixtures/list.pyi] [case testInferenceAgainstGenericParamSpecBasicInList] diff --git a/test-data/unit/check-inference.test b/test-data/unit/check-inference.test index c53cb8b75da3..9ee30b4df859 100644 --- a/test-data/unit/check-inference.test +++ b/test-data/unit/check-inference.test @@ -1386,10 +1386,8 @@ y = f(lambda x: x) # E: Need type annotation for "y" from typing import TypeVar, Callable X = TypeVar('X') def f(x: Callable[[X], X], y: str) -> X: pass -y = f(lambda x: x, 1) # Fail -[out] -main:5: error: Need type annotation for "y" -main:5: error: Argument 2 to "f" has incompatible type "int"; expected "str" +y = f(lambda x: x, 1) # E: Need type annotation for "y" \ + # E: Argument 2 to "f" has incompatible type "int"; expected "str" [case testInferLambdaNone] # flags: --no-strict-optional diff --git a/test-data/unit/check-parameter-specification.test b/test-data/unit/check-parameter-specification.test index e8161750ba56..f523cb005a2c 100644 --- a/test-data/unit/check-parameter-specification.test +++ b/test-data/unit/check-parameter-specification.test @@ -1554,6 +1554,14 @@ def dec(f: Callable[P, T]) -> Callable[P, List[T]]: ... def test(x: U) -> U: ... reveal_type(dec) # N: Revealed type is "def [P, T] (f: def (*P.args, **P.kwargs) -> T`-2) -> def (*P.args, **P.kwargs) -> builtins.list[T`-2]" reveal_type(dec(test)) # N: Revealed type is "def [T] (x: T`2) -> builtins.list[T`2]" + +class A: ... +TA = TypeVar("TA", bound=A) + +def test_with_bound(x: TA) -> TA: ... +reveal_type(dec(test_with_bound)) # N: Revealed type is "def [T <: __main__.A] (x: T`4) -> builtins.list[T`4]" +dec(test_with_bound)(0) # E: Value of type variable "T" of function cannot be "int" +dec(test_with_bound)(A()) # OK [builtins fixtures/paramspec.pyi] [case testParamSpecNestedApplyNoCrash] From 2d210324dabaac9d315e1dabe0a92e30d8284654 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Thu, 10 Aug 2023 15:40:58 +0100 Subject: [PATCH 12/15] Couple fixes --- mypy/checkexpr.py | 34 ++++++++++++++++++++++++++++++++-- mypy/constraints.py | 1 - 2 files changed, 32 insertions(+), 3 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 54c8fb517c2c..66e073e0c26b 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -1606,7 +1606,7 @@ def check_callable_call( lambda i: self.accept(args[i]), ) callee = self.infer_function_type_arguments( - callee, args, arg_kinds, formal_to_actual, context + callee, args, arg_kinds, arg_names, formal_to_actual, need_refresh, context ) if need_refresh: formal_to_actual = map_actuals_to_formals( @@ -1896,7 +1896,9 @@ def infer_function_type_arguments( callee_type: CallableType, args: list[Expression], arg_kinds: list[ArgKind], + arg_names: Sequence[str | None] | None, formal_to_actual: list[list[int]], + need_refresh: bool, context: Context, ) -> CallableType: """Infer the type arguments for a generic callee type. @@ -1938,7 +1940,14 @@ def infer_function_type_arguments( if 2 in arg_pass_nums: # Second pass of type inference. (callee_type, inferred_args) = self.infer_function_type_arguments_pass2( - callee_type, args, arg_kinds, formal_to_actual, inferred_args, context + callee_type, + args, + arg_kinds, + arg_names, + formal_to_actual, + inferred_args, + need_refresh, + context, ) if ( @@ -1964,6 +1973,17 @@ def infer_function_type_arguments( or set(get_type_vars(a)) & set(callee_type.variables) for a in inferred_args ): + if need_refresh: + # Technically we need to refresh formal_to_actual after *each* inference pass, + # since each pass can expand ParamSpec or TypeVarTuple. Although such situations + # are very rare, not doing this can cause crashes. + formal_to_actual = map_actuals_to_formals( + arg_kinds, + arg_names, + callee_type.arg_kinds, + callee_type.arg_names, + lambda a: self.accept(args[a]), + ) # If the regular two-phase inference didn't work, try inferring type # variables while allowing for polymorphic solutions, i.e. for solutions # potentially involving free variables. @@ -2011,8 +2031,10 @@ def infer_function_type_arguments_pass2( callee_type: CallableType, args: list[Expression], arg_kinds: list[ArgKind], + arg_names: Sequence[str | None] | None, formal_to_actual: list[list[int]], old_inferred_args: Sequence[Type | None], + need_refresh: bool, context: Context, ) -> tuple[CallableType, list[Type | None]]: """Perform second pass of generic function type argument inference. @@ -2034,6 +2056,14 @@ def infer_function_type_arguments_pass2( if isinstance(arg, (NoneType, UninhabitedType)) or has_erased_component(arg): inferred_args[i] = None callee_type = self.apply_generic_arguments(callee_type, inferred_args, context) + if need_refresh: + formal_to_actual = map_actuals_to_formals( + arg_kinds, + arg_names, + callee_type.arg_kinds, + callee_type.arg_names, + lambda a: self.accept(args[a]), + ) arg_types = self.infer_arg_types_in_context(callee_type, args, arg_kinds, formal_to_actual) diff --git a/mypy/constraints.py b/mypy/constraints.py index 6c9ef198fa0d..3e057ccc5021 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -918,7 +918,6 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]: if ( type_state.infer_polymorphic and cactual.variables - and cactual.param_spec() is None and not self.skip_neg_op # Technically, the correct inferred type for application of e.g. # Callable[..., T] -> Callable[..., T] (with literal ellipsis), to a generic From 59963c4ce1caf4d6ec17c2378543c42059c48c77 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Thu, 10 Aug 2023 18:26:36 +0100 Subject: [PATCH 13/15] Fix overload issue --- mypy/constraints.py | 4 ++-- test-data/unit/check-overloading.test | 17 +++++++++++++++++ 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/mypy/constraints.py b/mypy/constraints.py index 3e057ccc5021..715169f8f99e 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -963,8 +963,8 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]: if (tk == ARG_STAR and ak != ARG_STAR) or ( tk == ARG_STAR2 and ak != ARG_STAR2 ): - # Unpack may have shifted indices. - if not unpack_present: + if cactual.param_spec(): + # TODO: we should be more careful also for non-ParamSpec functions continue if isinstance(a, ParamSpecType): # TODO: can we infer something useful for *T vs P? diff --git a/test-data/unit/check-overloading.test b/test-data/unit/check-overloading.test index 466216e11d15..454ae30fb39e 100644 --- a/test-data/unit/check-overloading.test +++ b/test-data/unit/check-overloading.test @@ -6569,3 +6569,20 @@ S = TypeVar("S", bound=str) def foo(x: int = ...) -> Callable[[T], T]: ... @overload def foo(x: S = ...) -> Callable[[T], T]: ... + +[case testOverloadGenericStarArgOverlap] +from typing import Any, Callable, TypeVar, overload, Union, Tuple, List + +F = TypeVar("F", bound=Callable[..., Any]) +S = TypeVar("S", bound=int) + +def id(f: F) -> F: ... + +@overload +def struct(*cols: S) -> int: ... +@overload +def struct(__cols: Union[List[S], Tuple[S, ...]]) -> int: ... +@id +def struct(*cols: Union[S, Union[List[S], Tuple[S, ...]]]) -> int: + pass +[builtins fixtures/tuple.pyi] From 72da8f5d3293619cfc04804ed4ad30e7f0a7e055 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Fri, 11 Aug 2023 09:58:20 +0100 Subject: [PATCH 14/15] Fix corner case with upper bounds --- mypy/constraints.py | 2 +- test-data/unit/check-generics.test | 19 +++++++++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/mypy/constraints.py b/mypy/constraints.py index 715169f8f99e..f7b3a8047163 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -862,7 +862,7 @@ def visit_instance(self, template: Instance) -> list[Constraint]: elif isinstance(actual, TupleType) and self.direction == SUPERTYPE_OF: return infer_constraints(template, mypy.typeops.tuple_fallback(actual), self.direction) elif isinstance(actual, TypeVarType): - if not actual.values: + if not actual.values and not actual.id.is_meta_var(): return infer_constraints(template, actual.upper_bound, self.direction) return [] elif isinstance(actual, ParamSpecType): diff --git a/test-data/unit/check-generics.test b/test-data/unit/check-generics.test index 13111e01f013..8c7c4e035961 100644 --- a/test-data/unit/check-generics.test +++ b/test-data/unit/check-generics.test @@ -3244,3 +3244,22 @@ def apply(a: S, b: T) -> None: v2: Callable[[Tuple[S, T]], None] return pipe(a, v1, v2) [builtins fixtures/tuple.pyi] + +[case testInferenceAgainstGenericParamSpecSpuriousBoundsNotUsed] +# flags: --new-type-inference +from typing import TypeVar, Callable, Generic +from typing_extensions import ParamSpec, Concatenate + +Q = ParamSpec("Q") +class Foo(Generic[Q]): ... + +T1 = TypeVar("T1", bound=Foo[...]) +T2 = TypeVar("T2", bound=Foo[...]) +P = ParamSpec("P") +def pop_off(fn: Callable[Concatenate[T1, P], T2]) -> Callable[P, Callable[[T1], T2]]: + ... + +@pop_off +def test(command: Foo[Q]) -> Foo[Q]: ... +reveal_type(test) # N: Revealed type is "def () -> def [Q] (__main__.Foo[Q`-1]) -> __main__.Foo[Q`-1]" +[builtins fixtures/tuple.pyi] From 7a87692feff877fd227d6e7247fc559c1ea59d20 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Fri, 11 Aug 2023 21:34:43 +0100 Subject: [PATCH 15/15] Postpone few unneeded hacky changes; re-organize TODOs --- mypy/constraints.py | 38 +++++++++++++++----------------------- 1 file changed, 15 insertions(+), 23 deletions(-) diff --git a/mypy/constraints.py b/mypy/constraints.py index f7b3a8047163..04c3378ce16b 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -165,6 +165,8 @@ def infer_constraints_for_callable( actual_type = mapper.expand_actual_type( actual_arg_type, arg_kinds[actual], callee.arg_names[i], callee.arg_kinds[i] ) + # TODO: if callee has ParamSpec, we need to collect all actuals that map to star + # args and create single constraint between P and resulting Parameters instead. c = infer_constraints(callee.arg_types[i], actual_type, SUPERTYPE_OF) constraints.extend(c) @@ -585,13 +587,9 @@ def visit_parameters(self, template: Parameters) -> list[Constraint]: # in situations like [x: T] <: P <: [x: int]. res = [] if len(template.arg_types) == len(self.actual.arg_types): - for tt, at, tk, ak in zip( - template.arg_types, - self.actual.arg_types, - template.arg_kinds, - self.actual.arg_kinds, - ): - if tk == ARG_STAR and ak != ARG_STAR or tk == ARG_STAR2 and ak != ARG_STAR2: + for tt, at in zip(template.arg_types, self.actual.arg_types): + # This avoids bogus constraints like T <: P.args + if isinstance(at, ParamSpecType): continue res.extend(infer_constraints(tt, at, self.direction)) return res @@ -906,6 +904,8 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]: # Normalize callables before matching against each other. # Note that non-normalized callables can be created in annotations # using e.g. callback protocols. + # TODO: check that callables match? Ideally we should not infer constraints + # callables that can never be subtypes of one another in given direction. template = template.with_unpacked_kwargs() extra_tvars = False if isinstance(self.actual, CallableType): @@ -913,7 +913,6 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]: cactual = self.actual.with_unpacked_kwargs() param_spec = template.param_spec() if param_spec is None: - # TODO: verify argument counts; more generally, use some "formal to actual" map # TODO: Erase template variables if it is generic? if ( type_state.infer_polymorphic @@ -955,24 +954,19 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]: else: template_args = template.arg_types cactual_args = cactual.arg_types - # The lengths should match, but don't crash (it will error elsewhere). - for t, a, tk, ak in zip( - template_args, cactual_args, template.arg_kinds, cactual.arg_kinds - ): + # TODO: use some more principled "formal to actual" logic + # instead of this lock-step loop over argument types. This identical + # logic should be used in 5 places: in Parameters vs Parameters + # inference, in Instance vs Instance inference for prefixes (two + # branches), and in Callable vs Callable inference (two branches). + for t, a in zip(template_args, cactual_args): # This avoids bogus constraints like T <: P.args - if (tk == ARG_STAR and ak != ARG_STAR) or ( - tk == ARG_STAR2 and ak != ARG_STAR2 - ): - if cactual.param_spec(): - # TODO: we should be more careful also for non-ParamSpec functions - continue if isinstance(a, ParamSpecType): # TODO: can we infer something useful for *T vs P? continue # Negate direction due to function argument type contravariance. res.extend(infer_constraints(t, a, neg_op(self.direction))) else: - # TODO: check the prefixes match prefix = param_spec.prefix prefix_len = len(prefix.arg_types) cactual_ps = cactual.param_spec() @@ -1021,10 +1015,8 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]: arg_names=cactual.arg_names[:prefix_len], ) - for t, a, k in zip( - prefix.arg_types, cactual_prefix.arg_types, cactual_prefix.arg_kinds - ): - if k in (ARG_STAR, ARG_STAR2): + for t, a in zip(prefix.arg_types, cactual_prefix.arg_types): + if isinstance(a, ParamSpecType): continue res.extend(infer_constraints(t, a, neg_op(self.direction)))