Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions backends/arm/test/BUCK
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
load("@fbcode_macros//build_defs:build_file_migration.bzl", "fbcode_target", "non_fbcode_target")
# Copyright 2025 Arm Limited and/or its affiliates.
# Copyright 2025-2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand Down Expand Up @@ -53,7 +53,7 @@ fbcode_target(_kind = runtime.python_library,
name = "arm_tester_serialize",
srcs = ["tester/serialize.py"],
deps = [
"//executorch/backends/xnnpack/test/tester:tester",
"//executorch/backends/test/harness:tester",
"//executorch/devtools/backend_debug:delegation_info",
]
)
Expand All @@ -63,7 +63,7 @@ fbcode_target(_kind = runtime.python_library,
srcs = glob(["tester/*.py"], exclude = ["tester/serialize.py"]),
deps = [
":common",
"//executorch/backends/xnnpack/test/tester:tester",
"//executorch/backends/test/harness:tester",
"//executorch/backends/arm:ethosu",
"//executorch/backends/arm/quantizer:lib",
"//executorch/backends/arm/tosa:mapping",
Expand Down
4 changes: 2 additions & 2 deletions backends/arm/test/misc/test_bn_relu_folding_qat.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2025 Arm Limited and/or its affiliates.
# Copyright 2025-2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand All @@ -14,7 +14,7 @@
from executorch.backends.arm.test.tester.test_pipeline import TosaPipelineINT
from executorch.backends.arm.tosa import TosaSpecification

from executorch.backends.xnnpack.test.tester.tester import Quantize
from executorch.backends.test.harness.tester import Quantize
from torch import nn


Expand Down
4 changes: 2 additions & 2 deletions backends/arm/test/ops/test_where.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@
)

from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.quantize import ArmQuantize
from executorch.backends.arm.test.tester.test_pipeline import (
EthosU85PipelineINT,
OpNotSupportedPipeline,
TosaPipelineFP,
TosaPipelineINT,
VgfPipeline,
)
from executorch.backends.xnnpack.test.tester.tester import Quantize

aten_op = "torch.ops.aten.where.self"
exir_op = "executorch_exir_dialects_edge__ops_aten_where_self"
Expand Down Expand Up @@ -269,7 +269,7 @@ def test_where_self_u55_INT_not_delegated(test_module):
u55_subset=True,
)
pipeline.change_args(
"quantize", Quantize(quantizer, get_symmetric_quantization_config())
"quantize", ArmQuantize(quantizer, get_symmetric_quantization_config())
)
pipeline.run()

Expand Down
121 changes: 95 additions & 26 deletions backends/arm/test/tester/arm_tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

import copy

import inspect

import logging

from collections import Counter, defaultdict
Expand All @@ -24,7 +26,9 @@
Union,
)

import executorch.backends.xnnpack.test.tester.tester as tester
import executorch.backends.test.harness.stages as BaseStages

import executorch.backends.test.harness.tester as tester

import torch.fx
import torch.utils._pytree as pytree
Expand All @@ -46,7 +50,7 @@
dump_error_output,
print_error_diffs,
)
from executorch.backends.arm.test.tester.quantize import ArmQuantize as Quantize
from executorch.backends.arm.test.tester.quantize import ArmQuantize
from executorch.backends.arm.test.tester.serialize import Serialize

from executorch.backends.arm.tosa import TosaSpecification
Expand All @@ -62,14 +66,7 @@

from executorch.backends.test.harness.error_statistics import ErrorStatistics
from executorch.backends.test.harness.stages import Stage, StageType
from executorch.backends.xnnpack.test.tester import (
Partition as XnnpackPartitionStage,
Quantize as XnnpackQuantize,
Tester,
ToEdge as XnnpackToEdge,
ToEdgeTransformAndLower as XnnpackToEdgeTransformAndLower,
ToExecutorch as XnnpackToExecutorch,
)

from executorch.devtools.backend_debug import get_delegation_info

