Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions backends/arm/_passes/aten_to_tosa_activation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,3 +128,21 @@ def rewrite_clamp(node: Node, pass_: AtenToDialectPass) -> DialectNodeSpec | Non
exir_ops.backend.tosa.CLAMP.default,
(node.args[0], *min_max_args),
)


def get_activation_replacement(
    node: Node, pass_: AtenToDialectPass
) -> DialectNodeSpec | None:
    """Build the TOSA dialect replacement for an ATen activation node.

    Args:
        node: Node whose ``target`` is compared against the activation ops
            handled here (clamp, erf, sigmoid and tanh).
        pass_: Pass instance forwarded unchanged to the selected rewrite.

    Returns:
        Whatever the matching ``rewrite_*`` function returns, or ``None`` when
        ``node.target`` is not one of the handled activation ops.

    """
    target = node.target
    aten = exir_ops.edge.aten

    # Each ATen activation target maps to exactly one TOSA rewrite builder.
    if target == aten.clamp.default:
        return rewrite_clamp(node, pass_)
    if target == aten.erf.default:
        return rewrite_erf(node, pass_)
    if target == aten.sigmoid.default:
        return rewrite_sigmoid(node, pass_)
    if target == aten.tanh.default:
        return rewrite_tanh(node, pass_)
    return None
26 changes: 26 additions & 0 deletions backends/arm/_passes/aten_to_tosa_tensor_operators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import cast

from executorch.backends.transforms.aten_to_dialect_pass import (
AtenToDialectPass,
DialectNodeSpec,
)
from executorch.exir.dialects._ops import ops as exir_ops
from torch.fx import Node


def rewrite_argmax(node: Node, pass_: AtenToDialectPass) -> DialectNodeSpec:
    """Describe the TOSA ``ARGMAX`` node that replaces an ATen argmax node.

    Args:
        node: The argmax node. Its first positional argument is the input
            node; the reduction dimension is read from the ``dim`` keyword
            argument when present, otherwise from the second positional
            argument.
        pass_: Pass instance; not used by this rewrite.

    Returns:
        A ``DialectNodeSpec`` targeting ``tosa.ARGMAX`` with the input node
        and the reduction dimension as positional arguments and no keyword
        arguments.

    """
    source = cast(Node, node.args[0])

    if "dim" in node.kwargs:
        raw_dim = node.kwargs["dim"]
    else:
        raw_dim = node.args[1]
    axis = cast(int, raw_dim)

    # A negative dimension counts from the end of the input's shape.
    if axis < 0:
        axis = axis + len(source.meta["val"].shape)

    replacement_args = (source, axis)
    return DialectNodeSpec(
        exir_ops.backend.tosa.ARGMAX.default,
        replacement_args,
        {},
    )
43 changes: 22 additions & 21 deletions backends/arm/_passes/exir_to_tosa_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,37 +5,38 @@

import executorch.backends.arm.tosa.dialect # noqa: F401
from executorch.backends.arm._passes.aten_to_tosa_activation_functions import (
rewrite_clamp,
rewrite_erf,
rewrite_sigmoid,
rewrite_tanh,
get_activation_replacement,
)
from executorch.backends.arm._passes.aten_to_tosa_tensor_operators import rewrite_argmax
from executorch.backends.transforms.aten_to_dialect_pass import (
AtenToDialectPass,
DialectNodeSpec,
)
from executorch.backends.transforms.aten_to_dialect_pass import AtenToDialectPass
from executorch.exir.dialects._ops import ops as exir_ops
from torch.fx import Node


class ExirToTosaPass(AtenToDialectPass):
    """Rewrite simple EXIR ops to equivalent backend TOSA dialect ops.

    Rewrite functions are registered with the shared ATen-to-dialect pass
    infrastructure through ``register_dialect_substitution``.

    """


