diff --git a/backends/arm/test/BUCK b/backends/arm/test/BUCK index 4ec8d9f865e..af1c36a6532 100644 --- a/backends/arm/test/BUCK +++ b/backends/arm/test/BUCK @@ -1,5 +1,5 @@ load("@fbcode_macros//build_defs:build_file_migration.bzl", "fbcode_target", "non_fbcode_target") -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -53,7 +53,7 @@ fbcode_target(_kind = runtime.python_library, name = "arm_tester_serialize", srcs = ["tester/serialize.py"], deps = [ - "//executorch/backends/xnnpack/test/tester:tester", + "//executorch/backends/test/harness:tester", "//executorch/devtools/backend_debug:delegation_info", ] ) @@ -63,7 +63,7 @@ fbcode_target(_kind = runtime.python_library, srcs = glob(["tester/*.py"], exclude = ["tester/serialize.py"]), deps = [ ":common", - "//executorch/backends/xnnpack/test/tester:tester", + "//executorch/backends/test/harness:tester", "//executorch/backends/arm:ethosu", "//executorch/backends/arm/quantizer:lib", "//executorch/backends/arm/tosa:mapping", diff --git a/backends/arm/test/misc/test_bn_relu_folding_qat.py b/backends/arm/test/misc/test_bn_relu_folding_qat.py index 41675b33765..535ab6ea4e4 100644 --- a/backends/arm/test/misc/test_bn_relu_folding_qat.py +++ b/backends/arm/test/misc/test_bn_relu_folding_qat.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -14,7 +14,7 @@ from executorch.backends.arm.test.tester.test_pipeline import TosaPipelineINT from executorch.backends.arm.tosa import TosaSpecification -from executorch.backends.xnnpack.test.tester.tester import Quantize +from executorch.backends.test.harness.tester import Quantize from torch import nn diff --git a/backends/arm/test/ops/test_where.py b/backends/arm/test/ops/test_where.py index 6b519bea08d..40b0847b838 100644 --- a/backends/arm/test/ops/test_where.py +++ b/backends/arm/test/ops/test_where.py @@ -12,6 +12,7 @@ ) from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.quantize import ArmQuantize from executorch.backends.arm.test.tester.test_pipeline import ( EthosU85PipelineINT, OpNotSupportedPipeline, @@ -19,7 +20,6 @@ TosaPipelineINT, VgfPipeline, ) -from executorch.backends.xnnpack.test.tester.tester import Quantize aten_op = "torch.ops.aten.where.self" exir_op = "executorch_exir_dialects_edge__ops_aten_where_self" @@ -269,7 +269,7 @@ def test_where_self_u55_INT_not_delegated(test_module): u55_subset=True, ) pipeline.change_args( - "quantize", Quantize(quantizer, get_symmetric_quantization_config()) + "quantize", ArmQuantize(quantizer, get_symmetric_quantization_config()) ) pipeline.run() diff --git a/backends/arm/test/tester/arm_tester.py b/backends/arm/test/tester/arm_tester.py index 45b1b3b8485..7d3a945d54f 100644 --- a/backends/arm/test/tester/arm_tester.py +++ b/backends/arm/test/tester/arm_tester.py @@ -5,6 +5,8 @@ import copy +import inspect + import logging from collections import Counter, defaultdict @@ -24,7 +26,9 @@ Union, ) -import executorch.backends.xnnpack.test.tester.tester as tester +import executorch.backends.test.harness.stages as BaseStages + +import executorch.backends.test.harness.tester as tester import torch.fx import torch.utils._pytree as pytree @@ -46,7 +50,7 @@ dump_error_output, print_error_diffs, ) -from executorch.backends.arm.test.tester.quantize import ArmQuantize as Quantize +from executorch.backends.arm.test.tester.quantize import ArmQuantize from executorch.backends.arm.test.tester.serialize import Serialize from executorch.backends.arm.tosa import TosaSpecification @@ -62,14 +66,7 @@ from executorch.backends.test.harness.error_statistics import ErrorStatistics from executorch.backends.test.harness.stages import Stage, StageType -from executorch.backends.xnnpack.test.tester import ( - Partition as XnnpackPartitionStage, - Quantize as XnnpackQuantize, - Tester, - ToEdge as XnnpackToEdge, - ToEdgeTransformAndLower as XnnpackToEdgeTransformAndLower, - ToExecutorch as XnnpackToExecutorch, -) + from executorch.devtools.backend_debug import get_delegation_info from executorch.exir import ( @@ -85,8 +82,10 @@ from executorch.exir.lowered_backend_module import LoweredBackendModule from executorch.exir.pass_base import ExportPass from executorch.exir.pass_manager import PassType + from executorch.exir.program._program import ( _copy_module, + _transform, _update_exported_program_graph_module, ) from tabulate import tabulate # type: ignore[import-untyped] @@ -143,7 +142,13 @@ def _dump_lowered_modules_artifact( _dump_str(output, path_to_dump) -class Partition(tester.Partition): +def _get_default_edge_compile_config( + skip_dim_order: bool = False, +) -> EdgeCompileConfig: + return EdgeCompileConfig(_check_ir_validity=False, _skip_dim_order=skip_dim_order) + + +class Partition(BaseStages.Partition): def dump_artifact(self, path_to_dump: Optional[str]): super().dump_artifact(path_to_dump) artifact = cast(Optional[EdgeProgramManager], self.artifact) @@ -156,7 +161,7 @@ def dump_artifact(self, path_to_dump: Optional[str]): _dump_lowered_modules_artifact(path_to_dump, artifact, graph_module) -class ToEdgeTransformAndLower(tester.ToEdgeTransformAndLower): +class ToEdgeTransformAndLower(BaseStages.ToEdgeTransformAndLower): def __init__( self, partitioners: Optional[List[Partitioner]] = None, @@ -165,10 +170,17 @@ def __init__( transform_passes: Optional[ Union[Sequence[PassType], Dict[str, Sequence[PassType]]] ] = None, + compile_spec: Optional[ArmCompileSpec] = None, ): - super().__init__(partitioners, edge_compile_config) + super().__init__( + default_partitioner_cls=None, + partitioners=partitioners, + edge_compile_config=edge_compile_config + or _get_default_edge_compile_config(), + ) self.constant_methods = constant_methods self.transform_passes = transform_passes + self.partitioners = partitioners or [] def dump_artifact(self, path_to_dump: Optional[str]): super().dump_artifact(path_to_dump) @@ -195,7 +207,7 @@ def run( ) -class ToExecutorch(tester.ToExecutorch): +class ToExecutorch(BaseStages.ToExecutorch): def run_artifact(self, inputs): with TosaReferenceModelDispatch(): # Check if the model has mutable buffers. These are not delegated to the backend @@ -214,7 +226,38 @@ def run_artifact(self, inputs): return super().run_artifact(inputs) -class RunPasses(tester.RunPasses): +class RunPasses(BaseStages.RunPasses): + class TesterPassManager: + def __init__( + self, + exported_program: ExportedProgram, + passes: Optional[List[Type[PassType]]] = None, + ) -> None: + self._exported_program = exported_program + self.passes = passes or [] + + @property + def exported_program(self) -> ExportedProgram: + return self._exported_program + + def _instantiate_pass(self, pass_): + if not isinstance(pass_, type): + return pass_ + + if issubclass(pass_, ExportPass): + init_sig = inspect.signature(pass_.__init__) + if "exported_program" in init_sig.parameters: + return pass_(self.exported_program) + return pass_() + + return pass_() + + def transform(self) -> ExportedProgram: + ep = self.exported_program + for pass_ in self.passes: + ep = _transform(ep, self._instantiate_pass(pass_)) + return ep + @no_type_check def __init__( self, @@ -228,7 +271,11 @@ def __init__( passes_with_exported_program ) - super().__init__(pass_list, pass_functions) + super().__init__( + pass_manager_cls=RunPasses.TesterPassManager, + pass_list=pass_list, + pass_functions=pass_functions, + ) def run( self, artifact: Union[EdgeProgramManager, ExportedProgram], inputs=None @@ -282,7 +329,7 @@ def run_artifact(self, inputs): return self.model.forward(*inputs) -class ArmTester(Tester): +class ArmTester(tester.Tester): def __init__( self, model: torch.nn.Module, @@ -307,7 +354,19 @@ def __init__( self.transform_passes = transform_passes self.constant_methods = constant_methods self.compile_spec = compile_spec - super().__init__(model, example_inputs, dynamic_shapes) + stage_classes = tester.Tester.default_stage_classes() | { + StageType.PARTITION: Partition, + StageType.TO_EDGE_TRANSFORM_AND_LOWER: ToEdgeTransformAndLower, + StageType.TO_EXECUTORCH: ToExecutorch, + StageType.RUN_PASSES: RunPasses, + StageType.SERIALIZE: Serialize, + } + super().__init__( + model, + example_inputs, + dynamic_shapes=dynamic_shapes, + stage_classes=stage_classes, + ) self.pipeline[StageType.INITIAL_MODEL] = [ StageType.QUANTIZE, StageType.EXPORT, @@ -323,12 +382,12 @@ def __init__( @no_type_check def quantize( self, - quantize_stage: Optional[XnnpackQuantize] = None, + quantize_stage: Optional[BaseStages.Quantize] = None, ): # Same stage type as parent but exposed via module alias if quantize_stage is None: quantizer = create_quantizer(self.compile_spec) - quantize_stage = Quantize( + quantize_stage = ArmQuantize( quantizer, get_symmetric_quantization_config(), ) @@ -337,14 +396,16 @@ def quantize( @no_type_check def to_edge( self, - to_edge_stage: Optional[XnnpackToEdge] = None, + to_edge_stage: Optional[BaseStages.ToEdge] = None, # Keep config keyword-only to avoid positional clashes with legacy calls. *, config: Optional[EdgeCompileConfig] = None, ): # Allow optional config override beyond base signature if to_edge_stage is None: - to_edge_stage = tester.ToEdge(config) + to_edge_stage = BaseStages.ToEdge( + config or _get_default_edge_compile_config() + ) else: if config is not None: to_edge_stage.edge_compile_conf = config @@ -352,7 +413,7 @@ def to_edge( return super().to_edge(to_edge_stage) @no_type_check - def partition(self, partition_stage: Optional[XnnpackPartitionStage] = None): + def partition(self, partition_stage: Optional[BaseStages.Partition] = None): # Accept Arm-specific partition stage subclass if partition_stage is None: arm_partitioner = create_partitioner(self.compile_spec) @@ -362,7 +423,7 @@ def partition(self, partition_stage: Optional[XnnpackPartitionStage] = None): @no_type_check def to_edge_transform_and_lower( self, - to_edge_and_lower_stage: Optional[XnnpackToEdgeTransformAndLower] = None, + to_edge_and_lower_stage: Optional[BaseStages.ToEdgeTransformAndLower] = None, generate_etrecord: bool = False, # Force the optional tuning knobs to be keyword-only for readability/back-compat. *, @@ -391,6 +452,7 @@ def to_edge_transform_and_lower( edge_compile_config, constant_methods=self.constant_methods, transform_passes=self.transform_passes, + compile_spec=self.compile_spec, ) else: if partitioners is not None: @@ -402,7 +464,9 @@ def to_edge_transform_and_lower( ) @no_type_check - def to_executorch(self, to_executorch_stage: Optional[XnnpackToExecutorch] = None): + def to_executorch( + self, to_executorch_stage: Optional[BaseStages.ToExecutorch] = None + ): # Allow custom ExecuTorch stage subclass if to_executorch_stage is None: to_executorch_stage = ToExecutorch() @@ -411,7 +475,7 @@ def to_executorch(self, to_executorch_stage: Optional[XnnpackToExecutorch] = Non @no_type_check def serialize( self, - serialize_stage: Optional[Serialize] = None, + serialize_stage: Optional[BaseStages.Serialize] = None, # Keep timeout keyword-only so positional usage matches the base class. *, timeout: int = 480, @@ -428,6 +492,11 @@ def serialize( def is_quantized(self) -> bool: return self.stages[StageType.QUANTIZE] is not None + def run_passes(self, run_passes_stage: Optional[BaseStages.RunPasses] = None): + if run_passes_stage is None: + run_passes_stage = RunPasses() + return super().run_passes(run_passes_stage) + def _get_input_and_stages( self, inputs, stage, reference_stage_type, run_eager_mode: bool ): diff --git a/backends/arm/test/tester/serialize.py b/backends/arm/test/tester/serialize.py index 7f95446391c..cc19f15b8f0 100644 --- a/backends/arm/test/tester/serialize.py +++ b/backends/arm/test/tester/serialize.py @@ -8,7 +8,7 @@ import tempfile from typing import Optional -import executorch.backends.xnnpack.test.tester.tester as tester +import executorch.backends.test.harness.stages as BaseStages import torch.fx @@ -27,7 +27,7 @@ logger = logging.getLogger(__name__) -class Serialize(tester.Serialize): +class Serialize(BaseStages.Serialize): def __init__( self, compile_spec: ArmCompileSpec,