From 9b586c7795f362ca591e30dfc4e65f2cf9fb6748 Mon Sep 17 00:00:00 2001 From: Matthias Cremon Date: Wed, 18 Mar 2026 21:35:48 -0700 Subject: [PATCH 1/3] Add NHWC version of max_pool2d (#18239) Summary: Pull Request resolved: https://github.com/pytorch/executorch/pull/18239 As titled. Should perform better and also allow removing some permutes when convolutions are also moved to channel last. Differential Revision: D96869747 Reviewed By: hsharma35 --- backends/cadence/aot/functions.yaml | 9 +- backends/cadence/aot/ops_registrations.py | 55 ++++++- backends/cadence/aot/quantizer/patterns.py | 7 +- backends/cadence/aot/ref_implementations.py | 35 ++++- backends/cadence/aot/replace_ops.py | 62 ++++++++ .../aot/tests/test_replace_ops_passes.py | 54 +++++++ .../operators/op_quantized_max_pool2d.cpp | 12 +- .../operators/op_quantized_max_pool2d.h | 2 +- .../op_quantized_max_pool2d_nhwc.cpp | 136 ++++++++++++++++++ .../operators/op_quantized_max_pool2d_nhwc.h | 30 ++++ .../cadence/generic/operators/targets.bzl | 12 ++ 11 files changed, 397 insertions(+), 17 deletions(-) create mode 100644 backends/cadence/generic/operators/op_quantized_max_pool2d_nhwc.cpp create mode 100644 backends/cadence/generic/operators/op_quantized_max_pool2d_nhwc.h diff --git a/backends/cadence/aot/functions.yaml b/backends/cadence/aot/functions.yaml index 80de190fedf..1c4dd3e06f3 100644 --- a/backends/cadence/aot/functions.yaml +++ b/backends/cadence/aot/functions.yaml @@ -309,10 +309,15 @@ - arg_meta: null kernel_name: impl::generic::quantized_relu_asym8u_asym8u_per_tensor_out -- func: cadence::quantized_max_pool2d.out(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode, *, Tensor(a!) out) -> Tensor(a!) +- func: cadence::quantized_max_pool2d_nchw.out(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null - kernel_name: impl::generic::quantized_max_pool2d_out + kernel_name: impl::generic::quantized_max_pool2d_nchw_out + +- func: cadence::quantized_max_pool2d_nhwc.out(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: impl::generic::quantized_max_pool2d_nhwc_out - func: cadence::quantized_add.per_tensor_out(Tensor X, float X_scale, int X_zero_point, Tensor Y, float Y_scale, int Y_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!) kernels: diff --git a/backends/cadence/aot/ops_registrations.py b/backends/cadence/aot/ops_registrations.py index 601d54fe49b..060702becec 100644 --- a/backends/cadence/aot/ops_registrations.py +++ b/backends/cadence/aot/ops_registrations.py @@ -214,10 +214,16 @@ def register_fake( ) lib.define( - "quantized_max_pool2d(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> Tensor" + "quantized_max_pool2d_nchw(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> Tensor" ) lib.define( - "quantized_max_pool2d.out(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode, *, Tensor(a!) out) -> Tensor(a!)" + "quantized_max_pool2d_nchw.out(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode, *, Tensor(a!) out) -> Tensor(a!)" +) +lib.define( + "quantized_max_pool2d_nhwc(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> Tensor" +) +lib.define( + "quantized_max_pool2d_nhwc.out(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode, *, Tensor(a!) out) -> Tensor(a!)" ) lib.define( @@ -2277,8 +2283,8 @@ def quantized_relu_asym8u_asym8u_per_tensor_meta( return input.new_empty(input.size(), dtype=input.dtype) -@register_fake("cadence::quantized_max_pool2d") -def quantized_max_pool2d_meta( +@register_fake("cadence::quantized_max_pool2d_nchw") +def quantized_max_pool2d_nchw_meta( input: torch.Tensor, kernel_size: list[int], stride: list[int], @@ -2318,6 +2324,47 @@ def quantized_max_pool2d_meta( return input.new_empty([batch, channels, height_out, width_out], dtype=input.dtype) +@register_fake("cadence::quantized_max_pool2d_nhwc") +def quantized_max_pool2d_nhwc_meta( + input: torch.Tensor, + kernel_size: list[int], + stride: list[int], + padding: list[int], + dilation: list[int], + ceil_mode: bool, +) -> torch.Tensor: + assert ( + len(kernel_size) == 2 + ), f"kernel_size must have 2 elements, got {len(kernel_size)}" + assert len(stride) == 2, f"stride must have 2 elements, got {len(stride)}" + assert len(padding) == 2, f"padding must have 2 elements, got {len(padding)}" + assert len(dilation) == 2, f"dilation must have 2 elements, got {len(dilation)}" + assert ( + len(input.size()) == 4 + ), f"input must be 4D (N, H, W, C), got {len(input.size())}D" + + batch = input.size(0) + height_in = input.size(1) + width_in = input.size(2) + channels = input.size(3) + + height_out_raw = ( + height_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1 + ) / stride[0] + 1 + width_out_raw = ( + width_in + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1 + ) / stride[1] + 1 + + if ceil_mode: + height_out = ceil(height_out_raw) + width_out = ceil(width_out_raw) + else: + height_out = int(height_out_raw) + width_out = int(width_out_raw) + + return input.new_empty([batch, height_out, width_out, channels], dtype=input.dtype) + + @register_fake("cadence::fully_connected") def fully_connected_meta( src: torch.Tensor, diff --git a/backends/cadence/aot/quantizer/patterns.py b/backends/cadence/aot/quantizer/patterns.py index 0d52c004dea..204f066ebf4 100644 --- a/backends/cadence/aot/quantizer/patterns.py +++ b/backends/cadence/aot/quantizer/patterns.py @@ -459,7 +459,7 @@ def get_anchors( ) def replacement_op(self) -> OpOverload: - return torch.ops.cadence.quantized_max_pool2d.default + return torch.ops.cadence.quantized_max_pool2d_nchw.default class MaxPool2dWithoutIndicesPattern(QuantizationPattern): @@ -498,7 +498,10 @@ def get_anchors( ) def replacement_op(self) -> OpOverload: - return torch.ops.cadence.quantized_max_pool2d.default + return torch.ops.cadence.quantized_max_pool2d_nchw.default + + +# This is a base class for ReLU # This is a base class for ReLU, since it can be used with two different aten ops diff --git a/backends/cadence/aot/ref_implementations.py b/backends/cadence/aot/ref_implementations.py index ed8b3ca60ae..f985718c150 100644 --- a/backends/cadence/aot/ref_implementations.py +++ b/backends/cadence/aot/ref_implementations.py @@ -1868,8 +1868,8 @@ def rms_norm( return W * nn.RMSNorm(list(normalized_shape), eps=eps, dtype=X.dtype)(X) -@impl_tracked(m, "quantized_max_pool2d") -def quantized_max_pool2d( +@impl_tracked(m, "quantized_max_pool2d_nchw") +def quantized_max_pool2d_nchw( input: torch.Tensor, kernel_size: list[int], stride: list[int], @@ -1897,6 +1897,37 @@ def quantized_max_pool2d( ) +@impl_tracked(m, "quantized_max_pool2d_nhwc") +def quantized_max_pool2d_nhwc( + input: torch.Tensor, + kernel_size: list[int], + stride: list[int], + padding: list[int], + dilation: list[int], + ceil_mode: bool, +) -> torch.Tensor: + """ + Quantized max pooling in NHWC layout. + + Converts NHWC→NCHW, performs max pooling, then converts back NCHW→NHWC. + """ + # Convert NHWC [N, H, W, C] to NCHW [N, C, H, W] + input_nchw = input.permute(0, 3, 1, 2).contiguous() + + # Call the NCHW version + output_nchw = quantized_max_pool2d_nchw( + input_nchw, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + ceil_mode=ceil_mode, + ) + + # Convert NCHW [N, C, H_out, W_out] back to NHWC [N, H_out, W_out, C] + return output_nchw.permute(0, 2, 3, 1).contiguous() + + @impl_tracked(m, "where_Scalar") def where_Scalar( condition: torch.Tensor, diff --git a/backends/cadence/aot/replace_ops.py b/backends/cadence/aot/replace_ops.py index 14a35c01baf..6e6e98af267 100644 --- a/backends/cadence/aot/replace_ops.py +++ b/backends/cadence/aot/replace_ops.py @@ -1182,6 +1182,67 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: return True +@register_cadence_pass(CadencePassAttribute(opt_level=3)) +class ReplaceMaxPool2dWithChannelLastMaxPool2dPass(RemoveOrReplacePassInterface): + """ + Replace NCHW max pooling with NHWC (channel-last) max pooling by adding + permute operations before and after the max pooling. + """ + + @property + def targets(self) -> list[EdgeOpOverload]: + return [ + exir_ops.edge.cadence.quantized_max_pool2d_nchw.default, + ] + + def _change_nchw_to_nhwc( + self, graph: torch.fx.Graph, node: torch.fx.Node + ) -> torch.fx.Node: + """Convert NCHW format to NHWC format.""" + permute_node = graph.call_function( + exir_ops.edge.aten.permute_copy.default, (node, [0, 2, 3, 1]), {} + ) + permute_node.meta = node.meta + return permute_node + + def _change_nhwc_to_nchw( + self, graph: torch.fx.Graph, node: torch.fx.Node + ) -> torch.fx.Node: + """Convert NHWC format to NCHW format.""" + permute_node = graph.call_function( + exir_ops.edge.aten.permute_copy.default, (node, [0, 3, 1, 2]), {} + ) + permute_node.meta = node.meta + return permute_node + + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: + graph = node.graph + + # Get input node + input_node = cast(torch.fx.Node, node.args[0]) + + with graph.inserting_before(node): + # Convert input from NCHW to NHWC + input_nhwc = self._change_nchw_to_nhwc(graph, input_node) + + # Create the NHWC max pooling with the same args (kernel_size, stride, padding, dilation, ceil_mode) + new_args = (input_nhwc,) + tuple(node.args[1:]) + + new_pool = graph.call_function( + exir_ops.edge.cadence.quantized_max_pool2d_nhwc.default, + new_args, + node.kwargs, + ) + new_pool.meta = node.meta + + # Convert output back from NHWC to NCHW + nchw_output = self._change_nhwc_to_nchw(graph, new_pool) + + # Replace all uses with the final output + node.replace_all_uses_with(nchw_output) + return True + + @register_cadence_pass(CadencePassAttribute(opt_level=3)) class MakeSliceAndCatDimOutermostPass(RemoveOrReplacePassInterface): """ @@ -2561,6 +2622,7 @@ class CadenceReplaceOpsInGraph: ReplacePadWithCatPass, ReplaceConstantPadNdWithSlicePass, ReplaceConvWithChannelLastConvPass, + ReplaceMaxPool2dWithChannelLastMaxPool2dPass, ReplaceTrivialConvWithLinear, ReplaceConvWithIm2RowAndLinear, ReplaceTransposedConvWithLinearPass, diff --git a/backends/cadence/aot/tests/test_replace_ops_passes.py b/backends/cadence/aot/tests/test_replace_ops_passes.py index 95d470644a0..5d9f8c0784b 100644 --- a/backends/cadence/aot/tests/test_replace_ops_passes.py +++ b/backends/cadence/aot/tests/test_replace_ops_passes.py @@ -36,6 +36,7 @@ ReplaceLinearWithFullyConnectedOpPass, ReplaceLogicalNotBooleanWhereWithWherePass, ReplaceMatmulWithTransposedMatmulPass, + ReplaceMaxPool2dWithChannelLastMaxPool2dPass, ReplaceMMWithAddMMPass, ReplaceMulTensorWithMulAndFullOpsPass, ReplaceNopTransposeOrPermuteWithViewPass, @@ -2586,6 +2587,59 @@ def test_cat_insert_transpose(self) -> None: ) +class TestReplaceMaxPool2dWithChannelLastMaxPool2dPass(unittest.TestCase): + def test_replace_max_pool2d_nchw_with_nhwc(self) -> None: + # Create a graph with a single quantized_max_pool2d_nchw node. + x = torch.randint(0, 100, (1, 3, 8, 8), dtype=torch.int8) + gm = single_op_builder( + placeholders=(x,), + op=exir_ops.edge.cadence.quantized_max_pool2d_nchw.default, + args=(x, [2, 2], [2, 2], [0, 0], [1, 1], False), + ) + self.assertEqual( + count_node(gm, exir_ops.edge.cadence.quantized_max_pool2d_nchw.default), 1 + ) + self.assertEqual(count_node(gm, exir_ops.edge.aten.permute_copy.default), 0) + + # Deepcopy before the pass + original = copy.deepcopy(gm) + + # Apply replacement pass. + p = ReplaceMaxPool2dWithChannelLastMaxPool2dPass() + result = p.call(gm) + self.assertTrue(result.modified) + gm_after_replacement = result.graph_module + + # Check that replacement was made. + self.assertEqual( + count_node( + gm_after_replacement, + exir_ops.edge.cadence.quantized_max_pool2d_nhwc.default, + ), + 1, + ) + self.assertEqual( + count_node( + gm_after_replacement, + exir_ops.edge.cadence.quantized_max_pool2d_nchw.default, + ), + 0, + ) + # Two permutes: one for input NCHW->NHWC, one for output NHWC->NCHW + self.assertEqual( + count_node(gm_after_replacement, exir_ops.edge.aten.permute_copy.default), + 2, + ) + + # Validate numerical accuracy + validate( + original, + gm_after_replacement, + (x,), + "ReplaceMaxPool2dWithChannelLastMaxPool2dPass", + ) + + class TestReplaceEmptyTensorsWithFullPass(unittest.TestCase): def _get_slice_empty_gm(self) -> tuple[torch.fx.GraphModule, torch.Tensor]: builder = GraphBuilder() diff --git a/backends/cadence/generic/operators/op_quantized_max_pool2d.cpp b/backends/cadence/generic/operators/op_quantized_max_pool2d.cpp index b241b0851a9..f843ad84080 100644 --- a/backends/cadence/generic/operators/op_quantized_max_pool2d.cpp +++ b/backends/cadence/generic/operators/op_quantized_max_pool2d.cpp @@ -27,7 +27,7 @@ using ::executorch::runtime::KernelRuntimeContext; namespace { template -void quantized_max_pool2d_impl( +void quantized_max_pool2d_nchw_impl( const Tensor& input, IntArrayRef kernel_size, IntArrayRef stride, @@ -98,7 +98,7 @@ void quantized_max_pool2d_impl( } // namespace -Tensor& quantized_max_pool2d_out( +Tensor& quantized_max_pool2d_nchw_out( ET_UNUSED KernelRuntimeContext& ctx, const Tensor& input, IntArrayRef kernel_size, @@ -107,9 +107,9 @@ Tensor& quantized_max_pool2d_out( IntArrayRef dilation, bool ceil_mode, Tensor& output) { -#define typed_quantized_max_pool2d(ctype, dtype) \ +#define typed_quantized_max_pool2d_nchw(ctype, dtype) \ case ScalarType::dtype: { \ - quantized_max_pool2d_impl( \ + quantized_max_pool2d_nchw_impl( \ input, kernel_size, stride, padding, dilation, ceil_mode, output); \ break; \ } @@ -117,14 +117,14 @@ Tensor& quantized_max_pool2d_out( ScalarType dtype = input.scalar_type(); // NOLINTBEGIN(clang-diagnostic-switch-enum) switch (dtype) { - ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_max_pool2d) + ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_max_pool2d_nchw) default: ET_DCHECK_MSG( false, "Unhandled dtype %s", torch::executor::toString(dtype)); } // NOLINTEND(clang-diagnostic-switch-enum) -#undef typed_quantized_max_pool2d +#undef typed_quantized_max_pool2d_nchw return output; } diff --git a/backends/cadence/generic/operators/op_quantized_max_pool2d.h b/backends/cadence/generic/operators/op_quantized_max_pool2d.h index 07f406a37a7..453dd5a2582 100644 --- a/backends/cadence/generic/operators/op_quantized_max_pool2d.h +++ b/backends/cadence/generic/operators/op_quantized_max_pool2d.h @@ -15,7 +15,7 @@ namespace impl { namespace generic { namespace native { -::executorch::aten::Tensor& quantized_max_pool2d_out( +::executorch::aten::Tensor& quantized_max_pool2d_nchw_out( ::executorch::runtime::KernelRuntimeContext& ctx, const ::executorch::aten::Tensor& input, ::executorch::aten::IntArrayRef kernel_size, diff --git a/backends/cadence/generic/operators/op_quantized_max_pool2d_nhwc.cpp b/backends/cadence/generic/operators/op_quantized_max_pool2d_nhwc.cpp new file mode 100644 index 00000000000..d8f0d9e068b --- /dev/null +++ b/backends/cadence/generic/operators/op_quantized_max_pool2d_nhwc.cpp @@ -0,0 +1,136 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include +#include + +#include +#include + +namespace impl { +namespace generic { +namespace native { + +using ::executorch::aten::IntArrayRef; +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::runtime::KernelRuntimeContext; + +namespace { + +template +void quantized_max_pool2d_nhwc_impl( + const Tensor& input, + IntArrayRef kernel_size, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + ET_UNUSED bool ceil_mode, + Tensor& output) { + const T* __restrict__ in_data = input.const_data_ptr(); + T* __restrict__ out_data = output.mutable_data_ptr(); + + // Input dimensions: [N, H, W, C] + const int64_t batch_size = input.size(0); + const int64_t in_height = input.size(1); + const int64_t in_width = input.size(2); + const int64_t channels = input.size(3); + + // Output dimensions: [N, H_out, W_out, C] + const int64_t out_height = output.size(1); + const int64_t out_width = output.size(2); + + // Pooling parameters + const int64_t kernel_h = kernel_size[0]; + const int64_t kernel_w = kernel_size[1]; + const int64_t stride_h = stride[0]; + const int64_t stride_w = stride[1]; + const int64_t pad_h = padding[0]; + const int64_t pad_w = padding[1]; + const int64_t dilation_h = dilation[0]; + const int64_t dilation_w = dilation[1]; + + for (int64_t n = 0; n < batch_size; ++n) { + for (int64_t oh = 0; oh < out_height; ++oh) { + for (int64_t ow = 0; ow < out_width; ++ow) { + const int64_t ih_start = oh * stride_h - pad_h; + const int64_t iw_start = ow * stride_w - pad_w; + + T* __restrict__ out_ptr = + out_data + ((n * out_height + oh) * out_width + ow) * channels; + + // Initialize all channels to the minimum value. + for (int64_t c = 0; c < channels; ++c) { + out_ptr[c] = std::numeric_limits::lowest(); + } + + // For each kernel position, compute element-wise max across all + // channels. The inner loop over channels is a stride-1 contiguous + // access in NHWC layout, enabling SIMD auto-vectorization. + for (int64_t kh = 0; kh < kernel_h; ++kh) { + const int64_t ih = ih_start + kh * dilation_h; + if (ih < 0 || ih >= in_height) { + continue; + } + for (int64_t kw = 0; kw < kernel_w; ++kw) { + const int64_t iw = iw_start + kw * dilation_w; + if (iw < 0 || iw >= in_width) { + continue; + } + + const T* __restrict__ in_ptr = + in_data + ((n * in_height + ih) * in_width + iw) * channels; + + for (int64_t c = 0; c < channels; ++c) { + out_ptr[c] = std::max(out_ptr[c], in_ptr[c]); + } + } + } + } + } + } +} + +} // namespace + +Tensor& quantized_max_pool2d_nhwc_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + IntArrayRef kernel_size, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + bool ceil_mode, + Tensor& output) { +#define typed_quantized_max_pool2d_nhwc(ctype, dtype) \ + case ScalarType::dtype: { \ + quantized_max_pool2d_nhwc_impl( \ + input, kernel_size, stride, padding, dilation, ceil_mode, output); \ + break; \ + } + + ScalarType dtype = input.scalar_type(); + // NOLINTBEGIN(clang-diagnostic-switch-enum) + switch (dtype) { + ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_max_pool2d_nhwc) + default: + ET_DCHECK_MSG( + false, "Unhandled dtype %s", torch::executor::toString(dtype)); + } + // NOLINTEND(clang-diagnostic-switch-enum) + +#undef typed_quantized_max_pool2d_nhwc + return output; +} + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_quantized_max_pool2d_nhwc.h b/backends/cadence/generic/operators/op_quantized_max_pool2d_nhwc.h new file mode 100644 index 00000000000..2b0c02e4bb7 --- /dev/null +++ b/backends/cadence/generic/operators/op_quantized_max_pool2d_nhwc.h @@ -0,0 +1,30 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +namespace impl { +namespace generic { +namespace native { + +::executorch::aten::Tensor& quantized_max_pool2d_nhwc_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& input, + ::executorch::aten::IntArrayRef kernel_size, + ::executorch::aten::IntArrayRef stride, + ::executorch::aten::IntArrayRef padding, + ::executorch::aten::IntArrayRef dilation, + bool ceil_mode, + ::executorch::aten::Tensor& output); + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/targets.bzl b/backends/cadence/generic/operators/targets.bzl index bf1de9e009a..fa6708a188e 100644 --- a/backends/cadence/generic/operators/targets.bzl +++ b/backends/cadence/generic/operators/targets.bzl @@ -225,6 +225,18 @@ def define_common_targets(): visibility = ["PUBLIC"], ) + runtime.cxx_library( + name = "op_quantized_max_pool2d_nhwc", + srcs = ["op_quantized_max_pool2d_nhwc.cpp"], + exported_headers = ["op_quantized_max_pool2d_nhwc.h"], + platforms = CXX, + deps = [ + "//executorch/runtime/kernel:kernel_includes", + ":cadence_type_util", + ], + visibility = ["PUBLIC"], + ) + runtime.cxx_library( name = "op_quantized_matmul", srcs = ["op_quantized_matmul.cpp"], From c5dfd18aa361f4cd4a33d7692d4bea5e3754dbfa Mon Sep 17 00:00:00 2001 From: Matthias Cremon Date: Wed, 18 Mar 2026 21:35:48 -0700 Subject: [PATCH 2/3] Add dedicated HiFi kernel for max pool 2d (#18240) Summary: Pull Request resolved: https://github.com/pytorch/executorch/pull/18240 As titled. Calls into nnlib directly. Differential Revision: D96874522 Reviewed By: hsharma35 --- .../op_quantized_max_pool2d_nhwc.cpp | 1 + .../op_quantized_max_pool2d_nhwc.cpp | 148 ++++++++++++++++++ backends/cadence/hifi/operators/targets.bzl | 10 ++ 3 files changed, 159 insertions(+) create mode 100644 backends/cadence/hifi/operators/op_quantized_max_pool2d_nhwc.cpp diff --git a/backends/cadence/generic/operators/op_quantized_max_pool2d_nhwc.cpp b/backends/cadence/generic/operators/op_quantized_max_pool2d_nhwc.cpp index d8f0d9e068b..cb4a9616394 100644 --- a/backends/cadence/generic/operators/op_quantized_max_pool2d_nhwc.cpp +++ b/backends/cadence/generic/operators/op_quantized_max_pool2d_nhwc.cpp @@ -10,6 +10,7 @@ #include #include +#include #include #include diff --git a/backends/cadence/hifi/operators/op_quantized_max_pool2d_nhwc.cpp b/backends/cadence/hifi/operators/op_quantized_max_pool2d_nhwc.cpp new file mode 100644 index 00000000000..69c4a3fbc45 --- /dev/null +++ b/backends/cadence/hifi/operators/op_quantized_max_pool2d_nhwc.cpp @@ -0,0 +1,148 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +namespace impl { +namespace HiFi { +namespace native { + +using ::executorch::aten::IntArrayRef; +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::runtime::KernelRuntimeContext; + +Tensor& quantized_max_pool2d_nhwc_out( + KernelRuntimeContext& ctx, + const Tensor& input, + IntArrayRef kernel_size, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + bool ceil_mode, + Tensor& output) { + // NHWC layout: [N, H, W, C] + const int32_t batch_size = input.size(0); + const int32_t in_height = input.size(1); + const int32_t in_width = input.size(2); + const int32_t channels = input.size(3); + + const int32_t out_height = output.size(1); + const int32_t out_width = output.size(2); + + const int32_t kernel_h = kernel_size[0]; + const int32_t kernel_w = kernel_size[1]; + const int32_t stride_h = stride[0]; + const int32_t stride_w = stride[1]; + const int32_t pad_h = padding[0]; + const int32_t pad_w = padding[1]; + + // Determine NNLIB precision constants based on dtype + ScalarType dtype = input.scalar_type(); + int32_t nnlib_precision; + switch (dtype) { + case ScalarType::Char: // int8 + nnlib_precision = PREC_SYM8S; + break; + case ScalarType::Byte: // uint8 + nnlib_precision = PREC_ASYM8U; + break; + default: + ET_DCHECK_MSG( + false, + "Unsupported dtype %s for HiFi quantized_max_pool2d_nhwc", + torch::executor::toString(dtype)); + return output; + } + + // Compute scratch buffer size for NNLIB maxpool + int32_t scratch_size = xa_nn_maxpool_getsize( + channels, + nnlib_precision, + nnlib_precision, + in_height, + in_width, + kernel_h, + kernel_w, + stride_w, // x_stride + stride_h, // y_stride + pad_w, // x_padding + pad_h, // y_padding + out_height, + out_width, + 0, // inp_data_format: 0 = NHWC + 0); // out_data_format: 0 = NHWC + ET_DCHECK_MSG(scratch_size >= 0, "xa_nn_maxpool_getsize failed"); + + // Allocate aligned scratch memory + void* p_scratch = kernels::allocate_temp_memory(ctx, scratch_size); + + // Process each batch using NNLIB optimized maxpool kernel + for (int32_t n = 0; n < batch_size; ++n) { + const int32_t spatial_size = in_height * in_width * channels; + const int32_t out_spatial_size = out_height * out_width * channels; + + int32_t ret; + if (dtype == ScalarType::Char) { + const int8_t* in_batch = + input.const_data_ptr() + n * spatial_size; + int8_t* out_batch = + output.mutable_data_ptr() + n * out_spatial_size; + + ret = xa_nn_maxpool_8( + out_batch, + in_batch, + in_height, + in_width, + channels, + kernel_h, + kernel_w, + stride_w, // x_stride + stride_h, // y_stride + pad_w, // x_padding + pad_h, // y_padding + out_height, + out_width, + 0, // inp_data_format: NHWC + 0, // out_data_format: NHWC + p_scratch); + } else { + const uint8_t* in_batch = + input.const_data_ptr() + n * spatial_size; + uint8_t* out_batch = + output.mutable_data_ptr() + n * out_spatial_size; + + ret = xa_nn_maxpool_asym8( + out_batch, + in_batch, + in_height, + in_width, + channels, + kernel_h, + kernel_w, + stride_w, // x_stride + stride_h, // y_stride + pad_w, // x_padding + pad_h, // y_padding + out_height, + out_width, + 0, // inp_data_format: NHWC + 0, // out_data_format: NHWC + p_scratch); + } + ET_DCHECK_MSG(ret == 0, "HiFi xa_nn_maxpool failed"); + } + + return output; +} + +} // namespace native +} // namespace HiFi +} // namespace impl diff --git a/backends/cadence/hifi/operators/targets.bzl b/backends/cadence/hifi/operators/targets.bzl index 9753051bf72..1ea57862cf6 100644 --- a/backends/cadence/hifi/operators/targets.bzl +++ b/backends/cadence/hifi/operators/targets.bzl @@ -632,6 +632,16 @@ def define_common_targets(): compatible_with = ["ovr_config//cpu:xtensa"], ) + runtime.cxx_library( + name = "op_quantized_max_pool2d_nhwc", + srcs = ["op_quantized_max_pool2d_nhwc.cpp"], + exported_headers = ["operators.h"], + platforms = CXX, + deps = COMMON_DEPS, + visibility = ["PUBLIC"], + compatible_with = ["ovr_config//cpu:xtensa"], + ) + runtime.cxx_library( name = "op_quantized_relu_asym8s_asym8s_per_tensor_out", srcs = ["op_quantized_relu_asym8s_asym8s_per_tensor_out.cpp"], From fc5d0ba0fdbe7da3d1f404fd76db14bf8666faeb Mon Sep 17 00:00:00 2001 From: Matthias Cremon Date: Wed, 18 Mar 2026 21:39:06 -0700 Subject: [PATCH 3/3] Update permute removal pass to handle binary operations, and cleanup better (#18256) Summary: Pull Request resolved: https://github.com/pytorch/executorch/pull/18256 As titled. It is currently not cleaning up as much as it should, and the pass is only capable of handling single input cases. Result: from 9 to 1 (minimum by construction) permutes on Wake Gesture. Reviewed By: abeakkas Differential Revision: D96940254 --- backends/cadence/aot/BUCK | 1 + backends/cadence/aot/fuse_ops.py | 5 ++++- backends/cadence/aot/passes.py | 2 ++ backends/cadence/aot/remove_ops.py | 7 ++++++- 4 files changed, 13 insertions(+), 2 deletions(-) diff --git a/backends/cadence/aot/BUCK b/backends/cadence/aot/BUCK index c85dc23c4bd..e4bc833c183 100644 --- a/backends/cadence/aot/BUCK +++ b/backends/cadence/aot/BUCK @@ -300,6 +300,7 @@ fbcode_target(_kind = runtime.python_library, ], typing = True, deps = [ + ":fuse_ops", ":ops_registrations", "//caffe2:torch", "//executorch/backends/cadence/aot:pass_utils", diff --git a/backends/cadence/aot/fuse_ops.py b/backends/cadence/aot/fuse_ops.py index 34bf3b11684..e71803c03bb 100644 --- a/backends/cadence/aot/fuse_ops.py +++ b/backends/cadence/aot/fuse_ops.py @@ -1170,7 +1170,10 @@ def can_fuse_for_chain( return False # checking that permut2(permut1(identity)) == identity, modulo unitary dimensions - input_shape = cast(torch.fx.Node, producer.args[0]).meta["val"].shape + producer_input = cast(torch.fx.Node, producer.args[0]) + if "val" not in producer_input.meta: + return False + input_shape = producer_input.meta["val"].shape ident_dims = list(range(len(input_shape))) # this mapping helps to handle both transpose and permutations f: dict[Any, Callable] = { diff --git a/backends/cadence/aot/passes.py b/backends/cadence/aot/passes.py index bb4a8f065d5..647819d91ab 100644 --- a/backends/cadence/aot/passes.py +++ b/backends/cadence/aot/passes.py @@ -25,6 +25,7 @@ from executorch.backends.cadence.aot.remove_ops import ( CadenceRemoveNops, RemoveNopSliceOrViewOpPass, + RemovePermutesAroundElementwiseOps, RemoveRedundantOps, ) from executorch.backends.cadence.aot.reorder_ops import CadenceReorderOpsInGraph @@ -89,6 +90,7 @@ def get_passes_in_default_order() -> list[Type[ExportPass]]: CadenceSimplifyOpsInGraph.passes, FinalizePipeline, FuseFullThenReshapePass, + RemovePermutesAroundElementwiseOps, FuseTransposeOrPermuteOpPairsPass, RemoveNopSliceOrViewOpPass, CompileTimeTypeDispatchPass, diff --git a/backends/cadence/aot/remove_ops.py b/backends/cadence/aot/remove_ops.py index 8e1d6d1f07e..518c06bb90d 100644 --- a/backends/cadence/aot/remove_ops.py +++ b/backends/cadence/aot/remove_ops.py @@ -14,6 +14,8 @@ import torch import torch.fx + +from executorch.backends.cadence.aot.fuse_ops import FuseTransposeOrPermuteOpPairsPass from executorch.backends.cadence.aot.pass_utils import ( CadencePassAttribute, get_arg, @@ -21,7 +23,6 @@ RemoveOrReplacePassInterface, set_arg, ) - from executorch.backends.cadence.aot.simplify_ops import SimplifySliceOpPass from executorch.backends.cadence.aot.utils import get_edge_overload_packet from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform @@ -412,6 +413,9 @@ class Subgraph: exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, exir_ops.edge.cadence.quantize_per_tensor.default, exir_ops.edge.cadence.dequantize_per_tensor.default, + exir_ops.edge.cadence.quantized_relu.per_tensor, + exir_ops.edge.cadence.requantize.per_tensor, + exir_ops.edge.cadence.quantized_add.per_tensor, # Ops that require special handling. exir_ops.edge.aten.cat.default, exir_ops.edge.aten.mean.dim, @@ -804,6 +808,7 @@ class CommonRemovePasses: RemoveToOpsPass, RemoveZeroSizedCatArgsPass, RemovePermutesAroundElementwiseOps, + FuseTransposeOrPermuteOpPairsPass, RemoveSqueezeViewBeforeElementwiseOps, RemoveCatFromSliceCopyPass, RemoveCloneOpsTransformImported,