from executorch.exir import (
Expand All @@ -85,8 +82,10 @@
from executorch.exir.lowered_backend_module import LoweredBackendModule
from executorch.exir.pass_base import ExportPass
from executorch.exir.pass_manager import PassType

from executorch.exir.program._program import (
_copy_module,
_transform,
_update_exported_program_graph_module,
)
from tabulate import tabulate # type: ignore[import-untyped]
Expand Down Expand Up @@ -143,7 +142,13 @@ def _dump_lowered_modules_artifact(
_dump_str(output, path_to_dump)


class Partition(tester.Partition):
def _get_default_edge_compile_config(
skip_dim_order: bool = False,
) -> EdgeCompileConfig:
return EdgeCompileConfig(_check_ir_validity=False, _skip_dim_order=skip_dim_order)


class Partition(BaseStages.Partition):
def dump_artifact(self, path_to_dump: Optional[str]):
super().dump_artifact(path_to_dump)
artifact = cast(Optional[EdgeProgramManager], self.artifact)
Expand All @@ -156,7 +161,7 @@ def dump_artifact(self, path_to_dump: Optional[str]):
_dump_lowered_modules_artifact(path_to_dump, artifact, graph_module)


class ToEdgeTransformAndLower(tester.ToEdgeTransformAndLower):
class ToEdgeTransformAndLower(BaseStages.ToEdgeTransformAndLower):
def __init__(
self,
partitioners: Optional[List[Partitioner]] = None,
Expand All @@ -165,10 +170,17 @@ def __init__(
transform_passes: Optional[
Union[Sequence[PassType], Dict[str, Sequence[PassType]]]
] = None,
compile_spec: Optional[ArmCompileSpec] = None,
):
super().__init__(partitioners, edge_compile_config)
super().__init__(
default_partitioner_cls=None,
partitioners=partitioners,
edge_compile_config=edge_compile_config
or _get_default_edge_compile_config(),
)
self.constant_methods = constant_methods
self.transform_passes = transform_passes
self.partitioners = partitioners or []

def dump_artifact(self, path_to_dump: Optional[str]):
super().dump_artifact(path_to_dump)
Expand All @@ -195,7 +207,7 @@ def run(
)


class ToExecutorch(tester.ToExecutorch):
class ToExecutorch(BaseStages.ToExecutorch):
def run_artifact(self, inputs):
with TosaReferenceModelDispatch():
# Check if the model has mutable buffers. These are not delegated to the backend
Expand All @@ -214,7 +226,38 @@ def run_artifact(self, inputs):
return super().run_artifact(inputs)


class RunPasses(tester.RunPasses):
class RunPasses(BaseStages.RunPasses):
class TesterPassManager:
def __init__(
self,
exported_program: ExportedProgram,
passes: Optional[List[Type[PassType]]] = None,
) -> None:
self._exported_program = exported_program
self.passes = passes or []

@property
def exported_program(self) -> ExportedProgram:
return self._exported_program

def _instantiate_pass(self, pass_):
if not isinstance(pass_, type):
return pass_

if issubclass(pass_, ExportPass):
init_sig = inspect.signature(pass_.__init__)
if "exported_program" in init_sig.parameters:
return pass_(self.exported_program)
return pass_()

return pass_()

def transform(self) -> ExportedProgram:
ep = self.exported_program
for pass_ in self.passes:
ep = _transform(ep, self._instantiate_pass(pass_))
return ep

@no_type_check
def __init__(
self,
Expand All @@ -228,7 +271,11 @@ def __init__(
passes_with_exported_program
)

super().__init__(pass_list, pass_functions)
super().__init__(
pass_manager_cls=RunPasses.TesterPassManager,
pass_list=pass_list,
pass_functions=pass_functions,
)

def run(
self, artifact: Union[EdgeProgramManager, ExportedProgram], inputs=None
Expand Down Expand Up @@ -282,7 +329,7 @@ def run_artifact(self, inputs):
return self.model.forward(*inputs)


class ArmTester(Tester):
class ArmTester(tester.Tester):
def __init__(
self,
model: torch.nn.Module,
Expand All @@ -307,7 +354,19 @@ def __init__(
self.transform_passes = transform_passes
self.constant_methods = constant_methods
self.compile_spec = compile_spec
super().__init__(model, example_inputs, dynamic_shapes)
stage_classes = tester.Tester.default_stage_classes() | {
StageType.PARTITION: Partition,
StageType.TO_EDGE_TRANSFORM_AND_LOWER: ToEdgeTransformAndLower,
StageType.TO_EXECUTORCH: ToExecutorch,
StageType.RUN_PASSES: RunPasses,
StageType.SERIALIZE: Serialize,
}
super().__init__(
model,
example_inputs,
dynamic_shapes=dynamic_shapes,
stage_classes=stage_classes,
)
self.pipeline[StageType.INITIAL_MODEL] = [
StageType.QUANTIZE,
StageType.EXPORT,
Expand All @@ -323,12 +382,12 @@ def __init__(
@no_type_check
def quantize(
self,
quantize_stage: Optional[XnnpackQuantize] = None,
quantize_stage: Optional[BaseStages.Quantize] = None,
):
# Same stage type as parent but exposed via module alias
if quantize_stage is None:
quantizer = create_quantizer(self.compile_spec)
quantize_stage = Quantize(
quantize_stage = ArmQuantize(
quantizer,
get_symmetric_quantization_config(),
)
Expand All @@ -337,22 +396,24 @@ def quantize(
@no_type_check
def to_edge(
self,
to_edge_stage: Optional[XnnpackToEdge] = None,
to_edge_stage: Optional[BaseStages.ToEdge] = None,
# Keep config keyword-only to avoid positional clashes with legacy calls.
*,
config: Optional[EdgeCompileConfig] = None,
):
# Allow optional config override beyond base signature
if to_edge_stage is None:
to_edge_stage = tester.ToEdge(config)
to_edge_stage = BaseStages.ToEdge(
config or _get_default_edge_compile_config()
)
else:
if config is not None:
to_edge_stage.edge_compile_conf = config

return super().to_edge(to_edge_stage)

@no_type_check
def partition(self, partition_stage: Optional[XnnpackPartitionStage] = None):
def partition(self, partition_stage: Optional[BaseStages.Partition] = None):
# Accept Arm-specific partition stage subclass
if partition_stage is None:
arm_partitioner = create_partitioner(self.compile_spec)
Expand All @@ -362,7 +423,7 @@ def partition(self, partition_stage: Optional[XnnpackPartitionStage] = None):
@no_type_check
def to_edge_transform_and_lower(
self,
to_edge_and_lower_stage: Optional[XnnpackToEdgeTransformAndLower] = None,
to_edge_and_lower_stage: Optional[BaseStages.ToEdgeTransformAndLower] = None,
generate_etrecord: bool = False,
# Force the optional tuning knobs to be keyword-only for readability/back-compat.
*,
Expand Down Expand Up @@ -391,6 +452,7 @@ def to_edge_transform_and_lower(
edge_compile_config,
constant_methods=self.constant_methods,
transform_passes=self.transform_passes,
compile_spec=self.compile_spec,
)
else:
if partitioners is not None:
Expand All @@ -402,7 +464,9 @@ def to_edge_transform_and_lower(
)

@no_type_check
def to_executorch(self, to_executorch_stage: Optional[XnnpackToExecutorch] = None):
def to_executorch(
self, to_executorch_stage: Optional[BaseStages.ToExecutorch] = None
):
# Allow custom ExecuTorch stage subclass
if to_executorch_stage is None:
to_executorch_stage = ToExecutorch()
Expand All @@ -411,7 +475,7 @@ def to_executorch(self, to_executorch_stage: Optional[XnnpackToExecutorch] = Non
@no_type_check
def serialize(
self,
serialize_stage: Optional[Serialize] = None,
serialize_stage: Optional[BaseStages.Serialize] = None,
# Keep timeout keyword-only so positional usage matches the base class.
*,
timeout: int = 480,
Expand All @@ -428,6 +492,11 @@ def serialize(
def is_quantized(self) -> bool:
return self.stages[StageType.QUANTIZE] is not None

def run_passes(self, run_passes_stage: Optional[BaseStages.RunPasses] = None):
if run_passes_stage is None:
run_passes_stage = RunPasses()
return super().run_passes(run_passes_stage)

def _get_input_and_stages(
self, inputs, stage, reference_stage_type, run_eager_mode: bool
):
Expand Down
4 changes: 2 additions & 2 deletions backends/arm/test/tester/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import tempfile
from typing import Optional

import executorch.backends.xnnpack.test.tester.tester as tester
import executorch.backends.test.harness.stages as BaseStages

import torch.fx

Expand All @@ -27,7 +27,7 @@
logger = logging.getLogger(__name__)


class Serialize(tester.Serialize):
class Serialize(BaseStages.Serialize):
def __init__(
self,
compile_spec: ArmCompileSpec,
Expand Down
Loading