diff --git a/mypy/expandtype.py b/mypy/expandtype.py index 6aa18fb72c2f..2a85087cb5e6 100644 --- a/mypy/expandtype.py +++ b/mypy/expandtype.py @@ -228,11 +228,11 @@ def visit_instance(self, t: Instance) -> Type: if t.type.fullname == "builtins.tuple": # Normalize Tuple[*Tuple[X, ...], ...] -> Tuple[X, ...] arg = args[0] - if isinstance(arg, UnpackType): + if isinstance(arg, UnpackType) and not ( + isinstance(arg.type, TypeAliasType) and arg.type.is_recursive + ): unpacked = get_proper_type(arg.type) if isinstance(unpacked, Instance): - # TODO: this and similar asserts below may be unsafe because get_proper_type() - # may be called during semantic analysis before all invalid types are removed. assert unpacked.type.fullname == "builtins.tuple" args = list(unpacked.args) return t.copy_modified(args=args) @@ -535,7 +535,9 @@ def visit_tuple_type(self, t: TupleType) -> Type: if len(items) == 1: # Normalize Tuple[*Tuple[X, ...]] -> Tuple[X, ...] item = items[0] - if isinstance(item, UnpackType): + if isinstance(item, UnpackType) and not ( + isinstance(item.type, TypeAliasType) and item.type.is_recursive + ): unpacked = get_proper_type(item.type) if isinstance(unpacked, Instance): # expand_type() may be called during semantic analysis, before invalid unpacks are fixed. diff --git a/mypy/semanal.py b/mypy/semanal.py index b0294c782ab9..43995d97c466 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -4263,6 +4263,8 @@ def check_and_set_up_type_alias(self, s: AssignmentStmt) -> bool: # An alias gets updated. updated = False if isinstance(existing.node, TypeAlias): + # Invalidate recursive status cache in case it was previously set. + existing.node._is_recursive = None if existing.node.target != res: # Copy expansion to the existing alias, this matches how we update base classes # for a TypeInfo _in place_ if there are nested placeholders. @@ -4271,8 +4273,6 @@ def check_and_set_up_type_alias(self, s: AssignmentStmt) -> bool: existing.node.alias_tvars = alias_tvars existing.node.no_args = no_args updated = True - # Invalidate recursive status cache in case it was previously set. - existing.node._is_recursive = None else: # Otherwise just replace existing placeholder with type alias *in place*. existing._node = alias_node @@ -5830,6 +5830,8 @@ def visit_type_alias_stmt(self, s: TypeAliasStmt) -> None: ): updated = False if isinstance(existing.node, TypeAlias): + # Invalidate recursive status cache in case it was previously set. + existing.node._is_recursive = None if ( existing.node.target != res or existing.node.alias_tvars != alias_node.alias_tvars @@ -5840,8 +5842,6 @@ def visit_type_alias_stmt(self, s: TypeAliasStmt) -> None: existing.node.default_depends = default_depends existing.node.alias_tvars = alias_tvars updated = True - # Invalidate recursive status cache in case it was previously set. - existing.node._is_recursive = None else: # Otherwise just replace existing placeholder with type alias *in place*. existing._node = alias_node diff --git a/mypy/semanal_typeargs.py b/mypy/semanal_typeargs.py index 0f62a4aa8b1a..5b6d9b16273e 100644 --- a/mypy/semanal_typeargs.py +++ b/mypy/semanal_typeargs.py @@ -108,7 +108,10 @@ def visit_type_alias_type(self, t: TypeAliasType) -> None: self.seen_aliases.discard(t) def visit_tuple_type(self, t: TupleType) -> None: - t.items = flatten_nested_tuples(t.items) + # Unfortunately, universal normalization of tuples is not possible in presence of + # recursive aliases, see testNoCrashOnNonNormalRecursiveTuple for an example. + # TODO: update the places where we handle tuples to always expect non-normal ones. + t.items = flatten_nested_tuples(t.items, handle_recursive=False) for i, it in enumerate(t.items): if self.check_non_paramspec(it, "tuple", t): t.items[i] = AnyType(TypeOfAny.from_error) diff --git a/mypy/server/astmerge.py b/mypy/server/astmerge.py index 075bf7cb540b..5b723711405a 100644 --- a/mypy/server/astmerge.py +++ b/mypy/server/astmerge.py @@ -340,6 +340,8 @@ def visit_var(self, node: Var) -> None: super().visit_var(node) def visit_type_alias(self, node: TypeAlias) -> None: + # Updating alias target can invalidate its recursive status. + node._is_recursive = None self.fixup_type(node.target) for v in node.alias_tvars: self.fixup_type(v) diff --git a/mypy/typeanal.py b/mypy/typeanal.py index 51d26afd55e4..a690ca583844 100644 --- a/mypy/typeanal.py +++ b/mypy/typeanal.py @@ -75,6 +75,7 @@ BoolTypeQuery, CallableArgument, CallableType, + CollectAliasesVisitor, DeletedType, EllipsisType, ErasedType, @@ -275,7 +276,9 @@ def __init__( self.prohibit_special_class_field_types = prohibit_special_class_field_types # Allow variables typed as Type[Any] and type (useful for base classes). self.allow_type_any = allow_type_any - self.allow_type_var_tuple = False + # Level of nesting at which a TypeVarTuple is allowed. Note we specify exact level + # to prohibit things like Unpack[list[Ts]], which are not supported. + self.allow_type_var_tuple = -1 self.allow_unpack = allow_unpack # Set when we are analyzing a default of a type variable. self.analyzing_tvar_def = analyzing_tvar_def @@ -453,7 +456,7 @@ def visit_unbound_type_nonoptional(self, t: UnboundType, defining_literal: bool) self.fail(msg, t, code=codes.VALID_TYPE) return AnyType(TypeOfAny.from_error) assert isinstance(tvar_def, TypeVarTupleType) - if not self.allow_type_var_tuple: + if self.allow_type_var_tuple != self.nesting_level: self.fail( f'TypeVarTuple "{t.name}" is only valid with an unpack', t, @@ -808,9 +811,9 @@ def try_analyze_special_unbound_type(self, t: UnboundType, fullname: str) -> Typ if not self.allow_unpack: self.fail(message_registry.INVALID_UNPACK_POSITION, t, code=codes.VALID_TYPE) return AnyType(TypeOfAny.from_error) - self.allow_type_var_tuple = True + self.allow_type_var_tuple = self.nesting_level + 1 result = UnpackType(self.anal_type(t.args[0]), line=t.line, column=t.column) - self.allow_type_var_tuple = False + self.allow_type_var_tuple = -1 return result elif fullname in SELF_TYPE_NAMES: if t.args: @@ -1161,9 +1164,9 @@ def visit_unpack_type(self, t: UnpackType) -> Type: if not self.allow_unpack: self.fail(message_registry.INVALID_UNPACK_POSITION, t.type, code=codes.VALID_TYPE) return AnyType(TypeOfAny.from_error) - self.allow_type_var_tuple = True + self.allow_type_var_tuple = self.nesting_level + 1 result = UnpackType(self.anal_type(t.type), from_star_syntax=t.from_star_syntax) - self.allow_type_var_tuple = False + self.allow_type_var_tuple = -1 return result def visit_parameters(self, t: Parameters) -> Type: @@ -2518,6 +2521,15 @@ def detect_diverging_alias(node: TypeAlias, target: Type) -> bool: They may be handy in rare cases, e.g. to express a union of non-mixed nested lists: Nested = Union[T, Nested[List[T]]] ~> Union[T, List[T], List[List[T]], ...] """ + is_recursive = node._is_recursive + if is_recursive is None: + is_recursive = node in node.target.accept(CollectAliasesVisitor()) + if not is_recursive: + # Fast path: this is not a recursive alias at all. + return False + # Note we only cache positive case, caching negative case is risky, as this type alias + # (or more importantly any other alias it uses) may be not ready yet. + node._is_recursive = True visitor = DivergingAliasDetector({node}) _ = target.accept(visitor) return visitor.diverging diff --git a/mypy/types.py b/mypy/types.py index bc06e36d7a47..2c6f3ab75948 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -4251,12 +4251,12 @@ def find_unpack_in_list(items: Sequence[Type]) -> int | None: # Funky code here avoids mypyc narrowing the type of unpack_index. old_index = unpack_index assert old_index is None - # Don't return so that we can also sanity check there is only one. + # Don't return so that we can also sanity-check there is only one. unpack_index = i return unpack_index -def flatten_nested_tuples(types: Iterable[Type]) -> list[Type]: +def flatten_nested_tuples(types: Iterable[Type], handle_recursive: bool = True) -> list[Type]: """Recursively flatten TupleTypes nested with Unpack. For example this will transform @@ -4270,7 +4270,12 @@ def flatten_nested_tuples(types: Iterable[Type]) -> list[Type]: res.append(typ) continue p_type = get_proper_type(typ.type) - if not isinstance(p_type, TupleType): + if ( + not isinstance(p_type, TupleType) + or not handle_recursive + and isinstance(typ.type, TypeAliasType) + and typ.type.is_recursive + ): res.append(typ) continue if isinstance(typ.type, TypeAliasType): diff --git a/test-data/unit/check-typevar-tuple.test b/test-data/unit/check-typevar-tuple.test index 3f0765ba5c77..55f125d1fc03 100644 --- a/test-data/unit/check-typevar-tuple.test +++ b/test-data/unit/check-typevar-tuple.test @@ -882,7 +882,39 @@ z: C reveal_type(x) # N: Revealed type is "Any" reveal_type(y) # N: Revealed type is "Any" reveal_type(z) # N: Revealed type is "tuple[builtins.int, Unpack[builtins.tuple[Any, ...]]]" +[builtins fixtures/tuple.pyi] + +[case testBanPathologicalRecursiveTuplesGeneric] +from typing import TypeVarTuple, Unpack + +Ts = TypeVarTuple("Ts") +A = tuple[Unpack[B[Unpack[Ts]]]] # E: Invalid recursive alias: a tuple item of itself \ + # E: Name "B" is used before definition +B = tuple[Unpack[A[Unpack[Ts]]]] +[builtins fixtures/tuple.pyi] + +[case testNoCrashOnInvalidRecursiveUnpackOfUnion] +from typing import Unpack + +A = tuple[int, str] | list[tuple[Unpack[A]]] # E: "tuple[int, str] | list[tuple[Unpack[A]]]" cannot be unpacked (must be tuple or TypeVarTuple) +[builtins fixtures/tuple.pyi] + +[case testNoCrashOnNonNormalRecursiveTuple] +from typing import Unpack + +A = tuple[int, list[tuple[str, Unpack[A]]]] +a: A +x, y = a +y[0] = 1 # E: Incompatible types in assignment (expression has type "int", target has type "tuple[str, Unpack[A]]") +[builtins fixtures/list.pyi] +[case testBanTypeVarTupleNotImmediatelyInsideUnpack] +from typing import TypeVarTuple, Unpack + +Ts = TypeVarTuple("Ts") +A = tuple[Unpack[tuple[Ts]]] # E: TypeVarTuple "Ts" is only valid with an unpack +x: A[int, str] +reveal_type(x) # N: Revealed type is "tuple[Any]" [builtins fixtures/tuple.pyi] [case testInferenceAgainstGenericVariadicWithBadType]