From f75ac40df061eef7a30bf2be7e984959ad507bfb Mon Sep 17 00:00:00 2001 From: Vaclav Novak Date: Fri, 13 Mar 2026 10:15:31 +0100 Subject: [PATCH] feat: do not delegate relu/relu6 if alone in partition + clamp converter bugfix --- .../backend/ir/converter/node_converter.py | 28 +++++++ .../ops_converters/clamp_converter.py | 18 ++--- .../ops_converters/hardtanh_converter.py | 73 ++++++++++++++++-- .../ops_converters/relu_converter.py | 25 +++++- .../nxp/backend/neutron_operator_support.py | 28 ++++++- .../node_converter/test_clamp_converter.py | 23 ++++-- .../node_converter/test_hardtanh_converter.py | 76 +++++++++++++++---- .../node_converter/test_relu_converter.py | 42 +++++++++- backends/nxp/tests/models.py | 11 +++ 9 files changed, 277 insertions(+), 47 deletions(-) diff --git a/backends/nxp/backend/ir/converter/node_converter.py b/backends/nxp/backend/ir/converter/node_converter.py index 623ba97ba73..53cf1936435 100755 --- a/backends/nxp/backend/ir/converter/node_converter.py +++ b/backends/nxp/backend/ir/converter/node_converter.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. from abc import ABC, abstractmethod +from typing import Callable import torch @@ -181,6 +182,33 @@ def _has_shared_q_params_if_quantized(node: Node) -> bool: # Node not quantized return True + @staticmethod + def is_node_alone_in_partition( + node: Node, partition_list: list[Partition], filter_fn: Callable + ): + """Return True if `node` is the only node in its partition for which `filter_fn` + returns True. + + The function finds the unique partition containing `node` and applies + `filter_fn` to all nodes in that partition. If only one node passes the + predicate — and that node is `node` — the function returns True. + + :param node: The torch.Node to check. + :param partition_list: List of proposed partitions. + :param filter_fn: Predicate applied to nodes in the partition. + `node` is considered alone if it is the only node + for which this predicate returns True. + """ + partitions = [p for p in partition_list if node in p.nodes] + if len(partitions) != 1: + return False # Should never happen + + partition = partitions[0] + filtered_partition_nodes = list(filter(filter_fn, partition.nodes)) + return ( + len(filtered_partition_nodes) == 1 and filtered_partition_nodes[0] == node + ) + def assert_convertible(self, node): """Assert that the call `is_supported()` returns `True`. Otherwise, raise an exception and print an error message. diff --git a/backends/nxp/backend/ir/converter/node_converters/ops_converters/clamp_converter.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/clamp_converter.py index 82347e38e8a..17d6b0d8b99 100644 --- a/backends/nxp/backend/ir/converter/node_converters/ops_converters/clamp_converter.py +++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/clamp_converter.py @@ -12,6 +12,9 @@ from executorch.backends.nxp.backend.ir.lib.tflite.BuiltinOperator import ( BuiltinOperator, ) +from executorch.backends.nxp.backend.neutron_operator_support import ( + activation_supported_on_target, +) from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec from torch.fx import Node from torch.fx.passes.infra.partitioner import Partition @@ -77,18 +80,11 @@ def supports_partitioning_result( bounds = cls._get_clamp_bounds(node) if bounds in [cls.SUPPORTED_BOUNDS["Relu"], cls.SUPPORTED_BOUNDS["Relu6"]]: - # If this is the only operator in the partition, NeutronConverter will not create a NeutronNode for some - # reason. - clamp_partitions = [p for p in partition_list if node in p.nodes] - if len(clamp_partitions) != 1: - return False # Should never happen - - clamp_partition = clamp_partitions[0] - non_q_dq_partition_nodes = list( - filter(is_not_qdq_node, clamp_partition.nodes) + is_alone_in_partition = cls.is_node_alone_in_partition( + node, partition_list, filter_fn=is_not_qdq_node ) - if len(non_q_dq_partition_nodes) <= 1: - return False # This would be the only node in the partition, which would cause a crash later on. + if is_alone_in_partition: + return activation_supported_on_target(node, neutron_target_spec) return True diff --git a/backends/nxp/backend/ir/converter/node_converters/ops_converters/hardtanh_converter.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/hardtanh_converter.py index 14d69ed42fb..3fbe2b51638 100644 --- a/backends/nxp/backend/ir/converter/node_converters/ops_converters/hardtanh_converter.py +++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/hardtanh_converter.py @@ -1,15 +1,21 @@ -# Copyright 2025 NXP +# Copyright 2025-2026 NXP # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. from executorch.backends.nxp.backend.ir.converter.node_converter import ( CustomDelegationOptions, + is_not_qdq_node, NodeConverter, + Partition, ) from executorch.backends.nxp.backend.ir.lib.tflite.BuiltinOperator import ( BuiltinOperator, ) +from executorch.backends.nxp.backend.neutron_operator_support import ( + activation_supported_on_target, +) +from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec from torch.fx import Node from torch.nn import Parameter @@ -17,21 +23,76 @@ class HardTanhConverter(NodeConverter): # Maps possible input parameters of HardTanh to equivalent ReLU-based operators supported by TFLite. - supported_modes_map = { + SUPPORTED_MODES_MAP = { (0.0, 6.0): BuiltinOperator.RELU6, (-1.0, 1.0): BuiltinOperator.RELU_N1_TO_1, (0.0, 1.0): BuiltinOperator.RELU_0_TO_1, (0.0, float("inf")): BuiltinOperator.RELU, } + # Maps possible modes of HardTanh to equivalent ReLU bounds. + SUPPORTED_BOUNDS_MAP = { + "ReluN1To1": (-1.0, 1.0), + "Relu0To1": (0.0, 1.0), + "Relu6": (0.0, 6.0), + "Relu": (0.0, float("inf")), + } + + @staticmethod + def _get_hardtanh_bounds(node: Node) -> tuple[int, int]: + args = node.args + + match len(args): + case 1: + min_val = -1 + max_val = 1 + + case 2: + min_val = args[1] + max_val = 1 + + case 3: + min_val = args[1] + max_val = args[2] + + case _: + # should not occur + min_val = 0 + max_val = 1 + + return min_val, max_val + @staticmethod def _is_supported_in_IR( node: Node, parameters_mapping: dict[str, Parameter], custom_delegation_options: CustomDelegationOptions, ) -> bool: - _, min_value, max_value = node.args - return (min_value, max_value) in HardTanhConverter.supported_modes_map.keys() + bounds = HardTanhConverter._get_hardtanh_bounds(node) + return bounds in HardTanhConverter.SUPPORTED_MODES_MAP.keys() + + @classmethod + def supports_partitioning_result( + cls, + node: Node, + partition_list: list[Partition], + custom_delegation_options: CustomDelegationOptions, + neutron_target_spec: NeutronTargetSpec, + parameters_mapping: dict[str, Parameter], + ) -> bool: + bounds = HardTanhConverter._get_hardtanh_bounds(node) + + if bounds in [ + cls.SUPPORTED_BOUNDS_MAP["Relu"], + cls.SUPPORTED_BOUNDS_MAP["Relu6"], + ]: + is_alone_in_partition = cls.is_node_alone_in_partition( + node, partition_list, filter_fn=is_not_qdq_node + ) + if is_alone_in_partition: + return activation_supported_on_target(node, neutron_target_spec) + + return True def convert(self, node: Node): """Convert 'aten::hardtanh' to it's supported ReLU equivalent.""" @@ -39,9 +100,9 @@ def convert(self, node: Node): t_op = self._create_tflite_op_with_io_tensors(node) - _, min_value, max_value = node.args + bounds = HardTanhConverter._get_hardtanh_bounds(node) - op = self.supported_modes_map[(min_value, max_value)] + op = self.SUPPORTED_MODES_MAP[bounds] t_op.opcode_index = self.builder.op_code_index_for_op_type(op) self.builder.append_operators([t_op]) diff --git a/backends/nxp/backend/ir/converter/node_converters/ops_converters/relu_converter.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/relu_converter.py index eb9d62287c0..5bdc7fc0996 100644 --- a/backends/nxp/backend/ir/converter/node_converters/ops_converters/relu_converter.py +++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/relu_converter.py @@ -1,15 +1,21 @@ -# Copyright 2024-2025 NXP +# Copyright 2024-2026 NXP # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. from executorch.backends.nxp.backend.ir.converter.node_converter import ( CustomDelegationOptions, + is_not_qdq_node, NodeConverter, + Partition, ) from executorch.backends.nxp.backend.ir.lib.tflite.BuiltinOperator import ( BuiltinOperator, ) +from executorch.backends.nxp.backend.neutron_operator_support import ( + activation_supported_on_target, + NeutronTargetSpec, +) from torch.fx import Node from torch.nn import Parameter @@ -24,6 +30,23 @@ def _is_supported_in_IR( ) -> bool: return True + @classmethod + def supports_partitioning_result( + cls, + node: Node, + partition_list: list[Partition], + custom_delegation_options: CustomDelegationOptions, + neutron_target_spec: NeutronTargetSpec, + parameters_mapping: dict[str, Parameter], + ) -> bool: + is_alone_in_partition = cls.is_node_alone_in_partition( + node, partition_list, filter_fn=is_not_qdq_node + ) + if is_alone_in_partition: + return activation_supported_on_target(node, neutron_target_spec) + + return True + def convert(self, node: Node): self.assert_convertible(node) diff --git a/backends/nxp/backend/neutron_operator_support.py b/backends/nxp/backend/neutron_operator_support.py index cdb46870b2e..3dafefef484 100644 --- a/backends/nxp/backend/neutron_operator_support.py +++ b/backends/nxp/backend/neutron_operator_support.py @@ -1,9 +1,15 @@ -# Copyright 2025 NXP +# Copyright 2025-2026 NXP # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from executorch.backends.nxp.backend.data_format import NXP_NODE_FORMAT +from executorch.backends.nxp.backend.edge_helper import input_tensor +from executorch.backends.nxp.backend.ir.converter.conversion.translator import ( + dims_to_channels_last, +) from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec +from torch.fx import Node def is_tensor_invariant_permutation( @@ -77,3 +83,23 @@ def transposition_is_supported_on_neutron( return True return False + + +def activation_supported_on_target( + node: Node, neutron_target_spec: NeutronTargetSpec +) -> bool: + """This function determines if the current NeutronSoftware properly supports an activation operator represented by the given node. + + :param node: The node representing the activation operator. + :param neutron_target_spec: Object for querying the target platform to retrieve its properties. + """ + input_shape = list(input_tensor(node, 0).shape) + if node.args[0].meta[NXP_NODE_FORMAT].is_channels_first(): + input_shape = dims_to_channels_last(input_shape) + + c = input_shape[-1] + num_macs = neutron_target_spec.get_num_macs() + + # activations in Neutron are delegable only + # if `num_channels` % `num_macs` == 0 + return c % num_macs == 0 diff --git a/backends/nxp/tests/ir/converter/node_converter/test_clamp_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_clamp_converter.py index 5fd71b7219c..8ba3c97d19f 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_clamp_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_clamp_converter.py @@ -100,21 +100,28 @@ def test_convert_clamp__supported(mocker, min, max): # noinspection PyShadowingBuiltins @pytest.mark.parametrize( - "min, max", + "input_shape, min, max", [ - pytest.param(0, 6, id="min = 0, max = 6 (Relu6)"), - pytest.param(0, None, id="min = 0, max = None (Relu)"), + pytest.param( + (1, 7, 9, 11), + 0, + 6, + id="min = 0, max = 6 (Relu6), num_channels not divisible by NUM_MACS, alone in partition", + ), + pytest.param( + (1, 7, 9, 11), + 0, + None, + id="min = 0, max = None (Relu), num_channels not divisible by NUM_MACS, alone in partition", + ), ], ) -def test_convert_clamp__single_op__not_delegated_variants(min, max): - # Test that Clamp representable as Relu6 or Relu is NOT delegated, because it is a single op model which is not - # supported by Neutron. - input_shape = (23,) +def test_convert_clamp__unsupported_shape(input_shape, min, max): model = ClampModule(min, max) delegated_ep = to_quantized_edge_program(model, input_shape).exported_program() - # Make sure the `clamp` was NOT delegated (single op model). + # Make sure the `clamp` was NOT delegated. assert not graph_contains_any_of_ops(delegated_ep.graph, [ExecutorchDelegateCall]) assert graph_contains_any_of_ops(delegated_ep.graph, [Clamp]) diff --git a/backends/nxp/tests/ir/converter/node_converter/test_hardtanh_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_hardtanh_converter.py index fb272a2c650..2c3145bde18 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_hardtanh_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_hardtanh_converter.py @@ -1,4 +1,4 @@ -# Copyright 2025 NXP +# Copyright 2025-2026 NXP # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -10,17 +10,14 @@ from executorch.backends.nxp.backend.edge_program_converter import ( EdgeProgramToIRConverter, ) -from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.hardtanh_converter import ( - HardTanhConverter, -) from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program from executorch.backends.nxp.tests.executors import ( convert_run_compare, graph_contains_any_of_ops, - ToNCHWPreprocess, - ToNHWCPreprocess, + ToChannelFirstPreprocess, + ToChannelLastPreprocess, ) -from executorch.backends.nxp.tests.models import Conv2dWithActivation +from executorch.backends.nxp.tests.models import Conv2dWithActivation, HardTanhModule from executorch.exir.dialects._ops import ops as exir_ops from torch.export import ExportedProgram from executorch.backends.nxp.tests.use_qat import * # noqa F403 @@ -32,6 +29,11 @@ def reseed_model_per_test_run(): np.random.seed(23) +ExecutorchDelegateCall = torch.ops.higher_order.executorch_call_delegate +HardTanh = exir_ops.edge.aten.hardtanh.default +HardTanh_ = exir_ops.edge.aten.hardtanh_.default + + @pytest.mark.parametrize("input_shape", [(1, 3, 128, 128)]) @pytest.mark.parametrize("inplace", [True, False]) def test_relu6_quant(mocker, input_shape: tuple[int], inplace: bool, use_qat: bool): @@ -50,15 +52,15 @@ def test_relu6_quant(mocker, input_shape: tuple[int], inplace: bool, use_qat: bo tflite_flatbuffers_model, io_formats = converter_spy.spy_return exported_program: ExportedProgram = converter_spy.call_args.args[1] - ops = [exir_ops.edge.aten.hardtanh.default, exir_ops.edge.aten.hardtanh_.default] - assert not graph_contains_any_of_ops(graph=quantized_program.graph, ops=ops) + assert not graph_contains_any_of_ops(quantized_program.graph, [HardTanh, HardTanh_]) + assert graph_contains_any_of_ops(quantized_program.graph, [ExecutorchDelegateCall]) input_data = (np.random.random(input_shape) * 50).astype(np.int8) convert_run_compare( exported_program, tfl_model=tflite_flatbuffers_model, - tflite_input_preprocess=ToNHWCPreprocess(), - tflite_output_preprocess=ToNCHWPreprocess(), + tflite_input_preprocess=ToChannelLastPreprocess(), + tflite_output_preprocess=ToChannelFirstPreprocess(), input_data=input_data, atol=2.0, ) @@ -66,7 +68,17 @@ def test_relu6_quant(mocker, input_shape: tuple[int], inplace: bool, use_qat: bo @pytest.mark.parametrize("input_shape", [(1, 3, 16, 16), (1, 3, 32, 32)]) @pytest.mark.parametrize( - "activation_range", list(HardTanhConverter.supported_modes_map.keys()) + "activation_range", + [ + (0.0, 6.0), + (-1.0, 1.0), + (0.0, 1.0), + (0.0, float("inf")), + (0, 6), + (-1, 1), + (0, 1), + (0, float("inf")), + ], ) @pytest.mark.parametrize("inplace", [True, False]) def test_custom_hardtanh_quant( @@ -93,15 +105,47 @@ def test_custom_hardtanh_quant( tflite_flatbuffers_model, io_formats = converter_spy.spy_return exported_program: ExportedProgram = converter_spy.call_args.args[1] - ops = [exir_ops.edge.aten.hardtanh.default, exir_ops.edge.aten.hardtanh_.default] - assert not graph_contains_any_of_ops(graph=quantized_program.graph, ops=ops) + assert not graph_contains_any_of_ops(quantized_program.graph, [HardTanh, HardTanh_]) + assert graph_contains_any_of_ops(quantized_program.graph, [ExecutorchDelegateCall]) input_data = (np.random.random(input_shape) * 50).astype(np.int8) convert_run_compare( exported_program, tfl_model=tflite_flatbuffers_model, - tflite_input_preprocess=ToNHWCPreprocess(), - tflite_output_preprocess=ToNCHWPreprocess(), + tflite_input_preprocess=ToChannelLastPreprocess(), + tflite_output_preprocess=ToChannelFirstPreprocess(), input_data=input_data, atol=2.0, ) + + +@pytest.mark.parametrize( + "input_shape, activation_range", + [ + pytest.param( + (3, 7, 15, 7), + (0, float("inf")), + id="activation range: Relu, num_channels not divisible by NUM_MACS, alone in partition", + ), + pytest.param( + (3, 7, 15, 7), + (0, 6), + id="activation range: Relu6, num_channels not divisible by NUM_MACS, alone in partition", + ), + ], +) +def test_hardtanh__unsupported( + mocker, + input_shape: tuple[int], + activation_range: tuple[int, int], + use_qat: bool, +): + min_val, max_val = activation_range + model = HardTanhModule(min_val, max_val) + delegated_ep = to_quantized_edge_program( + model, input_shape, use_qat=use_qat + ).exported_program() + + # Make sure the `hardtanh` was NOT delegated. + assert not graph_contains_any_of_ops(delegated_ep.graph, [ExecutorchDelegateCall]) + assert graph_contains_any_of_ops(delegated_ep.graph, [HardTanh, HardTanh_]) diff --git a/backends/nxp/tests/ir/converter/node_converter/test_relu_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_relu_converter.py index b91720324f2..2ec285d6363 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_relu_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_relu_converter.py @@ -1,4 +1,4 @@ -# Copyright 2024 NXP +# Copyright 2024,2026 NXP # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -9,6 +9,7 @@ from executorch.backends.nxp.backend.edge_program_converter import ( EdgeProgramToIRConverter, + exir_ops, ) from executorch.backends.nxp.tests.executorch_pipeline import ( to_edge_program, @@ -16,6 +17,7 @@ ) from executorch.backends.nxp.tests.executors import ( convert_run_compare, + graph_contains_any_of_ops, ToNCHWPreprocess, ToNHWCPreprocess, ) @@ -30,6 +32,10 @@ def reseed_model_per_test_run(): np.random.seed(23) +ExecutorchDelegateCall = torch.ops.higher_order.executorch_call_delegate +ReLU = exir_ops.edge.aten.relu.default + + class ConvReLUModule(torch.nn.Module): def __init__(self): super().__init__() @@ -68,12 +74,12 @@ def test_relu_with_conv_quant_conversion(mocker, use_qat): converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") # Run conversion - _ = to_quantized_edge_program( + delegated_ep = to_quantized_edge_program( ConvReLUModule(), input_shape, use_qat=use_qat, use_neutron_for_format_conversion=False, - ) + ).exported_program() # Capture generated model tflite_flatbuffers_model, _ = converter_spy.spy_return @@ -85,6 +91,10 @@ def test_relu_with_conv_quant_conversion(mocker, use_qat): (2 * np.random.random(input_shape).astype(np.float32) - 1) * 50 ).astype(np.int8) + # Make sure the `relu` was delegated. + assert graph_contains_any_of_ops(delegated_ep.graph, [ExecutorchDelegateCall]) + assert not graph_contains_any_of_ops(delegated_ep.graph, [ReLU]) + convert_run_compare( edge_program, input_data, @@ -99,7 +109,9 @@ def test_relu_with_linear_quant_conversion(mocker, use_qat): converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") # Run conversion - _ = to_quantized_edge_program(LinearReLUModule(), input_shape, use_qat=use_qat) + delegated_ep = to_quantized_edge_program( + LinearReLUModule(), input_shape, use_qat=use_qat + ).exported_program() # Capture generated model tflite_flatbuffers_model, _ = converter_spy.spy_return @@ -111,4 +123,26 @@ def test_relu_with_linear_quant_conversion(mocker, use_qat): (2 * np.random.random(input_shape).astype(np.float32) - 1) * 50 ).astype(np.int8) + # Make sure the `relu` was delegated. + assert graph_contains_any_of_ops(delegated_ep.graph, [ExecutorchDelegateCall]) + assert not graph_contains_any_of_ops(delegated_ep.graph, [ReLU]) + convert_run_compare(edge_program, input_data, tfl_model=tflite_flatbuffers_model) + + +@pytest.mark.parametrize( + "input_shape", + [ + pytest.param( + (3, 9, 7), id="num_channels not divisible by NUM_MACS, alone in partition" + ), + ], +) +def test_relu_conversion__unsupported(mocker, input_shape): + delegated_ep = to_quantized_edge_program( + ReLUModule(), input_shape + ).exported_program() + + # Make sure the `relu` was NOT delegated. + assert not graph_contains_any_of_ops(delegated_ep.graph, [ExecutorchDelegateCall]) + assert graph_contains_any_of_ops(delegated_ep.graph, [ReLU]) diff --git a/backends/nxp/tests/models.py b/backends/nxp/tests/models.py index 9357660d6c6..1079269513d 100644 --- a/backends/nxp/tests/models.py +++ b/backends/nxp/tests/models.py @@ -191,6 +191,17 @@ def forward(self, x): return x +class HardTanhModule(torch.nn.Module): + def __init__(self, min_val, max_val, inplace=True): + super().__init__() + self.hardtanh = torch.nn.Hardtanh( + min_val=min_val, max_val=max_val, inplace=inplace + ) + + def forward(self, x): + return self.hardtanh(x) + + class SliceTensorConvModule(torch.nn.Module): def __init__(self, dims, starts, ends, in_channels, out_channels): super().__init__()