Skip to content

Commit fa13016

Browse files
avikchaudhurimeta-codesync[bot]
authored andcommitted
TypeVar variant for arithmetic
Summary: Add torch_shapes.TypeVar and TypeVarTuple classes that support arithmetic operators (N + 1, N * 2, etc.) at Python runtime, unlike typing.TypeVar which raises TypeError. These use __class__ = typing.TypeVar so that isinstance checks pass and Generic[N] works correctly. On the pyrefly side, extend the TypeVar/TypeVarTuple special export recognition to accept torch_shapes as a valid defining module, so torch_shapes.TypeVar is treated identically to typing.TypeVar by the type checker. Also add int_type_var model variants (nanogpt, gptfast) that use torch_shapes.TypeVar instead of PEP 695 syntax, with both type-checked and runnable versions, plus comprehensive runtime and type-checking tests. Reviewed By: stroxler Differential Revision: D95282862 fbshipit-source-id: ec07dbb9efc94b653d6afc4ec9bf6b3d821580f0
1 parent 02b181c commit fa13016

10 files changed

Lines changed: 2785 additions & 22 deletions

pyrefly/lib/export/special.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,10 +139,11 @@ impl SpecialExport {
139139

140140
pub fn defined_in(self, m: ModuleName) -> bool {
141141
match self {
142+
Self::TypeVar | Self::TypeVarTuple => {
143+
matches!(m.as_str(), "typing" | "typing_extensions" | "torch_shapes")
144+
}
142145
Self::TypeAlias
143-
| Self::TypeVar
144146
| Self::ParamSpec
145-
| Self::TypeVarTuple
146147
| Self::Annotated
147148
| Self::Literal
148149
| Self::TypedDict

test/tensor_shapes/fixtures/torch_shapes/__init__.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
don't crash when evaluated by Python.
1313
"""
1414

15+
import typing
16+
1517
import torch
1618
import torch.nn as nn
1719

@@ -31,3 +33,93 @@ class Dim[T]:
3133
"""
3234

3335
pass
36+
37+
38+
class TypeVar:
39+
"""TypeVar with arithmetic support for tensor shape dimensions.
40+
41+
Like typing.TypeVar but arithmetic operations (N + 1, N * 2, etc.)
42+
return self instead of raising TypeError. Setting
43+
__class__ = typing.TypeVar makes isinstance(x, typing.TypeVar)
44+
return True, so Generic[N] and TypedDict + Generic[N] both work.
45+
46+
In pyrefly, torch_shapes.TypeVar is treated identically to
47+
typing.TypeVar.
48+
"""
49+
50+
__class__ = typing.TypeVar
51+
52+
def __init__(self, name: str):
53+
self.__name__ = name
54+
self.name = name
55+
56+
def __repr__(self):
57+
return self.name
58+
59+
def __hash__(self):
60+
return hash(self.name)
61+
62+
def __eq__(self, other):
63+
return self is other
64+
65+
def __add__(self, other):
66+
return self
67+
68+
def __radd__(self, other):
69+
return self
70+
71+
def __sub__(self, other):
72+
return self
73+
74+
def __rsub__(self, other):
75+
return self
76+
77+
def __mul__(self, other):
78+
return self
79+
80+
def __rmul__(self, other):
81+
return self
82+
83+
def __floordiv__(self, other):
84+
return self
85+
86+
def __typing_subst__(self, arg):
87+
return arg
88+
89+
90+
class TypeVarTuple:
91+
"""TypeVarTuple with support for integer shape dimensions.
92+
93+
Like typing.TypeVarTuple but for use in tensor shape annotations.
94+
Setting __class__ = typing.TypeVarTuple and providing
95+
__typing_is_unpacked_typevartuple__ makes Generic[*Ns] work.
96+
97+
In pyrefly, torch_shapes.TypeVarTuple is treated identically to
98+
typing.TypeVarTuple.
99+
100+
__iter__ yields self so that *Ns unpacking works in subscripts
101+
like Generic[*Ns] or Tensor[*Ns, 3]. Python's star-unpacking
102+
calls __iter__ on the object.
103+
"""
104+
105+
__class__ = typing.TypeVarTuple
106+
107+
def __init__(self, name: str):
108+
self.__name__ = name
109+
self.name = name
110+
111+
def __repr__(self):
112+
return f"*{self.name}"
113+
114+
def __hash__(self):
115+
return hash(self.name)
116+
117+
def __eq__(self, other):
118+
return self is other
119+
120+
def __iter__(self):
121+
yield self
122+
123+
@property
124+
def __typing_is_unpacked_typevartuple__(self):
125+
return True

0 commit comments

Comments
 (0)