|
8 | 8 |
|
9 | 9 | import torch |
10 | 10 | import deepspeed.comm as dist |
11 | | -from torch._subclasses.fake_tensor import FakeTensorMode |
| 11 | +from torch._subclasses.fake_tensor import FakeTensorMode, maybe_get_fake_mode |
12 | 12 | from torch.fx import GraphModule, Node |
13 | 13 | from torch.fx.passes.fake_tensor_prop import FakeTensorProp |
14 | 14 | from torch.fx.experimental.symbolic_shapes import ShapeEnv |
@@ -80,7 +80,7 @@ def pass_shard_seq_dim(gm: GraphModule, example_inputs): |
80 | 80 | seq_symint = val.shape[1] |
81 | 81 | assert isinstance( |
82 | 82 | seq_symint, |
83 | | - torch.SymInt), f"expected sequence dimension to be of type `torch.SymInt` but found `{type(seq_symint)}`" |
| 83 | + torch.SymInt), f"expected sequence dimension to be of type {torch.SymInt!r} but found {type(seq_symint)!r}" |
84 | 84 |
|
85 | 85 | sym_seq_dim_node = find_node_by_name(gm, str(seq_symint)) |
86 | 86 | if sym_seq_dim_node is None: |
@@ -184,15 +184,52 @@ def pass_canonicalize(gm: GraphModule, real_inputs): |
184 | 184 |
|
185 | 185 |
|
def pass_propagate_shapes(gm: torch.fx.GraphModule, real_inputs):
    """Run fake-tensor shape propagation over ``gm``.

    Converts ``real_inputs`` into fake tensors under the graph's own fake
    mode (when one is recoverable from node metadata) and runs
    ``FakeTensorProp`` so every node's ``meta["val"]`` carries output
    shape/dtype information.

    Args:
        gm: FX graph module to annotate with shape metadata.
        real_inputs: Flat sequence of example inputs; tensors are converted
            to fake tensors, non-tensor values are passed through unchanged.
    """

    def _fake_mode_from_node(node):
        # Placeholders with a `val` entry are authoritative; for any other
        # node (or a placeholder missing `val`) fall back to Dynamo's
        # `example_value` key, then `val` — mirroring the original lookup.
        if node.op == "placeholder" and "val" in node.meta:
            candidate = node.meta["val"]
        else:
            candidate = node.meta.get("example_value", node.meta.get("val"))
        # isinstance already rejects None, so no separate None check needed.
        if isinstance(candidate, torch.Tensor):
            return maybe_get_fake_mode(candidate)
        return None

    # Reuse the graph's existing fake mode when metadata is already present.
    # Its ShapeEnv owns the symbolic dims captured during tracing, so using a
    # fresh mode here can desynchronize fake inputs from graph metadata.
    fake_mode = next(
        (mode for mode in map(_fake_mode_from_node, gm.graph.nodes) if mode is not None),
        None,
    )
    if fake_mode is None:
        # Some graphs do not carry fake tensor metadata yet; create a fallback
        # mode so FakeTensorProp can still run shape-only execution.
        fake_mode = FakeTensorMode(shape_env=ShapeEnv())

    fake_inputs = [fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t for t in real_inputs]

    # Torch 2.9 can fail fake propagation through SDPA's masked fake-CUDA path,
    # even though this pass only needs output metadata. Temporarily clear
    # attn_mask so shape propagation can proceed, then restore it immediately;
    # SDPA output shapes are still determined by Q/K/V shapes, not mask values.
    saved_sdpa_masks = []
    for attn_node in get_sdpa_nodes(gm):
        attn_mask = attn_node.kwargs.get("attn_mask")
        if attn_mask is not None:
            saved_sdpa_masks.append((attn_node, attn_mask))
            attn_node.update_kwarg("attn_mask", None)

    try:
        # fake_inputs are already created under fake_mode above, so run
        # propagation without reconverting them into a different fake mode.
        FakeTensorProp(gm, mode=fake_mode).propagate_dont_convert_inputs(*fake_inputs)
    finally:
        for attn_node, attn_mask in saved_sdpa_masks:
            attn_node.update_kwarg("attn_mask", attn_mask)
196 | 233 |
|
197 | 234 |
|
198 | 235 | def apply_autosp(gm: GraphModule, |
|
0 commit comments