Skip to content

Commit 4c8f994

Browse files
Fix constructor calls for union-bounded TypeVars (#21571)
Fixes #21106
1 parent cd75c4e commit 4c8f994

4 files changed

Lines changed: 51 additions & 7 deletions

File tree

mypy/checkexpr.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1948,12 +1948,7 @@ def analyze_type_type_callee(self, item: ProperType, context: Context) -> Type:
19481948
# but better than AnyType...), but replace the return type
19491949
# with typevar.
19501950
callee = self.analyze_type_type_callee(get_proper_type(item.upper_bound), context)
1951-
callee = get_proper_type(callee)
1952-
if isinstance(callee, CallableType):
1953-
callee = callee.copy_modified(ret_type=item)
1954-
elif isinstance(callee, Overloaded):
1955-
callee = Overloaded([c.copy_modified(ret_type=item) for c in callee.items])
1956-
return callee
1951+
return self.replace_type_type_callee_ret_type(callee, item)
19571952
# We support Type of namedtuples but not of tuples in general
19581953
if isinstance(item, TupleType) and tuple_fallback(item).type.fullname != "builtins.tuple":
19591954
return self.analyze_type_type_callee(tuple_fallback(item), context)
@@ -1963,6 +1958,23 @@ def analyze_type_type_callee(self, item: ProperType, context: Context) -> Type:
19631958
self.msg.unsupported_type_type(item, context)
19641959
return AnyType(TypeOfAny.from_error)
19651960

1961+
def replace_type_type_callee_ret_type(self, callee: Type, ret_type: Type) -> Type:
1962+
callee = get_proper_type(callee)
1963+
if isinstance(callee, CallableType):
1964+
return callee.copy_modified(ret_type=ret_type)
1965+
if isinstance(callee, Overloaded):
1966+
return Overloaded([c.copy_modified(ret_type=ret_type) for c in callee.items])
1967+
if isinstance(callee, UnionType):
1968+
return UnionType(
1969+
[
1970+
self.replace_type_type_callee_ret_type(item, ret_type)
1971+
for item in callee.relevant_items()
1972+
],
1973+
line=callee.line,
1974+
column=callee.column,
1975+
)
1976+
return callee
1977+
19661978
def infer_arg_types_in_empty_context(self, args: list[Expression]) -> list[Type]:
19671979
"""Infer argument expression types in an empty context.
19681980

test-data/unit/check-classes.test

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3889,6 +3889,24 @@ def process(cls: Type[U]):
38893889
[builtins fixtures/classmethod.pyi]
38903890
[out]
38913891

3892+
[case testTypeUsingTypeCConstructorReturnFromTypeVarUnionBound]
3893+
from typing import Optional, Type, TypeVar, Union
3894+
3895+
class A:
3896+
def __init__(self, value: str = "") -> None: pass
3897+
class B:
3898+
def __init__(self, value: str = "") -> None: pass
3899+
3900+
T = TypeVar("T", bound=Union[A, B])
3901+
3902+
def make(ftype: Type[T], value: Optional[str]) -> T:
3903+
if value is None:
3904+
return ftype()
3905+
return ftype(value)
3906+
3907+
reveal_type(make(A, "a")) # N: Revealed type is "__main__.A"
3908+
reveal_type(make(B, None)) # N: Revealed type is "__main__.B"
3909+
38923910
[case testTypeUsingTypeCErrorUnsupportedType]
38933911
from typing import Type, Tuple
38943912
def foo(arg: Type[Tuple[int]]):

test-data/unit/check-python310.test

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3531,7 +3531,7 @@ def switch(choice: type[T_Choice]) -> None:
35313531
reveal_type(choice()) # N: Revealed type is "b.Two"
35323532
case _:
35333533
reveal_type(choice) # N: Revealed type is "type[T_Choice`-1]"
3534-
reveal_type(choice()) # N: Revealed type is "b.One | b.Two"
3534+
reveal_type(choice()) # N: Revealed type is "T_Choice`-1"
35353535

35363536
[file b.py]
35373537
class One: ...

test-data/unit/check-python312.test

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -824,6 +824,20 @@ f(1, u)
824824
f('x', None) # E: Value of type variable "T" of "f" cannot be "str" \
825825
# E: Value of type variable "S" of "f" cannot be "None"
826826

827+
[case testPEP695UpperBoundTypeTypeConstructorReturnType]
828+
class A:
829+
def __init__(self, value: str = "") -> None: pass
830+
class B:
831+
def __init__(self, value: str = "") -> None: pass
832+
833+
def make[T: A | B](ftype: type[T], value: str | None) -> T:
834+
if value is None:
835+
return ftype()
836+
return ftype(value)
837+
838+
reveal_type(make(A, "a")) # N: Revealed type is "__main__.A"
839+
reveal_type(make(B, None)) # N: Revealed type is "__main__.B"
840+
827841
[case testPEP695InferVarianceOfTupleType]
828842
class Cov[T](tuple[int, str]):
829843
def f(self) -> T: pass

0 commit comments

Comments
 (0)