# Maps each activation ATen edge op to the function that rewrites it.
_ACTIVATION_FUNCTION_REWRITES = {
    exir_ops.edge.aten.clamp.default: rewrite_clamp,
    exir_ops.edge.aten.erf.default: rewrite_erf,
    exir_ops.edge.aten.sigmoid.default: rewrite_sigmoid,
    exir_ops.edge.aten.tanh.default: rewrite_tanh,
}
@ExirToTosaPass.register_dialect_substitution(exir_ops.edge.aten.argmax.default)
def _get_tensor_operators_replacement(
    node: Node, pass_: AtenToDialectPass
) -> DialectNodeSpec:
    """Return the replacement spec for an ATen argmax node.

    Registered on ``ExirToTosaPass`` for ``aten.argmax.default``; the work is
    delegated to ``rewrite_argmax``.

    """
    replacement = rewrite_argmax(node, pass_)
    return replacement

# Rewrite tables keyed by op category name; each value maps ATen edge targets
# to their rewrite functions.
_DIRECT_REWRITE_CATEGORIES = {
    "activation_functions": _ACTIVATION_FUNCTION_REWRITES,
}

# Register each category's ATen targets with the function that builds the
# corresponding TOSA dialect node spec.
for _rewrite_category in _DIRECT_REWRITE_CATEGORIES.values():
    for _edge_target, _rewrite_fn in _rewrite_category.items():
        # register_dialect_substitution(target) returns a decorator, which is
        # applied here directly to the rewrite function.
        ExirToTosaPass.register_dialect_substitution(_edge_target)(_rewrite_fn)
