Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions backends/cadence/aot/tests/test_pass_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,48 @@ def test_get_arg_list_type_mismatch_raises(self) -> None:
_, node = self._create_graph_with_kwargs(input="not_a_list", other=2)
with self.assertRaises(TypeError):
get_arg(node, "input", list)

def _create_aten_add_node(self) -> torch.fx.Node:
"""A graph with aten.add.Tensor(self, other) called positionally.

Its schema names the first arg ``self``; torch.fx renames that to
``input`` in the normalized signature. Args are positional (not kwargs) so
get_arg resolves them through the normalization path, not node.kwargs.
"""
graph = torch.fx.Graph()
x = graph.placeholder("x")
y = graph.placeholder("y")
node = graph.call_function(torch.ops.aten.add.Tensor, args=(x, y))
graph.output(node)
# Owns the graph so node.graph.owning_module is set for normalization.
torch.fx.GraphModule(torch.nn.Module(), graph)
return node

def test_get_arg_self_resolves_first_arg(self) -> None:
"""get_arg resolves the schema arg named 'self' (e.g. aten.add.Tensor),
which torch.fx renames to 'input' in the normalized signature."""
node = self._create_aten_add_node()
x, y = node.args
self.assertIs(get_arg(node, "self"), x)
# A sibling arg with an unchanged name still resolves alongside 'self'.
self.assertIs(get_arg(node, "other"), y)

def test_get_arg_self_op_fills_defaults(self) -> None:
"""A trailing arg left at its default (alpha) resolves via get_arg on a
'self' op even though it is absent from node.args."""
node = self._create_aten_add_node()
self.assertEqual(get_arg(node, "alpha"), 1)

def test_get_arg_self_rejected_for_input_op(self) -> None:
"""'self' must not silently resolve to a genuine 'input' arg. aten.linear
names its first arg 'input' (no 'self'), so the self->input remap must NOT
apply: 'input' resolves, but 'self' is an invalid name and raises."""
graph = torch.fx.Graph()
x = graph.placeholder("x")
w = graph.placeholder("w")
node = graph.call_function(torch.ops.aten.linear.default, args=(x, w))
graph.output(node)
torch.fx.GraphModule(torch.nn.Module(), graph)
self.assertIs(get_arg(node, "input"), x)
with self.assertRaises(KeyError):
get_arg(node, "self")
20 changes: 18 additions & 2 deletions backends/transforms/permute_pass_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,18 @@ def get_arg(
kwarg_name: str,
expected_type: Type[T] = Argument,
) -> T:
"""Get the arg with kwarg_name of the node."""
"""Get the arg with kwarg_name of the node.

``kwarg_name`` is the op's schema argument name. Note that torch.fx renames a
schema arg named ``self`` to ``input`` when building the op signature (see
torch/fx/operator_schemas.py), so the normalized kwargs key for the first
tensor arg of ops like add/mul/bmm is ``input``. We mirror that rename so
callers can pass the real schema name ``self`` and still resolve the arg --
but only for ops whose schema actually declares a ``self`` arg. For ops with
a genuine ``input`` arg (e.g. linear/conv), ``self`` is not a valid name and
must NOT silently resolve to ``input``; such a lookup raises (KeyError) as it
would for any other unknown arg name.
"""
if kwarg_name in node.kwargs:
value = node.kwargs[kwarg_name]
else:
Expand All @@ -99,7 +110,12 @@ def get_arg(
raise RuntimeError(
f"get_arg: Node {node} does not support normalization of arguments"
)
value = normalized_args.kwargs[kwarg_name]
normalized_name = kwarg_name
if kwarg_name == "self":
schema = getattr(node.target, "_schema", None)
if schema is not None and any(a.name == "self" for a in schema.arguments):
normalized_name = "input"
value = normalized_args.kwargs[normalized_name]

if expected_type is not Argument:
try:
Expand Down
Loading