From 8f347eae89e6a5f299d822a51a31344c71acb927 Mon Sep 17 00:00:00 2001 From: Saoirse Stewart Date: Tue, 2 Jun 2026 11:49:56 +0100 Subject: [PATCH 1/2] Arm backend: Add TOSA dialect ARGMAX node visitor - Use _DIALECT_SUBSTITUTIONS for activations Change-Id: I7a4bb7a53789b0b86618686927a93840464b80ba Signed-off-by: Saoirse Stewart --- .../aten_to_tosa_activation_functions.py | 18 +++ .../_passes/aten_to_tosa_tensor_operators.py | 26 ++++ backends/arm/_passes/exir_to_tosa_pass.py | 43 +++---- .../tosa_profile_supported_op_lists.py | 2 + .../tosa_supported_operators.py | 52 +++++++- backends/arm/operators/__init__.py | 1 + backends/arm/operators/op_tosa_argmax.py | 63 ++++++++++ backends/arm/test/ops/test_argmax.py | 119 ++++++++++++++++++ 8 files changed, 302 insertions(+), 22 deletions(-) create mode 100644 backends/arm/_passes/aten_to_tosa_tensor_operators.py create mode 100644 backends/arm/operators/op_tosa_argmax.py create mode 100644 backends/arm/test/ops/test_argmax.py diff --git a/backends/arm/_passes/aten_to_tosa_activation_functions.py b/backends/arm/_passes/aten_to_tosa_activation_functions.py index 9b92b31e630..8d51f092991 100644 --- a/backends/arm/_passes/aten_to_tosa_activation_functions.py +++ b/backends/arm/_passes/aten_to_tosa_activation_functions.py @@ -128,3 +128,21 @@ def rewrite_clamp(node: Node, pass_: AtenToDialectPass) -> DialectNodeSpec | Non exir_ops.backend.tosa.CLAMP.default, (node.args[0], *min_max_args), ) + + +def get_activation_replacement( + node: Node, pass_: AtenToDialectPass +) -> DialectNodeSpec | None: + # Dispatch activation rewrites from their ATen target to the matching TOSA + # dialect node builder. + match node.target: + case exir_ops.edge.aten.clamp.default: + return rewrite_clamp(node, pass_) + case exir_ops.edge.aten.erf.default: + return rewrite_erf(node, pass_) + case exir_ops.edge.aten.sigmoid.default: + return rewrite_sigmoid(node, pass_) + case exir_ops.edge.aten.tanh.default: + return rewrite_tanh(node, pass_) + case _: + return None diff --git a/backends/arm/_passes/aten_to_tosa_tensor_operators.py b/backends/arm/_passes/aten_to_tosa_tensor_operators.py new file mode 100644 index 00000000000..140aa87615f --- /dev/null +++ b/backends/arm/_passes/aten_to_tosa_tensor_operators.py @@ -0,0 +1,26 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import cast + +from executorch.backends.transforms.aten_to_dialect_pass import ( + AtenToDialectPass, + DialectNodeSpec, +) +from executorch.exir.dialects._ops import ops as exir_ops +from torch.fx import Node + + +def rewrite_argmax(node: Node, pass_: AtenToDialectPass) -> DialectNodeSpec: + input_node = cast(Node, node.args[0]) + dim = cast(int, node.kwargs["dim"] if "dim" in node.kwargs else node.args[1]) + if dim < 0: + dim += len(input_node.meta["val"].shape) + + return DialectNodeSpec( + exir_ops.backend.tosa.ARGMAX.default, + (input_node, dim), + {}, + ) diff --git a/backends/arm/_passes/exir_to_tosa_pass.py b/backends/arm/_passes/exir_to_tosa_pass.py index b77171b9eaf..c0c6efb1a6c 100644 --- a/backends/arm/_passes/exir_to_tosa_pass.py +++ b/backends/arm/_passes/exir_to_tosa_pass.py @@ -5,37 +5,38 @@ import executorch.backends.arm.tosa.dialect # noqa: F401 from executorch.backends.arm._passes.aten_to_tosa_activation_functions import ( - rewrite_clamp, - rewrite_erf, - rewrite_sigmoid, - rewrite_tanh, + get_activation_replacement, +) +from executorch.backends.arm._passes.aten_to_tosa_tensor_operators import rewrite_argmax +from executorch.backends.transforms.aten_to_dialect_pass import ( + AtenToDialectPass, + DialectNodeSpec, ) -from executorch.backends.transforms.aten_to_dialect_pass import AtenToDialectPass from executorch.exir.dialects._ops import ops as exir_ops +from torch.fx import Node class ExirToTosaPass(AtenToDialectPass): """Rewrite simple EXIR ops to equivalent backend TOSA dialect ops. - Rewrite functions are grouped by op category and registered with the shared - ATen-to-dialect pass infrastructure. + Rewrite functions are registered with the shared ATen-to-dialect pass + infrastructure. """ -_ACTIVATION_FUNCTION_REWRITES = { - exir_ops.edge.aten.clamp.default: rewrite_clamp, - exir_ops.edge.aten.erf.default: rewrite_erf, - exir_ops.edge.aten.sigmoid.default: rewrite_sigmoid, - exir_ops.edge.aten.tanh.default: rewrite_tanh, -} +@ExirToTosaPass.register_dialect_substitution(exir_ops.edge.aten.argmax.default) +def _get_tensor_operators_replacement( + node: Node, pass_: AtenToDialectPass +) -> DialectNodeSpec: + return rewrite_argmax(node, pass_) -_DIRECT_REWRITE_CATEGORIES = { - "activation_functions": _ACTIVATION_FUNCTION_REWRITES, -} -# Register each category's ATen targets with the function that builds the -# corresponding TOSA dialect node spec. -for _rewrite_category in _DIRECT_REWRITE_CATEGORIES.values(): - for _edge_target, _rewrite_fn in _rewrite_category.items(): - ExirToTosaPass.register_dialect_substitution(_edge_target)(_rewrite_fn) +@ExirToTosaPass.register_dialect_substitution(exir_ops.edge.aten.clamp.default) +@ExirToTosaPass.register_dialect_substitution(exir_ops.edge.aten.erf.default) +@ExirToTosaPass.register_dialect_substitution(exir_ops.edge.aten.sigmoid.default) +@ExirToTosaPass.register_dialect_substitution(exir_ops.edge.aten.tanh.default) +def _get_activation_replacement( + node: Node, pass_: AtenToDialectPass +) -> DialectNodeSpec | None: + return get_activation_replacement(node, pass_) diff --git a/backends/arm/operator_support/tosa_profile_supported_op_lists.py b/backends/arm/operator_support/tosa_profile_supported_op_lists.py index fab4e6c60c1..7d7a40e48c5 100644 --- a/backends/arm/operator_support/tosa_profile_supported_op_lists.py +++ b/backends/arm/operator_support/tosa_profile_supported_op_lists.py @@ -99,6 +99,7 @@ exir_ops.edge.quantized_decomposed.dequantize_per_channel.default, exir_ops.edge.aten.pad.default, exir_ops.edge.aten.constant_pad_nd.default, + exir_ops.edge.aten.argmax.default, exir_ops.edge.aten.amax.default, exir_ops.edge.aten.amin.default, exir_ops.edge.aten.eye.default, @@ -237,6 +238,7 @@ operator.getitem, exir_ops.edge.aten.pad.default, exir_ops.edge.aten.constant_pad_nd.default, + exir_ops.edge.aten.argmax.default, exir_ops.edge.aten.amax.default, exir_ops.edge.aten.amin.default, exir_ops.edge.aten.eye.default, diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py index 2d064ed298c..fcddaebe8d4 100644 --- a/backends/arm/operator_support/tosa_supported_operators.py +++ b/backends/arm/operator_support/tosa_supported_operators.py @@ -592,6 +592,56 @@ def inside_int32_bounds(self, node: torch.fx.Node) -> bool: min_val, max_val = int(torch.min(data)), int(torch.max(data)) return min_val >= self.int32_min and max_val <= self.int32_max + def has_rejected_int64_output( + self, node: torch.fx.Node, tensor_list: Sequence[typing.Any] + ) -> bool: + if node.target in ( + torch.ops.aten.argmax.default, + exir_ops.edge.aten.argmax.default, + ): + return not self._is_tosa_argmax_supported(node) + return any(tensor.dtype == torch.int64 for tensor in tensor_list) + + def _is_tosa_argmax_supported(self, node: torch.fx.Node) -> bool: + dim = node.kwargs.get("dim", node.args[1] if len(node.args) > 1 else None) + if dim is None: + self.reporter.report_reject( + node, "TOSA ARGMAX requires an explicit reduction dimension." + ) + return False + if not isinstance(dim, int): + self.reporter.report_reject( + node, "TOSA ARGMAX requires a statically known reduction dimension." + ) + return False + + input_node = typing.cast(torch.fx.Node, node.args[0]) + input_rank = len(get_first_fake_tensor(input_node).shape) + if input_rank == 0: + self.reporter.report_reject( + node, "TOSA ARGMAX requires an input with rank at least 1." + ) + return False + + axis = dim + input_rank if dim < 0 else dim + if axis < 0 or axis >= input_rank: + self.reporter.report_reject( + node, + f"TOSA ARGMAX axis must be in [0, {input_rank - 1}] but got {dim}.", + ) + return False + + keepdim = node.kwargs.get( + "keepdim", node.args[2] if len(node.args) > 2 else False + ) + if keepdim: + self.reporter.report_reject( + node, "TOSA ARGMAX does not support keepdim=True." + ) + return False + + return True + def is_node_supported( self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node ) -> bool: @@ -601,7 +651,7 @@ def is_node_supported( vals = node.meta["val"] tensor_list = vals if isinstance(vals, (list, tuple)) else [vals] - any_int64 = any(tensor.dtype == torch.int64 for tensor in tensor_list) + any_int64 = self.has_rejected_int64_output(node, tensor_list) # Don't partition nodes with int64 output... if any_int64: # ... Except for constant ops that are directly cast to something non-int64. diff --git a/backends/arm/operators/__init__.py b/backends/arm/operators/__init__.py index d2c2846b68c..357f1704278 100644 --- a/backends/arm/operators/__init__.py +++ b/backends/arm/operators/__init__.py @@ -42,6 +42,7 @@ op_sub, op_sum, op_to_dim_order_copy, + op_tosa_argmax, op_tosa_avg_pool2d, op_tosa_avg_pool2d_adaptive, op_tosa_clamp, diff --git a/backends/arm/operators/op_tosa_argmax.py b/backends/arm/operators/op_tosa_argmax.py new file mode 100644 index 00000000000..01917e0870a --- /dev/null +++ b/backends/arm/operators/op_tosa_argmax.py @@ -0,0 +1,63 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, List + +import torch.fx +import tosa_serializer as ts + +from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor +from executorch.backends.arm.operators.node_visitor import ( + NodeVisitor, + register_node_visitor, +) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, + validate_valid_dtype, +) +from executorch.backends.arm.tosa.mapping import TosaArg + + +@register_node_visitor +class ArgMaxVisitor(NodeVisitor): + target = "tosa.ARGMAX.default" + + def define_node( + self, + node: torch.fx.Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + validate_num_inputs(self.target, inputs, 2) + validate_valid_dtype( + self.target, + inputs[0], + [ + ts.DType.INT8, + ts.DType.INT16, + ts.DType.FP16, + ts.DType.FP32, + ts.DType.BF16, + ], + self.tosa_spec, + ) + validate_valid_dtype(self.target, output, ts.DType.INT32, self.tosa_spec) + + axis = inputs[1].number + if axis < 0: + tensor = get_first_fake_tensor(node) + axis += len(tensor.size()) + + attr = ts.TosaSerializerAttribute() + attr.ArgMaxAttribute(axis, ts.NanPropagationMode.PROPAGATE) + self._serialize_operator( + node, + tosa_graph, + ts.Op.ARGMAX, + [inputs[0].name], + [output.name], + attr, + ) diff --git a/backends/arm/test/ops/test_argmax.py b/backends/arm/test/ops/test_argmax.py new file mode 100644 index 00000000000..62433db65c3 --- /dev/null +++ b/backends/arm/test/ops/test_argmax.py @@ -0,0 +1,119 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Tuple + +import torch +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.test_pipeline import ( + OpNotSupportedPipeline, + TosaPipelineFP, + TosaPipelineINT, +) + +aten_op = "torch.ops.aten.argmax.default" +exir_op = "executorch_exir_dialects_edge__ops_aten_argmax_default" +input_t = Tuple[torch.Tensor] + + +class Argmax(torch.nn.Module): + def __init__(self, dim: int): + super().__init__() + self.dim = dim + + def forward(self, x: torch.Tensor): + return torch.argmax(x, dim=self.dim).to(torch.int32) + + test_data: dict[str, Tuple[input_t, int]] = { + "rank_1_dim_0": lambda: ((torch.rand(10),), 0), + "rank_2_dim_1": lambda: ((torch.rand(2, 5),), 1), + "rank_4_dim_2": lambda: ((torch.rand(1, 3, 4, 5),), 2), + "rank_4_dim_3": lambda: ((torch.rand(1, 3, 4, 5),), 3), + "rank_4_dim_neg1": lambda: ((torch.rand(1, 3, 4, 5),), -1), + } + + test_data_fp16: dict[str, Tuple[input_t, int]] = { + "rank_2_dim_1_fp16": lambda: ((torch.rand(2, 5, dtype=torch.float16),), 1), + } + + test_data_bf16: dict[str, Tuple[input_t, int]] = { + "rank_2_dim_1_bf16": lambda: ((torch.rand(2, 5, dtype=torch.bfloat16),), 1), + } + + test_data_int: dict[str, Tuple[input_t, int]] = { + "rank_1_dim_0_int8": lambda: ( + (torch.randint(-128, 127, (10,), dtype=torch.int8),), + 0, + ), + "rank_2_dim_1_int8": lambda: ( + (torch.randint(-128, 127, (2, 5), dtype=torch.int8),), + 1, + ), + "rank_4_dim_2_int8": lambda: ( + (torch.randint(-128, 127, (1, 3, 4, 5), dtype=torch.int8),), + 2, + ), + "rank_4_dim_3_int8": lambda: ( + (torch.randint(-128, 127, (1, 3, 4, 5), dtype=torch.int8),), + 3, + ), + } + + +class ArgmaxAll(torch.nn.Module): + def forward(self, x: torch.Tensor): + return torch.argmax(x) + + +class ArgmaxKeepDim(torch.nn.Module): + def forward(self, x: torch.Tensor): + return torch.argmax(x, dim=1, keepdim=True) + + +@common.parametrize( + "test_data", Argmax.test_data | Argmax.test_data_fp16 | Argmax.test_data_bf16 +) +def test_argmax_tosa_FP(test_data: Tuple[input_t, int]): + data, dim = test_data() + pipeline = TosaPipelineFP[input_t]( + Argmax(dim), + data, + aten_op, + exir_op, + tosa_extensions=["bf16"], + ) + pipeline.count_tosa_ops({"ARGMAX": 1}) + pipeline.run() + + +def test_argmax_all_tosa_FP_not_delegated(): + pipeline = OpNotSupportedPipeline[input_t]( + ArgmaxAll(), + (torch.rand(2, 5),), + {exir_op: 1}, + ) + pipeline.run() + + +def test_argmax_keepdim_tosa_FP_not_delegated(): + pipeline = OpNotSupportedPipeline[input_t]( + ArgmaxKeepDim(), + (torch.rand(2, 5),), + {exir_op: 1}, + ) + pipeline.run() + + +@common.parametrize("test_data", Argmax.test_data_int) +def test_argmax_tosa_INT(test_data: Tuple[input_t, int]): + data, dim = test_data() + pipeline = TosaPipelineINT[input_t]( + Argmax(dim), + data, + aten_op, + exir_op, + ) + pipeline.count_tosa_ops({"ARGMAX": 1}) + pipeline.run() From d9f70b08926d109aa9cc8d839020b29b283a950b Mon Sep 17 00:00:00 2001 From: Saoirse Stewart Date: Tue, 2 Jun 2026 11:49:56 +0100 Subject: [PATCH 2/2] Arm backend: Reject int32 for ARGMAX for now. Signed-off-by: Saoirse Stewart Change-Id: I7a4bb7a53789b0b86618686927a93840464b80ba --- .../tosa_supported_operators.py | 54 +++++++++++++++++-- backends/arm/test/ops/test_argmax.py | 14 +++++ 2 files changed, 65 insertions(+), 3 deletions(-) diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py index 0657497174e..48c5382f6f1 100644 --- a/backends/arm/operator_support/tosa_supported_operators.py +++ b/backends/arm/operator_support/tosa_supported_operators.py @@ -292,7 +292,7 @@ def tosa_support_factory( # Negative checks: Remove nodes from partitioning negative_checks: list[OperatorSupportBase] = [ - CheckInt64InputsAndOutputs(exported_program, reporter), + CheckInt64InputsAndOutputs(exported_program, reporter, tosa_spec), RankCheck(reporter, max_rank=MAX_RANK), *[ reporter.wrap_check(check, f"Rejected by {check.__class__.__name__}") @@ -588,7 +588,10 @@ class CheckInt64InputsAndOutputs(OperatorSupportBase): """ def __init__( - self, exported_program: ExportedProgram, reporter: WhyNoPartitionReporter + self, + exported_program: ExportedProgram, + reporter: WhyNoPartitionReporter, + tosa_spec: TosaSpecification, ): """Initialize the check with program context and reporter.""" self.input_names = [ @@ -597,6 +600,7 @@ def __init__( if spec.kind == InputKind.USER_INPUT ] self.reporter = reporter + self.tosa_spec = tosa_spec self.int32_min = torch.iinfo(torch.int32).min self.int32_max = torch.iinfo(torch.int32).max super().__init__() @@ -619,6 +623,46 @@ def has_rejected_int64_output( return not self._is_tosa_argmax_supported(node) return any(tensor.dtype == torch.int64 for tensor in tensor_list) + def _is_tosa_argmax_dtype_supported( + self, node: torch.fx.Node, input_dtype: torch.dtype + ) -> bool: + if input_dtype == torch.int8: + if not self.tosa_spec.support_integer(): + self.reporter.report_reject( + node, "TOSA ARGMAX requires PRO-INT for int8 input." + ) + return False + elif input_dtype == torch.int16: + if not ( + self.tosa_spec.support_integer() + and self.tosa_spec.support_extension("int16") + ): + self.reporter.report_reject( + node, "TOSA ARGMAX requires EXT-INT16 for int16 input." + ) + return False + elif input_dtype in (torch.float16, torch.float32): + if not self.tosa_spec.support_float(): + self.reporter.report_reject( + node, f"TOSA ARGMAX requires PRO-FP for {input_dtype} input." + ) + return False + elif input_dtype == torch.bfloat16: + if not ( + self.tosa_spec.support_float() + and self.tosa_spec.support_extension("bf16") + ): + self.reporter.report_reject( + node, "TOSA ARGMAX requires EXT-BF16 for bfloat16 input." + ) + return False + else: + self.reporter.report_reject( + node, f"TOSA ARGMAX does not support {input_dtype} input." + ) + return False + return True + def _is_tosa_argmax_supported(self, node: torch.fx.Node) -> bool: dim = node.kwargs.get("dim", node.args[1] if len(node.args) > 1 else None) if dim is None: @@ -633,7 +677,11 @@ def _is_tosa_argmax_supported(self, node: torch.fx.Node) -> bool: return False input_node = typing.cast(torch.fx.Node, node.args[0]) - input_rank = len(get_first_fake_tensor(input_node).shape) + input_tensor = get_first_fake_tensor(input_node) + if not self._is_tosa_argmax_dtype_supported(node, input_tensor.dtype): + return False + + input_rank = len(input_tensor.shape) if input_rank == 0: self.reporter.report_reject( node, "TOSA ARGMAX requires an input with rank at least 1." diff --git a/backends/arm/test/ops/test_argmax.py b/backends/arm/test/ops/test_argmax.py index 62433db65c3..08e2172cdeb 100644 --- a/backends/arm/test/ops/test_argmax.py +++ b/backends/arm/test/ops/test_argmax.py @@ -72,6 +72,11 @@ def forward(self, x: torch.Tensor): return torch.argmax(x, dim=1, keepdim=True) +class ArgmaxInt32(torch.nn.Module): + def forward(self, x: torch.Tensor): + return torch.argmax(x, dim=1).to(torch.int32) + + @common.parametrize( "test_data", Argmax.test_data | Argmax.test_data_fp16 | Argmax.test_data_bf16 ) @@ -106,6 +111,15 @@ def test_argmax_keepdim_tosa_FP_not_delegated(): pipeline.run() +def test_argmax_int32_tosa_FP_not_delegated(): + pipeline = OpNotSupportedPipeline[input_t]( + ArgmaxInt32(), + (torch.randint(0, 10, (2, 5), dtype=torch.int32),), + {exir_op: 1}, + ) + pipeline.run() + + @common.parametrize("test_data", Argmax.test_data_int) def test_argmax_tosa_INT(test_data: Tuple[input_t, int]): data, dim = test_data()