From 00fa742cdd4fd31d0d7748ac6fc3baa33b7c7c94 Mon Sep 17 00:00:00 2001
From: Arik Horodniceanu
Date: Thu, 30 Apr 2026 14:06:28 -0700
Subject: [PATCH] Qualcomm AI Engine Direct - Updating Claude skill for new op
 development

---
 .claude/skills/qualcomm/new_op_development.md | 453 +++++++-----------
 1 file changed, 175 insertions(+), 278 deletions(-)

diff --git a/.claude/skills/qualcomm/new_op_development.md b/.claude/skills/qualcomm/new_op_development.md
index dc639655257..6e1abcf77f6 100644
--- a/.claude/skills/qualcomm/new_op_development.md
+++ b/.claude/skills/qualcomm/new_op_development.md
@@ -1,361 +1,258 @@
-# New Op Development
+# New Op Development — QNN/HTP Backend

-Full reference: `backends/qualcomm/builders/README.md` (op builder) and `backends/qualcomm/quantizer/README.md` (quantizer annotation).
+## Decision Tree

-## Overview
-
-Adding a new op requires three steps:
-1. Implement the op builder (`builders/op_*.py`)
-2. Register quantizer annotation (`quantizer/annotators/`)
-3. Add unit tests (`tests/`)
-
-**Important**: If the torch op requires **multiple QNN ops** to implement (e.g., no direct QNN equivalent), use a **decompose pass** instead of creating multiple ops in a single builder. Skip Steps 3–6 and follow the **Decompose Pass Approach** section at the bottom of this file.
+1. **QNN has a native op?** → Native builder approach (Steps 1–8)
+2. **No native op, needs multiple QNN ops?** → Decompose pass approach

---

## Step 1: Identify the Unsupported Op

-Run the model through the QNN backend. A missing op surfaces as:
-
-```
-KeyError: 'aten.native_layer_norm.default'
-```
-
-To trace back to the source PyTorch layer:
-
-```python
-from executorch.backends.qualcomm.utils.utils import capture_program
-
-prog = capture_program(MyModel(), example_inputs)
-for node in prog.exported_program.graph.nodes:
-    if node.op == "call_function" and node.target.__name__ == 'aten.native_layer_norm.default':
-        print(node.meta["source_fn_stack"])
-```
-
+Missing ops surface as `KeyError: 'aten.my_op.default'` when running the model through the QNN backend.
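+
+To trace the missing op back to its source PyTorch layer, a minimal sketch (adapted from the previous revision of this skill; assumes `capture_program` and the `source_fn_stack` node metadata are still available):
+
+```python
+from executorch.backends.qualcomm.utils.utils import capture_program
+
+prog = capture_program(MyModel(), example_inputs)
+for node in prog.exported_program.graph.nodes:
+    # call_function nodes record the originating nn.Module in their metadata
+    if node.op == "call_function" and node.target.__name__ == "aten.my_op.default":
+        print(node.meta["source_fn_stack"])
+```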
---

## Step 2: Check Operator Spec

-- **QNN side**: [Operator Definitions](https://docs.qualcomm.com/bundle/publicresource/topics/80-63442-10/MasterOpDef.html) — check IO order, required vs optional tensors, parameter names and shapes
-- **PyTorch side**: [ATen Operator Definitions](https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/native) — map PyTorch args to QNN IO/params
-- **Fallback search**: [Supported Ops table](https://docs.qualcomm.com/bundle/publicresource/topics/80-63442-10/SupportedOps.html)
-- **Header reference**: `$QNN_SDK_ROOT/include/QNN/QnnOpDef.h` — authoritative string literals
+- [Master Op Definitions](https://docs.qualcomm.com/bundle/publicresource/topics/80-63442-10/MasterOpDef.html) — IO order, params, shapes
+- [HTP Op Def Supplement](https://docs.qualcomm.com/doc/80-63442-10/topic/HtpOpDefSupplement.html) — HTP-specific constraints & supported dtypes
+- [Supported Ops table](https://docs.qualcomm.com/bundle/publicresource/topics/80-63442-10/SupportedOps.html)
+- `$QNN_SDK_ROOT/include/QNN/QnnOpDef.h` — authoritative string literals
+- [ATen native ops](https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/native) — PyTorch arg mapping
+
+**⚠️ Caveats:**
+- An op listed in the Master op definitions may **not exist** in the HTP supplement → not available on HTP
+- HTP docs may claim a dtype is supported yet still **fail at runtime** → always test on-device

---

-## Step 3: Add Op Constant
-
-In `builders/qnn_constants.py`, add a dataclass (alphabetical order):
+## Step 3: Add Op Constant (`builders/qnn_constants.py`)

```python
@dataclass(init=False, frozen=True)
-class OpLayerNorm:
-    op_name: str = "LayerNorm"
-    param_epsilon = "epsilon"
-    param_axes = "axes"
+class OpMyOp:
+    op_name: str = "MyOp"  # Must match QnnOpDef.h exactly
+    param_axis: str = "axis"
+    param_epsilon: str = "epsilon"
```

-String values must exactly match `QnnOpDef.h`.
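+
+To confirm the exact op-name string before wiring it in, a quick check against the SDK header (assumes `QNN_SDK_ROOT` is set; `"MyOp"` is the hypothetical name from the example above):
+
+```bash
+grep -n '"MyOp"' "$QNN_SDK_ROOT/include/QNN/QnnOpDef.h"
+```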
---

-## Step 4: Implement the Builder
-
-Create `builders/op_layer_norm.py`:
+## Step 4: Implement Builder (`builders/op_my_op.py`)

```python
-import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager
-import numpy as np
-import torch
-from executorch.backends.qualcomm.utils.constants import QCOM_DATA
-from .node_visitor import NodeVisitor
-from .node_visitor_manager import register_node_visitor
-from .qnn_constants import OpLayerNorm, QNN_OP_PACKAGE_NAME_QTI_AISW
-from .utils import get_parameter
-
@register_node_visitor
-class LayerNormVisitor(NodeVisitor):
-    target = ["aten.native_layer_norm.default"]
-
-    def __init__(self, *args) -> None:
-        super().__init__(*args)
+class MyOpVisitor(NodeVisitor):  # imports: same set as the removed LayerNorm example above
+    target = ["aten.my_op.default"]  # Must be a list

    def define_node(self, node, nodes_to_wrappers):
-        # 1. Input activation
-        input_node = node.args[0]
+        input_node = self.get_node(node.args[0])
        input_tensor = self.get_tensor(input_node, node)
-        input_tensor_wrapper = self.define_tensor(
-            input_node, node, input_tensor,
-            PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
-            nodes_to_wrappers,
-        )
-
-        # 2. Weight (gamma) and bias (beta) — STATIC tensors
-        weight_node = self.get_node(node.args[2])
-        weight_tensor = get_parameter(weight_node, self.edge_program)
-        weight_tensor_wrapper = self.define_tensor(
-            weight_node, node, weight_tensor,
-            PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
-            nodes_to_wrappers,
-        )
-
-        bias_node = self.get_node(node.args[3])
-        bias_tensor = get_parameter(bias_node, self.edge_program)
-        bias_tensor_wrapper = self.define_tensor(
-            bias_node, node, bias_tensor,
-            PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
-            nodes_to_wrappers,
-        )
-
-        # 3. Parameters
-        normalized_shapes = node.args[1]
-        if len(normalized_shapes) != 1:
-            print("QNN only supports normalized output with rank 1")
-            return
-        axes = [len(input_tensor.shape) - 1]
-        axes_shape = [len(axes)]
-        epsilon = node.args[4]
-
-        # 4. Output
-        output_tensor = self.get_tensor(node, node, 0)
-        output_tensor_wrapper = self.define_tensor(
-            node, node, output_tensor,
-            PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
-            nodes_to_wrappers,
-        )
-
-        # 5. Build op
-        layer_norm_op = PyQnnManager.PyQnnOpWrapper(
-            node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpLayerNorm.op_name,
-        )
-        layer_norm_op.AddInputTensors(
-            [input_tensor_wrapper, weight_tensor_wrapper, bias_tensor_wrapper]
-        )
-        layer_norm_op.AddOutputTensors([output_tensor_wrapper])
-        layer_norm_op.AddScalarParam(
-            OpLayerNorm.param_epsilon,
-            PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32,
-            {QCOM_DATA: np.float32(epsilon)},
-        )
-        layer_norm_op.AddTensorParam(
-            OpLayerNorm.param_axes,
-            PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
-            len(axes_shape), axes_shape,
-            np.array(axes, dtype=np.uint32),
-            True,
-        )
-        return layer_norm_op
+        input_wrapper = self.define_tensor(input_node, node, input_tensor,
+            PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers)
+
+        axis = node.args[1] % len(input_tensor.shape)  # normalize negative dims (QNN requires positive axes)
+
+        output_tensor = self.get_tensor(node, node)
+        output_wrapper = self.define_tensor(node, node, output_tensor,
+            PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers)
+
+        op = PyQnnManager.PyQnnOpWrapper(node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpMyOp.op_name)
+        op.AddInputTensors([input_wrapper])
+        op.AddOutputTensors([output_wrapper])
+        op.AddScalarParam(OpMyOp.param_axis, PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
+            {QCOM_DATA: np.uint32(axis)})
+        return op  # return None instead to signal an unsupported config → op falls back to CPU
```

-Key notes:
-- `target` must be a list (multiple targets can share one visitor)
-- Use `QNN_TENSOR_TYPE_NATIVE` for activations, `QNN_TENSOR_TYPE_STATIC` for weights/biases
-- `define_tensor` handles `APP_READ`/`APP_WRITE` detection internally — always pass `NATIVE`
-- `wrapper_idx` needed when node output is a tuple (e.g. split ops)
-- Return `None` to signal validation failure → op falls back to CPU
+**Key patterns:**
+- `QNN_TENSOR_TYPE_NATIVE` for activations, `QNN_TENSOR_TYPE_STATIC` for weights/params
+- `wrapper_idx=i` for multi-output ops (tuples); a companion `getitem` skip op handles the indexing
+- Negative dims: `dim = dim % len(shape)` (QNN requires positive axes)
+- Axis remapping: `if QCOM_AXIS_ORDER in node.meta: dim = node.meta[QCOM_AXIS_ORDER].index(dim)`
+- Static params: `weight = get_parameter(self.get_node(node.args[1]), self.edge_program)`
+- Scalar params → `AddScalarParam`; array params → `AddTensorParam` (see the sketch after this list)
+- Data types: axis/dims = `UINT_32`, epsilon = `FLOAT_32`, booleans = `BOOL_8`
+- Int64 index tensors: use `.to(torch.int32)` in the builder and add the op to `I64_IN_OPS` in `_passes/i64_to_i32.py` for CPU-fallback safety (see the `op_gather.py` pattern)
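+
+Array-valued parameters go through `AddTensorParam`; a sketch adapted from the removed LayerNorm builder above (argument order taken from that example; `param_axes` is a hypothetical array-param constant):
+
+```python
+axes = [len(input_tensor.shape) - 1]
+axes_shape = [len(axes)]  # rank-1 tensor param holding one axis
+op.AddTensorParam(
+    OpMyOp.param_axes,
+    PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
+    len(axes_shape), axes_shape,
+    np.array(axes, dtype=np.uint32),
+    True,
+)
+```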
---

-## Step 5: Register the Builder
-
-In `builders/__init__.py` (alphabetical order):
-
-```python
-from . import (
-    ...
-    op_layer_norm,
-    ...
-)
-__all__ = [..., op_layer_norm, ...]
-```
+## Step 5: Register Builder (`builders/__init__.py`)
+
+Add `op_my_op` to both `from . import (...)` and `__all__ = [...]` (alphabetical).

---

-## Step 6: Add Quantizer Annotation
-
-In `quantizer/annotators/{backend}_rules.py`:
-
-```python
-@register_annotator(
-    [torch.ops.aten.native_layer_norm.default],
-    QnnConstants.OpLayerNorm.op_name,
-)
-class LayerNormAnnotator(GeneralOpDef):
-    @staticmethod
-    def annotate(node, quantization_config):
-        annotate_single_in_single_out(node, quantization_config)
-```
+## Step 6: Add Quantizer Annotation
+
+Add to BOTH `quantizer/annotators/htp_rules.py` AND `quantizer/annotators/lpai_rules.py`:
+
+```python
+@register_annotator([torch.ops.aten.my_op.default], QnnConstants.OpMyOp.op_name)
+class MyOp(GeneralOpDef):
+    pass  # Default: annotate_single_in_single_out
+```
+
+**Annotation function selection:**
+
+| Op type | Function | When |
+|---------|----------|------|
+| Compute (new scale) | `annotate_single_in_single_out` | Default — most ops |
+| Pass-through (`is_math_invariant`) | `annotate_in_out_obs_sharing_op` + fallback `annotate_single_in_share_out` | Reshape, Permute, Squeeze, Gather |
+| Two data inputs (same quant) | Custom `annotate` with `SharedQuantizationSpec` | Scatter-style ops where data and src need the same spec |
+| Two inputs | `annotate_binary` | Add, Mul, Sub |
+| Conv/Linear (weight+bias) | `annotate_conv` | Convolution, Linear |
+| Skip (no QNN mapping) | `qnn_op=None` | getitem, index_copy |
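+
+**Pass-through annotator** (a sketch following the removed ChannelShuffle example below; the class name is hypothetical — share the input observer with the output, and fall back when the producer is not annotated yet):
+
+```python
+@register_annotator([torch.ops.aten.my_op.default], QnnConstants.OpMyOp.op_name)
+class MyPassThroughOp(GeneralOpDef):
+    @staticmethod
+    def annotate(node, quantization_config):
+        # Satisfies is_math_invariant: input and output share quant params
+        annotate_in_out_obs_sharing_op(node, quantization_config)
+        if not _is_annotated([node]):
+            annotate_single_in_share_out(node, quantization_config)
+```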
+**Custom multi-input annotator** (e.g., scatter, where `args[0]` and `args[3]` are both data tensors):
+
+```python
+@register_annotator([torch.ops.aten.scatter.src], qnn_op=None)
+class ScatterElements(GeneralOpDef):
+    @staticmethod
+    def annotate(node, quantization_config):
+        if _is_annotated([node]):
+            return
+        input_qspec_map = {}
+        input_act = node.args[0]
+        input_qspec_map[input_act] = quantization_config.input_activation
+        if isinstance(node.args[3], Node) and _is_float_tensor(node.args[3]):
+            input_qspec_map[node.args[3]] = SharedQuantizationSpec((input_act, node))
+        output_qspec = SharedQuantizationSpec((input_act, node)) if _is_float_tensor(node) else None
+        node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
+            input_qspec_map=input_qspec_map, output_qspec=output_qspec, _annotated=True)
+```

-- Use `qnn_op=None` for skip ops (e.g. `operator.getitem`)
-- `annotate_single_in_single_out` covers most cases; implement custom logic for multi-input ops
-
-Full annotation tutorial: `backends/qualcomm/quantizer/README.md`
-
-### Choosing the right annotate function
-
-The QNN backend validates quantization constraints via `backend_opinfo` (QNN SDK ≥ 2.41). If validation fails with:
-
-```
-ValueError: Validation failed for node with target aten.<op>.default
-```
-
-Check the warning log above it — it will say which constraint failed. The most common case is `is_math_invariant=True`, which means the op does not change values (only rearranges data), so input and output **must share the same quantization parameters**.
-
-| Op type | annotate function | Example ops |
-|---------|-------------------|-------------|
-| General (input → output with new scale) | `annotate_single_in_single_out` | LayerNorm, Conv2d |
-| Pass-through (rearranges data only) | `annotate_in_out_obs_sharing_op` + fallback | Reshape, ChannelShuffle, PixelShuffle |
-| Multi-input | `annotate_binary` | Add, Mul |
-
-For **pass-through ops** (reshape, shuffle, permute — ops where `is_math_invariant=True`), override `annotate` like this:
-
-```python
-@register_annotator(
-    [torch.ops.aten.channel_shuffle.default], QnnConstants.OpChannelShuffle.op_name
-)
-class ChannelShuffle(GeneralOpDef):
-    @staticmethod
-    def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
-        annotate_in_out_obs_sharing_op(node, quantization_config)
-        if not _is_annotated([node]):
-            annotate_single_in_share_out(node, quantization_config)
-```
-
-`annotate_in_out_obs_sharing_op` shares the input's observer with the output (satisfies `is_math_invariant`). The fallback `annotate_single_in_share_out` handles the case where the input node is not yet annotated.

---

+## Step 7: Add Layout Transform Registration (`_passes/layout_transform.py`)
+
+Add the op to `layout_agnostic_ops` (most ops) or `layout_sensitive_ops` (conv, pool, etc.):
+
+```python
+exir_ops.edge.aten.my_op.default,
+```

-## Step 7: Add Unit Tests
-
-In `tests/models.py` (alphabetical order):
+## Step 8: Add Unit Tests
+
+**Model** in `tests/models.py` (alphabetical, parameterize variants):

```python
-class LayerNorm(torch.nn.Module):
-    def __init__(self):
+class MyOp(torch.nn.Module):
+    def __init__(self, param=0):
         super().__init__()
-        self.layer_norm = torch.nn.LayerNorm([768], eps=1e-6)
+        self.param = param

     def forward(self, x):
-        return self.layer_norm(x)
+        return torch.my_op(x, self.param)
```

-In `tests/test_qnn_delegate.py`, add to both `TestQNNFloatingPointOperator` and `TestQNNQuantizedOperator` (alphabetical order):
+**Tests** in `tests/test_qnn_delegate.py` — add to BOTH `TestQNNFloatingPointOperator` and `TestQNNQuantizedOperator`:

```python
-def test_qnn_backend_layer_norm(self):
-    module = LayerNorm()
-    sample_input = (torch.randn(196, 768),)
-    module = self.get_qdq_module(module, sample_input)  # quantized only
-    self.lower_module_and_test_output(module, sample_input)
+def test_qnn_backend_my_op(self):
+    test_comb = [{
+        QCOM_MODULE: [MyOp(), MyOp(param=1)],
+        QCOM_SAMPLE_INPUTS: [(torch.randn(3, 4),), (torch.randn(3, 4, dtype=torch.float16),)],
+    }]
+    index = 0
+    for comb in test_comb:
+        for module in comb[QCOM_MODULE]:
+            for sample_input in comb[QCOM_SAMPLE_INPUTS]:
+                with self.subTest(i=index):
+                    index += 1
+                    self.lower_module_and_test_output(module, sample_input)
```

-Expected result: 1 delegated node, only placeholders/output nodes remain outside the delegate.
+**Quantized test** — use a separate variable to avoid overwriting the float module:
+
+```python
+qdq_module = self.get_qdq_module(module, sample_input)
+self.lower_module_and_test_output(qdq_module, sample_input)
+```
+
+**Test data rules:**
+- No duplicate indices for scatter/gather with `reduction=NONE`
+- Deterministic inputs for precision-sensitive decompositions (avoid boundary values)
+- Bounded inputs for ops with singularities (tan, reciprocal): `torch.rand() * 2 - 1`
+
+**Run on-device:**
+
+```bash
+python backends/qualcomm/tests/test_qnn_delegate.py \
+  -k TestQNNFloatingPointOperator.test_qnn_backend_my_op \
+  --model SM8750 --host --device --build_folder build-android
+```
+
+Always ask the user for the `--model`, `--host`, `--device`, and `--build_folder` values.

---

-## Step 8: Prevent Decomposition (if needed)
-
-Some torch ops are in ExecuTorch's default decomposition table and will be broken into primitives **before** the QNN partitioner sees them. If QNN has a native op for it, you must explicitly skip decomposition.
-
-**Check first** with a quick Python snippet (run from the executorch root with the `executorch` conda env active):
-
-```python
-import torch
-from executorch.exir.tracer import _default_decomposition_table
-
-decomp_table = _default_decomposition_table()
-op = torch.ops.aten.channel_shuffle.default
-print(op in decomp_table)  # True → will be decomposed
-```
-
-Output:
-```
-True  # in ExecuTorch decomp table
-```
-
-If `True`, add the op to `get_skip_decomp_table()` in `partition/utils.py` (alphabetical order):
-
-```python
-def get_skip_decomp_table() -> List[torch._ops.OperatorBase]:
-    do_not_decompose = [
-        torch.ops.aten.adaptive_avg_pool2d.default,
-        torch.ops.aten.channel_shuffle.default,  # ← add here
-        torch.ops.aten.col2im.default,
-        ...
-    ]
-```
-
-**Verification**: After adding, re-run the tests. The partitioner log should show:
-
-```
-[QNN Partitioner Op Support]: aten.channel_shuffle.default | True
-```
-
-If the op was decomposed (not in the skip table), the partitioner would never see `aten.channel_shuffle.default` and the test would still pass — but via decomposed primitives, not the native QNN op.
+## Step 9: Prevent Decomposition (if needed)
+
+If the ATen op exists in ExecuTorch's decomp table and you have a builder for it:
+- Add to `partition/utils.py` → `get_skip_decomp_table()`
+- Remove from `partition/common_defs.py` → `to_be_implemented_operator` if listed there
+
+## Step 10: Update Documentation
+
+- `builders/README.md` — Update the QNN ops table (✗ → ✓) and add to the "Additional Operators Supported via Passes" table if using decomposition

---

-## Decompose Pass Approach (for ops without direct QNN equivalent)
-
-When a torch op has **no direct QNN equivalent** and requires multiple QNN ops to implement, use a **decompose pass** to rewrite the graph into primitive ops that QNN already supports. This is preferred over creating multiple ops in a single builder.
-
-**Reference**: `backends/qualcomm/_passes/decompose_linalg_vector_norm.py`
-
-### Pattern
-
-```python
-# 1. Define a torch.nn.Module that implements the op using supported primitives
-class MyOpDecomposed(torch.nn.Module):
-    def __init__(self, param):
-        super().__init__()
-        self.param = param
-
-    def forward(self, x):
-        # Use only ops that QNN supports
-        return torch.some_supported_op(x, self.param)
-
-
-# 2. Create the ExportPass
-class DecomposeMyOp(ExportPass):
-    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
-        graph = graph_module.graph
-        for node in list(graph.nodes):
-            if node.target == torch.ops.aten.my_op.default:
-                param = node.args[1]  # extract params from node
-                model = MyOpDecomposed(param)
-                ep = torch.export.export(model, (node.args[0].meta["val"],), strict=True)
-                decomposed_module = ep.run_decompositions().graph_module
-
-                with graph.inserting_before(node):
-                    remap = {"x": node.args[0]}
-                    merge_decomposed_graph(
-                        remap=remap,
-                        target_node=node,
-                        target_graph=graph,
-                        decomposed_graph_module=decomposed_module,
-                    )
-                graph.erase_node(node)
-
-        graph.eliminate_dead_code()
-        graph_module.recompile()
-        return PassResult(graph_module, True)
-```
+## Decompose Pass Approach
+
+Use when QNN has **no native op** — decompose into supported primitives.
+
+### Approach A: Module Export
+**Ref:** `_passes/decompose_linalg_vector_norm.py`. Write a `torch.nn.Module`, export it, and merge the exported graph via `merge_decomposed_graph`. Simple, but may produce unexpected ops.
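+
+A minimal Approach A sketch, condensed from the removed Pattern example above (assumes `merge_decomposed_graph` from `_passes/utils`; `torch.some_supported_op` stands in for real supported primitives):
+
+```python
+class MyOpDecomposed(torch.nn.Module):
+    def forward(self, x):
+        return torch.some_supported_op(x)  # only ops QNN already supports
+
+class DecomposeMyOpViaExport(ExportPass):
+    def call(self, graph_module):
+        graph = graph_module.graph
+        for node in list(graph.nodes):
+            if node.target == torch.ops.aten.my_op.default:
+                ep = torch.export.export(MyOpDecomposed(), (node.args[0].meta["val"],), strict=True)
+                decomposed_module = ep.run_decompositions().graph_module
+                with graph.inserting_before(node):
+                    merge_decomposed_graph(
+                        remap={"x": node.args[0]},  # placeholder name → real input node
+                        target_node=node,
+                        target_graph=graph,
+                        decomposed_graph_module=decomposed_module,
+                    )
+                graph.erase_node(node)
+        graph.eliminate_dead_code()
+        graph_module.recompile()
+        return PassResult(graph_module, True)
+```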
+### Approach B: Direct Graph Manipulation (RECOMMENDED)
+**Ref:** `_passes/decompose_remainder.py`, `_passes/decompose_log_variants.py`.

```python
+class DecomposeMyOp(ExportPass):
+    def __init__(self):
+        super().__init__()
+        self.targets = {torch.ops.aten.my_op.default, exir_ops.edge.aten.my_op.default}
+
+    def call(self, graph_module):
+        graph = graph_module.graph
+        const_cache = {}  # cache for constants lifted via get_const_node (rule 4 below)
+        for node in list(graph.nodes):
+            if node.op == "call_function" and node.target in self.targets:
+                # Handle both dialects: aten.* before export, exir_ops.edge.* after to_edge
+                is_edge = isinstance(node.target, EdgeOpOverload)
+                op = exir_ops.edge.aten.div.Tensor if is_edge else torch.ops.aten.div.Tensor
+                with graph.inserting_before(node):
+                    # e.g., rewrite my_op(a, b) as div(a, b)
+                    new_node = graph.create_node("call_function", op, (node.args[0], node.args[1]))
+                    new_node.meta = copy_meta(node.meta)
+                    for user in node.users.copy():
+                        user.replace_input_with(node, new_node)
+        graph.eliminate_dead_code()
+        graph_module.recompile()
+        return PassResult(graph_module, True)
```

+**Critical rules:** (1) handle both dialects via the `EdgeOpOverload` check, (2) `copy_meta` on every new node, (3) lift scalars to tensors in the edge dialect with `get_const_node`, (4) cache constants with `const_cache`, (5) for bool-output nodes use `callback=lambda m: {**m, "val": m["val"].to(torch.bool)}` in `create_node`.
+
+### Approach C: Built-in Decomposition Table
+**Ref:** `_passes/decompose_triu.py`. Uses `make_fx` + `get_decompositions`. Only works if PyTorch has a registered decomp.

-### Registration
-
-1. Add to `_passes/__init__.py` (alphabetical order):
-   ```python
-   from .decompose_my_op import DecomposeMyOp
-   ```
-
-2. Add to `_passes/qnn_pass_manager.py` imports and both pipeline methods:
-   - `transform_for_annotation_pipeline` (before quantizer)
-   - `transform_for_export_pipeline` (before `to_edge`)
-
-3. Remove the op from `to_be_implemented_operator` in `partition/common_defs.py`
-
-### Notes
-- The decomposed module must only use ops that QNN already supports
-- `ep.run_decompositions()` ensures the graph is in edge IR form
-- `remap` maps placeholder names in the decomposed graph to actual nodes in the target graph
-- No separate quantizer annotation needed — the decomposed ops already have their own annotations
+### Registration (all decompose passes)
+1. `_passes/__init__.py` — import + `__all__`
+2. `_passes/qnn_pass_manager.py` — import + `transform_for_annotation_pipeline` + `transform_for_export_pipeline` + `get_capture_program_passes`
+3. `_passes/utils.py` — add to `get_passes_dependency_for_capture_program()` with a `[RemoveRedundancy]` dependency

---

+## Common Gotchas
+
+- **Op name mismatch**: `aten.clamp` → `ReluMinMax`, `aten.expand` → `Tile`, `aten.select_copy` → `StridedSlice`. Search by functionality.
+- **Multi-output ops**: Use `wrapper_idx=i` + a `getitem` skip op
+- **Negative dims**: QNN needs positive → `dim = dim % len(shape)`
+- **QCOM_AXIS_ORDER**: `LayoutTransform` permutes NCHW→NHWC; remap the axis with `.index(dim)`. `get_tensor()` auto-permutes data.
+- **Int64 indices**: Add to `I64_IN_OPS` in `i64_to_i32.py` + `.to(torch.int32)` in the builder (see `op_gather.py`)
+- **Recompose passes**: Detect primitive sequences and replace them with a single native op. Ref: `recompose_pixel_unshuffle.py`
+- **`partition/common_defs.py`**: Remove the op from `to_be_implemented_operator` when adding support
+- **HTP doc bugs**: If the runtime fails where the docs say supported → always test on-device.
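+
+The two axis gotchas above combine in builders like this (a sketch; `QCOM_AXIS_ORDER` handling as described in this doc):
+
+```python
+dim = node.args[1]
+dim = dim % len(input_tensor.shape)              # QNN requires positive axes
+if QCOM_AXIS_ORDER in node.meta:
+    dim = node.meta[QCOM_AXIS_ORDER].index(dim)  # remap the NCHW axis after the NHWC layout transform
+```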
---

+## Error Debugging
+
+| Error | Cause | Fix |
+|-------|-------|-----|
+| `KeyError: 'aten.my_op.default'` | Builder not registered | Check `builders/__init__.py` + `@register_node_visitor` |
+| `was not decomposed or delegated` | Op is in the skip-decomp table but the partitioner rejected it | Check `define_node` errors in the builder; check `I64_IN_OPS` |
+| `QNN_GRAPH_ERROR` / `validateOpConfig failed` | HTP doesn't support the config | Check params against the HTP Op Def Supplement |
+| `Tensor mismatching datatypes` | Quantized: not all inputs annotated | Use a custom annotator with `SharedQuantizationSpec` |
+| `ValueError: Validation failed` | Wrong annotation | Check `is_math_invariant`; use `annotate_in_out_obs_sharing_op` |
+| `Expected dtype int64 for index` | Op fell back to CPU with an int32 index | Add to `I64_IN_OPS` + `.to(torch.int32)` in the builder |
+| Numerical mismatch | Precision issue | Quantized: check quant params. Float: HTP FP16 precision limit |
+
+**Debug order:** Float test first → then quantized. If the float test fails → builder/config issue. If only the quantized test fails → annotation issue.

---

+## Quick Reference Checklists
+
+**Native QNN Op:** `qnn_constants.py` → `op_my_op.py` → `builders/__init__.py` → `htp_rules.py` → `lpai_rules.py` → `layout_transform.py` → `tests/models.py` → `test_qnn_delegate.py` → `partition/utils.py` (skip decomp) → `common_defs.py` (remove from `to_be_implemented_operator`) → `builders/README.md`
+
+**Decompose Pass:** `_passes/decompose_my_op.py` → `_passes/__init__.py` → `qnn_pass_manager.py` (annotation + export + capture) → `_passes/utils.py` (dependency) → `tests/models.py` → `test_qnn_delegate.py` → `common_defs.py` → `builders/README.md`
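+
+For the decompose checklist's `_passes/utils.py` step, the dependency entry looks roughly like this (a sketch; the real dict lives inside `get_passes_dependency_for_capture_program()`, names per this doc):
+
+```python
+def get_passes_dependency_for_capture_program():
+    return {
+        # ... existing entries ...
+        DecomposeMyOp: [RemoveRedundancy],
+    }
+```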