diff --git a/backends/nxp/backend/ir/converter/node_converters/ops_converters/clamp_converter.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/clamp_converter.py index 82347e38e8a..9294818b90c 100644 --- a/backends/nxp/backend/ir/converter/node_converters/ops_converters/clamp_converter.py +++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/clamp_converter.py @@ -6,15 +6,16 @@ from executorch.backends.nxp.backend.edge_helper import try_get_arg from executorch.backends.nxp.backend.ir.converter.node_converter import ( CustomDelegationOptions, - is_not_qdq_node, NodeConverter, ) +from executorch.backends.nxp.backend.neutron_operator_support import ( + activation_supported_on_target +) from executorch.backends.nxp.backend.ir.lib.tflite.BuiltinOperator import ( BuiltinOperator, ) from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec from torch.fx import Node -from torch.fx.passes.infra.partitioner import Partition from torch.nn import Parameter @@ -63,34 +64,9 @@ def _is_supported_on_target( if bounds not in ClampConverter.SUPPORTED_BOUNDS.values(): return False - return True - - @classmethod - def supports_partitioning_result( - cls, - node: Node, - partition_list: list[Partition], - custom_delegation_options: CustomDelegationOptions, - neutron_target_spec: NeutronTargetSpec, - parameters_mapping: dict[str, Parameter], - ) -> bool: - bounds = cls._get_clamp_bounds(node) - - if bounds in [cls.SUPPORTED_BOUNDS["Relu"], cls.SUPPORTED_BOUNDS["Relu6"]]: - # If this is the only operator in the partition, NeutronConverter will not create a NeutronNode for some - # reason. 
- clamp_partitions = [p for p in partition_list if node in p.nodes] - if len(clamp_partitions) != 1: - return False # Should never happen - - clamp_partition = clamp_partitions[0] - non_q_dq_partition_nodes = list( - filter(is_not_qdq_node, clamp_partition.nodes) - ) - if len(non_q_dq_partition_nodes) <= 1: - return False # This would be the only node in the partition, which would cause a crash later on. - - return True + # `clamp` is converted to `relu`, so we need to check if such activation + # is supported. + return activation_supported_on_target(node, neutron_target_spec) def convert(self, node: Node): """Convert the `aten.clamp.default` operator to Neutron IR `Relu*` operators. diff --git a/backends/nxp/backend/ir/converter/node_converters/ops_converters/hardtanh_converter.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/hardtanh_converter.py index 14d69ed42fb..50567360b63 100644 --- a/backends/nxp/backend/ir/converter/node_converters/ops_converters/hardtanh_converter.py +++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/hardtanh_converter.py @@ -1,4 +1,4 @@ -# Copyright 2025 NXP +# Copyright 2025-2026 NXP # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
@@ -7,9 +7,13 @@ CustomDelegationOptions, NodeConverter, ) +from executorch.backends.nxp.backend.neutron_operator_support import ( + activation_supported_on_target +) from executorch.backends.nxp.backend.ir.lib.tflite.BuiltinOperator import ( BuiltinOperator, ) +from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec from torch.fx import Node from torch.nn import Parameter @@ -33,6 +37,17 @@ def _is_supported_in_IR( _, min_value, max_value = node.args return (min_value, max_value) in HardTanhConverter.supported_modes_map.keys() + @staticmethod + def _is_supported_on_target( + node: Node, + neutron_target_spec: NeutronTargetSpec, + parameters_mapping: dict[str, Parameter], + custom_delegation_options: CustomDelegationOptions, + ) -> bool: + # `hardtanh` is converted to `relu`, so we need to check if such activation + # is supported. + return activation_supported_on_target(node, neutron_target_spec) + def convert(self, node: Node): """Convert 'aten::hardtanh' to it's supported ReLU equivalent.""" self.assert_convertible(node) diff --git a/backends/nxp/backend/ir/converter/node_converters/ops_converters/leaky_relu_converter.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/leaky_relu_converter.py index e6fcf0e5110..a3b4e912a82 100644 --- a/backends/nxp/backend/ir/converter/node_converters/ops_converters/leaky_relu_converter.py +++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/leaky_relu_converter.py @@ -7,9 +7,13 @@ CustomDelegationOptions, NodeConverter, ) +from executorch.backends.nxp.backend.neutron_operator_support import ( + activation_supported_on_target +) from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options.leaky_relu_options import ( LeakyRelu, ) +from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec from torch.fx import Node from torch.nn import Parameter @@ -24,6 +28,15 @@ def _is_supported_in_IR( ) -> bool: return True + @staticmethod + def 
_is_supported_on_target( + node: Node, + neutron_target_spec: NeutronTargetSpec, + parameters_mapping: dict[str, Parameter], + custom_delegation_options: CustomDelegationOptions, + ) -> bool: + return activation_supported_on_target(node, neutron_target_spec) + def convert(self, node: Node): """Convert the `aten.leaky_relu.default` operator to Neutron IR `LeakyRelu`. The schema is: diff --git a/backends/nxp/backend/ir/converter/node_converters/ops_converters/relu_converter.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/relu_converter.py index eb9d62287c0..14751acbb25 100644 --- a/backends/nxp/backend/ir/converter/node_converters/ops_converters/relu_converter.py +++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/relu_converter.py @@ -1,4 +1,4 @@ -# Copyright 2024-2025 NXP +# Copyright 2024-2026 NXP # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -7,9 +7,13 @@ CustomDelegationOptions, NodeConverter, ) +from executorch.backends.nxp.backend.neutron_operator_support import ( + activation_supported_on_target +) from executorch.backends.nxp.backend.ir.lib.tflite.BuiltinOperator import ( BuiltinOperator, ) +from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec from torch.fx import Node from torch.nn import Parameter @@ -24,6 +28,15 @@ def _is_supported_in_IR( ) -> bool: return True + @staticmethod + def _is_supported_on_target( + node: Node, + neutron_target_spec: NeutronTargetSpec, + parameters_mapping: dict[str, Parameter], + custom_delegation_options: CustomDelegationOptions, + ) -> bool: + return activation_supported_on_target(node, neutron_target_spec) + def convert(self, node: Node): self.assert_convertible(node) diff --git a/backends/nxp/backend/ir/converter/node_converters/ops_converters/sigmoid_converter.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/sigmoid_converter.py index 
96e4655d011..9e111477fce 100644 --- a/backends/nxp/backend/ir/converter/node_converters/ops_converters/sigmoid_converter.py +++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/sigmoid_converter.py @@ -1,4 +1,4 @@ -# Copyright 2025 NXP +# Copyright 2025-2026 NXP # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -7,9 +7,13 @@ CustomDelegationOptions, NodeConverter, ) +from executorch.backends.nxp.backend.neutron_operator_support import ( + activation_supported_on_target +) from executorch.backends.nxp.backend.ir.lib.tflite.BuiltinOperator import ( BuiltinOperator, ) +from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec from torch.fx import Node from torch.nn import Parameter @@ -24,6 +28,15 @@ def _is_supported_in_IR( ) -> bool: return True + @staticmethod + def _is_supported_on_target( + node: Node, + neutron_target_spec: NeutronTargetSpec, + parameters_mapping: dict[str, Parameter], + custom_delegation_options: CustomDelegationOptions, + ) -> bool: + return activation_supported_on_target(node, neutron_target_spec) + def convert(self, node: Node): self.assert_convertible(node) diff --git a/backends/nxp/backend/ir/converter/node_converters/ops_converters/tanh_converter.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/tanh_converter.py index 427865f8ee7..9f3fc6649d8 100644 --- a/backends/nxp/backend/ir/converter/node_converters/ops_converters/tanh_converter.py +++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/tanh_converter.py @@ -1,4 +1,4 @@ -# Copyright 2025 NXP +# Copyright 2025-2026 NXP # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
@@ -7,9 +7,13 @@ CustomDelegationOptions, ) from executorch.backends.nxp.backend.ir.converter.node_converter import NodeConverter +from executorch.backends.nxp.backend.neutron_operator_support import ( + activation_supported_on_target +) from executorch.backends.nxp.backend.ir.lib.tflite.BuiltinOperator import ( BuiltinOperator, ) +from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec from torch.fx import Node from torch.nn import Parameter @@ -24,6 +28,15 @@ def _is_supported_in_IR( ) -> bool: return True + @staticmethod + def _is_supported_on_target( + node: Node, + neutron_target_spec: NeutronTargetSpec, + parameters_mapping: dict[str, Parameter], + custom_delegation_options: CustomDelegationOptions, + ) -> bool: + return activation_supported_on_target(node, neutron_target_spec) + def convert(self, node: Node): self.assert_convertible(node) diff --git a/backends/nxp/backend/neutron_operator_support.py b/backends/nxp/backend/neutron_operator_support.py index cdb46870b2e..7dba1ada0b3 100644 --- a/backends/nxp/backend/neutron_operator_support.py +++ b/backends/nxp/backend/neutron_operator_support.py @@ -1,10 +1,15 @@ -# Copyright 2025 NXP +# Copyright 2025-2026 NXP # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec - +from executorch.backends.nxp.backend.data_format import NXP_NODE_FORMAT +from executorch.backends.nxp.backend.edge_helper import input_tensor +from executorch.backends.nxp.backend.ir.converter.conversion.translator import ( + dims_to_channels_last, +) +from torch.fx import Node def is_tensor_invariant_permutation( input_shape: list[int], permutation: list[int] @@ -77,3 +82,24 @@ def transposition_is_supported_on_neutron( return True return False + + +def activation_supported_on_target( + node: Node, neutron_target_spec: NeutronTargetSpec +) -> bool: + """This function determines if the current NeutronSoftware properly supports an activation operator represented by the given node. + + :param node: The node representing the activation operator. + :param neutron_target_spec: Object for querying the target platform to retrieve its properties. + """ + input_shape = list(input_tensor(node, 0).shape) + if node.args[0].meta[NXP_NODE_FORMAT].is_channels_first(): + input_shape = dims_to_channels_last(input_shape) + + n = 1 if len(input_shape) == 1 else input_shape[0] + c = input_shape[-1] + num_macs = neutron_target_spec.get_num_macs() + + # activations in Neutron are delegable only + # if `num_channels` % `num_macs` == 0 and `num_batches` == 1 + return n == 1 and c % num_macs == 0 diff --git a/backends/nxp/tests/ir/converter/node_converter/test_abs_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_abs_converter.py index 2e9a1b393ff..c8356fadd6e 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_abs_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_abs_converter.py @@ -30,12 +30,12 @@ def reseed_model_per_test_run(): class ConvBlocksWithAbs(torch.nn.Module): - def __init__(self, conv_in_channels: int = 3): + def __init__(self, conv_in_channels: int = 3, conv_out_channels: int = 8): super().__init__() self.block1 = torch.nn.Sequential( 
torch.nn.Conv2d( in_channels=conv_in_channels, - out_channels=3, + out_channels=conv_out_channels, kernel_size=(2, 2), stride=(2, 2), ), @@ -64,8 +64,10 @@ def forward(self, x): return x.abs() -def test_conv_abs(mocker, use_qat, input_shape: tuple[int] = (1, 3, 112, 112)): - model = ConvBlocksWithAbs(conv_in_channels=input_shape[1]) +def test_conv_abs(mocker, use_qat, input_shape: tuple[int] = (1, 8, 112, 112)): + model = ConvBlocksWithAbs( + conv_in_channels=input_shape[1], conv_out_channels=input_shape[1] + ) converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") diff --git a/backends/nxp/tests/ir/converter/node_converter/test_clamp_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_clamp_converter.py index 5fd71b7219c..61218646571 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_clamp_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_clamp_converter.py @@ -70,7 +70,7 @@ def forward(self, x): ], ) def test_convert_clamp__supported(mocker, min, max): - input_shape = (23,) + input_shape = (24,) model = AddClampModule(min, max) converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") @@ -100,21 +100,31 @@ def test_convert_clamp__supported(mocker, min, max): # noinspection PyShadowingBuiltins @pytest.mark.parametrize( - "min, max", + "input_shape, min, max", [ - pytest.param(0, 6, id="min = 0, max = 6 (Relu6)"), - pytest.param(0, None, id="min = 0, max = None (Relu)"), + pytest.param( + (1, 8, 8, 7), + 0, + 6, + id="min = 0, max = 6 (Relu6), num_channels not divisible by NUM_MACS", + ), + pytest.param( + (1, 8, 8, 7), + 0, + None, + id="min = 0, max = None (Relu), num_channels not divisible by NUM_MACS", + ), + pytest.param( + (2, 16, 8, 8), 0, None, id="min = 0, max = None (Relu), num_batches != 1" + ), ], ) -def test_convert_clamp__single_op__not_delegated_variants(min, max): - # Test that Clamp representable as Relu6 or Relu is NOT delegated, because it is a single op model which 
is not - # supported by Neutron. - input_shape = (23,) +def test_convert_clamp__unsupported_shape(input_shape, min, max): model = ClampModule(min, max) delegated_ep = to_quantized_edge_program(model, input_shape).exported_program() - # Make sure the `clamp` was NOT delegated (single op model). + # Make sure the `clamp` was NOT delegated. assert not graph_contains_any_of_ops(delegated_ep.graph, [ExecutorchDelegateCall]) assert graph_contains_any_of_ops(delegated_ep.graph, [Clamp]) @@ -129,7 +139,7 @@ def test_convert_clamp__single_op__not_delegated_variants(min, max): ) def test_convert_clamp__single_op__delegated_variants(mocker, min, max): # Test that Clamp representable as Relu0To1 or ReluN1To1 is delegated, even though it is a single op model. - input_shape = (23,) + input_shape = (24,) model = ClampModule(min, max) converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") @@ -166,7 +176,7 @@ def test_convert_clamp__single_op__delegated_variants(mocker, min, max): ], ) def test_convert_clamp__no_delegation__unsupported_bounds(min, max): - input_shape = (23,) + input_shape = (24,) model = AddClampModule(min, max) delegated_ep = to_quantized_edge_program(model, input_shape).exported_program() diff --git a/backends/nxp/tests/ir/converter/node_converter/test_hardtanh_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_hardtanh_converter.py index fb272a2c650..40688fb1d19 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_hardtanh_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_hardtanh_converter.py @@ -1,4 +1,4 @@ -# Copyright 2025 NXP +# Copyright 2025-2026 NXP # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
@@ -17,10 +17,10 @@ from executorch.backends.nxp.tests.executors import ( convert_run_compare, graph_contains_any_of_ops, - ToNCHWPreprocess, - ToNHWCPreprocess, + ToChannelFirstPreprocess, + ToChannelLastPreprocess, ) -from executorch.backends.nxp.tests.models import Conv2dWithActivation +from executorch.backends.nxp.tests.models import Conv2dWithActivation, HardTanhModule from executorch.exir.dialects._ops import ops as exir_ops from torch.export import ExportedProgram from executorch.backends.nxp.tests.use_qat import * # noqa F403 @@ -32,6 +32,11 @@ def reseed_model_per_test_run(): np.random.seed(23) +ExecutorchDelegateCall = torch.ops.higher_order.executorch_call_delegate +HardTanh = exir_ops.edge.aten.hardtanh.default +HardTanh_ = exir_ops.edge.aten.hardtanh_.default + + @pytest.mark.parametrize("input_shape", [(1, 3, 128, 128)]) @pytest.mark.parametrize("inplace", [True, False]) def test_relu6_quant(mocker, input_shape: tuple[int], inplace: bool, use_qat: bool): @@ -50,15 +55,15 @@ def test_relu6_quant(mocker, input_shape: tuple[int], inplace: bool, use_qat: bo tflite_flatbuffers_model, io_formats = converter_spy.spy_return exported_program: ExportedProgram = converter_spy.call_args.args[1] - ops = [exir_ops.edge.aten.hardtanh.default, exir_ops.edge.aten.hardtanh_.default] - assert not graph_contains_any_of_ops(graph=quantized_program.graph, ops=ops) + assert not graph_contains_any_of_ops(quantized_program.graph, [HardTanh, HardTanh_]) + assert graph_contains_any_of_ops(quantized_program.graph, [ExecutorchDelegateCall]) input_data = (np.random.random(input_shape) * 50).astype(np.int8) convert_run_compare( exported_program, tfl_model=tflite_flatbuffers_model, - tflite_input_preprocess=ToNHWCPreprocess(), - tflite_output_preprocess=ToNCHWPreprocess(), + tflite_input_preprocess=ToChannelLastPreprocess(), + tflite_output_preprocess=ToChannelFirstPreprocess(), input_data=input_data, atol=2.0, ) @@ -93,15 +98,42 @@ def test_custom_hardtanh_quant( 
tflite_flatbuffers_model, io_formats = converter_spy.spy_return exported_program: ExportedProgram = converter_spy.call_args.args[1] - ops = [exir_ops.edge.aten.hardtanh.default, exir_ops.edge.aten.hardtanh_.default] - assert not graph_contains_any_of_ops(graph=quantized_program.graph, ops=ops) + assert not graph_contains_any_of_ops(quantized_program.graph, [HardTanh, HardTanh_]) + assert graph_contains_any_of_ops(quantized_program.graph, [ExecutorchDelegateCall]) input_data = (np.random.random(input_shape) * 50).astype(np.int8) convert_run_compare( exported_program, tfl_model=tflite_flatbuffers_model, - tflite_input_preprocess=ToNHWCPreprocess(), - tflite_output_preprocess=ToNCHWPreprocess(), + tflite_input_preprocess=ToChannelLastPreprocess(), + tflite_output_preprocess=ToChannelFirstPreprocess(), input_data=input_data, atol=2.0, ) + + +@pytest.mark.parametrize( + "input_shape", + [ + pytest.param((1, 8, 16, 7), id="num_channels not divisible by NUM_MACS"), + pytest.param((2, 8, 16, 24), id="num_batches != 1"), + ], +) +@pytest.mark.parametrize( + "activation_range", list(HardTanhConverter.supported_modes_map.keys()) +) +def test_hardtanh__unsupported( + mocker, + input_shape: tuple[int], + activation_range: tuple[int, int], + use_qat: bool, +): + min_val, max_val = activation_range + model = HardTanhModule(min_val, max_val) + delegated_ep = to_quantized_edge_program( + model, input_shape, use_qat=use_qat + ).exported_program() + + # Make sure the `hardtanh` was NOT delegated. 
+ assert not graph_contains_any_of_ops(delegated_ep.graph, [ExecutorchDelegateCall]) + assert graph_contains_any_of_ops(delegated_ep.graph, [HardTanh, HardTanh_]) diff --git a/backends/nxp/tests/ir/converter/node_converter/test_leaky_relu_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_leaky_relu_converter.py index 35b58c88608..b5cfea04307 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_leaky_relu_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_leaky_relu_converter.py @@ -76,7 +76,7 @@ def forward(self, x): def test_convert_leaky_relu__alpha(mocker, alpha): _assert_successful_delegation( LeakyReluModule(negative_slope=alpha), - (23,), + (24,), mocker, atol=1, # Common quantization rounding error. ) @@ -84,9 +84,7 @@ def test_convert_leaky_relu__alpha(mocker, alpha): def test_convert_leaky_relu__default_alpha(mocker): _assert_successful_delegation( - LeakyReluModule(), # Leave the default alpha. - (23,), - mocker, + LeakyReluModule(), (24,), mocker, atol=1 # Leave the default alpha. ) @@ -97,20 +95,18 @@ def test_convert_leaky_relu__default_alpha(mocker): ) def test_convert_leaky_relu__inplace(mocker, inplace): _assert_successful_delegation( - LeakyReluModule(inplace=inplace), - (23,), - mocker, + LeakyReluModule(inplace=inplace), (24,), mocker, atol=1 ) @pytest.mark.parametrize( "input_shape", [ - (5,), - (4, 5), - (3, 4, 5), - (2, 3, 4, 5), - (1, 2, 3, 4, 5), + (24,), + (1, 8), + (1, 4, 8), + (1, 3, 4, 8), + (1, 2, 3, 4, 8), ], ids=lambda input_shape: f"{len(input_shape)}D", ) @@ -121,3 +117,20 @@ def test_convert_leaky_relu__ranks(mocker, input_shape: tuple[int, ...]): mocker, atol=1, # Common quantization rounding error. 
) + + +@pytest.mark.parametrize( + "input_shape", + [ + pytest.param((1, 8, 7), id="num_channels not divisible by NUM_MACS"), + pytest.param((2, 8, 16), id="num_batches != 1"), + ], +) +def test_convert_leaky_relu__unsupported(mocker, input_shape: tuple[int, ...]): + model = LeakyReluModule() + + delegated_ep = to_quantized_edge_program(model, input_shape).exported_program() + + # Make sure the `leaky_relu` was NOT delegated. + assert not graph_contains_any_of_ops(delegated_ep.graph, [ExecutorchDelegateCall]) + assert graph_contains_any_of_ops(delegated_ep.graph, [LeakyRelu2D]) diff --git a/backends/nxp/tests/ir/converter/node_converter/test_relu_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_relu_converter.py index b91720324f2..c0d2a9f223a 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_relu_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_relu_converter.py @@ -1,4 +1,4 @@ -# Copyright 2024 NXP +# Copyright 2024,2026 NXP # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
@@ -9,6 +9,7 @@ from executorch.backends.nxp.backend.edge_program_converter import ( EdgeProgramToIRConverter, + exir_ops, ) from executorch.backends.nxp.tests.executorch_pipeline import ( to_edge_program, @@ -16,11 +17,11 @@ ) from executorch.backends.nxp.tests.executors import ( convert_run_compare, - ToNCHWPreprocess, - ToNHWCPreprocess, + graph_contains_any_of_ops, + ToChannelFirstPreprocess, + ToChannelLastPreprocess, ) -from executorch.backends.nxp.tests.models import Conv2dModule, LinearModule, ReLUModule -from torch.export import ExportedProgram +from executorch.backends.nxp.tests.models import ReLUModule from executorch.backends.nxp.tests.use_qat import * # noqa F403 @@ -30,32 +31,23 @@ def reseed_model_per_test_run(): np.random.seed(23) -class ConvReLUModule(torch.nn.Module): - def __init__(self): - super().__init__() - - self.conv = Conv2dModule() - self.relu = torch.nn.ReLU() +ExecutorchDelegateCall = torch.ops.higher_order.executorch_call_delegate +ReLU = exir_ops.edge.aten.relu.default - def forward(self, x): - x = self.conv(x) - return self.relu(x) - -class LinearReLUModule(torch.nn.Module): +class ConvReLUModule(torch.nn.Module): def __init__(self): super().__init__() - - self.linear = LinearModule(bias=True) + self.conv = torch.nn.Conv2d(8, 8, 1) self.relu = torch.nn.ReLU() def forward(self, x): - x = self.linear(x) + x = self.conv(x) return self.relu(x) def test_relu_conversion(): - input_shape = (10, 4, 32, 32) + input_shape = (1, 8, 32, 32) edge_program = to_edge_program(ReLUModule(), input_shape).exported_program() input_data = 2 * np.random.random(input_shape).astype(np.float32) - 1 @@ -63,52 +55,54 @@ def test_relu_conversion(): convert_run_compare(edge_program, input_data=input_data) +@pytest.mark.parametrize( + "input_shape", + [ + pytest.param((1, 8, 7), id="num_channels not divisible by NUM_MACS"), + pytest.param((2, 16, 8), id="num_batches != 1"), + ], +) +def test_relu_conversion__unsupported(mocker, input_shape): + delegated_ep = 
to_quantized_edge_program( + ReLUModule(), input_shape + ).exported_program() + + # Make sure the `relu` was NOT delegated. + assert not graph_contains_any_of_ops(delegated_ep.graph, [ExecutorchDelegateCall]) + assert graph_contains_any_of_ops(delegated_ep.graph, [ReLU]) + + def test_relu_with_conv_quant_conversion(mocker, use_qat): - input_shape = (1, 4, 32, 32) + input_shape = (1, 8, 32, 32) converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") # Run conversion - _ = to_quantized_edge_program( - ConvReLUModule(), + model = ConvReLUModule() + delegated_ep = to_quantized_edge_program( + model, input_shape, - use_qat=use_qat, use_neutron_for_format_conversion=False, - ) + ).exported_program() - # Capture generated model - tflite_flatbuffers_model, _ = converter_spy.spy_return + # Make sure the `relu` was delegated. + assert graph_contains_any_of_ops(delegated_ep.graph, [ExecutorchDelegateCall]) + assert not graph_contains_any_of_ops(delegated_ep.graph, [ReLU]) - # Capture converted program - edge_program: ExportedProgram = converter_spy.call_args.args[1] + intermediate_ep = converter_spy.call_args.args[1] + neutron_ir_model, _ = converter_spy.spy_return input_data = ( (2 * np.random.random(input_shape).astype(np.float32) - 1) * 50 ).astype(np.int8) - convert_run_compare( - edge_program, - input_data, - tfl_model=tflite_flatbuffers_model, - tflite_input_preprocess=ToNHWCPreprocess(), - tflite_output_preprocess=ToNCHWPreprocess(), - ) - - -def test_relu_with_linear_quant_conversion(mocker, use_qat): - input_shape = (256, 32) - converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") - - # Run conversion - _ = to_quantized_edge_program(LinearReLUModule(), input_shape, use_qat=use_qat) - - # Capture generated model - tflite_flatbuffers_model, _ = converter_spy.spy_return - - # Capture converted program - edge_program: ExportedProgram = converter_spy.call_args.args[1] - input_data = ( - (2 * 
np.random.random(input_shape).astype(np.float32) - 1) * 50 + np.random.random(input_shape).astype(np.float32) * 256.0 - 128.0 ).astype(np.int8) - convert_run_compare(edge_program, input_data, tfl_model=tflite_flatbuffers_model) + convert_run_compare( + intermediate_ep, + tfl_model=neutron_ir_model, + input_data=input_data, + tflite_input_preprocess=ToChannelLastPreprocess(), + tflite_output_preprocess=ToChannelFirstPreprocess(), + ) diff --git a/backends/nxp/tests/ir/converter/node_converter/test_sigmoid_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_sigmoid_converter.py index ad03aa18ded..9cd82b405cf 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_sigmoid_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_sigmoid_converter.py @@ -1,4 +1,4 @@ -# Copyright 2025 NXP +# Copyright 2025-2026 NXP # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -10,10 +10,12 @@ from executorch.backends.nxp.backend.edge_program_converter import ( EdgeProgramToIRConverter, + exir_ops, ) from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program from executorch.backends.nxp.tests.executors import ( convert_run_compare, + graph_contains_any_of_ops, ToNCHWPreprocess, ToNHWCPreprocess, ) @@ -29,6 +31,10 @@ def reseed_model_per_test_run(): np.random.seed(23) +ExecutorchDelegateCall = torch.ops.higher_order.executorch_call_delegate +Sigmoid = exir_ops.edge.aten.sigmoid.default + + def test_conv_sigmoid(mocker, use_qat, input_shape: tuple[int] = (1, 3, 112, 112)): model = ConvWithSigmoid(conv_in_channels=input_shape[1]) @@ -55,11 +61,11 @@ def test_conv_sigmoid(mocker, use_qat, input_shape: tuple[int] = (1, 3, 112, 112 @pytest.mark.parametrize( "input_shape", [ - pytest.param((10,), id="Scalar"), - pytest.param((10, 25), id="1D"), - pytest.param((10, 25, 25), id="2D"), - pytest.param((10, 3, 25, 25), id="3D"), - 
pytest.param((10, 3, 25, 25, 25), id="4D"),
+        pytest.param((24,), id="Scalar"),
+        pytest.param((1, 24), id="1D"),
+        pytest.param((1, 25, 24), id="2D"),
+        pytest.param((1, 3, 25, 24), id="3D"),
+        pytest.param((1, 3, 25, 25, 24), id="4D"),
     ],
 )
 def test_sigmoid_only(mocker, use_qat, input_shape):
@@ -67,12 +73,35 @@
     model = nn.Sigmoid()

     converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")

-    to_quantized_edge_program(model, input_shape, use_qat=use_qat).exported_program()
+    delegated_ep = to_quantized_edge_program(
+        model, input_shape, use_qat=use_qat
+    ).exported_program()

-    tflite_flatbuffers_model, io_formats = converter_spy.spy_return
+    # Make sure the `sigmoid` was delegated.
+    assert graph_contains_any_of_ops(delegated_ep.graph, [ExecutorchDelegateCall])
+    assert not graph_contains_any_of_ops(delegated_ep.graph, [Sigmoid])
+
+    neutron_ir_model = converter_spy.spy_return[0]
     exported_program: ExportedProgram = converter_spy.call_args.args[1]

     input_data = (np.random.random(input_shape) * 50).astype(np.int8)

     convert_run_compare(
-        exported_program, tfl_model=tflite_flatbuffers_model, input_data=input_data
+        exported_program, tfl_model=neutron_ir_model, input_data=input_data
     )
+
+
+@pytest.mark.parametrize(
+    "input_shape",
+    [
+        pytest.param((1, 8, 7), id="num_channels not divisible by NUM_MACS"),
+        pytest.param((2, 8, 16), id="num_batches != 1"),
+    ],
+)
+def test_sigmoid__unsupported(mocker, input_shape: tuple[int, ...]):
+    model = nn.Sigmoid()
+
+    delegated_ep = to_quantized_edge_program(model, input_shape).exported_program()
+
+    # Make sure the `sigmoid` was NOT delegated.
+ assert not graph_contains_any_of_ops(delegated_ep.graph, [ExecutorchDelegateCall]) + assert graph_contains_any_of_ops(delegated_ep.graph, [Sigmoid]) diff --git a/backends/nxp/tests/ir/converter/node_converter/test_tanh_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_tanh_converter.py index 10892d28e38..65bd2ea0abf 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_tanh_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_tanh_converter.py @@ -1,4 +1,4 @@ -# Copyright 2025 NXP +# Copyright 2025-2026 NXP # All rights reserved. # # This source code is licensed under the BSD-style license found in the @@ -9,7 +9,6 @@ import kgb import numpy as np import torch - from executorch.backends.nxp.nxp_backend import EdgeProgramToIRConverter from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program from executorch.backends.nxp.tests.executors import ( @@ -18,12 +17,18 @@ ToChannelFirstPreprocess, ToChannelLastPreprocess, ) -from executorch.backends.nxp.tests.models import Conv2dWithActivation + +from executorch.backends.nxp.tests.models import Conv2dWithActivation, TanhModule from executorch.exir.dialects._ops import ops as exir_ops from parameterized import parameterized from torch.export import ExportedProgram +ExecutorchDelegateCall = torch.ops.higher_order.executorch_call_delegate +Tanh = exir_ops.edge.aten.tanh.default +Tanh_ = exir_ops.edge.aten.tanh_.default + + class TestTanhConverter(unittest.TestCase): __test__ = False # Prevent interfering with PyTest tests @@ -45,7 +50,7 @@ def test_conv_tanh( _: str, inplace: bool, use_qat: bool, - input_shape: tuple[int] = (1, 3, 112, 112), + input_shape: tuple[int] = (1, 8, 112, 112), ): with kgb.spy_on( EdgeProgramToIRConverter.convert_program, @@ -74,8 +79,8 @@ def test_conv_tanh( quantized_program.graph_module.lowered_module_0.original_module.graph ) tanh_ops = [ - exir_ops.edge.aten.tanh.default, - exir_ops.edge.aten.tanh_.default, + Tanh, + 
Tanh_, ] assert graph_contains_any_of_ops(graph=lowered_module_graph, ops=tanh_ops) @@ -88,3 +93,22 @@ def test_conv_tanh( input_data=input_data, atol=2.0, ) + + @parameterized.expand( + input=[ + ((1, 8, 7),), # num_channels not divisible by NUM_MACS + ((2, 8, 16),), # num_batches != 1 + ] + ) + def test_tanh__unsupported( + self, + input_shape: tuple[int], + ): + model = TanhModule() + delegated_ep = to_quantized_edge_program(model, input_shape).exported_program() + + # Make sure the `tanh` was NOT delegated. + assert not graph_contains_any_of_ops( + delegated_ep.graph, [ExecutorchDelegateCall] + ) + assert graph_contains_any_of_ops(delegated_ep.graph, [Tanh_, Tanh]) diff --git a/backends/nxp/tests/ir/converter/node_converter/test_view_copy_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_view_copy_converter.py index 2a0e69dcd54..04d986a274c 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_view_copy_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_view_copy_converter.py @@ -272,7 +272,7 @@ def test_view_copy_w_linear_quant_conversion(mocker, input_shape, new_shape, use @pytest.mark.parametrize( "input_shape, channels_view_out", [ - pytest.param((1, 4, 16, 16), 196, id="4D"), + pytest.param((1, 8, 16, 16), 392, id="4D"), ], ) def test_view_w_conv_linear_quant_conversion( diff --git a/backends/nxp/tests/models.py b/backends/nxp/tests/models.py index 9357660d6c6..49251beee05 100644 --- a/backends/nxp/tests/models.py +++ b/backends/nxp/tests/models.py @@ -163,6 +163,16 @@ def forward(self, x): return self.block(x) +class TanhModule(torch.nn.Module): + def __init__(self): + super().__init__() + + self.tanh = nn.Tanh() + + def forward(self, x): + return self.tanh(x) + + class LinearModule(torch.nn.Module): def __init__(self, bias: bool): super().__init__() @@ -415,11 +425,16 @@ def forward(self, x): class Conv2dWithActivation(torch.nn.Module): - def __init__(self, activation: torch.nn.Module | Callable, 
in_channels: int = 3): + def __init__( + self, + activation: torch.nn.Module | Callable, + in_channels: int = 3, + out_channels=64, + ): super().__init__() self.conv = torch.nn.Conv2d( - in_channels=in_channels, out_channels=64, kernel_size=(3, 3) + in_channels=in_channels, out_channels=out_channels, kernel_size=(3, 3) ) self.activation = activation @@ -606,6 +621,17 @@ def forward(self, x): return torch.mean(x, dim=self.dim, keepdim=self.keepdim) +class HardTanhModule(torch.nn.Module): + def __init__(self, min_val, max_val, inplace=True): + super().__init__() + self.hardtanh = torch.nn.Hardtanh( + min_val=min_val, max_val=max_val, inplace=inplace + ) + + def forward(self, x): + return self.hardtanh(x) + + class MeanDimConvModule(torch.nn.Module): def __init__(self, dim, keepdim, out_channels=8): super().__init__()