Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/arm/_passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@
from .fuse_equal_placeholders_pass import FuseEqualPlaceholdersPass # noqa
from .fuse_quantized_activation_pass import FuseQuantizedActivationPass # noqa
from .fuse_view_copy_transform_pass import FuseViewCopyTransformPass # noqa
from .insert_const_shapes import InsertConstShapesPass # noqa
from .insert_int32_casts_after_int64_placeholders import ( # noqa
InsertInt32CastsAfterInt64PlaceholdersPass,
)
Expand Down
2 changes: 2 additions & 0 deletions backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@
FuseEqualPlaceholdersPass,
FuseQuantizedActivationPass,
FuseViewCopyTransformPass,
InsertConstShapesPass,
InsertControlFlowRescalesPass,
InsertInt32CastsAfterInt64PlaceholdersPass,
InsertRescaleInt32Pass,
Expand Down Expand Up @@ -376,6 +377,7 @@ def _tosa_pipeline(
RewriteMatmulPass(),
RewritePadPass(),
RewriteSlicePass(),
InsertConstShapesPass(),
]
)

Expand Down
63 changes: 63 additions & 0 deletions backends/arm/_passes/insert_const_shapes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Optional

from executorch.backends.arm._passes import ArmPass
from executorch.backends.arm.tosa.dialect.shape import meta_has_shape_mark
from executorch.exir.dialects._ops import ops as exir_ops


class InsertConstShapesPass(ArmPass):
"""Materialize literal shape arguments as CONST_SHAPE nodes.

This pass targets ops such as `aten.view_copy` and `aten.repeat` whose shape
arguments might otherwise remain raw Python lists/tuples. Replacing them
with explicit CONST_SHAPE nodes simplifies the serialization of these ops
the serialization of their arguments is handled by the CONST_SHAPE node visitor.

"""

_passes_required_after = set()
targeted_ops = {
exir_ops.edge.aten.view_copy.default,
exir_ops.edge.aten.repeat.default,
}

@staticmethod
def _is_shape_arg(arg: Any) -> bool:
"""Return True when `arg` looks like a literal shape list/tuple."""
is_shape_op = meta_has_shape_mark(arg.meta) if hasattr(arg, "meta") else False
return (
not is_shape_op
and isinstance(arg, (list, tuple))
and all(isinstance(x, int) for x in arg)
)

def call_operator(self, op, args, kwargs, meta, updated: Optional[bool] = False):
if op not in self.targeted_ops:
return super().call_operator(op, args, kwargs, meta, updated)
if any(InsertConstShapesPass._is_shape_arg(arg) for arg in args):
new_args = []
for arg in args:
if InsertConstShapesPass._is_shape_arg(arg):
# Insert a const node for the shape argument
if op == exir_ops.edge.aten.view_copy.default:
arg = meta.data["val"].shape
const_node = super().call_shape_operator(
exir_ops.backend.tosa.CONST_SHAPE.default,
(arg,),
{},
meta,
True,
)
new_args.append(const_node)
updated = True
else:
new_args.append(arg)

return super().call_operator(op, tuple(new_args), kwargs, meta, updated)

return super().call_operator(op, args, kwargs, meta, updated)
15 changes: 1 addition & 14 deletions backends/arm/operators/op_repeat.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
validate_valid_dtype,
)
from executorch.backends.arm.tosa.mapping import TosaArg
from executorch.backends.arm.tosa.utils import tosa_shape


@register_node_visitor
Expand Down Expand Up @@ -53,25 +52,13 @@ def define_node(
self.tosa_spec,
)

multiples = inputs[1].special

if len(multiples) == 0:
raise ValueError(f"Length of multiples argument is 0: {inputs[1]}!")

multiple_shapes = tosa_graph.addConst(
(len(multiples),),
ts.DType.SHAPE,
list(tosa_shape(multiples, output.dim_order)),
name=output.name + "_multiples",
)

attr = ts.TosaSerializerAttribute()
attr.TileAttribute()
self._serialize_operator(
node,
tosa_graph,
ts.Op.TILE,
[inputs[0].name, multiple_shapes.name],
[inputs[0].name, inputs[1].name],
[output.name],
attr,
)
17 changes: 1 addition & 16 deletions backends/arm/operators/op_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
validate_valid_dtype,
)
from executorch.backends.arm.tosa.mapping import TosaArg
from executorch.backends.arm.tosa.utils import tosa_shape


@register_node_visitor
Expand Down Expand Up @@ -55,27 +54,13 @@ def define_node(

tosa_graph = cast(ts.TosaSerializer, tosa_graph)

if len(output.shape) != 0:
shape_len = [len(output.shape)]
shape_data = list(tosa_shape(output.shape, output.dim_order))
else:
shape_len = []
shape_data = []

shape = tosa_graph.addConst(
shape_len,
ts.DType.SHAPE,
shape_data,
name=output.name + "_shape",
)

attr = ts.TosaSerializerAttribute()
attr.ReshapeAttribute()
self._serialize_operator(
node,
tosa_graph,
ts.Op.RESHAPE,
[inputs[0].name, shape.name],
[inputs[0].name, inputs[1].name],
[output.name],
attr,
)
Loading