From 6a1d8973d0061565d2ad5ff1d6784ed0d7b2b74f Mon Sep 17 00:00:00 2001 From: Andrew Grebenisan Date: Thu, 11 Jun 2026 10:00:51 -0700 Subject: [PATCH] Fix get_arg to handle "self" (#20200) Summary: torch/fx/operator_schemas.py:112 In _torchscript_schema_to_signature_impl, self is renamed input. We need to copy this behavior in get_arg. Reviewed By: ethansfng Differential Revision: D108186918 --- backends/cadence/aot/tests/test_pass_utils.py | 45 +++++++++++++++++++ backends/transforms/permute_pass_utils.py | 20 ++++++++- 2 files changed, 63 insertions(+), 2 deletions(-) diff --git a/backends/cadence/aot/tests/test_pass_utils.py b/backends/cadence/aot/tests/test_pass_utils.py index c9987cb7196..0b89fb41013 100644 --- a/backends/cadence/aot/tests/test_pass_utils.py +++ b/backends/cadence/aot/tests/test_pass_utils.py @@ -83,3 +83,48 @@ def test_get_arg_list_type_mismatch_raises(self) -> None: _, node = self._create_graph_with_kwargs(input="not_a_list", other=2) with self.assertRaises(TypeError): get_arg(node, "input", list) + + def _create_aten_add_node(self) -> torch.fx.Node: + """A graph with aten.add.Tensor(self, other) called positionally. + + Its schema names the first arg ``self``; torch.fx renames that to + ``input`` in the normalized signature. Args are positional (not kwargs) so + get_arg resolves them through the normalization path, not node.kwargs. + """ + graph = torch.fx.Graph() + x = graph.placeholder("x") + y = graph.placeholder("y") + node = graph.call_function(torch.ops.aten.add.Tensor, args=(x, y)) + graph.output(node) + # Owns the graph so node.graph.owning_module is set for normalization. + torch.fx.GraphModule(torch.nn.Module(), graph) + return node + + def test_get_arg_self_resolves_first_arg(self) -> None: + """get_arg resolves the schema arg named 'self' (e.g. aten.add.Tensor), + which torch.fx renames to 'input' in the normalized signature.""" + node = self._create_aten_add_node() + x, y = node.args + self.assertIs(get_arg(node, "self"), x) + # A sibling arg with an unchanged name still resolves alongside 'self'. + self.assertIs(get_arg(node, "other"), y) + + def test_get_arg_self_op_fills_defaults(self) -> None: + """A trailing arg left at its default (alpha) resolves via get_arg on a + 'self' op even though it is absent from node.args.""" + node = self._create_aten_add_node() + self.assertEqual(get_arg(node, "alpha"), 1) + + def test_get_arg_self_rejected_for_input_op(self) -> None: + """'self' must not silently resolve to a genuine 'input' arg. aten.linear + names its first arg 'input' (no 'self'), so the self->input remap must NOT + apply: 'input' resolves, but 'self' is an invalid name and raises.""" + graph = torch.fx.Graph() + x = graph.placeholder("x") + w = graph.placeholder("w") + node = graph.call_function(torch.ops.aten.linear.default, args=(x, w)) + graph.output(node) + torch.fx.GraphModule(torch.nn.Module(), graph) + self.assertIs(get_arg(node, "input"), x) + with self.assertRaises(KeyError): + get_arg(node, "self") diff --git a/backends/transforms/permute_pass_utils.py b/backends/transforms/permute_pass_utils.py index fca8946165e..b879e08d3ce 100644 --- a/backends/transforms/permute_pass_utils.py +++ b/backends/transforms/permute_pass_utils.py @@ -88,7 +88,18 @@ def get_arg( kwarg_name: str, expected_type: Type[T] = Argument, ) -> T: - """Get the arg with kwarg_name of the node.""" + """Get the arg with kwarg_name of the node. + + ``kwarg_name`` is the op's schema argument name. Note that torch.fx renames a + schema arg named ``self`` to ``input`` when building the op signature (see + torch/fx/operator_schemas.py), so the normalized kwargs key for the first + tensor arg of ops like add/mul/bmm is ``input``. We mirror that rename so + callers can pass the real schema name ``self`` and still resolve the arg -- + but only for ops whose schema actually declares a ``self`` arg. For ops with + a genuine ``input`` arg (e.g. linear/conv), ``self`` is not a valid name and + must NOT silently resolve to ``input``; such a lookup raises (KeyError) as it + would for any other unknown arg name. + """ if kwarg_name in node.kwargs: value = node.kwargs[kwarg_name] else: @@ -99,7 +110,12 @@ def get_arg( raise RuntimeError( f"get_arg: Node {node} does not support normalization of arguments" ) - value = normalized_args.kwargs[kwarg_name] + normalized_name = kwarg_name + if kwarg_name == "self": + schema = getattr(node.target, "_schema", None) + if schema is not None and any(a.name == "self" for a in schema.arguments): + normalized_name = "input" + value = normalized_args.kwargs[normalized_name] if expected_type is not Argument: try: