Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/cadence/aot/BUCK
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,7 @@ fbcode_target(_kind = runtime.python_library,
],
typing = True,
deps = [
":fuse_ops",
":ops_registrations",
"//caffe2:torch",
"//executorch/backends/cadence/aot:pass_utils",
Expand Down
9 changes: 7 additions & 2 deletions backends/cadence/aot/functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -309,10 +309,15 @@
- arg_meta: null
kernel_name: impl::generic::quantized_relu_asym8u_asym8u_per_tensor_out

- func: cadence::quantized_max_pool2d.out(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode, *, Tensor(a!) out) -> Tensor(a!)
- func: cadence::quantized_max_pool2d_nchw.out(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode, *, Tensor(a!) out) -> Tensor(a!)
kernels:
- arg_meta: null
kernel_name: impl::generic::quantized_max_pool2d_out
kernel_name: impl::generic::quantized_max_pool2d_nchw_out

- func: cadence::quantized_max_pool2d_nhwc.out(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode, *, Tensor(a!) out) -> Tensor(a!)
kernels:
- arg_meta: null
kernel_name: impl::generic::quantized_max_pool2d_nhwc_out

- func: cadence::quantized_add.per_tensor_out(Tensor X, float X_scale, int X_zero_point, Tensor Y, float Y_scale, int Y_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!)
kernels:
Expand Down
5 changes: 4 additions & 1 deletion backends/cadence/aot/fuse_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1170,7 +1170,10 @@ def can_fuse_for_chain(
return False

# checking that permut2(permut1(identity)) == identity, modulo unitary dimensions
input_shape = cast(torch.fx.Node, producer.args[0]).meta["val"].shape
producer_input = cast(torch.fx.Node, producer.args[0])
if "val" not in producer_input.meta:
return False
input_shape = producer_input.meta["val"].shape
ident_dims = list(range(len(input_shape)))
# this mapping helps to handle both transpose and permutations
f: dict[Any, Callable] = {
Expand Down
55 changes: 51 additions & 4 deletions backends/cadence/aot/ops_registrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,10 +214,16 @@ def register_fake(
)

lib.define(
"quantized_max_pool2d(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> Tensor"
"quantized_max_pool2d_nchw(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> Tensor"
)
lib.define(
"quantized_max_pool2d.out(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode, *, Tensor(a!) out) -> Tensor(a!)"
"quantized_max_pool2d_nchw.out(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode, *, Tensor(a!) out) -> Tensor(a!)"
)
lib.define(
"quantized_max_pool2d_nhwc(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> Tensor"
)
lib.define(
"quantized_max_pool2d_nhwc.out(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode, *, Tensor(a!) out) -> Tensor(a!)"
)

lib.define(
Expand Down Expand Up @@ -2277,8 +2283,8 @@ def quantized_relu_asym8u_asym8u_per_tensor_meta(
return input.new_empty(input.size(), dtype=input.dtype)


@register_fake("cadence::quantized_max_pool2d")
def quantized_max_pool2d_meta(
@register_fake("cadence::quantized_max_pool2d_nchw")
def quantized_max_pool2d_nchw_meta(
input: torch.Tensor,
kernel_size: list[int],
stride: list[int],
Expand Down Expand Up @@ -2318,6 +2324,47 @@ def quantized_max_pool2d_meta(
return input.new_empty([batch, channels, height_out, width_out], dtype=input.dtype)


@register_fake("cadence::quantized_max_pool2d_nhwc")
def quantized_max_pool2d_nhwc_meta(
    input: torch.Tensor,
    kernel_size: list[int],
    stride: list[int],
    padding: list[int],
    dilation: list[int],
    ceil_mode: bool,
) -> torch.Tensor:
    """Shape inference for quantized max pooling over an NHWC tensor.

    Returns an empty tensor of shape [N, H_out, W_out, C] with the input's
    dtype. H_out/W_out follow the standard pooling arithmetic, rounded up
    when ceil_mode is set and truncated toward zero otherwise.
    """
    # Every geometry argument describes the two spatial axes (H, W).
    for arg_name, arg in (
        ("kernel_size", kernel_size),
        ("stride", stride),
        ("padding", padding),
        ("dilation", dilation),
    ):
        assert len(arg) == 2, f"{arg_name} must have 2 elements, got {len(arg)}"
    assert (
        len(input.size()) == 4
    ), f"input must be 4D (N, H, W, C), got {len(input.size())}D"

    batch, height_in, width_in, channels = input.size()

    def _out_extent(extent: int, axis: int) -> float:
        # Classic pooling output-size formula along one spatial axis.
        return (
            extent + 2 * padding[axis] - dilation[axis] * (kernel_size[axis] - 1) - 1
        ) / stride[axis] + 1

    rounding = ceil if ceil_mode else int
    height_out = rounding(_out_extent(height_in, 0))
    width_out = rounding(_out_extent(width_in, 1))

    return input.new_empty([batch, height_out, width_out, channels], dtype=input.dtype)


@register_fake("cadence::fully_connected")
def fully_connected_meta(
src: torch.Tensor,
Expand Down
2 changes: 2 additions & 0 deletions backends/cadence/aot/passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from executorch.backends.cadence.aot.remove_ops import (
CadenceRemoveNops,
RemoveNopSliceOrViewOpPass,
RemovePermutesAroundElementwiseOps,
RemoveRedundantOps,
)
from executorch.backends.cadence.aot.reorder_ops import CadenceReorderOpsInGraph
Expand Down Expand Up @@ -89,6 +90,7 @@ def get_passes_in_default_order() -> list[Type[ExportPass]]:
CadenceSimplifyOpsInGraph.passes,
FinalizePipeline,
FuseFullThenReshapePass,
RemovePermutesAroundElementwiseOps,
FuseTransposeOrPermuteOpPairsPass,
RemoveNopSliceOrViewOpPass,
CompileTimeTypeDispatchPass,
Expand Down
7 changes: 5 additions & 2 deletions backends/cadence/aot/quantizer/patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,7 @@ def get_anchors(
)

def replacement_op(self) -> OpOverload:
return torch.ops.cadence.quantized_max_pool2d.default
return torch.ops.cadence.quantized_max_pool2d_nchw.default


class MaxPool2dWithoutIndicesPattern(QuantizationPattern):
Expand Down Expand Up @@ -498,7 +498,10 @@ def get_anchors(
)

def replacement_op(self) -> OpOverload:
return torch.ops.cadence.quantized_max_pool2d.default
return torch.ops.cadence.quantized_max_pool2d_nchw.default


# This is a base class for ReLU


# This is a base class for ReLU, since it can be used with two different aten ops
Expand Down
35 changes: 33 additions & 2 deletions backends/cadence/aot/ref_implementations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1868,8 +1868,8 @@ def rms_norm(
return W * nn.RMSNorm(list(normalized_shape), eps=eps, dtype=X.dtype)(X)


@impl_tracked(m, "quantized_max_pool2d")
def quantized_max_pool2d(
@impl_tracked(m, "quantized_max_pool2d_nchw")
def quantized_max_pool2d_nchw(
input: torch.Tensor,
kernel_size: list[int],
stride: list[int],
Expand Down Expand Up @@ -1897,6 +1897,37 @@ def quantized_max_pool2d(
)


@impl_tracked(m, "quantized_max_pool2d_nhwc")
def quantized_max_pool2d_nhwc(
    input: torch.Tensor,
    kernel_size: list[int],
    stride: list[int],
    padding: list[int],
    dilation: list[int],
    ceil_mode: bool,
) -> torch.Tensor:
    """Reference implementation of quantized max pooling for NHWC input.

    Delegates to the NCHW reference kernel: the input is transposed to
    channel-first layout, pooled, and the result transposed back to NHWC.
    """
    # [N, H, W, C] -> [N, C, H, W] for the channel-first kernel.
    as_nchw = input.permute(0, 3, 1, 2).contiguous()

    pooled = quantized_max_pool2d_nchw(
        as_nchw,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        ceil_mode=ceil_mode,
    )

    # [N, C, H_out, W_out] -> [N, H_out, W_out, C].
    return pooled.permute(0, 2, 3, 1).contiguous()


@impl_tracked(m, "where_Scalar")
def where_Scalar(
condition: torch.Tensor,
Expand Down
7 changes: 6 additions & 1 deletion backends/cadence/aot/remove_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,15 @@

import torch
import torch.fx

from executorch.backends.cadence.aot.fuse_ops import FuseTransposeOrPermuteOpPairsPass
from executorch.backends.cadence.aot.pass_utils import (
CadencePassAttribute,
get_arg,
register_cadence_pass,
RemoveOrReplacePassInterface,
set_arg,
)

from executorch.backends.cadence.aot.simplify_ops import SimplifySliceOpPass
from executorch.backends.cadence.aot.utils import get_edge_overload_packet
from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform
Expand Down Expand Up @@ -412,6 +413,9 @@ class Subgraph:
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
exir_ops.edge.cadence.quantize_per_tensor.default,
exir_ops.edge.cadence.dequantize_per_tensor.default,
exir_ops.edge.cadence.quantized_relu.per_tensor,
exir_ops.edge.cadence.requantize.per_tensor,
exir_ops.edge.cadence.quantized_add.per_tensor,
# Ops that require special handling.
exir_ops.edge.aten.cat.default,
exir_ops.edge.aten.mean.dim,
Expand Down Expand Up @@ -804,6 +808,7 @@ class CommonRemovePasses:
RemoveToOpsPass,
RemoveZeroSizedCatArgsPass,
RemovePermutesAroundElementwiseOps,
FuseTransposeOrPermuteOpPairsPass,
RemoveSqueezeViewBeforeElementwiseOps,
RemoveCatFromSliceCopyPass,
RemoveCloneOpsTransformImported,
Expand Down
62 changes: 62 additions & 0 deletions backends/cadence/aot/replace_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1182,6 +1182,67 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
return True


@register_cadence_pass(CadencePassAttribute(opt_level=3))
class ReplaceMaxPool2dWithChannelLastMaxPool2dPass(RemoveOrReplacePassInterface):
    """
    Replace NCHW max pooling with NHWC (channel-last) max pooling by adding
    permute operations before and after the max pooling.

    Fix over the previous version: inserted nodes no longer *alias* the
    source node's meta dict (mutating one node's meta would silently mutate
    the others'), and each node's "val" fake tensor is permuted to match the
    layout that node actually produces, instead of inheriting a stale
    pre-permute shape.
    """

    @property
    def targets(self) -> list[EdgeOpOverload]:
        # Only the NCHW flavor is rewritten; the NHWC flavor is the target form.
        return [
            exir_ops.edge.cadence.quantized_max_pool2d_nchw.default,
        ]

    def _change_nchw_to_nhwc(
        self, graph: torch.fx.Graph, node: torch.fx.Node
    ) -> torch.fx.Node:
        """Insert a permute converting `node`'s NCHW output to NHWC."""
        permute_node = graph.call_function(
            exir_ops.edge.aten.permute_copy.default, (node, [0, 2, 3, 1]), {}
        )
        # Copy, don't alias, the source meta; fix up "val" so shape
        # metadata reflects the permuted NHWC layout.
        permute_node.meta = dict(node.meta)
        if "val" in node.meta:
            permute_node.meta["val"] = node.meta["val"].permute(0, 2, 3, 1)
        return permute_node

    def _change_nhwc_to_nchw(
        self, graph: torch.fx.Graph, node: torch.fx.Node
    ) -> torch.fx.Node:
        """Insert a permute converting `node`'s NHWC output to NCHW."""
        permute_node = graph.call_function(
            exir_ops.edge.aten.permute_copy.default, (node, [0, 3, 1, 2]), {}
        )
        # Copy, don't alias, the source meta; fix up "val" so shape
        # metadata reflects the permuted NCHW layout.
        permute_node.meta = dict(node.meta)
        if "val" in node.meta:
            permute_node.meta["val"] = node.meta["val"].permute(0, 3, 1, 2)
        return permute_node

    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
        graph = node.graph

        # The pooled tensor is always the first argument.
        input_node = cast(torch.fx.Node, node.args[0])

        with graph.inserting_before(node):
            # Convert input from NCHW to NHWC.
            input_nhwc = self._change_nchw_to_nhwc(graph, input_node)

            # NHWC pooling takes the same trailing args
            # (kernel_size, stride, padding, dilation, ceil_mode).
            new_args = (input_nhwc,) + tuple(node.args[1:])
            new_pool = graph.call_function(
                exir_ops.edge.cadence.quantized_max_pool2d_nhwc.default,
                new_args,
                node.kwargs,
            )
            # The NHWC pool produces an NHWC-shaped value; derive its meta
            # from the original node but permute "val" accordingly.
            new_pool.meta = dict(node.meta)
            if "val" in node.meta:
                new_pool.meta["val"] = node.meta["val"].permute(0, 2, 3, 1)

            # Convert output back from NHWC to NCHW; permuting new_pool's
            # "val" by [0, 3, 1, 2] restores the original NCHW shape.
            nchw_output = self._change_nhwc_to_nchw(graph, new_pool)

        # Replace all uses with the final, layout-restored output.
        node.replace_all_uses_with(nchw_output)
        return True


@register_cadence_pass(CadencePassAttribute(opt_level=3))
class MakeSliceAndCatDimOutermostPass(RemoveOrReplacePassInterface):
"""
Expand Down Expand Up @@ -2561,6 +2622,7 @@ class CadenceReplaceOpsInGraph:
ReplacePadWithCatPass,
ReplaceConstantPadNdWithSlicePass,
ReplaceConvWithChannelLastConvPass,
ReplaceMaxPool2dWithChannelLastMaxPool2dPass,
ReplaceTrivialConvWithLinear,
ReplaceConvWithIm2RowAndLinear,
ReplaceTransposedConvWithLinearPass,
Expand Down
54 changes: 54 additions & 0 deletions backends/cadence/aot/tests/test_replace_ops_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
ReplaceLinearWithFullyConnectedOpPass,
ReplaceLogicalNotBooleanWhereWithWherePass,
ReplaceMatmulWithTransposedMatmulPass,
ReplaceMaxPool2dWithChannelLastMaxPool2dPass,
ReplaceMMWithAddMMPass,
ReplaceMulTensorWithMulAndFullOpsPass,
ReplaceNopTransposeOrPermuteWithViewPass,
Expand Down Expand Up @@ -2586,6 +2587,59 @@ def test_cat_insert_transpose(self) -> None:
)


class TestReplaceMaxPool2dWithChannelLastMaxPool2dPass(unittest.TestCase):
    def test_replace_max_pool2d_nchw_with_nhwc(self) -> None:
        """The NCHW pool op is rewritten to the NHWC op wrapped in permutes."""
        nchw_op = exir_ops.edge.cadence.quantized_max_pool2d_nchw.default
        nhwc_op = exir_ops.edge.cadence.quantized_max_pool2d_nhwc.default
        permute_op = exir_ops.edge.aten.permute_copy.default

        # Build a graph containing a single NCHW quantized max-pool node.
        sample = torch.randint(0, 100, (1, 3, 8, 8), dtype=torch.int8)
        gm = single_op_builder(
            placeholders=(sample,),
            op=nchw_op,
            args=(sample, [2, 2], [2, 2], [0, 0], [1, 1], False),
        )
        self.assertEqual(count_node(gm, nchw_op), 1)
        self.assertEqual(count_node(gm, permute_op), 0)

        # Keep a pristine copy for the numerical comparison below.
        original = copy.deepcopy(gm)

        # Run the replacement pass.
        result = ReplaceMaxPool2dWithChannelLastMaxPool2dPass().call(gm)
        self.assertTrue(result.modified)
        gm_after_replacement = result.graph_module

        # The NCHW op is gone, replaced by exactly one NHWC op plus the two
        # layout permutes (input NCHW->NHWC and output NHWC->NCHW).
        self.assertEqual(count_node(gm_after_replacement, nhwc_op), 1)
        self.assertEqual(count_node(gm_after_replacement, nchw_op), 0)
        self.assertEqual(count_node(gm_after_replacement, permute_op), 2)

        # The transformed graph must match the original numerically.
        validate(
            original,
            gm_after_replacement,
            (sample,),
            "ReplaceMaxPool2dWithChannelLastMaxPool2dPass",
        )


class TestReplaceEmptyTensorsWithFullPass(unittest.TestCase):
def _get_slice_empty_gm(self) -> tuple[torch.fx.GraphModule, torch.Tensor]:
builder = GraphBuilder()
Expand Down
Loading
Loading