@ExirToTosaPass.register_dialect_substitution(exir_ops.edge.aten.clamp.default)
@ExirToTosaPass.register_dialect_substitution(exir_ops.edge.aten.erf.default)
@ExirToTosaPass.register_dialect_substitution(exir_ops.edge.aten.sigmoid.default)
@ExirToTosaPass.register_dialect_substitution(exir_ops.edge.aten.tanh.default)
def _get_activation_replacement(
    node: Node, pass_: AtenToDialectPass
) -> DialectNodeSpec | None:
    """Return the replacement spec for an ATen activation node.

    Registered on ``ExirToTosaPass`` for clamp, erf, sigmoid and tanh; the
    per-target dispatch is delegated to ``get_activation_replacement``, whose
    result (possibly ``None``) is returned unchanged.

    """
    replacement = get_activation_replacement(node, pass_)
    return replacement
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@
exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
exir_ops.edge.aten.pad.default,
exir_ops.edge.aten.constant_pad_nd.default,
exir_ops.edge.aten.argmax.default,
exir_ops.edge.aten.amax.default,
exir_ops.edge.aten.amin.default,
exir_ops.edge.aten.eye.default,
Expand Down Expand Up @@ -237,6 +238,7 @@
operator.getitem,
exir_ops.edge.aten.pad.default,
exir_ops.edge.aten.constant_pad_nd.default,
exir_ops.edge.aten.argmax.default,
exir_ops.edge.aten.amax.default,
exir_ops.edge.aten.amin.default,
exir_ops.edge.aten.eye.default,
Expand Down
104 changes: 101 additions & 3 deletions backends/arm/operator_support/tosa_supported_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ def tosa_support_factory(

# Negative checks: Remove nodes from partitioning
negative_checks: list[OperatorSupportBase] = [
CheckInt64InputsAndOutputs(exported_program, reporter),
CheckInt64InputsAndOutputs(exported_program, reporter, tosa_spec),
RankCheck(reporter, max_rank=MAX_RANK),
*[
reporter.wrap_check(check, f"Rejected by {check.__class__.__name__}")
Expand Down Expand Up @@ -588,7 +588,10 @@ class CheckInt64InputsAndOutputs(OperatorSupportBase):
"""

def __init__(
self, exported_program: ExportedProgram, reporter: WhyNoPartitionReporter
self,
exported_program: ExportedProgram,
reporter: WhyNoPartitionReporter,
tosa_spec: TosaSpecification,
):
"""Initialize the check with program context and reporter."""
self.input_names = [
Expand All @@ -597,6 +600,7 @@ def __init__(
if spec.kind == InputKind.USER_INPUT
]
self.reporter = reporter
self.tosa_spec = tosa_spec
self.int32_min = torch.iinfo(torch.int32).min
self.int32_max = torch.iinfo(torch.int32).max
super().__init__()
Expand All @@ -609,6 +613,100 @@ def inside_int32_bounds(self, node: torch.fx.Node) -> bool:
min_val, max_val = int(torch.min(data)), int(torch.max(data))
return min_val >= self.int32_min and max_val <= self.int32_max

def has_rejected_int64_output(
    self, node: torch.fx.Node, tensor_list: Sequence[typing.Any]
) -> bool:
    """Tell whether ``node`` should be treated as producing a rejected int64.

    Args:
        node: Node being checked.
        tensor_list: Values produced by ``node``; each must expose ``dtype``.

    Returns:
        For argmax nodes (ATen or edge dialect), ``True`` exactly when
        ``_is_tosa_argmax_supported`` rejects the node. For every other node,
        ``True`` when any entry of ``tensor_list`` has dtype ``torch.int64``.

    """
    argmax_targets = (
        torch.ops.aten.argmax.default,
        exir_ops.edge.aten.argmax.default,
    )
    if node.target not in argmax_targets:
        for tensor in tensor_list:
            if tensor.dtype == torch.int64:
                return True
        return False

    # Argmax is judged by its own support check instead of its output dtype.
    argmax_ok = self._is_tosa_argmax_supported(node)
    return not argmax_ok

def _is_tosa_argmax_dtype_supported(
    self, node: torch.fx.Node, input_dtype: torch.dtype
) -> bool:
    """Check that the active TOSA specification can run ARGMAX on a dtype.

    Args:
        node: Node the rejection is reported against.
        input_dtype: Dtype of the argmax input tensor.

    Returns:
        ``True`` when ``self.tosa_spec`` supports ``input_dtype`` for ARGMAX.
        Otherwise the reason is passed to ``self.reporter.report_reject`` and
        ``False`` is returned.

    """
    spec = self.tosa_spec

    # Pick the capability test and the matching rejection text per dtype.
    if input_dtype == torch.int8:
        supported = spec.support_integer()
        reason = "TOSA ARGMAX requires PRO-INT for int8 input."
    elif input_dtype == torch.int16:
        supported = spec.support_integer() and spec.support_extension("int16")
        reason = "TOSA ARGMAX requires EXT-INT16 for int16 input."
    elif input_dtype in (torch.float16, torch.float32):
        supported = spec.support_float()
        reason = f"TOSA ARGMAX requires PRO-FP for {input_dtype} input."
    elif input_dtype == torch.bfloat16:
        supported = spec.support_float() and spec.support_extension("bf16")
        reason = "TOSA ARGMAX requires EXT-BF16 for bfloat16 input."
    else:
        supported = False
        reason = f"TOSA ARGMAX does not support {input_dtype} input."

    if supported:
        return True
    self.reporter.report_reject(node, reason)
    return False

def _is_tosa_argmax_supported(self, node: torch.fx.Node) -> bool:
    """Decide whether an argmax node can be lowered to TOSA ARGMAX.

    The checks run in this order, and the first failure is reported through
    ``self.reporter.report_reject``: an explicit reduction dimension, an
    ``int`` dimension, a supported input dtype, an input rank of at least one,
    an in-range axis, and ``keepdim`` being falsy.

    Args:
        node: The argmax node. ``dim`` and ``keepdim`` are read from keyword
            arguments first, then from positional arguments 1 and 2.

    Returns:
        ``True`` when every check passes, ``False`` otherwise.

    """

    def reject(reason: str) -> bool:
        # Report the reason against the node and yield the rejection result.
        self.reporter.report_reject(node, reason)
        return False

    positional = node.args

    if "dim" in node.kwargs:
        dim = node.kwargs["dim"]
    elif len(positional) > 1:
        dim = positional[1]
    else:
        dim = None

    if dim is None:
        return reject("TOSA ARGMAX requires an explicit reduction dimension.")
    if not isinstance(dim, int):
        return reject("TOSA ARGMAX requires a statically known reduction dimension.")

    input_tensor = get_first_fake_tensor(typing.cast(torch.fx.Node, positional[0]))
    if not self._is_tosa_argmax_dtype_supported(node, input_tensor.dtype):
        # The dtype helper has already reported the reason.
        return False

    rank = len(input_tensor.shape)
    if rank == 0:
        return reject("TOSA ARGMAX requires an input with rank at least 1.")

    # Negative dimensions count from the end of the input shape.
    axis = dim if dim >= 0 else dim + rank
    if not 0 <= axis < rank:
        return reject(
            f"TOSA ARGMAX axis must be in [0, {rank - 1}] but got {dim}."
        )

    if "keepdim" in node.kwargs:
        keepdim = node.kwargs["keepdim"]
    elif len(positional) > 2:
        keepdim = positional[2]
    else:
        keepdim = False

    if keepdim:
        return reject("TOSA ARGMAX does not support keepdim=True.")

    return True

def is_node_supported(
self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
) -> bool:
Expand All @@ -618,7 +716,7 @@ def is_node_supported(
vals = node.meta["val"]
tensor_list = vals if isinstance(vals, (list, tuple)) else [vals]

any_int64 = any(tensor.dtype == torch.int64 for tensor in tensor_list)
any_int64 = self.has_rejected_int64_output(node, tensor_list)
# Don't partition nodes with int64 output...
if any_int64:
# ... Except for constant ops that are directly cast to something non-int64.
Expand Down
1 change: 1 addition & 0 deletions backends/arm/operators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
op_sub,
op_sum,
op_to_dim_order_copy,
op_tosa_argmax,
op_tosa_avg_pool2d,
op_tosa_avg_pool2d_adaptive,
op_tosa_cast_to_block_scaled,
Expand Down
63 changes: 63 additions & 0 deletions backends/arm/operators/op_tosa_argmax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, List

import torch.fx
import tosa_serializer as ts

from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
)
from executorch.backends.arm.operators.operator_validation_utils import (
validate_num_inputs,
validate_valid_dtype,
)
from executorch.backends.arm.tosa.mapping import TosaArg


@register_node_visitor
class ArgMaxVisitor(NodeVisitor):
    """Serialize ``tosa.ARGMAX.default`` nodes as TOSA ``ARGMAX`` operators."""

    target = "tosa.ARGMAX.default"

    def define_node(
        self,
        node: torch.fx.Node,
        tosa_graph: Any,
        inputs: List[TosaArg],
        output: TosaArg,
    ) -> None:
        """Validate the node and emit a single ``ARGMAX`` operator.

        Args:
            node: The dialect node being serialized.
            tosa_graph: Graph the operator is serialized into.
            inputs: Exactly two arguments: the input tensor and the reduction
                axis (read from ``inputs[1].number``).
            output: The result tensor, which must be INT32.

        The input tensor must be INT8, INT16, FP16, FP32 or BF16. Validation
        is done by ``validate_num_inputs`` / ``validate_valid_dtype``.

        """
        validate_num_inputs(self.target, inputs, 2)
        validate_valid_dtype(
            self.target,
            inputs[0],
            [
                ts.DType.INT8,
                ts.DType.INT16,
                ts.DType.FP16,
                ts.DType.FP32,
                ts.DType.BF16,
            ],
            self.tosa_spec,
        )
        validate_valid_dtype(self.target, output, ts.DType.INT32, self.tosa_spec)

        axis = inputs[1].number
        if axis < 0:
            # The axis indexes the *input* tensor, so a negative value has to
            # be wrapped with the input's rank. The node's own value is the
            # ARGMAX result, which has the reduced axis removed and therefore
            # a rank one lower; wrapping with it would be off by one.
            input_tensor = get_first_fake_tensor(node.all_input_nodes[0])
            axis += len(input_tensor.size())

        attr = ts.TosaSerializerAttribute()
        attr.ArgMaxAttribute(axis, ts.NanPropagationMode.PROPAGATE)
        self._serialize_operator(
            node,
            tosa_graph,
            ts.Op.ARGMAX,
            [inputs[0].name],
            [output.name],
            attr,
        )
Loading
Loading