From 42926f1008f2a6c6ca58c77f192e87d5bcf59cc8 Mon Sep 17 00:00:00 2001 From: Oscar Andersson Date: Fri, 10 Oct 2025 11:10:05 +0200 Subject: [PATCH] Arm backend: Add pass to create CONST_SHAPEs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add pass that creates CONST_SHAPEs for view_copy and repeat based on their list arguments. Change-Id: Icf8e96383825cf0d17710d900e9041191f46f749 Co-authored-by: Per Åstrand Signed-off-by: Oscar Andersson --- backends/arm/_passes/__init__.py | 1 + backends/arm/_passes/arm_pass_manager.py | 2 + backends/arm/_passes/insert_const_shapes.py | 63 +++++++++++++++++++++ backends/arm/operators/op_repeat.py | 15 +---- backends/arm/operators/op_view.py | 17 +----- 5 files changed, 68 insertions(+), 30 deletions(-) create mode 100644 backends/arm/_passes/insert_const_shapes.py diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py index 4fc6bbf1cbc..89f21553771 100644 --- a/backends/arm/_passes/__init__.py +++ b/backends/arm/_passes/__init__.py @@ -105,6 +105,7 @@ from .fuse_equal_placeholders_pass import FuseEqualPlaceholdersPass # noqa from .fuse_quantized_activation_pass import FuseQuantizedActivationPass # noqa from .fuse_view_copy_transform_pass import FuseViewCopyTransformPass # noqa +from .insert_const_shapes import InsertConstShapesPass # noqa from .insert_int32_casts_after_int64_placeholders import ( # noqa InsertInt32CastsAfterInt64PlaceholdersPass, ) diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index df2b85601bd..386c7e8aa54 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -99,6 +99,7 @@ FuseEqualPlaceholdersPass, FuseQuantizedActivationPass, FuseViewCopyTransformPass, + InsertConstShapesPass, InsertControlFlowRescalesPass, InsertInt32CastsAfterInt64PlaceholdersPass, InsertRescaleInt32Pass, @@ -376,6 +377,7 @@ def _tosa_pipeline( RewriteMatmulPass(), RewritePadPass(), RewriteSlicePass(), + InsertConstShapesPass(), ] ) diff --git a/backends/arm/_passes/insert_const_shapes.py b/backends/arm/_passes/insert_const_shapes.py new file mode 100644 index 00000000000..b03394379d9 --- /dev/null +++ b/backends/arm/_passes/insert_const_shapes.py @@ -0,0 +1,63 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, Optional + +from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm.tosa.dialect.shape import meta_has_shape_mark +from executorch.exir.dialects._ops import ops as exir_ops + + +class InsertConstShapesPass(ArmPass): + """Materialize literal shape arguments as CONST_SHAPE nodes. + + This pass targets ops such as `aten.view_copy` and `aten.repeat` whose shape + arguments might otherwise remain raw Python lists/tuples. Replacing them + with explicit CONST_SHAPE nodes simplifies the serialization of these ops + the serialization of their arguments is handled by the CONST_SHAPE node visitor. + + """ + + _passes_required_after = set() + targeted_ops = { + exir_ops.edge.aten.view_copy.default, + exir_ops.edge.aten.repeat.default, + } + + @staticmethod + def _is_shape_arg(arg: Any) -> bool: + """Return True when `arg` looks like a literal shape list/tuple.""" + is_shape_op = meta_has_shape_mark(arg.meta) if hasattr(arg, "meta") else False + return ( + not is_shape_op + and isinstance(arg, (list, tuple)) + and all(isinstance(x, int) for x in arg) + ) + + def call_operator(self, op, args, kwargs, meta, updated: Optional[bool] = False): + if op not in self.targeted_ops: + return super().call_operator(op, args, kwargs, meta, updated) + if any(InsertConstShapesPass._is_shape_arg(arg) for arg in args): + new_args = [] + for arg in args: + if InsertConstShapesPass._is_shape_arg(arg): + # Insert a const node for the shape argument + if op == exir_ops.edge.aten.view_copy.default: + arg = meta.data["val"].shape + const_node = super().call_shape_operator( + exir_ops.backend.tosa.CONST_SHAPE.default, + (arg,), + {}, + meta, + True, + ) + new_args.append(const_node) + updated = True + else: + new_args.append(arg) + + return super().call_operator(op, tuple(new_args), kwargs, meta, updated) + + return super().call_operator(op, args, kwargs, meta, updated) diff --git a/backends/arm/operators/op_repeat.py b/backends/arm/operators/op_repeat.py index 64fa3b19a29..b583cca4b43 100644 --- a/backends/arm/operators/op_repeat.py +++ b/backends/arm/operators/op_repeat.py @@ -19,7 +19,6 @@ validate_valid_dtype, ) from executorch.backends.arm.tosa.mapping import TosaArg -from executorch.backends.arm.tosa.utils import tosa_shape @register_node_visitor @@ -53,25 +52,13 @@ def define_node( self.tosa_spec, ) - multiples = inputs[1].special - - if len(multiples) == 0: - raise ValueError(f"Length of multiples argument is 0: {inputs[1]}!") - - multiple_shapes = tosa_graph.addConst( - (len(multiples),), - ts.DType.SHAPE, - list(tosa_shape(multiples, output.dim_order)), - name=output.name + "_multiples", - ) - attr = ts.TosaSerializerAttribute() attr.TileAttribute() self._serialize_operator( node, tosa_graph, ts.Op.TILE, - [inputs[0].name, multiple_shapes.name], + [inputs[0].name, inputs[1].name], [output.name], attr, ) diff --git a/backends/arm/operators/op_view.py b/backends/arm/operators/op_view.py index fd9c42c4419..5903a0b1555 100644 --- a/backends/arm/operators/op_view.py +++ b/backends/arm/operators/op_view.py @@ -19,7 +19,6 @@ validate_valid_dtype, ) from executorch.backends.arm.tosa.mapping import TosaArg -from executorch.backends.arm.tosa.utils import tosa_shape @register_node_visitor @@ -55,27 +54,13 @@ def define_node( tosa_graph = cast(ts.TosaSerializer, tosa_graph) - if len(output.shape) != 0: - shape_len = [len(output.shape)] - shape_data = list(tosa_shape(output.shape, output.dim_order)) - else: - shape_len = [] - shape_data = [] - - shape = tosa_graph.addConst( - shape_len, - ts.DType.SHAPE, - shape_data, - name=output.name + "_shape", - ) - attr = ts.TosaSerializerAttribute() attr.ReshapeAttribute() self._serialize_operator( node, tosa_graph, ts.Op.RESHAPE, - [inputs[0].name, shape.name], + [inputs[0].name, inputs[1].name], [output.name], attr, )