From 4a6594dfb643bb42dd7bc5a8c20b1ed2c06b74e2 Mon Sep 17 00:00:00 2001 From: winskuo-quic Date: Wed, 25 Feb 2026 10:20:04 +0800 Subject: [PATCH] Qualcomm AI Engine Direct - Python API Refactor --- .ci/scripts/test_qnn_static_llm.sh | 6 +- backends/qualcomm/__init__.py | 3 + .../qualcomm/bc/test_qnn_static_llama_bc.sh | 4 +- backends/qualcomm/builders/README.md | 5 +- backends/qualcomm/debugger/README.md | 47 +- backends/qualcomm/export_utils.py | 871 ++++++++++++++++++ .../runtime/backends/direct_mode/README.md | 2 +- backends/qualcomm/tests/test_qnn_delegate.py | 217 ++--- backends/qualcomm/tests/utils.py | 88 +- backends/qualcomm/utils/utils.py | 20 +- examples/models/llama/export_llama_lib.py | 4 +- examples/qualcomm/README.md | 35 +- examples/qualcomm/custom_op/custom_ops_1.py | 52 +- examples/qualcomm/oss_scripts/albert.py | 84 +- examples/qualcomm/oss_scripts/bert.py | 81 +- examples/qualcomm/oss_scripts/conv_former.py | 45 +- .../qualcomm/oss_scripts/convnext_small.py | 49 +- examples/qualcomm/oss_scripts/cvt.py | 44 +- examples/qualcomm/oss_scripts/deit.py | 46 +- examples/qualcomm/oss_scripts/dino_v2.py | 46 +- examples/qualcomm/oss_scripts/distilbert.py | 81 +- examples/qualcomm/oss_scripts/dit.py | 55 +- .../oss_scripts/efficientSAM/efficientSAM.py | 46 +- examples/qualcomm/oss_scripts/efficientnet.py | 50 +- examples/qualcomm/oss_scripts/esrgan.py | 49 +- examples/qualcomm/oss_scripts/eurobert.py | 97 +- examples/qualcomm/oss_scripts/fastvit.py | 51 +- examples/qualcomm/oss_scripts/fbnet.py | 42 +- examples/qualcomm/oss_scripts/focalnet.py | 45 +- .../oss_scripts/gMLP_image_classification.py | 46 +- .../oss_scripts/llama/artifacts/README.md | 2 +- .../llama/decoder_runtime_evaluator.py | 24 +- examples/qualcomm/oss_scripts/llama/llama.py | 16 +- .../oss_scripts/llama/range_setting_pt2e.py | 3 +- .../llama/wrappers/attention_sink_wrappers.py | 4 +- .../llama/wrappers/llm_wrappers.py | 2 +- .../qualcomm/oss_scripts/llm_utils/README.md | 6 +- 
.../llm_utils/eval_decoder_model_qnn.py | 27 +- .../llm_utils/qnn_decoder_model_manager.py | 2 +- examples/qualcomm/oss_scripts/maxvit_t.py | 47 +- examples/qualcomm/oss_scripts/mobilevit_v1.py | 54 +- examples/qualcomm/oss_scripts/mobilevit_v2.py | 50 +- examples/qualcomm/oss_scripts/moshi/mimi.py | 97 +- examples/qualcomm/oss_scripts/pvt.py | 45 +- .../qualcomm/oss_scripts/qwen2_5/qwen2_5.py | 49 +- examples/qualcomm/oss_scripts/regnet.py | 45 +- examples/qualcomm/oss_scripts/retinanet.py | 50 +- examples/qualcomm/oss_scripts/roberta.py | 51 +- examples/qualcomm/oss_scripts/squeezenet.py | 45 +- examples/qualcomm/oss_scripts/ssd300_vgg16.py | 50 +- .../qualcomm/oss_scripts/swin_transformer.py | 45 +- examples/qualcomm/oss_scripts/swin_v2_t.py | 48 +- examples/qualcomm/oss_scripts/t5/t5.py | 40 +- examples/qualcomm/oss_scripts/vit_b_16.py | 49 +- .../qualcomm/oss_scripts/whisper/whisper.py | 63 +- .../llama/llama2/qaihub_llama2_7b.py | 25 +- .../llama/llama3/qaihub_llama3_8b.py | 28 +- .../qaihub_stable_diffusion.py | 28 +- .../qualcomm/qaihub_scripts/utils/export.py | 34 +- examples/qualcomm/sample_config.json | 13 + examples/qualcomm/scripts/deeplab_v3.py | 52 +- examples/qualcomm/scripts/edsr.py | 50 +- examples/qualcomm/scripts/inception_v3.py | 46 +- examples/qualcomm/scripts/inception_v4.py | 46 +- .../qualcomm/scripts/mobilebert_fine_tune.py | 46 +- examples/qualcomm/scripts/mobilenet_v2.py | 46 +- examples/qualcomm/scripts/mobilenet_v3.py | 52 +- examples/qualcomm/scripts/torchvision_vit.py | 35 +- examples/qualcomm/scripts/wav2letter.py | 38 +- examples/qualcomm/util_scripts/README.md | 2 +- examples/qualcomm/util_scripts/cli.py | 44 +- .../qualcomm/util_scripts/gen_etrecord.py | 27 +- .../util_scripts/qairt_visualizer_demo.py | 34 +- .../qnn_intermediate_debugger_demo.py | 34 +- examples/qualcomm/utils.py | 757 +-------------- 75 files changed, 2064 insertions(+), 2598 deletions(-) create mode 100644 backends/qualcomm/export_utils.py create mode 100644 
examples/qualcomm/sample_config.json diff --git a/.ci/scripts/test_qnn_static_llm.sh b/.ci/scripts/test_qnn_static_llm.sh index 91e37aae7c8..16d7e1615f1 100644 --- a/.ci/scripts/test_qnn_static_llm.sh +++ b/.ci/scripts/test_qnn_static_llm.sh @@ -47,11 +47,11 @@ if [[ "${TASK_NAME}" == "stories_110m" ]]; then $PYTHON_EXECUTABLE -m pytorch_tokenizers.tools.llama2c.convert -t tokenizer.model -o tokenizer.bin # Compile only as weight sharing is not applicable on x86. - $PYTHON_EXECUTABLE backends/qualcomm/tests/test_qnn_delegate.py -k TestExampleLLMScript.test_llama_stories_110m --model SM8650 --build_folder build-android/ --executorch_root . --artifact_dir ./stories_110m_pte_size --llama_artifacts . --compile_only + $PYTHON_EXECUTABLE backends/qualcomm/tests/test_qnn_delegate.py -k TestExampleLLMScript.test_llama_stories_110m --soc_model SM8650 --build_folder build-android/ --executorch_root . --artifact_dir ./stories_110m_pte_size --llama_artifacts . --compile_only exit_code1=$? # Checks accuracy with weight sharing disabled since x86 does not support weight sharing. - $PYTHON_EXECUTABLE backends/qualcomm/tests/test_qnn_delegate.py -k TestExampleLLMScript.test_llama_stories_110m --model SM8650 --build_folder build-x86/ --executorch_root . --artifact_dir ./stories_110m_accuracy --llama_artifacts . --enable_x86_64 + $PYTHON_EXECUTABLE backends/qualcomm/tests/test_qnn_delegate.py -k TestExampleLLMScript.test_llama_stories_110m --soc_model SM8650 --build_folder build-x86/ --executorch_root . --artifact_dir ./stories_110m_accuracy --llama_artifacts . --enable_x86_64 exit_code2=$? # Check the exit codes and print messages @@ -84,7 +84,7 @@ elif [[ "${TASK_NAME}" == "smollm2_135m" ]]; then if [ -n "$2" ]; then EXTRA_FLAGS="$EXTRA_FLAGS --static_llm_eval_method $2" fi - $PYTHON_EXECUTABLE backends/qualcomm/tests/test_qnn_delegate.py -k TestExampleLLMScript.test_static_llm_model --model_name smollm2_135m --model SM8650 --build_folder build-x86/ --executorch_root . 
--artifact_dir ./static_smollm2 --enable_x86_64 $EXTRA_FLAGS + $PYTHON_EXECUTABLE backends/qualcomm/tests/test_qnn_delegate.py -k TestExampleLLMScript.test_static_llm_model --model_name smollm2_135m --soc_model SM8650 --build_folder build-x86/ --executorch_root . --artifact_dir ./static_smollm2 --enable_x86_64 $EXTRA_FLAGS exit_code1=$? if [ $exit_code1 -ne 0 ]; then exit 1 diff --git a/backends/qualcomm/__init__.py b/backends/qualcomm/__init__.py index 5770dfb0fcd..d0fb22bc5c3 100644 --- a/backends/qualcomm/__init__.py +++ b/backends/qualcomm/__init__.py @@ -1,5 +1,7 @@ import os +import torch + from .scripts.download_qnn_sdk import install_qnn_sdk, is_linux_x86 @@ -11,3 +13,4 @@ ok = install_qnn_sdk() if not ok: raise RuntimeError("Failed to install QNN SDK. Please check the logs above.") +torch.backends.mkldnn.enabled = False diff --git a/backends/qualcomm/bc/test_qnn_static_llama_bc.sh b/backends/qualcomm/bc/test_qnn_static_llama_bc.sh index 478e6118641..36d9cb61189 100644 --- a/backends/qualcomm/bc/test_qnn_static_llama_bc.sh +++ b/backends/qualcomm/bc/test_qnn_static_llama_bc.sh @@ -27,11 +27,11 @@ touch ${llama_artifacts}/params.json echo '{"dim": 64, "n_layers": 5, "n_heads": 8, "n_kv_heads": 4, "vocab_size": 512, "multiple_of": 4, "max_seq_len": 512}' > ${llama_artifacts}/params.json # Checks e2e accuracy -expected=$($PYTHON_EXECUTABLE backends/qualcomm/tests/test_qnn_delegate.py -k TestExampleLLMScript.test_llama_stories_260k --model SM8650 --build_folder build-x86/ --executorch_root . --artifact_dir . --llama_artifacts $llama_artifacts --enable_x86_64 | grep "Model CI result:") +expected=$($PYTHON_EXECUTABLE backends/qualcomm/tests/test_qnn_delegate.py -k TestExampleLLMScript.test_llama_stories_260k --soc_model SM8650 --build_folder build-x86/ --executorch_root . --artifact_dir . --llama_artifacts $llama_artifacts --enable_x86_64 | grep "Model CI result:") exit_code1=$? 
# Checks accuracy with precompiled -output=$($PYTHON_EXECUTABLE backends/qualcomm/tests/test_qnn_delegate.py -k TestExampleLLMScript.test_llama_stories_260k --model SM8650 --build_folder build-x86/ --executorch_root . --artifact_dir $PTE_ARTIFACT --llama_artifacts $llama_artifacts --enable_x86_64 --pre_gen_pte $PTE_ARTIFACT | grep "Model CI result:") +output=$($PYTHON_EXECUTABLE backends/qualcomm/tests/test_qnn_delegate.py -k TestExampleLLMScript.test_llama_stories_260k --soc_model SM8650 --build_folder build-x86/ --executorch_root . --artifact_dir $PTE_ARTIFACT --llama_artifacts $llama_artifacts --enable_x86_64 --pre_gen_pte $PTE_ARTIFACT | grep "Model CI result:") exit_code2=$? if [[ "$output" == "$expected" ]]; then diff --git a/backends/qualcomm/builders/README.md b/backends/qualcomm/builders/README.md index 4ea5b1d5c40..21fb4b431a3 100644 --- a/backends/qualcomm/builders/README.md +++ b/backends/qualcomm/builders/README.md @@ -41,12 +41,11 @@ class MyModel(torch.nn.Module): ``` At the time we try to lower it with Qualcomm backend: ```python -from executorch.examples.qualcomm.utils import build_executorch_binary +from executorch.backends.qualcomm.export_utils import build_executorch_binary build_executorch_binary( model=MyModel(), - inputs=(torch.randn(200, 768),), - soc_model="SM8650" + qnn_config=qnn_config, file_name="my_model", dataset=None, ) diff --git a/backends/qualcomm/debugger/README.md b/backends/qualcomm/debugger/README.md index 9dc03b617e3..fb8f9a1c662 100644 --- a/backends/qualcomm/debugger/README.md +++ b/backends/qualcomm/debugger/README.md @@ -31,13 +31,14 @@ To enable model visualization, please add the `--online_prepare` flag. ## Details ### 1. Lower to QNN backend Generate an ExecuTorch binary for Qualcomm platforms. +Ensure that qnn_config.profile_level is set to 3, which will generate op_trace. 
```python +qnn_config.profile_level = 3 build_executorch_binary( - model, - example_input, - args.model, - f"{args.artifact}/{pte_filename}", - [example_input], + model=model, + qnn_config=qnn_config, + file_name=f"{args.artifact}/{pte_filename}", + dataset=[example_input], quant_dtype=QuantDtype.use_8a8w, online_prepare=args.online_prepare, optrace=True, @@ -47,14 +48,9 @@ build_executorch_binary( Generate optrace and QHAS files using QNN tools under $QNN_SDK_ROOT. After finishing, you will get a `binaries_trace` dictionary. ``` python adb = SimpleADB( - qnn_sdk=os.getenv("QNN_SDK_ROOT"), - build_path=f"{args.build_folder}", + qnn_config=qnn_config, pte_path=f"{args.artifact}/{pte_filename}.pte", - workspace=f"/data/local/tmp/executorch/{pte_filename}", - device_id=args.device, - host_id=args.host, - soc_model=args.model, - target=args.target, + workspace=f"/data/local/tmp/executorch/{pte_filename}, ) binaries_trace = generate_optrace( args, adb, f"{args.artifact}/{pte_filename}.pte", example_input @@ -139,42 +135,23 @@ When executing the script, please add the flag `--dump_intermediate_outputs`. Th Initialize a `QNNIntermediateDebugger`. Please pass initialized `QNNIntermediateDebugger` and the `args.dump_intermediate_outputs` to `build_executorch_binary` method as well. 
#### Example: ```python -from executorch.examples.qualcomm.utils import build_executorch_binary +from executorch.backends.qualcomm.export_utils import build_executorch_binary from executorch.backends.qualcomm.debugger.qnn_intermediate_debugger import QNNIntermediateDebugger qnn_intermediate_debugger = QNNIntermediateDebugger() build_executorch_binary( model=MyModel(), - inputs=(torch.randn(200, 768),), - soc_model="SM8650", + qnn_config=qnn_config, file_name="my_model", dataset=my_dataset, - dump_intermediate_outputs=args.dump_intermediate_outputs, # Add this flag - qnn_intermediate_debugger=qnn_intermediate_debugger, # Add this flag + qnn_intermediate_debugger=qnn_intermediate_debugger, # Provide this param ) ``` ### 4. Set data num to 1 It is perfectly fine for users to pass the desired amount of datasets to `build_executorch_binary`, which helps achieve better quantization results. However, after `build_executorch_binary` is called, we need to ensure that we only perform one inference during execution. Please ensure that CPU and QNN is using the same input during execution; otherwise, the debugging results might not be accurate. -### 5. Pass flag to SimpleADB -When creating `SimpleADB`, please also pass the flag `args.dump_intermediate_outputs`. This tells the runner to create files that store the intermediate output schema and binary data. -#### Example: -```python -adb = SimpleADB( - qnn_sdk=os.getenv("QNN_SDK_ROOT"), - build_path=f"{args.build_folder}", - pte_path=f"{args.artifact}/{pte_filename}.pte", - workspace=f"/data/local/tmp/executorch/{pte_filename}", - device_id=args.device, - host_id=args.host, - soc_model=args.model, - shared_buffer=args.shared_buffer, - dump_intermediate_outputs=args.dump_intermediate_outputs, # Add this flag -) -``` - -### 6: Pull and process the results. +### 5: Pull and process the results. 
After QNN execution with the runner, if the previous steps are done correctly, we should be able to get two files: `etdump.etdp` and `debug_output.bin`. The following example pulls the files back and calls a callback function to process the results. In this callback function, we create the `Inspector`. Then we perform CPU inference to get CPU intermediate results. Now, we have both QNN and CPU intermediate results, we can start generating results to compare the accuracy. Taking the following example, we should be able to get `debug_graph.svg` as an output in the current directory. #### Example: diff --git a/backends/qualcomm/export_utils.py b/backends/qualcomm/export_utils.py new file mode 100644 index 00000000000..dff7618df03 --- /dev/null +++ b/backends/qualcomm/export_utils.py @@ -0,0 +1,871 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# TODO: reenable pyre after fixing the issues +# pyre-ignore-all-errors +import argparse +import json +import logging +import os +import random +import subprocess +import sys +import tempfile +from dataclasses import dataclass, fields +from pathlib import Path +from typing import Callable, List, Optional, Set, Tuple, Union + +import numpy as np +import torch +import torchao +from executorch.backends.qualcomm.debugger.qnn_intermediate_debugger import ( + QNNIntermediateDebugger, +) +from executorch.backends.qualcomm.quantizer.quantizer import ( + ModuleQConfig, + QnnQuantizer, + QuantDtype, +) +from executorch.backends.qualcomm.serialization.qc_schema import ( + QcomChipset, + QnnExecuTorchBackendType, + QnnExecuTorchOpPackageOptions, +) +from executorch.backends.qualcomm.utils.constants import ( + DSP_VERSION, + HEXAGON_SDK_ROOT, + HEXAGON_TOOLS_ROOT, +) +from executorch.backends.qualcomm.utils.utils import ( + generate_gpu_compiler_spec, + generate_htp_compiler_spec, + generate_qnn_executorch_compiler_spec, + get_qnn_context_binary_alignment, + get_soc_to_arch_map, + to_edge_transform_and_lower_to_qnn, +) +from executorch.exir.backend.utils import get_delegates +from executorch.exir.capture._config import ExecutorchBackendConfig +from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass +from torchao.quantization.pt2e import MovingAverageMinMaxObserver +from torchao.quantization.pt2e.quantize_pt2e import ( + convert_pt2e, + prepare_pt2e, + prepare_qat_pt2e, +) + + +@dataclass +class QnnConfig: + """ + A configuration used as input to QNN ExecuTorch’s lowering API. + This config initialization currently supports: + 1. Provide command-line arguments paired with setup_common_args_and_variables. + 2. Provide a json file that stores desired config. + + Attributes: + backend (str): The target backend, such as htp, gpu, etc. QnnConfig will then parse this to type QnnExecuTorchBackendType. 
+ soc_model (QcomChipset): The target Qualcomm System on Chip (SoC) model. + build_folder (str): Path to cmake binary directory for target platform, e.g., /path/to/build-android. + direct_build_folder (str): Path to cmake binary directory for direct_mode. E.g., path/to/build-hexagon. + target (str): Target platform for deployment. + online_prepare (bool): Compose QNN graph on device if set to True. + shared_buffer (bool): Enables usage of shared buffer(zero-copy mechanism) between application and backend for graph I/O during runtime. + dump_intermediate_outputs (bool): Enables dumping model intermediate outputs. + profile_level (int): Level of profiling in runtime. + enable_x86_64: Enable x86_64 simulator execution. + host (str): Hostname where android device is connected. + device (str): Serial number for android device communicated via ADB. + port (int): IPC port for delivering execution result + ip (str): IPC address for delivering execution result. + skip_delegate_node_ids (str): If specified, skip delegation for the specified node based on node ids. Node ids should be separated by comma. e.g., aten_relu_default_10,aten_relu_default_2 + skip_delegate_node_ops (str): If specified, skip delegation for the specified op. Node ops should be separated by comma. e.g., aten.add.Tensor,aten.relu.default + compile_only (bool): If specified, only compile the model. + pre_gen_pte (str): Run the pre-generated pte in the given directory. + skip_push: If specified, skip pushing files to device. Assumes all required files are on device already. + ci (bool): This flag is for Continuous Integration(CI) purpose and is NOT recommended to turn on for typical use cases. It will use random inputs instead of real inputs. + seed (int): Set the seed for generating random numbers in both torch and random. 
+ """ + + soc_model: str + build_folder: str + direct_build_folder: Optional[str] = None + backend: str = "htp" + target: str = "aarch64-android" + online_prepare: Optional[bool] = False + shared_buffer: Optional[bool] = False + dump_intermediate_outputs: Optional[bool] = False + profile_level: Optional[int] = 0 + enable_x86_64: Optional[bool] = False + host: Optional[str] = None + device: Optional[str] = None + port: Optional[str] = -1 + ip: Optional[str] = "" + skip_delegate_node_ids: Optional[str] = None + skip_delegate_node_ops: Optional[str] = None + compile_only: Optional[bool] = False + pre_gen_pte: Optional[str] = None + skip_push: Optional[bool] = False + ci: Optional[bool] = False + seed: Optional[int] = None + + def __post_init__(self): + assert self.soc_model, "Please provide the soc_model" + assert self.build_folder, "Please provide the build_folder." + assert not ( + self.compile_only and self.pre_gen_pte + ), "Cannot set both compile_only and pre_gen_pte as true" + assert ( + "QNN_SDK_ROOT" in os.environ + ), "Environment variable QNN_SDK_ROOT must be set." + if not self.enable_x86_64 and not self.compile_only and self.device is None: + raise RuntimeError( + "device serial is required if not compile only. Please specify a device serial." + ) + + if self.seed: + torch.manual_seed(self.seed) + np.random.seed(self.seed) + random.seed(self.seed) + + self.backend = get_backend_type(self.backend) + self.skip_delegate_node_ids, self.skip_delegate_node_ops = ( + self._parse_skip_delegation_node( + self.skip_delegate_node_ids, self.skip_delegate_node_ops + ) + ) + + @classmethod + def load_config(cls, config: Union[argparse.Namespace, str]) -> "QnnConfig": + """ + config (Union[argparse.Namespace, str]): Accepts either a parser generated from setup_common_args_and_variables() or a json file. 
+ """ + qnn_config = None + if isinstance(config, argparse.Namespace): + logging.info("Using parser's config") + args_dict = vars(config) + matched_keys = {f.name for f in fields(QnnConfig)} + config = {k: v for k, v in args_dict.items() if k in matched_keys} + qnn_config = cls(**config) + elif isinstance(config, str): + logging.info(f"Using {config}'s config.") + + with open(config) as f: + qnn_config = cls(**json.load(f)) + else: + raise TypeError( + f"Invalid config type {type(config).__name__}. Expected argparse.Namespace or str." + ) + + return qnn_config + + def _parse_skip_delegation_node( + self, skip_delegate_node_ids, skip_delegate_node_ops + ): + skip_node_id_set = set() + skip_node_op_set = set() + + if skip_delegate_node_ids: + skip_node_id_set = set(map(str, skip_delegate_node_ids.split(","))) + print("Skipping following node ids: ", skip_node_id_set) + + if skip_delegate_node_ops: + skip_node_op_set = set(map(str, skip_delegate_node_ops.split(","))) + print("Skipping following node ops: ", skip_node_op_set) + + return skip_node_id_set, skip_node_op_set + + +class SimpleADB: + """ + A wrapper class for communicating with Android device + + Attributes: + qnn_config: (QnnConfig): A config class that saves qnn lowering and execution configuration. + pte_path (Union[str, list]): Path where executorch binary was stored. If there are multiple pte files, provide a list of pte paths. 
+ workspace (str): Folder for storing artifacts on android device + error_only (bool): Redirect stdio and leave error messages only + runner (str): Runtime executor binary + expected_input_shape (Tuple[torch.Size]): Input shape of dynamic graph + expected_output_shape (Tuple[torch.Size]): Output shape of dynamic graph + """ + + def __init__( + self, + qnn_config: QnnConfig, + pte_path: Union[str, list], + workspace, + error_only=False, + runner=None, + expected_input_shape=None, + expected_output_shape=None, + ): + if runner is None: + runner = ( + "examples/qualcomm/executor_runner/qnn_executor_runner" + if qnn_config.direct_build_folder is None + else "examples/qualcomm/direct_executor_runner/qnn_executor_direct_runner" + ) + self.runner = runner + if qnn_config.direct_build_folder: + required_env = [HEXAGON_SDK_ROOT, HEXAGON_TOOLS_ROOT, DSP_VERSION] + assert all( + var in os.environ for var in required_env + ), f"Please ensure the following environment variables are set: {required_env}" + self.hexagon_sdk_root = os.getenv(HEXAGON_SDK_ROOT) + self.hexagon_tools_root = os.getenv(HEXAGON_TOOLS_ROOT) + self.dsp_arch = os.getenv(DSP_VERSION) + logging.info(f"{HEXAGON_SDK_ROOT}={self.hexagon_sdk_root}") + logging.info(f"{HEXAGON_TOOLS_ROOT}={self.hexagon_tools_root}") + logging.info(f"{DSP_VERSION}={self.dsp_arch}") + self.qnn_config = qnn_config + self.qnn_sdk = os.getenv("QNN_SDK_ROOT") + self.build_path = qnn_config.build_folder + self.direct_build_folder = qnn_config.direct_build_folder + self.pte_path = pte_path if isinstance(pte_path, list) else [pte_path] + if qnn_config.pre_gen_pte: + self.pte_path = [ + os.path.join(qnn_config.pre_gen_pte, os.path.basename(p)) + for p in self.pte_path + ] + assert all( + os.path.exists(p) for p in self.pte_path + ), f"{self.pte_path} not found. Please ensure there are pregenerated pte files under pre_gen_pte path." + logging.info( + f"Pregenerated pte path given. 
Using pre_gen_pte path: {self.pte_path}" + ) + self.workspace = workspace + self.device_id = qnn_config.device + self.host_id = qnn_config.host + if len(self.pte_path) > 0: + self.working_dir = Path(self.pte_path[0]).parent.absolute() + else: + self.working_dir = Path.cwd() + self.input_list_filename = "input_list.txt" + self.etdump_path = f"{self.workspace}/etdump.etdp" + self.dump_intermediate_outputs = qnn_config.dump_intermediate_outputs + self.debug_output_path = f"{self.workspace}/debug_output.bin" + self.output_folder = f"{self.workspace}/outputs" + self.htp_arch = get_soc_to_arch_map()[qnn_config.soc_model] + self.error_only = error_only + self.shared_buffer = qnn_config.shared_buffer + self.target = qnn_config.target + self.expected_input_shape = expected_input_shape + self.expected_output_shape = expected_output_shape + self.extra_cmds = "" + self.skip_push = qnn_config.skip_push + self.backend_library_paths = {} + + if self.direct_build_folder: + direct_general_artifacts = [ + f"{self.build_path}/examples/qualcomm/direct_executor_runner/libqnn_executorch_stub.so", + f"{self.direct_build_folder}/backends/qualcomm/libqnn_executorch_backend.so", + f"{self.direct_build_folder}/backends/qualcomm/qnn_executorch/direct_mode/libqnn_executorch_skel.so", + ] + self.backend_library_paths.update( + { + QnnExecuTorchBackendType.kHtpBackend: [ + f"{self.qnn_sdk}/lib/hexagon-v{self.htp_arch}/unsigned/libQnnHtpV{self.htp_arch}.so", + f"{self.qnn_sdk}/lib/hexagon-v{self.htp_arch}/unsigned/libQnnSystem.so", + f"{self.hexagon_tools_root}/Tools/target/hexagon/lib/v{self.htp_arch}/G0/pic/libc++abi.so.1", + f"{self.hexagon_tools_root}/Tools/target/hexagon/lib/v{self.htp_arch}/G0/pic/libc++.so.1", + ] + } + ) + for _, library_paths in self.backend_library_paths.items(): + library_paths.extend(direct_general_artifacts) + else: + traditional_general_artifacts = [ + f"{self.qnn_sdk}/lib/{self.target}/libQnnSystem.so", + 
f"{self.build_path}/backends/qualcomm/libqnn_executorch_backend.so", + f"{self.qnn_sdk}/lib/{self.target}/libQnnModelDlc.so", + ] + self.backend_library_paths.update( + { + QnnExecuTorchBackendType.kHtpBackend: [ + f"{self.qnn_sdk}/lib/{self.target}/libQnnHtp.so", + ( + f"{self.qnn_sdk}/lib/hexagon-v{self.htp_arch}/" + f"unsigned/libQnnHtpV{self.htp_arch}Skel.so" + ), + ( + f"{self.qnn_sdk}/lib/{self.target}/" + f"libQnnHtpV{self.htp_arch}Stub.so" + ), + f"{self.qnn_sdk}/lib/{self.target}/libQnnHtpPrepare.so", + ], + QnnExecuTorchBackendType.kGpuBackend: [ + f"{self.qnn_sdk}/lib/{self.target}/libQnnGpu.so", + ], + } + ) + for _, library_paths in self.backend_library_paths.items(): + library_paths.extend(traditional_general_artifacts) + + def _adb(self, cmd, output_callback: Optional[Callable[[str], None]] = None): + if not self.host_id: + cmds = ["adb", "-s", self.device_id] + else: + cmds = ["adb", "-H", self.host_id, "-s", self.device_id] + cmds.extend(cmd) + + if output_callback: + result = subprocess.run( + cmds, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True + ) + output_callback(result) + else: + subprocess.run( + cmds, stdout=subprocess.DEVNULL if self.error_only else sys.stdout + ) + + def push( # noqa: C901 + self, + inputs=None, + files=None, + backends: Optional[Set[QnnExecuTorchBackendType]] = None, + init_env=True, + ): + # Assume all required files are on device already + if self.skip_push: + return + + artifacts = [*self.pte_path, f"{self.build_path}/{self.runner}"] + if init_env: + self._adb(["shell", f"rm -rf {self.workspace}"]) + self._adb(["shell", f"mkdir -p {self.workspace}"]) + + if backends is None: + backends = {self.qnn_config.backend} + + # backend libraries + for backend in backends: + artifacts.extend(self.backend_library_paths[backend]) + with tempfile.TemporaryDirectory() as tmp_dir: + input_list_file, input_files = generate_inputs( + tmp_dir, self.input_list_filename, inputs + ) + + if input_list_file is not None: + # 
prepare input list + artifacts.append(input_list_file) + + for artifact in artifacts: + self._adb(["push", artifact, self.workspace]) + + # input data + for file_name in input_files: + self._adb(["push", file_name, self.workspace]) + + # dynamic shape related + if self.expected_input_shape and self.expected_output_shape: + shape_info = { + "input_shape": self.expected_input_shape, + "output_shape": self.expected_output_shape, + } + for name, shapes in shape_info.items(): + with open(f"{tmp_dir}/{name}.txt", "w") as f: + for s in shapes: + f.write(str(tuple(s)).strip("()") + "\n") + self._adb(["push", f"{tmp_dir}/{name}.txt", self.workspace]) + self.extra_cmds += f" --{name}_path {name}.txt" + + # custom files + if files is not None: + for file_name in files: + self._adb(["push", file_name, self.workspace]) + + def execute( + self, + custom_runner_cmd=None, + method_index=0, + output_callback: Optional[Callable[[str], None]] = None, + ): + self._adb(["shell", f"rm -rf {self.output_folder}"]) + self._adb(["shell", f"mkdir -p {self.output_folder}"]) + # run the delegation + if custom_runner_cmd is None: + qnn_executor_runner_args = ( + " ".join( + [ + f"--model_path {os.path.basename(self.pte_path[0])}", + f"--output_folder_path {self.output_folder}", + f"--input_list_path {self.input_list_filename}", + f"--etdump_path {self.etdump_path}", + "--shared_buffer" if self.shared_buffer else "", + f"--debug_output_path {self.debug_output_path}", + ( + "--dump_intermediate_outputs" + if self.dump_intermediate_outputs + else "" + ), + f"--method_index {method_index}", + ] + ) + + self.extra_cmds + ) + qnn_executor_runner_cmds = " ".join( + [ + f"cd {self.workspace} &&", + f"chmod +x {os.path.basename(self.runner)} &&", + f"export LD_LIBRARY_PATH=. && export ADSP_LIBRARY_PATH=. 
&& echo 0x0C > {os.path.basename(self.runner)}.farf && ./{os.path.basename(self.runner)} {qnn_executor_runner_args}", + ] + ) + else: + qnn_executor_runner_cmds = custom_runner_cmd + self._adb( + ["shell", f"{qnn_executor_runner_cmds}"], output_callback=output_callback + ) + + def pull(self, host_output_path, device_output_path=None, callback=None): + if device_output_path is None: + device_output_path = self.output_folder + self._adb(["pull", "-a", device_output_path, host_output_path]) + if callback: + callback() + + def pull_etdump(self, output_path, callback=None): + self._adb(["pull", self.etdump_path, output_path]) + if callback: + callback() + + def pull_debug_output(self, etdump_path, debug_ouput_path, callback=None): + self._adb(["pull", self.etdump_path, etdump_path]) + self._adb(["pull", self.debug_output_path, debug_ouput_path]) + if callback: + callback() + + +def build_executorch_binary( + model: torch.nn.Module, # noqa: B006 + qnn_config: QnnConfig, + file_name: str, + dataset: List[torch.Tensor] | Callable[[torch.fx.GraphModule], None], + quant_dtype: Optional[QuantDtype] = None, + custom_quantizer: Optional[QnnQuantizer] = None, + metadata=None, + qnn_intermediate_debugger: QNNIntermediateDebugger = None, + passes_job=None, + passes_dependency=None, + qat_training_data=None, + op_package_options: QnnExecuTorchOpPackageOptions = None, +): + """ + A function to generate an ExecuTorch binary for Qualcomm platforms. + + Attributes: + model (torch.nn.Module): The model to be converted into an ExecuTorch binary. + qnn_config: (QnnConfig): A config class that saves qnn lowering and execution configuration. + file_name (str): Name for the output binary file (.pte). + dataset (List[torch.Tensor] | Callable): A dataset for quantization calibration. + quant_dtype (QuantDtype, optional): Data type for quantization. + custom_quantizer (Callable, optional): Custom quantizer. 
+ metadata (dict, optional): An optional dictionary that maps each method name to a constant value in eager mode. + passes_job (OrderedDict, optional): Custom passes job in to_edge_transform_and_lower, users can enable/disable specific passes or modify their attributes. + passes_dependency (Dict, optional): A dictionary mapping each pass to its corresponding list of dependencies. + qat_training_data (List[torch.Tensor], optional): A dataset for quantization aware training(QAT). Typically is a pair of tensors, such as [features, ground truth]. + op_package_options: Optional structure to specify op packages + loaded and used by the backend. + + Returns: + None: The function writes the output to a specified .pte file. + """ + if qnn_config.pre_gen_pte: + logging.info( + f"Skip build_executorch_binary, using {file_name} under {qnn_config.pre_gen_pte}." + ) + return + + sample_input = dataset[0] + if ( + qnn_config.backend == QnnExecuTorchBackendType.kGpuBackend + and not qnn_config.online_prepare + ): + raise RuntimeError("Currently GPU backend only supports online_prepare.") + backend_options = { + QnnExecuTorchBackendType.kGpuBackend: generate_gpu_compiler_spec(), + QnnExecuTorchBackendType.kHtpBackend: generate_htp_compiler_spec( + use_fp16=False if any([quant_dtype, custom_quantizer]) is not None else True + ), + }[qnn_config.backend] + compile_spec = generate_qnn_executorch_compiler_spec( + soc_model=getattr(QcomChipset, qnn_config.soc_model), + backend_options=backend_options, + online_prepare=qnn_config.online_prepare, + profile_level=qnn_config.profile_level, + shared_buffer=qnn_config.shared_buffer, + dump_intermediate_outputs=qnn_config.dump_intermediate_outputs, + op_package_options=op_package_options, + ) + if quant_dtype is not None or custom_quantizer is not None: + captured_model = torch.export.export(model, sample_input, strict=False).module() + if qat_training_data: + quantizer = custom_quantizer or make_quantizer( + quant_dtype=quant_dtype, + 
is_qat=True, + backend=qnn_config.backend, + soc_model=qnn_config.soc_model, + ) + # qat training + annotated_model = _qat_train( + model, captured_model, quantizer, qat_training_data + ) + else: + quantizer = custom_quantizer or make_quantizer( + quant_dtype=quant_dtype, + backend=qnn_config.backend, + soc_model=qnn_config.soc_model, + ) + # ptq calibration + with torch.no_grad(): + annotated_model = _ptq_calibrate(captured_model, quantizer, dataset) + + quantized_model = convert_pt2e(annotated_model) + edge_prog_mgr = to_edge_transform_and_lower_to_qnn( + quantized_model, + sample_input, + compile_spec, + constant_methods=metadata, + passes_job=passes_job, + dep_table=passes_dependency, + skip_node_id_set=qnn_config.skip_delegate_node_ids, + skip_node_op_set=qnn_config.skip_delegate_node_ops, + ) + else: + edge_prog_mgr = to_edge_transform_and_lower_to_qnn( + model, + sample_input, + compile_spec, + constant_methods=metadata, + passes_job=passes_job, + skip_node_id_set=qnn_config.skip_delegate_node_ids, + skip_node_op_set=qnn_config.skip_delegate_node_ops, + ) + + if qnn_intermediate_debugger: + lowered_module_nodes = get_delegates(edge_prog_mgr.exported_program().graph) + assert ( + len(lowered_module_nodes) == 1 + ), "Graph with partitions are currently unsupported." + + lowered_module_node = lowered_module_nodes[0] + lower_module = getattr( + edge_prog_mgr.exported_program().graph_module, lowered_module_node.name + ) + edge_module = lower_module.original_module.module() + qnn_intermediate_debugger.set_edge_module(edge_module=edge_module) + + allocate_io = not (qnn_config.shared_buffer or qnn_config.direct_build_folder) + executorch_config = ExecutorchBackendConfig( + # For shared buffer, user must pass the memory address + # which is allocated by RPC memory to executor runner. + # Therefore, won't want to pre-allocate + # by memory manager in runtime. 
+ memory_planning_pass=MemoryPlanningPass( + alloc_graph_input=allocate_io, + alloc_graph_output=allocate_io, + ), + segment_alignment=get_qnn_context_binary_alignment(), + ) + pte_name = f"{file_name}.pte" + exec_prog_mgr = edge_prog_mgr.to_executorch(config=executorch_config) + with open(pte_name, "wb") as file: + exec_prog_mgr.write_to_file(file) + + if qnn_config.compile_only: + sys.exit(0) + + +def make_quantizer( + quant_dtype: Optional[QuantDtype] = QuantDtype.use_8a8w, + custom_annotations=(), + per_channel_conv=True, + per_channel_linear=False, + act_observer=MovingAverageMinMaxObserver, + act_symmetric=False, + is_qat=False, + submodule_qconfig_list: Optional[List[Tuple[Callable, ModuleQConfig]]] = None, + backend=QnnExecuTorchBackendType.kHtpBackend, + soc_model="SM8750", + eps=None, +): + quantizer = QnnQuantizer(backend=backend, soc_model=getattr(QcomChipset, soc_model)) + quantizer.add_custom_quant_annotations(custom_annotations) + quantizer.set_default_quant_config( + quant_dtype, + is_qat=is_qat, + is_conv_per_channel=per_channel_conv, + is_linear_per_channel=per_channel_linear, + act_observer=act_observer, + act_symmetric=act_symmetric, + eps=eps, + ) + submodule_qconfig_list = submodule_qconfig_list or [] + quantizer.set_submodule_qconfig_list(submodule_qconfig_list) + return quantizer + + +def get_backend_type(backend: str): + return getattr(QnnExecuTorchBackendType, f"k{backend.title()}Backend") + + +def setup_common_args_and_variables(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--config_file", + help="To reduce the effort of providing a lot of command-line arguments, users can choose to save all arguments to a .json file and pass it in. Please refer to executorch/examples/qualcomm/executor_runner/sample_config.json for sample.", + type=str, + required=False, + ) + + parser.add_argument( + "-m", + "--soc_model", + "--model", # Deprecate this flag in future. + help="SoC model of current device. e.g. 
'SM8550' for Snapdragon 8 Gen 2.", + type=str, + default=None, + ) + + parser.add_argument( + "-b", + "--build_folder", + help="path to cmake binary directory for target platform, e.g., /path/to/build-android", + type=str, + default=None, + ) + + parser.add_argument( + "-H", + "--host", + help="hostname where android device is connected.", + default=None, + type=str, + ) + + parser.add_argument( + "--online_prepare", + help="If specified, compose QNN graph on device.", + action="store_true", + default=False, + ) + + parser.add_argument( + "--ip", + help="IPC address for delivering execution result", + default="", + type=str, + ) + + parser.add_argument( + "--port", + help="IPC port for delivering execution result", + default=-1, + type=int, + ) + + parser.add_argument( + "-S", + "--skip_delegate_node_ids", + help="If specified, skip delegation for the specified node based on node ids. Node ids should be separated by comma. e.g., aten_relu_default_10,aten_relu_default_2", + default=None, + type=str, + ) + + parser.add_argument( + "-f", + "--skip_delegate_node_ops", + help="If specified, skip delegation for the specified op. Node ops should be separated by comma. 
e.g., aten.add.Tensor,aten.relu.default", + default=None, + type=str, + ) + + parser.add_argument( + "-c", + "--compile_only", + help="If specified, only compile the model.", + action="store_true", + default=False, + ) + + parser.add_argument( + "-s", + "--device", + help="serial number for android device communicated via ADB.", + type=str, + ) + + parser.add_argument( + "--backend", + help="Backend to be deployed ('htp'/'gpu' are currently supported).", + choices=["htp", "gpu"], + default="htp", + type=str, + ) + + parser.add_argument( + "-z", + "--shared_buffer", + help="Enables usage of shared buffer(zero-copy mechanism) between application and backend for graph I/O.", + action="store_true", + ) + + parser.add_argument( + "--skip_push", + help="If specified, skip pushing files to device.", + action="store_true", + default=False, + ) + + parser.add_argument( + "-D", + "--dump_intermediate_outputs", + help="If specified, enable dump intermediate outputs", + action="store_true", + default=False, + ) + + parser.add_argument( + "--profile_level", + type=int, + help="Profiling level of the delegate and QNN backend. 0=Off, 1=Basic(Currently not supported), 2=Detailed, 3=Optrace.", + choices=[0, 2, 3], + default=0, + ) + + parser.add_argument( + "-x", + "--enable_x86_64", + help="Enable unittest to be executed on x86_64 platform", + action="store_true", + ) + + parser.add_argument( + "--ci", + help="This flag is for Continuous Integration(CI) purpose and is NOT recommended to turn on for typical use cases. 
It will use random inputs instead of real inputs.", + action="store_true", + default=False, + ) + + parser.add_argument( + "--seed", + help="Set the seed for generating random numbers in both torch and random.", + type=int, + ) + + parser.add_argument( + "-t", + "--target", + help="Target platform for deployment", + choices=[ + "aarch64-android", + "aarch64-oe-linux-gcc9.3", + "aarch64-oe-linux-gcc11.2", + ], + default="aarch64-android", + type=str, + ) + + parser.add_argument( + "--pre_gen_pte", + help="Run the pre-generated pte in the given directory.", + type=str, + ) + + parser.add_argument( + "--direct_build_folder", + help="Path to cmake binary directory for direct_mode. E.g., path/to/build-hexagon." + "If enabled, run self-defined protocol to control fastrpc communication.", + type=str, + ) + + return parser + + +def generate_inputs( + dest_path: str, + input_list_filename: str, + inputs=None, + prefix_input_filename: str = "", +): + + input_list_file = None + input_files = [] + + def prepare_input_file(tensor, fd, index, sub_index): + # transform torch.Tensor to raw file + input_file_name = f"{prefix_input_filename}_input_{index}_{sub_index}.raw" + input_file_path = f"{dest_path}/{input_file_name}" + if not isinstance(tensor, torch.Tensor): + tensor = torch.tensor(tensor) + tensor.detach().numpy().tofile(input_file_path) + input_files.append(input_file_path) + # prepare input_list + if sub_index > 0: + fd.write(" ") + fd.write(input_file_name) + + # Prepare input data + if inputs is not None: + input_list_file = f"{dest_path}/{input_list_filename}" + + with open(input_list_file, "w") as f: + for idx, data in enumerate(inputs): + sub_index = 0 + for d in data: + if isinstance(d, (list, tuple)): + for sub_d in d: + prepare_input_file(sub_d, f, idx, sub_index) + sub_index += 1 + else: + prepare_input_file(d, f, idx, sub_index) + sub_index += 1 + + f.write("\n") + + return input_list_file, input_files + + +def _qat_train(ori_model, captured_model, quantizer, 
dataset): + data, targets = dataset + annotated_model = torchao.quantization.pt2e.move_exported_model_to_train( + prepare_qat_pt2e(captured_model, quantizer) + ) + optimizer = torch.optim.SGD(annotated_model.parameters(), lr=0.00001) + criterion = torch.nn.CrossEntropyLoss() + for i, d in enumerate(data): + print(f"Epoch {i}") + if i > 3: + # Freeze quantizer parameters + annotated_model.apply( + torchao.quantization.pt2e.fake_quantize.disable_observer + ) + if i > 2: + # Freeze batch norm mean and variance estimates + annotated_model.apply(torch.nn.intrinsic.qat.freeze_bn_stats) + + output = annotated_model(*d) + loss = criterion(output, targets[i]) + optimizer.zero_grad() + loss.backward() + optimizer.step() + + return convert_pt2e( + torchao.quantization.pt2e.move_exported_model_to_eval(annotated_model), + ) + + +def _ptq_calibrate(captured_model, quantizer, dataset): + annotated_model = prepare_pt2e(captured_model, quantizer) + print("Quantizing(PTQ) the model...") + # calibration + if callable(dataset): + dataset(annotated_model) + else: + for data in dataset: + annotated_model(*data) + return annotated_model diff --git a/backends/qualcomm/runtime/backends/direct_mode/README.md b/backends/qualcomm/runtime/backends/direct_mode/README.md index c9792b2fe92..0fa89114938 100644 --- a/backends/qualcomm/runtime/backends/direct_mode/README.md +++ b/backends/qualcomm/runtime/backends/direct_mode/README.md @@ -27,7 +27,7 @@ backends/qualcomm/scripts/build.sh --enable_hexagon 3. Execution Below is an example to execute a unit test with direct mode using qnn_executor_direct_runner. 
``` -python backends/qualcomm/tests/test_qnn_delegate.py -k TestQNNQuantizedOperator.test_qnn_backend_adaptive_avg_pool2d --model SM8750 --device $DEVICE_ID --build_folder build-android --direct_build_folder build-hexagon/ +python backends/qualcomm/tests/test_qnn_delegate.py -k TestQNNQuantizedOperator.test_qnn_backend_adaptive_avg_pool2d --soc_model SM8750 --device $DEVICE_ID --build_folder build-android --direct_build_folder build-hexagon/ ``` ### Note diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index 10c2a717e20..15d901cbe3c 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -27,6 +27,12 @@ get_passes_dependency_for_capture_program, ) from executorch.backends.qualcomm.debugger.utils import generate_optrace + +from executorch.backends.qualcomm.export_utils import ( + get_backend_type, + make_quantizer, + setup_common_args_and_variables, +) from executorch.backends.qualcomm.serialization.qc_schema import ( QnnExecuTorchBackendType, ) @@ -64,12 +70,6 @@ update_spill_fill_size, ) -from executorch.examples.qualcomm.utils import ( - get_backend_type, - make_quantizer, - setup_common_args_and_variables, -) - from executorch.backends.qualcomm.tests.models import * # noqa: F403 import os @@ -110,13 +110,13 @@ def setUp(self): TestQNN.atol = 1e-1 TestQNN.rtol = 1e-1 TestQNN.compiler_specs = generate_qnn_executorch_compiler_spec( - soc_model=self.chipset_table[TestQNN.model], + soc_model=self.chipset_table[TestQNN.soc_model], backend_options=backend_options, debug=False, saver=False, online_prepare=TestQNN.online_prepare, dump_intermediate_outputs=TestQNN.dump_intermediate_outputs, - profile=TestQNN.enable_profile, + profile_level=TestQNN.profile_level, shared_buffer=TestQNN.shared_buffer, ) @@ -1939,13 +1939,13 @@ def setUp(self): TestQNN.rtol = 1e-1 backend_options = generate_htp_compiler_spec(use_fp16=True) TestQNN.compiler_specs = 
generate_qnn_executorch_compiler_spec( - soc_model=self.chipset_table[TestQNN.model], + soc_model=self.chipset_table[TestQNN.soc_model], backend_options=backend_options, debug=False, saver=False, online_prepare=TestQNN.online_prepare, dump_intermediate_outputs=TestQNN.dump_intermediate_outputs, - profile=TestQNN.enable_profile, + profile_level=TestQNN.profile_level, shared_buffer=TestQNN.shared_buffer, ) @@ -2174,13 +2174,13 @@ def setUp(self): TestQNN.rtol = 1 backend_options = generate_htp_compiler_spec(use_fp16=False) TestQNN.compiler_specs = generate_qnn_executorch_compiler_spec( - soc_model=self.chipset_table[TestQNN.model], + soc_model=self.chipset_table[TestQNN.soc_model], backend_options=backend_options, debug=False, saver=False, online_prepare=TestQNN.online_prepare, dump_intermediate_outputs=TestQNN.dump_intermediate_outputs, - profile=TestQNN.enable_profile, + profile_level=TestQNN.profile_level, shared_buffer=TestQNN.shared_buffer, ) @@ -4311,13 +4311,13 @@ def setUp(self): TestQNN.rtol = 1 backend_options = generate_htp_compiler_spec(use_fp16=False) TestQNN.compiler_specs = generate_qnn_executorch_compiler_spec( - soc_model=self.chipset_table[TestQNN.model], + soc_model=self.chipset_table[TestQNN.soc_model], backend_options=backend_options, debug=False, saver=False, online_prepare=TestQNN.online_prepare, dump_intermediate_outputs=TestQNN.dump_intermediate_outputs, - profile=TestQNN.enable_profile, + profile_level=TestQNN.profile_level, shared_buffer=TestQNN.shared_buffer, ) @@ -4520,9 +4520,9 @@ def test_qnn_backend_masked_softmax(self): ) backend_options = generate_htp_compiler_spec(use_fp16=False) compiler_spec = generate_qnn_executorch_compiler_spec( - soc_model=self.chipset_table[TestQNN.model], + soc_model=self.chipset_table[TestQNN.soc_model], backend_options=backend_options, - optrace=True, + profile_level=3, ) with tempfile.TemporaryDirectory() as tmp_dir: edge_prog_mgr = to_edge_transform_and_lower_to_qnn( @@ -4533,7 +4533,11 @@ def 
test_qnn_backend_masked_softmax(self): edge_prog_mgr.write_to_file(f) adb = self.get_adb_tool(pte_path) binaries_trace = generate_optrace( - tmp_dir, self.chipset_table[self.model], adb, pte_path, [sample_input] + tmp_dir, + self.chipset_table[self.soc_model], + adb, + pte_path, + [sample_input], ) has_masked_softmax = False for _, (_, qhas) in binaries_trace.items(): @@ -4709,7 +4713,7 @@ def setUp(self): TestQNN.rtol = 1e-1 backend_options = generate_htp_compiler_spec(use_fp16=True) TestQNN.compiler_specs = generate_qnn_executorch_compiler_spec( - soc_model=self.chipset_table[TestQNN.model], + soc_model=self.chipset_table[TestQNN.soc_model], backend_options=backend_options, debug=False, saver=False, @@ -4719,7 +4723,7 @@ def test_qnn_backend_dump_intermediate_outputs_topk(self): TestQNN.dump_intermediate_outputs = True backend_options = generate_htp_compiler_spec(use_fp16=True) TestQNN.compiler_specs = generate_qnn_executorch_compiler_spec( - soc_model=self.chipset_table[TestQNN.model], + soc_model=self.chipset_table[TestQNN.soc_model], backend_options=backend_options, dump_intermediate_outputs=True, ) @@ -4737,7 +4741,7 @@ def test_qnn_backend_dump_intermediate_outputs_simple_model(self): TestQNN.dump_intermediate_outputs = True backend_options = generate_htp_compiler_spec(use_fp16=True) TestQNN.compiler_specs = generate_qnn_executorch_compiler_spec( - soc_model=self.chipset_table[TestQNN.model], + soc_model=self.chipset_table[TestQNN.soc_model], backend_options=backend_options, dump_intermediate_outputs=True, ) @@ -4779,7 +4783,7 @@ def test_qnn_backend_spill_fill_buffer_size(self): use_multi_contexts=True, ) compiler_specs = generate_qnn_executorch_compiler_spec( - soc_model=self.chipset_table[TestQNN.model], + soc_model=self.chipset_table[TestQNN.soc_model], backend_options=backend_options, ) edge_prog = to_edge_transform_and_lower_to_qnn( @@ -4798,7 +4802,7 @@ def test_qnn_backend_multi_contexts(self): use_multi_contexts=True, ) compiler_specs = 
generate_qnn_executorch_compiler_spec( - soc_model=self.chipset_table[TestQNN.model], + soc_model=self.chipset_table[TestQNN.soc_model], backend_options=backend_options, ) pass_jobs = get_capture_program_passes() @@ -4825,7 +4829,7 @@ def test_qnn_backend_multi_contexts_composite(self): use_multi_contexts=True, ) compiler_specs = generate_qnn_executorch_compiler_spec( - soc_model=self.chipset_table[TestQNN.model], + soc_model=self.chipset_table[TestQNN.soc_model], backend_options=backend_options, ) module = CompositeDelegateModule( # noqa: F405 @@ -4855,7 +4859,7 @@ def test_qnn_backend_multi_graphs(self): ) compiler_specs = [ generate_qnn_executorch_compiler_spec( - soc_model=self.chipset_table[TestQNN.model], + soc_model=self.chipset_table[TestQNN.soc_model], backend_options=backend_options, ) ] * len(graph_names) @@ -4880,12 +4884,12 @@ def test_qnn_backend_multi_graphs(self): ) def test_qnn_backend_profile_op(self): - TestQNN.enable_profile = True + TestQNN.profile_level = 2 backend_options = generate_htp_compiler_spec(use_fp16=True) TestQNN.compiler_specs = generate_qnn_executorch_compiler_spec( - soc_model=self.chipset_table[TestQNN.model], + soc_model=self.chipset_table[TestQNN.soc_model], backend_options=backend_options, - profile=True, + profile_level=2, ) module = SimpleModel() # noqa: F405 sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28)) @@ -4895,11 +4899,12 @@ def test_qnn_backend_profile_op(self): expected_partitions=1, expected_profile_events=30, ) + TestQNN.profile_level = 0 def test_qnn_backend_runtime_option_htp_performance(self): backend_options = generate_htp_compiler_spec(use_fp16=True) TestQNN.compiler_specs = generate_qnn_executorch_compiler_spec( - soc_model=self.chipset_table[TestQNN.model], + soc_model=self.chipset_table[TestQNN.soc_model], backend_options=backend_options, ) module = SimpleModel() # noqa: F405 @@ -4946,7 +4951,7 @@ def output_callback(log_msg, is_burst): def test_qnn_backend_runtime_option_log(self): 
backend_options = generate_htp_compiler_spec(use_fp16=True) TestQNN.compiler_specs = generate_qnn_executorch_compiler_spec( - soc_model=self.chipset_table[TestQNN.model], + soc_model=self.chipset_table[TestQNN.soc_model], backend_options=backend_options, ) module = SimpleModel() # noqa: F405 @@ -4975,12 +4980,12 @@ def output_callback(log_msg): ) def test_qnn_backend_runtime_option_profile(self): - TestQNN.enable_profile = True + TestQNN.profile_level = 2 backend_options = generate_htp_compiler_spec(use_fp16=True) TestQNN.compiler_specs = generate_qnn_executorch_compiler_spec( - soc_model=self.chipset_table[TestQNN.model], + soc_model=self.chipset_table[TestQNN.soc_model], backend_options=backend_options, - profile=False, # Turn on using runtime command + profile_level=0, # Turn on using runtime command ) module = SimpleModel() # noqa: F405 sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28)) @@ -4993,6 +4998,7 @@ def test_qnn_backend_runtime_option_profile(self): expected_profile_events=30, extra_cmds=runtime_commands, ) + TestQNN.profile_level = 0 def test_qnn_backend_shared_buffer(self): TestQNN.shared_buffer = True @@ -5000,7 +5006,7 @@ def test_qnn_backend_shared_buffer(self): use_fp16=True, ) TestQNN.compiler_specs = generate_qnn_executorch_compiler_spec( - soc_model=self.chipset_table[TestQNN.model], + soc_model=self.chipset_table[TestQNN.soc_model], backend_options=backend_options, shared_buffer=True, ) @@ -5018,7 +5024,7 @@ def test_qnn_backend_online_prepare(self): backend_options = generate_htp_compiler_spec(use_fp16=True) TestQNN.compiler_specs = generate_qnn_executorch_compiler_spec( - soc_model=self.chipset_table[TestQNN.model], + soc_model=self.chipset_table[TestQNN.soc_model], backend_options=backend_options, online_prepare=True, ) @@ -5243,14 +5249,14 @@ def test_qnn_backend_generate_optrace(self): compiler_specs = [ generate_qnn_executorch_compiler_spec( - soc_model=self.chipset_table[TestQNN.model], + 
soc_model=self.chipset_table[TestQNN.soc_model], backend_options=backend_options, online_prepare=True, ), generate_qnn_executorch_compiler_spec( - soc_model=self.chipset_table[TestQNN.model], + soc_model=self.chipset_table[TestQNN.soc_model], backend_options=backend_options, - optrace=True, + profile_level=3, ), ] @@ -5266,7 +5272,7 @@ def test_qnn_backend_generate_optrace(self): adb = self.get_adb_tool(pte_path) binaries_trace = generate_optrace( tmp_dir, - self.chipset_table[self.model], + self.chipset_table[self.soc_model], adb, pte_path, [sample_input], @@ -5298,7 +5304,7 @@ def setUp(self): TestQNN.rtol = 1 backend_options = generate_htp_compiler_spec(use_fp16=False) TestQNN.compiler_specs = generate_qnn_executorch_compiler_spec( - soc_model=self.chipset_table[TestQNN.model], + soc_model=self.chipset_table[TestQNN.soc_model], backend_options=backend_options, debug=False, saver=False, @@ -5308,7 +5314,7 @@ def test_qnn_backend_dump_intermediate_outputs_simple_model(self): TestQNN.dump_intermediate_outputs = True backend_options = generate_htp_compiler_spec(use_fp16=False) TestQNN.compiler_specs = generate_qnn_executorch_compiler_spec( - soc_model=self.chipset_table[TestQNN.model], + soc_model=self.chipset_table[TestQNN.soc_model], backend_options=backend_options, dump_intermediate_outputs=True, ) @@ -5328,7 +5334,7 @@ def test_qnn_backend_dump_intermediate_outputs_topk(self): TestQNN.dump_intermediate_outputs = True backend_options = generate_htp_compiler_spec(use_fp16=False) TestQNN.compiler_specs = generate_qnn_executorch_compiler_spec( - soc_model=self.chipset_table[TestQNN.model], + soc_model=self.chipset_table[TestQNN.soc_model], backend_options=backend_options, dump_intermediate_outputs=True, ) @@ -5420,7 +5426,7 @@ def test_qnn_backend_rewrite_prepared_observer(self): module = torch.export.export(module, sample_input, strict=True).module() quantizer = make_quantizer( - backend=get_backend_type(self.backend), soc_model=self.model + 
backend=get_backend_type(self.backend), soc_model=self.soc_model ) prepared = prepare_pt2e(module, quantizer) @@ -5447,7 +5453,7 @@ def test_qnn_backend_rewrite_prepared_observer(self): def test_qnn_backend_saver_backend(self): backend_options = generate_htp_compiler_spec(use_fp16=False) TestQNN.compiler_specs = generate_qnn_executorch_compiler_spec( - soc_model=self.chipset_table[TestQNN.model], + soc_model=self.chipset_table[TestQNN.soc_model], backend_options=backend_options, saver=True, ) @@ -5496,12 +5502,12 @@ def test_qnn_backend_skip_node_id_quantizer(self): use_fp16=False, ) compiler_specs = generate_qnn_executorch_compiler_spec( - soc_model=self.chipset_table[TestQNN.model], + soc_model=self.chipset_table[TestQNN.soc_model], backend_options=backend_options, ) # define quantizer quantizer = make_quantizer( - backend=get_backend_type(self.backend), soc_model=self.model + backend=get_backend_type(self.backend), soc_model=self.soc_model ) # define calibration method @@ -5544,12 +5550,12 @@ def test_qnn_backend_skip_node_op_quantizer(self): use_fp16=False, ) compiler_specs = generate_qnn_executorch_compiler_spec( - soc_model=self.chipset_table[TestQNN.model], + soc_model=self.chipset_table[TestQNN.soc_model], backend_options=backend_options, ) # define quantizer quantizer = make_quantizer( - backend=get_backend_type(self.backend), soc_model=self.model + backend=get_backend_type(self.backend), soc_model=self.soc_model ) # define calibration method @@ -5581,7 +5587,7 @@ def test_qnn_backend_spill_fill_buffer_size(self): use_multi_contexts=True, ) compiler_specs = generate_qnn_executorch_compiler_spec( - soc_model=self.chipset_table[TestQNN.model], + soc_model=self.chipset_table[TestQNN.soc_model], backend_options=backend_options, ) edge_prog = to_edge_transform_and_lower_to_qnn( @@ -5600,12 +5606,12 @@ def test_qnn_backend_graph_level_mixed_precision(self): use_fp16=False, ) compiler_specs = generate_qnn_executorch_compiler_spec( - 
soc_model=self.chipset_table[TestQNN.model], + soc_model=self.chipset_table[TestQNN.soc_model], backend_options=backend_options, ) # define quantizer quantizer = make_quantizer( - backend=get_backend_type(self.backend), soc_model=self.model + backend=get_backend_type(self.backend), soc_model=self.soc_model ) # define calibration method @@ -5639,7 +5645,7 @@ def test_qnn_backend_multi_contexts(self): use_multi_contexts=True, ) compiler_specs = generate_qnn_executorch_compiler_spec( - soc_model=self.chipset_table[TestQNN.model], + soc_model=self.chipset_table[TestQNN.soc_model], backend_options=backend_options, ) pass_jobs = get_capture_program_passes() @@ -5666,7 +5672,7 @@ def test_qnn_backend_multi_contexts_composite(self): use_multi_contexts=True, ) compiler_specs = generate_qnn_executorch_compiler_spec( - soc_model=self.chipset_table[TestQNN.model], + soc_model=self.chipset_table[TestQNN.soc_model], backend_options=backend_options, ) module = CompositeDelegateModule( # noqa: F405 @@ -5697,7 +5703,7 @@ def test_qnn_backend_multi_graphs(self): ) compiler_specs = [ generate_qnn_executorch_compiler_spec( - soc_model=self.chipset_table[TestQNN.model], + soc_model=self.chipset_table[TestQNN.soc_model], backend_options=backend_options, ) ] * len(graph_names) @@ -5709,7 +5715,7 @@ def test_qnn_backend_multi_graphs(self): module_prepared = prepare_pt2e( module_exported, make_quantizer( - backend=get_backend_type(self.backend), soc_model=self.model + backend=get_backend_type(self.backend), soc_model=self.soc_model ), ) module_prepared(*sample_inputs[i]) @@ -5730,12 +5736,12 @@ def test_qnn_backend_multi_graphs(self): ) def test_qnn_backend_profile_op(self): - TestQNN.enable_profile = True + TestQNN.profile_level = 2 backend_options = generate_htp_compiler_spec(use_fp16=False) TestQNN.compiler_specs = generate_qnn_executorch_compiler_spec( - soc_model=self.chipset_table[TestQNN.model], + soc_model=self.chipset_table[TestQNN.soc_model], backend_options=backend_options, - 
profile=True, + profile_level=2, ) module = SimpleModel() # noqa: F405 sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28)) @@ -5746,11 +5752,12 @@ def test_qnn_backend_profile_op(self): expected_partitions=1, expected_profile_events=30, ) + TestQNN.profile_level = 0 def test_qnn_backend_runtime_option_htp_performance(self): backend_options = generate_htp_compiler_spec(use_fp16=False) TestQNN.compiler_specs = generate_qnn_executorch_compiler_spec( - soc_model=self.chipset_table[TestQNN.model], + soc_model=self.chipset_table[TestQNN.soc_model], backend_options=backend_options, ) module = SimpleModel() # noqa: F405 @@ -5798,7 +5805,7 @@ def output_callback(log_msg, is_burst): def test_qnn_backend_runtime_option_log(self): backend_options = generate_htp_compiler_spec(use_fp16=False) TestQNN.compiler_specs = generate_qnn_executorch_compiler_spec( - soc_model=self.chipset_table[TestQNN.model], + soc_model=self.chipset_table[TestQNN.soc_model], backend_options=backend_options, ) module = SimpleModel() # noqa: F405 @@ -5828,12 +5835,12 @@ def output_callback(log_msg): ) def test_qnn_backend_runtime_option_profile(self): - TestQNN.enable_profile = True + TestQNN.profile_level = 2 backend_options = generate_htp_compiler_spec(use_fp16=False) TestQNN.compiler_specs = generate_qnn_executorch_compiler_spec( - soc_model=self.chipset_table[TestQNN.model], + soc_model=self.chipset_table[TestQNN.soc_model], backend_options=backend_options, - profile=False, # Turn on using runtime command + profile_level=0, # Turn on using runtime command ) module = SimpleModel() # noqa: F405 sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28)) @@ -5847,6 +5854,7 @@ def test_qnn_backend_runtime_option_profile(self): expected_profile_events=30, extra_cmds=runtime_commands, ) + TestQNN.profile_level = 0 def test_qnn_backend_shared_buffer(self): TestQNN.shared_buffer = True @@ -5854,7 +5862,7 @@ def test_qnn_backend_shared_buffer(self): use_fp16=False, ) 
TestQNN.compiler_specs = generate_qnn_executorch_compiler_spec( - soc_model=self.chipset_table[TestQNN.model], + soc_model=self.chipset_table[TestQNN.soc_model], backend_options=backend_options, shared_buffer=True, ) @@ -5873,7 +5881,7 @@ def test_qnn_backend_online_prepare(self): backend_options = generate_htp_compiler_spec(use_fp16=False) TestQNN.compiler_specs = generate_qnn_executorch_compiler_spec( - soc_model=self.chipset_table[TestQNN.model], + soc_model=self.chipset_table[TestQNN.soc_model], backend_options=backend_options, online_prepare=True, ) @@ -6121,14 +6129,14 @@ def test_qnn_backend_generate_optrace(self): compiler_specs = [ generate_qnn_executorch_compiler_spec( - soc_model=self.chipset_table[TestQNN.model], + soc_model=self.chipset_table[TestQNN.soc_model], backend_options=backend_options, online_prepare=True, ), generate_qnn_executorch_compiler_spec( - soc_model=self.chipset_table[TestQNN.model], + soc_model=self.chipset_table[TestQNN.soc_model], backend_options=backend_options, - optrace=True, + profile_level=3, ), ] @@ -6144,7 +6152,7 @@ def test_qnn_backend_generate_optrace(self): adb = self.get_adb_tool(pte_path) binaries_trace = generate_optrace( tmp_dir, - self.chipset_table[self.model], + self.chipset_table[self.soc_model], adb, pte_path, [sample_input], @@ -6182,11 +6190,11 @@ def test_qnn_backend_seq_mse(self): # per-channel / per-block quantizers = [ make_quantizer( - backend=get_backend_type(self.backend), soc_model=self.model + backend=get_backend_type(self.backend), soc_model=self.soc_model ), make_quantizer( backend=get_backend_type(self.backend), - soc_model=self.model, + soc_model=self.soc_model, quant_dtype=QuantDtype.use_16a4w_block, ), ] @@ -6269,7 +6277,7 @@ def setUp(self): SM8650=37, SM8750=45, pte_size=1_500_000_000, # 1.5 GB - wikitext_ppl=16, + wikitext_ppl=17, hellaswag_acc_norm=None, sqnr=15, ), @@ -6425,7 +6433,9 @@ def test_static_llm_model(self): # noqa: C901 conn = listener.accept() p.communicate() msg = 
json.loads(conn.recv()) - logging.info(f"Model Name: {self.model_name}\nTarget Device: {self.model}") + logging.info( + f"Model Name: {self.model_name}\nTarget Device: {self.soc_model}" + ) logging.info(f"Eval Result: {msg}") if "Error" in msg: self.fail(msg["Error"]) @@ -6452,9 +6462,9 @@ def test_static_llm_model(self): # noqa: C901 sqnr = msg["sqnr"] self.assertGreaterEqual(sqnr, llm_spec.sqnr) - if not self.enable_x86_64 and hasattr(llm_spec, self.model): + if not self.enable_x86_64 and hasattr(llm_spec, self.soc_model): device_inference_speed = msg["inference_speed"] - expected_inference_speed = getattr(llm_spec, self.model) + expected_inference_speed = getattr(llm_spec, self.soc_model) self.assertGreaterEqual( device_inference_speed, expected_inference_speed ) @@ -6784,8 +6794,8 @@ def test_static_vlm(self): self.artifact_dir, "--build_folder", self.build_folder, - "--model", - self.model, + "--soc_model", + self.soc_model, "--ip", self.ip, "--port", @@ -6844,7 +6854,7 @@ def test_static_vlm(self): print(f"Token Embedding PTE Size: {tok_embedding_pte_size} bytes") print(f"Text Decoder PTE Size: {decoder_pte_size} bytes") - attr_name = f"{self.model.lower()}_token_rate" + attr_name = f"{self.soc_model.lower()}_token_rate" if ( not self.compile_only and not self.enable_x86_64 @@ -7772,7 +7782,7 @@ def test_utils_export(self): "-a", ctx_path, "-m", - self.model, + self.soc_model, "-l", "False", "-b", @@ -8297,8 +8307,8 @@ def test_cli(self): f"{tmp_dir}/q_out", "--input_list", f"{tmp_dir}/input_list", - "--model", - self.model, + "--soc_model", + self.soc_model, ] subprocess.run(cmds, stdout=subprocess.DEVNULL) self.assertTrue(os.path.isfile(f"{tmp_dir}/q_out/relu_quantized.pt2")) @@ -8312,8 +8322,8 @@ def test_cli(self): f"{tmp_dir}/q_out/relu_quantized.pt2", "--output_folder", f"{tmp_dir}/c_out", - "--model", - self.model, + "--soc_model", + self.soc_model, ] subprocess.run(cmds, stdout=subprocess.DEVNULL) 
self.assertTrue(os.path.isfile(f"{tmp_dir}/c_out/relu_quantized.pte")) @@ -8332,8 +8342,10 @@ def test_cli(self): self.build_folder, "--input_list", f"{tmp_dir}/input_list", - "--model", - self.model, + "--soc_model", + self.soc_model, + "--host", + self.host, "--target", self.target, "--device", @@ -8369,8 +8381,8 @@ def test_cli_with_input_list_assignment(self): f"{tmp_dir}/q_out", "--input_list", f"{tmp_dir}/input_list", - "--model", - self.model, + "--soc_model", + self.soc_model, ] subprocess.run(cmds, stdout=subprocess.DEVNULL) self.assertTrue(os.path.isfile(f"{tmp_dir}/q_out/sub_quantized.pt2")) @@ -8384,8 +8396,8 @@ def test_cli_with_input_list_assignment(self): f"{tmp_dir}/q_out/sub_quantized.pt2", "--output_folder", f"{tmp_dir}/c_out", - "--model", - self.model, + "--soc_model", + self.soc_model, ] subprocess.run(cmds, stdout=subprocess.DEVNULL) self.assertTrue(os.path.isfile(f"{tmp_dir}/c_out/sub_quantized.pte")) @@ -8400,8 +8412,8 @@ def test_cli_with_input_list_assignment(self): f"{tmp_dir}/c_out/sub_quantized.pte", "--output_folder", f"{tmp_dir}/e_out", - "--model", - self.model, + "--soc_model", + self.soc_model, "--target", self.target, "--device", @@ -8432,8 +8444,8 @@ def test_custom_op(self): self.build_folder, "--device", self.device, - "--model", - self.model, + "--soc_model", + self.soc_model, "--target", self.target, "--ip", @@ -8466,18 +8478,19 @@ def test_debugger_generate_optrace(self): self.build_folder, "--device", self.device, - "--model", - self.model, + "--soc_model", + self.soc_model, "--target", self.target, "--ip", self.ip, "--port", str(self.port), + "--profile_level", + "3", ] if self.host: cmds.extend(["--host", self.host]) - p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL) with Listener((self.ip, self.port)) as listener: conn = listener.accept() @@ -8517,8 +8530,8 @@ def test_intermediate_debugger(self): self.build_folder, "--device", self.device, - "--model", - self.model, + "--soc_model", + self.soc_model, "--ip", 
self.ip, "--port", @@ -8614,12 +8627,6 @@ def setup_environment(): help="Input the model to export", type=str, ) - parser.add_argument( - "-P", - "--enable_profile", - help="Profile the performance of each operator with kProfileDetailed profile level", - action="store_true", - ) parser.add_argument( "-e", "--error_only", @@ -8653,7 +8660,7 @@ def setup_environment(): args, ns_args = parser.parse_known_args(namespace=unittest) TestQNN.host = args.host TestQNN.device = args.device - TestQNN.model = args.model + TestQNN.soc_model = args.soc_model TestQNN.build_folder = args.build_folder TestQNN.executorch_root = args.executorch_root TestQNN.artifact_dir = args.artifact_dir @@ -8663,7 +8670,7 @@ def setup_environment(): TestQNN.pretrained_weight = args.pretrained_weight TestQNN.model_name = args.model_name TestQNN.online_prepare = args.online_prepare - TestQNN.enable_profile = args.enable_profile + TestQNN.profile_level = args.profile_level TestQNN.error_only = args.error_only TestQNN.oss_repo = args.oss_repo TestQNN.shared_buffer = args.shared_buffer diff --git a/backends/qualcomm/tests/utils.py b/backends/qualcomm/tests/utils.py index 1b92df80dcf..06ec042b017 100644 --- a/backends/qualcomm/tests/utils.py +++ b/backends/qualcomm/tests/utils.py @@ -18,6 +18,13 @@ from executorch.backends.qualcomm.debugger.qnn_intermediate_debugger import ( QNNIntermediateDebugger, ) +from executorch.backends.qualcomm.export_utils import ( + generate_inputs, + get_backend_type, + make_quantizer, + QnnConfig, + SimpleADB, +) from executorch.backends.qualcomm.qnn_preprocess import QnnBackend from executorch.backends.qualcomm.quantizer.quantizer import ModuleQConfig, QuantDtype from executorch.backends.qualcomm.serialization.qc_schema import QcomChipset @@ -35,13 +42,7 @@ ) from executorch.devtools import Inspector from executorch.devtools.inspector._inspector_utils import TimeScale -from executorch.examples.qualcomm.utils import ( - generate_inputs, - get_backend_type, - make_output_dir, 
- make_quantizer, - SimpleADB, -) +from executorch.examples.qualcomm.utils import make_output_dir from executorch.exir.backend.compile_spec_schema import CompileSpec from executorch.exir.backend.utils import get_delegates @@ -155,7 +156,7 @@ class TestQNN(unittest.TestCase): host: str = "" device: str = "" build_folder: str = "" - model: QcomChipset = None + soc_model: QcomChipset = None compiler_specs: List[CompileSpec] = None chipset_table = get_soc_to_chipset_map() error_only = False @@ -167,7 +168,7 @@ class TestQNN(unittest.TestCase): qa_dataset: str = "" sentence_dataset: str = "" pretrained_weight: str = "" - enable_profile: bool = False + profile_level: int = 0 op_package_dir: str = "" target: str = "" model_name: str = "" @@ -191,19 +192,24 @@ class TestQNN(unittest.TestCase): @classmethod def setUpClass(cls): if not cls.enable_x86_64 and not cls.compile_only: + qnn_config = QnnConfig( + backend=cls.backend, + target=cls.target, + build_folder=cls.build_folder, + device=cls.device, + host=cls.host, + soc_model=cls.soc_model, + direct_build_folder=cls.direct_build_folder, + ) + # init device once adb = SimpleADB( - qnn_sdk=os.getenv("QNN_SDK_ROOT"), - build_path=cls.build_folder, - direct_mode_build_path=cls.direct_build_folder, + qnn_config=qnn_config, pte_path=[], workspace="/data/local/tmp/qnn_executorch_test", - device_id=cls.device, - host_id=cls.host, - soc_model=cls.model, error_only=cls.error_only, - target=cls.target, ) + adb.push( backends={get_backend_type(cls.backend)}, init_env=True, @@ -257,8 +263,8 @@ def required_envs(self, conditions=None) -> bool: def add_default_cmds(self, cmds): cmds.extend( [ - "--model", - self.model, + "--soc_model", + self.soc_model, "--target", self.target, "--ip", @@ -485,17 +491,22 @@ def validate_intermediate_tensor(): self.inference_speed = float(f.read()) else: + qnn_config = QnnConfig( + backend=self.backend, + target=self.target, + build_folder=self.build_folder, + device=self.device, + host=self.host, + 
soc_model=self.soc_model, + dump_intermediate_outputs=expected_intermediate_events != -1, + direct_build_folder=self.direct_build_folder, + ) + adb = SimpleADB( - qnn_sdk=os.getenv("QNN_SDK_ROOT"), - build_path=self.build_folder, - direct_mode_build_path=self.direct_build_folder, + qnn_config=qnn_config, pte_path=pte_fname, workspace="/data/local/tmp/qnn_executorch_test", - device_id=self.device, - host_id=self.host, - soc_model=self.model, error_only=self.error_only, - dump_intermediate_outputs=self.dump_intermediate_outputs, expected_input_shape=( (tensor.shape for tensor in processed_inputs) if check_io_shape @@ -506,7 +517,6 @@ def validate_intermediate_tensor(): if check_io_shape else None ), - target=self.target, ) adb.push( inputs=[processed_inputs], @@ -565,7 +575,8 @@ def lower_module_and_test_output( skip_node_id_set=skip_node_id_set, skip_node_op_set=skip_node_op_set, skip_mutable_buffer=skip_mutable_buffer, - generate_etrecord=self.enable_profile, + generate_etrecord=self.profile_level != 0 + or expected_intermediate_events != -1, ) qnn_intermediate_debugger = None @@ -614,7 +625,7 @@ def lower_module_and_test_output( ) etrecord_path = "etrecord.bin" - if self.enable_profile: + if self.profile_level: exec_prog.get_etrecord().save(etrecord_path) # Check numerics if ( @@ -660,7 +671,7 @@ def get_qdq_module( per_channel_linear=is_linear_per_channel, submodule_qconfig_list=submodule_qconfig_list, backend=get_backend_type(self.backend), - soc_model=self.model, + soc_model=self.soc_model, ) if block_size_map is not None: quantizer.set_block_size_map(block_size_map) @@ -702,7 +713,7 @@ def get_prepared_qat_module( is_qat=True, submodule_qconfig_list=submodule_qconfig_list, backend=get_backend_type(self.backend), - soc_model=self.model, + soc_model=self.soc_model, ) if block_size_map is not None: quantizer.set_block_size_map(block_size_map) @@ -729,17 +740,20 @@ def get_converted_sgd_trained_module( return convert_pt2e(prepared) def get_adb_tool(self, 
pte_fname): + qnn_config = QnnConfig( + backend=self.backend, + build_folder=self.build_folder, + device=self.device, + host=self.host, + soc_model=self.soc_model, + direct_build_folder=self.direct_build_folder, + ) + adb = SimpleADB( - qnn_sdk=os.getenv("QNN_SDK_ROOT"), - build_path=self.build_folder, - direct_mode_build_path=self.direct_build_folder, + qnn_config=qnn_config, pte_path=pte_fname, workspace="/data/local/tmp/qnn_executorch_test", - device_id=self.device, - host_id=self.host, - soc_model=self.model, error_only=self.error_only, - target=self.target, ) return adb diff --git a/backends/qualcomm/utils/utils.py b/backends/qualcomm/utils/utils.py index 19603e6219b..4927d37f1f7 100644 --- a/backends/qualcomm/utils/utils.py +++ b/backends/qualcomm/utils/utils.py @@ -1047,8 +1047,7 @@ def generate_qnn_executorch_compiler_spec( saver: bool = False, online_prepare: bool = False, dump_intermediate_outputs: bool = False, - profile: bool = False, - optrace: bool = False, + profile_level: int = 0, shared_buffer: bool = False, is_from_context_binary: bool = False, op_package_options: QnnExecuTorchOpPackageOptions = None, @@ -1075,9 +1074,8 @@ def generate_qnn_executorch_compiler_spec( for debugging purpose. dump_intermediate_outputs: If tensor dump is enabled, all intermediate tensors output will be dumped. This option exists for debugging accuracy issues - profile: Enable profile the performance of per operator. - Note that for now only support kProfileDetailed to - profile the performance of each operator with cycle unit. + profile_level: Enable profiling the performance of per operator. + Note that for now only support kProfileDetailed and kProfileOptrace. shared_buffer: Enables usage of shared buffer between application and backend for graph I/O. is_from_context_binary: True if current graph comes from pre-built context binary. 
@@ -1096,7 +1094,7 @@ def generate_qnn_executorch_compiler_spec( if soc_model not in _supported_soc_models: raise ValueError(f"unknown SoC model for QNN: {soc_model}") - if profile and dump_intermediate_outputs: + if profile_level and dump_intermediate_outputs: warnings.warn( "It is not recommended to turn on both profiling and dump_intermediate_outputs the same time" ", because dump_intermediate_outputs will cause performance drop.", @@ -1119,13 +1117,19 @@ def generate_qnn_executorch_compiler_spec( qnn_executorch_options.saver = True qnn_executorch_options.saver_output_dir = "saver_output" - if optrace: + if profile_level == 3: qnn_executorch_options.profile_level = QnnExecuTorchProfileLevel.kProfileOptrace - elif profile: + elif profile_level == 2: qnn_executorch_options.profile_level = ( QnnExecuTorchProfileLevel.kProfileDetailed ) else: + if profile_level == 1: + warnings.warn( + "Profile Level 1, kProfileBasic, is not supported, turning off profiling.", + DeprecationWarning, + stacklevel=1, + ) qnn_executorch_options.profile_level = QnnExecuTorchProfileLevel.kProfileOff if ( diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index 0394bf7f320..e605392cf7e 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -196,7 +196,7 @@ def build_args_parser() -> argparse.ArgumentParser: "--model", default="llama3", choices=EXECUTORCH_DEFINED_MODELS + TORCHTUNE_DEFINED_MODELS, - help="The Lllama model to export. stories110M, llama2, llama3, llama3_1, and llama3_2 use the same underlying LlamaTransformer architecture defined in ExecuTorch. All other models use TorchTune model definitions.", + help="The Llama model to export. stories110M, llama2, llama3, llama3_1, and llama3_2 use the same underlying LlamaTransformer architecture defined in ExecuTorch. 
All other models use TorchTune model definitions.", ) parser.add_argument( "-E", @@ -423,7 +423,7 @@ def build_args_parser() -> argparse.ArgumentParser: type=parse_list_of_ints, default=None, help="List of integers specifying local and global attention pattern, e.g., [0, 16, 0, 16] to specify that every other layer is sliding window of 16." - " [0, 16, 32] pattern specifes 2nd and 3rd layer has sliding window of 16 and 32 respecitvely. " + " [0, 16, 32] pattern specifies 2nd and 3rd layers have sliding window of 16 and 32 respectively." " [16] pattern specifies all layers have sliding window of 16.", ) diff --git a/examples/qualcomm/README.md b/examples/qualcomm/README.md index 53505db6d69..34a0b4e8fe8 100644 --- a/examples/qualcomm/README.md +++ b/examples/qualcomm/README.md @@ -42,7 +42,7 @@ Or, you could put QNN libraries to default search path of the dynamic linker. Please connect an Android phone to the workstation. We use `adb` to communicate with the device. If the device is in a remote host, you might want to add `-H` to the `adb` -commands in the `SimpleADB` class inside [utils.py](utils.py). +commands in the `SimpleADB` class inside [export_utils.py](../../backends/qualcomm/export_utils.py). ## Please use python xxx.py --help for information of each examples. @@ -86,31 +86,31 @@ If you run into the following error, that means the ${QNN_SDK_ROOT} that you are Error: Failed to get context binary info. ``` ## Model Structure -This section outlines the essential APIs and utilities provided to streamline the process of model conversion, deployment, and evaluation on Qualcomm hardware using ExecuTorch. +This section outlines the essential APIs and utilities provided to streamline the process of model conversion, deployment, and evaluation on Qualcomm hardware using ExecuTorch. The official APIs can be found under [export_utils.py](../../backends/qualcomm/export_utils.py) -1. `build_executorch_binary()`: +1. 
`setup_common_args_and_variables()`: - build_executorch_binary is a high-level API used to convert a PyTorch model into a Qualcomm-compatible .pte binary format. This function streamlines the process of quantization, transformation, optimization, and export, enabling users to efficiently deploy models on Qualcomm hardware. + `setup_common_args_and_variables()` returns an `argparse.ArgumentParser`. This parser defines both required and optional arguments, which can later be passed into the ExecuTorch QNN API, `QnnConfig.load_config()`. -2. `SimpleADB`: +2. `QnnConfig.load_config()`: - SimpleADB is a Python class that provides a simplified interface for interacting with Android devices. It allows users to execute ADB commands, retrieve device information, and manage files on the device. + `QnnConfig.load_config` accepts either: + 1. An `argparse.ArgumentParser` created by `setup_common_args_and_variables()` + 2. A `.json` configuration file. A sample file is provided under [sample_config.json](./sample_config.json) for reference. -3. `get_imagenet_dataset`: - - If the model requires ImageNet, this function can be used to load the dataset and apply the necessary preprocessing steps to prepare it for inference or quantization calibration. + This function returns a `QnnConfig`, which serves as an input to some of the key APIs that will be covered below: `build_executorch_binary()`, `SimpleADB`. -4. `topk_accuracy`: +3. `build_executorch_binary()`: - Calculates the Top-K accuracy for classification models, used to evaluate model performance. + `build_executorch_binary` is a high-level API used to convert a PyTorch model into a Qualcomm-compatible .pte binary format. This function streamlines the process of quantization, transformation, optimization, and export, enabling users to efficiently deploy models on Qualcomm hardware. -5. `parse_skip_delegation_node`: +4. 
`SimpleADB`: - Parses command-line arguments to identify node IDs or operation types that should be skipped during model conversion. + `SimpleADB` provides a simplified interface for interacting with Android devices. It allows users to execute ADB commands such as: + 1. Push necessary artifacts to device + 2. Execute the runner + 3. Pull the execution outputs/results -6. `make_output_dir`: - - Creates a clean directory for storing model outputs or intermediate results. If the directory already exists, it will be deleted and recreated to ensure a consistent environment for each run. ## Run Inference Using Shared Buffer This section shows how to use shared buffer for input/output tensors in QNN ExecuTorch, usually graph inputs and outputs on shared memory to reduce huge tensor copying time from CPU to HTP. This feature can accelerate inference speed. Users need to do shared memory resource management by themselves. The key idea is to use `QnnExecuTorchAllocCustomMem` to allocate a large chunk of memory on the device, then use `QnnExecuTorchFreeCustomMem` to free it after inference. @@ -146,8 +146,7 @@ pip install scikit-learn pandas graphviz ## Limitation -1. QNN 2.28 is used for all examples. Newer or older QNN might work, -but the performance and accuracy number can differ. +1. QNN 2.37 is used for all examples. Newer or older QNN might work, but the performance and accuracy number can differ. 2. The mobilebert example is on QNN HTP fp16, which is only supported by a limited set of SoCs. Please check QNN documents for details. 
diff --git a/examples/qualcomm/custom_op/custom_ops_1.py b/examples/qualcomm/custom_op/custom_ops_1.py index 58a3bb64344..c91f75b8816 100644 --- a/examples/qualcomm/custom_op/custom_ops_1.py +++ b/examples/qualcomm/custom_op/custom_ops_1.py @@ -13,6 +13,15 @@ import numpy as np import torch +from executorch.backends.qualcomm.export_utils import ( + build_executorch_binary, + generate_inputs, + get_backend_type, + make_quantizer, + QnnConfig, + setup_common_args_and_variables, + SimpleADB, +) from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype from executorch.backends.qualcomm.serialization.qc_schema import ( @@ -24,15 +33,7 @@ QnnExecuTorchOpPackagePlatform, QnnExecuTorchOpPackageTarget, ) -from executorch.examples.qualcomm.utils import ( - build_executorch_binary, - generate_inputs, - get_backend_type, - make_output_dir, - make_quantizer, - setup_common_args_and_variables, - SimpleADB, -) +from executorch.examples.qualcomm.utils import make_output_dir from torch.library import impl, Library my_op_lib = Library("my_ops", "DEF") @@ -165,6 +166,8 @@ def prepare_op_package( def main(args): + qnn_config = QnnConfig.load_config(args.config_file if args.config_file else args) + if args.build_op_package: if "HEXAGON_SDK_ROOT" not in os.environ: raise RuntimeError("Environment variable HEXAGON_SDK_ROOT must be set") @@ -186,7 +189,7 @@ def main(args): sample_input = (torch.ones(1, 32, 28, 28),) workspace = f"/data/local/tmp/executorch/{pte_filename}" - soc_info = _soc_info_table[getattr(QcomChipset, args.model)] + soc_info = _soc_info_table[getattr(QcomChipset, args.soc_model)] op_package_options, op_package_paths = prepare_op_package( workspace, @@ -198,23 +201,19 @@ def main(args): quant_dtype=quant_dtype, custom_annotations=(annotate_custom,), backend=get_backend_type(args.backend), - soc_model=args.model, + soc_model=args.soc_model, ) build_executorch_binary( - instance, - sample_input, - args.model, - f"{args.artifact}/{pte_filename}", - 
sample_input, + model=instance, + qnn_config=qnn_config, + file_name=f"{args.artifact}/{pte_filename}", + dataset=[sample_input], op_package_options=op_package_options, quant_dtype=quant_dtype, custom_quantizer=quantizer, ) - if args.compile_only: - sys.exit(0) - # collect output data output_data_folder = f"{args.artifact}/outputs" make_output_dir(output_data_folder) @@ -244,22 +243,14 @@ def main(args): capture_output=True, ) else: - # setup required paths accordingly - # qnn_sdk : QNN SDK path setup in environment variable - # artifact_path : path where artifacts were built - # pte_path : path where executorch binary was stored + # setup required params accordingly + # qnn_config : QnnConfig that saves config info # device_id : serial number of android device # workspace : folder for storing artifacts on android device adb = SimpleADB( - qnn_sdk=os.getenv("QNN_SDK_ROOT"), - build_path=f"{args.build_folder}", + qnn_config=qnn_config, pte_path=f"{args.artifact}/{pte_filename}.pte", workspace=workspace, - device_id=args.device, - host_id=args.host, - soc_model=args.model, - shared_buffer=args.shared_buffer, - target=args.target, ) adb.push(inputs=sample_input, files=op_package_paths) adb.execute() @@ -327,7 +318,6 @@ def main(args): ) args = parser.parse_args() - args.validate(args) try: main(args) diff --git a/examples/qualcomm/oss_scripts/albert.py b/examples/qualcomm/oss_scripts/albert.py index 92546a301db..b5fbb8a2615 100644 --- a/examples/qualcomm/oss_scripts/albert.py +++ b/examples/qualcomm/oss_scripts/albert.py @@ -13,30 +13,30 @@ import evaluate import numpy as np import torch + +from executorch.backends.qualcomm.export_utils import ( + build_executorch_binary, + make_quantizer, + QnnConfig, + setup_common_args_and_variables, + SimpleADB, +) from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype from executorch.backends.qualcomm.serialization.qc_schema import ( QnnExecuTorchBackendType, ) from executorch.examples.qualcomm.utils import ( - 
build_executorch_binary, - get_backend_type, get_masked_language_model_dataset, make_output_dir, - make_quantizer, - parse_skip_delegation_node, - setup_common_args_and_variables, - SimpleADB, ) + from transformers import AutoModelForMaskedLM, AutoTokenizer from transformers.masking_utils import create_bidirectional_mask def main(args): - if args.compile_only and args.pre_gen_pte: - raise RuntimeError("Cannot set both compile_only and pre_gen_pte as true") - - skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args) + qnn_config = QnnConfig.load_config(args.config_file if args.config_file else args) os.makedirs(args.artifact, exist_ok=True) data_size = 100 @@ -77,59 +77,36 @@ def main(args): for input_ids, attention_mask in inputs ] - # Skip lowering/compilation if using pre-generated PTE - if not args.pre_gen_pte: - # lower to QNN - backend = get_backend_type(args.backend) - quantizer = { - QnnExecuTorchBackendType.kGpuBackend: None, - QnnExecuTorchBackendType.kHtpBackend: make_quantizer( - quant_dtype=QuantDtype.use_16a16w, - eps=2**-20, - backend=backend, - soc_model=args.model, - ), - }[backend] - build_executorch_binary( - module, - inputs[0], - args.model, - f"{args.artifact}/{pte_filename}", - dataset=inputs, - skip_node_id_set=skip_node_id_set, - skip_node_op_set=skip_node_op_set, - backend=backend, - custom_quantizer=quantizer, - shared_buffer=args.shared_buffer, - online_prepare=args.online_prepare, - ) - - if args.compile_only: - return - - workspace = f"/data/local/tmp/{getpass.getuser()}/executorch/{pte_filename}" - pte_path = ( - f"{args.pre_gen_pte}/{pte_filename}.pte" - if args.pre_gen_pte - else f"{args.artifact}/{pte_filename}.pte" + # lower to QNN + quantizer = { + QnnExecuTorchBackendType.kGpuBackend: None, + QnnExecuTorchBackendType.kHtpBackend: make_quantizer( + quant_dtype=QuantDtype.use_16a16w, + eps=2**-20, + backend=qnn_config.backend, + soc_model=qnn_config.soc_model, + ), + }[qnn_config.backend] + build_executorch_binary( + 
model=module, + qnn_config=qnn_config, + file_name=f"{args.artifact}/{pte_filename}", + dataset=inputs, + custom_quantizer=quantizer, ) + workspace = f"/data/local/tmp/{getpass.getuser()}/executorch/{pte_filename}" + pte_path = f"{args.artifact}/{pte_filename}.pte" adb = SimpleADB( - qnn_sdk=os.getenv("QNN_SDK_ROOT"), - build_path=f"{args.build_folder}", + qnn_config=qnn_config, pte_path=pte_path, workspace=workspace, - device_id=args.device, - host_id=args.host, - soc_model=args.model, - shared_buffer=args.shared_buffer, - target=args.target, ) output_data_folder = f"{args.artifact}/outputs" make_output_dir(output_data_folder) # accuracy analysis - adb.push(inputs=inputs, backends={backend}) + adb.push(inputs=inputs) adb.execute() adb.pull(host_output_path=args.artifact) # since the original nn.Module could not perform well on this task either @@ -187,7 +164,6 @@ def main(args): ) args = parser.parse_args() - args.validate(args) try: main(args) diff --git a/examples/qualcomm/oss_scripts/bert.py b/examples/qualcomm/oss_scripts/bert.py index 6043cb697ac..7eacd035f17 100644 --- a/examples/qualcomm/oss_scripts/bert.py +++ b/examples/qualcomm/oss_scripts/bert.py @@ -13,30 +13,28 @@ import evaluate import numpy as np import torch +from executorch.backends.qualcomm.export_utils import ( + build_executorch_binary, + make_quantizer, + QnnConfig, + setup_common_args_and_variables, + SimpleADB, +) from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype from executorch.backends.qualcomm.serialization.qc_schema import ( QnnExecuTorchBackendType, ) from executorch.examples.qualcomm.utils import ( - build_executorch_binary, - get_backend_type, get_masked_language_model_dataset, make_output_dir, - make_quantizer, - parse_skip_delegation_node, - setup_common_args_and_variables, - SimpleADB, ) from transformers import AutoModelForMaskedLM, AutoTokenizer from transformers.masking_utils import create_bidirectional_mask def main(args): - if args.compile_only and 
args.pre_gen_pte: - raise RuntimeError("Cannot set both compile_only and pre_gen_pte as true") - - skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args) + qnn_config = QnnConfig.load_config(args.config_file if args.config_file else args) os.makedirs(args.artifact, exist_ok=True) data_size = 100 @@ -77,59 +75,37 @@ def main(args): for input_ids, attention_mask in inputs ] - # Skip lowering/compilation if using pre-generated PTE - if not args.pre_gen_pte: - # lower to QNN - backend = get_backend_type(args.backend) - quantizer = { - QnnExecuTorchBackendType.kGpuBackend: None, - QnnExecuTorchBackendType.kHtpBackend: make_quantizer( - quant_dtype=QuantDtype.use_16a8w, - eps=2**-20, - backend=backend, - soc_model=args.model, - ), - }[backend] - build_executorch_binary( - module, - inputs[0], - args.model, - f"{args.artifact}/{pte_filename}", - dataset=inputs, - skip_node_id_set=skip_node_id_set, - skip_node_op_set=skip_node_op_set, - backend=backend, - custom_quantizer=quantizer, - shared_buffer=args.shared_buffer, - online_prepare=args.online_prepare, - ) - - if args.compile_only: - return + # lower to QNN + quantizer = { + QnnExecuTorchBackendType.kGpuBackend: None, + QnnExecuTorchBackendType.kHtpBackend: make_quantizer( + quant_dtype=QuantDtype.use_16a8w, + eps=2**-20, + backend=qnn_config.backend, + soc_model=qnn_config.soc_model, + ), + }[qnn_config.backend] + build_executorch_binary( + model=module, + qnn_config=qnn_config, + file_name=f"{args.artifact}/{pte_filename}", + dataset=inputs, + custom_quantizer=quantizer, + ) workspace = f"/data/local/tmp/{getpass.getuser()}/executorch/{pte_filename}" - pte_path = ( - f"{args.pre_gen_pte}/{pte_filename}.pte" - if args.pre_gen_pte - else f"{args.artifact}/{pte_filename}.pte" - ) + pte_path = f"{args.artifact}/{pte_filename}.pte" adb = SimpleADB( - qnn_sdk=os.getenv("QNN_SDK_ROOT"), - build_path=f"{args.build_folder}", + qnn_config=qnn_config, pte_path=pte_path, workspace=workspace, - 
device_id=args.device, - host_id=args.host, - soc_model=args.model, - shared_buffer=args.shared_buffer, - target=args.target, ) output_data_folder = f"{args.artifact}/outputs" make_output_dir(output_data_folder) # accuracy analysis - adb.push(inputs=inputs, backends={backend}) + adb.push(inputs=inputs) adb.execute() adb.pull(host_output_path=args.artifact) goldens, predictions = [], [] @@ -176,7 +152,6 @@ def main(args): ) args = parser.parse_args() - args.validate(args) try: main(args) diff --git a/examples/qualcomm/oss_scripts/conv_former.py b/examples/qualcomm/oss_scripts/conv_former.py index c43cbecdd5a..c913357e8ad 100644 --- a/examples/qualcomm/oss_scripts/conv_former.py +++ b/examples/qualcomm/oss_scripts/conv_former.py @@ -8,30 +8,30 @@ import logging import os -import sys from multiprocessing.connection import Client import numpy as np import timm import torch +from executorch.backends.qualcomm.export_utils import ( + build_executorch_binary, + QnnConfig, + setup_common_args_and_variables, + SimpleADB, +) from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype from executorch.backends.qualcomm.serialization.qc_schema import ( QnnExecuTorchBackendType, ) from executorch.examples.qualcomm.utils import ( - build_executorch_binary, - get_backend_type, get_imagenet_dataset, make_output_dir, - parse_skip_delegation_node, - setup_common_args_and_variables, - SimpleADB, topk_accuracy, ) def main(args): - skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args) + qnn_config = QnnConfig.load_config(args.config_file if args.config_file else args) # ensure the working directory exist. 
os.makedirs(args.artifact, exist_ok=True) @@ -56,40 +56,24 @@ def main(args): model = model.eval() # lower to QNN - backend = get_backend_type(args.backend) quant_dtype = { QnnExecuTorchBackendType.kGpuBackend: None, QnnExecuTorchBackendType.kHtpBackend: QuantDtype.use_8a8w, - }[backend] + }[qnn_config.backend] build_executorch_binary( - model, - inputs[0], - args.model, - f"{args.artifact}/{pte_filename}", - inputs, + model=model, + qnn_config=qnn_config, + file_name=f"{args.artifact}/{pte_filename}", + dataset=inputs, quant_dtype=quant_dtype, - backend=backend, - skip_node_id_set=skip_node_id_set, - skip_node_op_set=skip_node_op_set, - shared_buffer=args.shared_buffer, - online_prepare=args.online_prepare, ) - if args.compile_only: - sys.exit(0) - adb = SimpleADB( - qnn_sdk=os.getenv("QNN_SDK_ROOT"), - build_path=f"{args.build_folder}", + qnn_config=qnn_config, pte_path=f"{args.artifact}/{pte_filename}.pte", workspace=f"/data/local/tmp/executorch/{pte_filename}", - device_id=args.device, - host_id=args.host, - soc_model=args.model, - shared_buffer=args.shared_buffer, - target=args.target, ) - adb.push(inputs=inputs, backends={backend}) + adb.push(inputs=inputs) adb.execute() # collect output data @@ -140,7 +124,6 @@ def main(args): ) args = parser.parse_args() - args.validate(args) try: main(args) diff --git a/examples/qualcomm/oss_scripts/convnext_small.py b/examples/qualcomm/oss_scripts/convnext_small.py index 394cbc85e38..4e543c768a6 100755 --- a/examples/qualcomm/oss_scripts/convnext_small.py +++ b/examples/qualcomm/oss_scripts/convnext_small.py @@ -14,23 +14,28 @@ import torch import torchvision + +from executorch.backends.qualcomm.export_utils import ( + build_executorch_binary, + make_quantizer, + QnnConfig, + setup_common_args_and_variables, + SimpleADB, +) from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype from executorch.backends.qualcomm.serialization.qc_schema import ( QnnExecuTorchBackendType, ) from 
executorch.examples.qualcomm.utils import ( - build_executorch_binary, - get_backend_type, get_imagenet_dataset, make_output_dir, - make_quantizer, - setup_common_args_and_variables, - SimpleADB, topk_accuracy, ) def main(args): + qnn_config = QnnConfig.load_config(args.config_file if args.config_file else args) + # ensure the working directory exist. os.makedirs(args.artifact, exist_ok=True) @@ -50,43 +55,29 @@ def main(args): pte_filename = "convnext_small_qnn" instance = torchvision.models.convnext_small(weights="IMAGENET1K_V1").eval() - backend = get_backend_type(args.backend) qnn_quantizer = { QnnExecuTorchBackendType.kGpuBackend: None, QnnExecuTorchBackendType.kHtpBackend: make_quantizer( quant_dtype=QuantDtype.use_8a8w, per_channel_linear=True, - backend=backend, - soc_model=args.model, + backend=qnn_config.backend, + soc_model=qnn_config.soc_model, ), - }[backend] + }[qnn_config.backend] build_executorch_binary( - instance, - inputs[0], - args.model, - f"{args.artifact}/{pte_filename}", - inputs, + model=instance, + qnn_config=qnn_config, + file_name=f"{args.artifact}/{pte_filename}", + dataset=inputs, custom_quantizer=qnn_quantizer, - backend=backend, - shared_buffer=args.shared_buffer, - online_prepare=args.online_prepare, ) - if args.compile_only: - return - adb = SimpleADB( - qnn_sdk=os.getenv("QNN_SDK_ROOT"), - build_path=f"{args.build_folder}", + qnn_config=qnn_config, pte_path=f"{args.artifact}/{pte_filename}.pte", workspace=f"/data/local/tmp/executorch/{pte_filename}", - device_id=args.device, - host_id=args.host, - soc_model=args.model, - shared_buffer=args.shared_buffer, - target=args.target, ) - adb.push(inputs=inputs, backends={backend}) + adb.push(inputs=inputs) adb.execute() # collect output data @@ -137,7 +128,7 @@ def main(args): ) args = parser.parse_args() - args.validate(args) + try: main(args) except Exception as e: diff --git a/examples/qualcomm/oss_scripts/cvt.py b/examples/qualcomm/oss_scripts/cvt.py index 8e9399c1f72..dafd00a6743 
100644 --- a/examples/qualcomm/oss_scripts/cvt.py +++ b/examples/qualcomm/oss_scripts/cvt.py @@ -14,18 +14,19 @@ import numpy as np import torch +from executorch.backends.qualcomm.export_utils import ( + build_executorch_binary, + QnnConfig, + setup_common_args_and_variables, + SimpleADB, +) from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype from executorch.backends.qualcomm.serialization.qc_schema import ( QnnExecuTorchBackendType, ) from executorch.examples.qualcomm.utils import ( - build_executorch_binary, - get_backend_type, get_imagenet_dataset, make_output_dir, - parse_skip_delegation_node, - setup_common_args_and_variables, - SimpleADB, topk_accuracy, ) from transformers import AutoModelForImageClassification @@ -93,7 +94,7 @@ def _replace_attention( def main(args): - skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args) + qnn_config = QnnConfig.load_config(args.config_file if args.config_file else args) # ensure the working directory exist. os.makedirs(args.artifact, exist_ok=True) @@ -120,40 +121,24 @@ def main(args): # Fix prepare failed due to einsum module = _replace_attention(module) pte_filename = "cvt_qnn" - backend = get_backend_type(args.backend) quant_dtype = { QnnExecuTorchBackendType.kGpuBackend: None, QnnExecuTorchBackendType.kHtpBackend: QuantDtype.use_8a8w, - }[backend] + }[qnn_config.backend] build_executorch_binary( - module.eval(), - inputs[0], - args.model, - f"{args.artifact}/{pte_filename}", - inputs, - skip_node_id_set=skip_node_id_set, - skip_node_op_set=skip_node_op_set, + model=module.eval(), + qnn_config=qnn_config, + file_name=f"{args.artifact}/{pte_filename}", + dataset=inputs, quant_dtype=quant_dtype, - backend=backend, - shared_buffer=args.shared_buffer, - online_prepare=args.online_prepare, ) - if args.compile_only: - return - adb = SimpleADB( - qnn_sdk=os.getenv("QNN_SDK_ROOT"), - build_path=f"{args.build_folder}", + qnn_config=qnn_config, pte_path=f"{args.artifact}/{pte_filename}.pte", 
workspace=f"/data/local/tmp/executorch/{pte_filename}", - device_id=args.device, - host_id=args.host, - soc_model=args.model, - shared_buffer=args.shared_buffer, - target=args.target, ) - adb.push(inputs=inputs, backends={backend}) + adb.push(inputs=inputs) adb.execute() # collect output data @@ -205,7 +190,6 @@ def main(args): ) args = parser.parse_args() - args.validate(args) try: main(args) diff --git a/examples/qualcomm/oss_scripts/deit.py b/examples/qualcomm/oss_scripts/deit.py index 68d38d266d9..98623a333b8 100644 --- a/examples/qualcomm/oss_scripts/deit.py +++ b/examples/qualcomm/oss_scripts/deit.py @@ -13,21 +13,19 @@ import numpy as np import torch -from executorch.backends.qualcomm._passes.qnn_pass_manager import ( - get_capture_program_passes, +from executorch.backends.qualcomm.export_utils import ( + build_executorch_binary, + QnnConfig, + setup_common_args_and_variables, + SimpleADB, ) from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype from executorch.backends.qualcomm.serialization.qc_schema import ( QnnExecuTorchBackendType, ) from executorch.examples.qualcomm.utils import ( - build_executorch_binary, - get_backend_type, get_imagenet_dataset, make_output_dir, - parse_skip_delegation_node, - setup_common_args_and_variables, - SimpleADB, topk_accuracy, ) from transformers import AutoConfig, AutoModelForImageClassification @@ -46,7 +44,7 @@ def get_instance(): def main(args): - skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args) + qnn_config = QnnConfig.load_config(args.config_file if args.config_file else args) os.makedirs(args.artifact, exist_ok=True) config = AutoConfig.from_pretrained("facebook/deit-base-distilled-patch16-224") @@ -72,45 +70,28 @@ def main(args): pte_filename = "deit_qnn" # lower to QNN - passes_job = get_capture_program_passes() - backend = get_backend_type(args.backend) quant_dtype = { QnnExecuTorchBackendType.kGpuBackend: None, QnnExecuTorchBackendType.kHtpBackend: QuantDtype.use_8a8w, - 
}[backend] + }[qnn_config.backend] + build_executorch_binary( - model, - inputs[0], - args.model, - f"{args.artifact}/{pte_filename}", + model=model, + qnn_config=qnn_config, + file_name=f"{args.artifact}/{pte_filename}", dataset=inputs, - skip_node_id_set=skip_node_id_set, - skip_node_op_set=skip_node_op_set, quant_dtype=quant_dtype, - backend=backend, - passes_job=passes_job, - shared_buffer=args.shared_buffer, - online_prepare=args.online_prepare, ) - if args.compile_only: - return - workspace = f"/data/local/tmp/{getpass.getuser()}/executorch/{pte_filename}" pte_path = f"{args.artifact}/{pte_filename}.pte" adb = SimpleADB( - qnn_sdk=os.getenv("QNN_SDK_ROOT"), - build_path=f"{args.build_folder}", + qnn_config=qnn_config, pte_path=pte_path, workspace=workspace, - device_id=args.device, - host_id=args.host, - soc_model=args.model, - shared_buffer=args.shared_buffer, - target=args.target, ) - adb.push(inputs=inputs, backends={backend}) + adb.push(inputs=inputs) adb.execute() # collect output data @@ -161,7 +142,6 @@ def main(args): ) args = parser.parse_args() - args.validate(args) try: main(args) diff --git a/examples/qualcomm/oss_scripts/dino_v2.py b/examples/qualcomm/oss_scripts/dino_v2.py index 5aa4a79788d..363eea7d429 100644 --- a/examples/qualcomm/oss_scripts/dino_v2.py +++ b/examples/qualcomm/oss_scripts/dino_v2.py @@ -10,23 +10,23 @@ from multiprocessing.connection import Client import numpy as np -import torch from executorch.backends.qualcomm._passes.qnn_pass_manager import ( get_capture_program_passes, ) +from executorch.backends.qualcomm.export_utils import ( + build_executorch_binary, + QnnConfig, + setup_common_args_and_variables, + SimpleADB, +) from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype from executorch.backends.qualcomm.serialization.qc_schema import ( QnnExecuTorchBackendType, ) from executorch.examples.qualcomm.utils import ( - build_executorch_binary, - get_backend_type, get_imagenet_dataset, make_output_dir, - 
parse_skip_delegation_node, - setup_common_args_and_variables, - SimpleADB, topk_accuracy, ) @@ -42,7 +42,7 @@ def get_instance(): def main(args): - skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args) + qnn_config = QnnConfig.load_config(args.config_file if args.config_file else args) # ensure the working directory exist. os.makedirs(args.artifact, exist_ok=True) @@ -54,46 +54,29 @@ def main(args): image_shape=(256, 256), crop_size=img_size, ) - sample_input = (torch.randn((1, 3, img_size, img_size)),) pte_filename = "dino_v2" instance = get_instance() passes_job = get_capture_program_passes() - backend = get_backend_type(args.backend) quant_dtype = { QnnExecuTorchBackendType.kGpuBackend: None, QnnExecuTorchBackendType.kHtpBackend: QuantDtype.use_8a8w, - }[backend] + }[qnn_config.backend] build_executorch_binary( - instance, - sample_input, - args.model, - f"{args.artifact}/{pte_filename}", - inputs, - skip_node_id_set=skip_node_id_set, - skip_node_op_set=skip_node_op_set, + model=instance, + qnn_config=qnn_config, + file_name=f"{args.artifact}/{pte_filename}", + dataset=inputs, quant_dtype=quant_dtype, - backend=backend, passes_job=passes_job, - shared_buffer=args.shared_buffer, - online_prepare=args.online_prepare, ) - if args.compile_only: - return - adb = SimpleADB( - qnn_sdk=os.getenv("QNN_SDK_ROOT"), - build_path=f"{args.build_folder}", + qnn_config=qnn_config, pte_path=f"{args.artifact}/{pte_filename}.pte", workspace=f"/data/local/tmp/executorch/{pte_filename}", - device_id=args.device, - host_id=args.host, - soc_model=args.model, - shared_buffer=args.shared_buffer, - target=args.target, ) - adb.push(inputs=inputs, backends={backend}) + adb.push(inputs=inputs) adb.execute() # collect output data @@ -145,7 +128,6 @@ def main(args): ) args = parser.parse_args() - args.validate(args) try: main(args) diff --git a/examples/qualcomm/oss_scripts/distilbert.py b/examples/qualcomm/oss_scripts/distilbert.py index 6fe2fd063e8..5e9f3aa0a14 100644 --- 
a/examples/qualcomm/oss_scripts/distilbert.py +++ b/examples/qualcomm/oss_scripts/distilbert.py @@ -13,6 +13,13 @@ import evaluate import numpy as np import torch +from executorch.backends.qualcomm.export_utils import ( + build_executorch_binary, + make_quantizer, + QnnConfig, + setup_common_args_and_variables, + SimpleADB, +) from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype from executorch.backends.qualcomm.serialization.qc_schema import ( @@ -20,24 +27,15 @@ ) from executorch.examples.qualcomm.utils import ( - build_executorch_binary, - get_backend_type, get_masked_language_model_dataset, make_output_dir, - make_quantizer, - parse_skip_delegation_node, - setup_common_args_and_variables, - SimpleADB, ) from transformers import AutoModelForMaskedLM, AutoTokenizer from transformers.masking_utils import create_bidirectional_mask def main(args): - if args.compile_only and args.pre_gen_pte: - raise RuntimeError("Cannot set both compile_only and pre_gen_pte as true") - - skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args) + qnn_config = QnnConfig.load_config(args.config_file if args.config_file else args) os.makedirs(args.artifact, exist_ok=True) data_size = 100 @@ -78,59 +76,37 @@ def main(args): for input_ids, attention_mask in inputs ] - # Skip lowering/compilation if using pre-generated PTE - if not args.pre_gen_pte: - # lower to QNN - backend = get_backend_type(args.backend) - quantizer = { - QnnExecuTorchBackendType.kGpuBackend: None, - QnnExecuTorchBackendType.kHtpBackend: make_quantizer( - quant_dtype=QuantDtype.use_16a8w, - eps=2**-20, - backend=backend, - soc_model=args.model, - ), - }[backend] - build_executorch_binary( - module, - inputs[0], - args.model, - f"{args.artifact}/{pte_filename}", - dataset=inputs, - skip_node_id_set=skip_node_id_set, - skip_node_op_set=skip_node_op_set, - backend=backend, - custom_quantizer=quantizer, - shared_buffer=args.shared_buffer, - online_prepare=args.online_prepare, - ) - - if 
args.compile_only: - return + # lower to QNN + quantizer = { + QnnExecuTorchBackendType.kGpuBackend: None, + QnnExecuTorchBackendType.kHtpBackend: make_quantizer( + quant_dtype=QuantDtype.use_16a8w, + eps=2**-20, + backend=qnn_config.backend, + soc_model=qnn_config.soc_model, + ), + }[qnn_config.backend] + build_executorch_binary( + model=module, + qnn_config=qnn_config, + file_name=f"{args.artifact}/{pte_filename}", + dataset=inputs, + custom_quantizer=quantizer, + ) workspace = f"/data/local/tmp/{getpass.getuser()}/executorch/{pte_filename}" - pte_path = ( - f"{args.pre_gen_pte}/{pte_filename}.pte" - if args.pre_gen_pte - else f"{args.artifact}/{pte_filename}.pte" - ) + pte_path = f"{args.artifact}/{pte_filename}.pte" adb = SimpleADB( - qnn_sdk=os.getenv("QNN_SDK_ROOT"), - build_path=f"{args.build_folder}", + qnn_config=qnn_config, pte_path=pte_path, workspace=workspace, - device_id=args.device, - host_id=args.host, - soc_model=args.model, - shared_buffer=args.shared_buffer, - target=args.target, ) output_data_folder = f"{args.artifact}/outputs" make_output_dir(output_data_folder) # accuracy analysis - adb.push(inputs=inputs, backends={backend}) + adb.push(inputs=inputs) adb.execute() adb.pull(host_output_path=args.artifact) goldens, predictions = [], [] @@ -177,7 +153,6 @@ def main(args): ) args = parser.parse_args() - args.validate(args) try: main(args) diff --git a/examples/qualcomm/oss_scripts/dit.py b/examples/qualcomm/oss_scripts/dit.py index 41dd831259d..0e87045f789 100644 --- a/examples/qualcomm/oss_scripts/dit.py +++ b/examples/qualcomm/oss_scripts/dit.py @@ -12,22 +12,20 @@ import numpy as np import torch +from executorch.backends.qualcomm.export_utils import ( + build_executorch_binary, + make_quantizer, + QnnConfig, + setup_common_args_and_variables, + SimpleADB, +) from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype from executorch.backends.qualcomm.serialization.qc_schema import ( QnnExecuTorchBackendType, ) -from 
executorch.examples.qualcomm.utils import ( - build_executorch_binary, - get_backend_type, - make_output_dir, - make_quantizer, - parse_skip_delegation_node, - setup_common_args_and_variables, - SimpleADB, - topk_accuracy, -) +from executorch.examples.qualcomm.utils import make_output_dir, topk_accuracy from torchao.quantization.pt2e import HistogramObserver from transformers import AutoImageProcessor, AutoModelForImageClassification @@ -57,7 +55,7 @@ def get_rvlcdip_dataset(data_size): def main(args): - skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args) + qnn_config = QnnConfig.load_config(args.config_file if args.config_file else args) # ensure the working directory exist. os.makedirs(args.artifact, exist_ok=True) @@ -82,45 +80,29 @@ def main(args): pte_filename = "dit_qnn" # Use HistogramObserver to get better performance - backend = get_backend_type(args.backend) quantizer = { QnnExecuTorchBackendType.kGpuBackend: None, QnnExecuTorchBackendType.kHtpBackend: make_quantizer( quant_dtype=QuantDtype.use_8a8w, act_observer=HistogramObserver, - backend=backend, - soc_model=args.model, + backend=qnn_config.backend, + soc_model=qnn_config.soc_model, ), - }[backend] + }[qnn_config.backend] build_executorch_binary( - module.eval(), - inputs[0], - args.model, - f"{args.artifact}/{pte_filename}", - inputs, - skip_node_id_set=skip_node_id_set, - skip_node_op_set=skip_node_op_set, + model=module.eval(), + qnn_config=qnn_config, + file_name=f"{args.artifact}/{pte_filename}", + dataset=inputs, custom_quantizer=quantizer, - backend=backend, - shared_buffer=args.shared_buffer, - online_prepare=args.online_prepare, ) - if args.compile_only: - return - adb = SimpleADB( - qnn_sdk=os.getenv("QNN_SDK_ROOT"), - build_path=f"{args.build_folder}", + qnn_config=qnn_config, pte_path=f"{args.artifact}/{pte_filename}.pte", workspace=f"/data/local/tmp/executorch/{pte_filename}", - device_id=args.device, - host_id=args.host, - soc_model=args.model, - 
shared_buffer=args.shared_buffer, - target=args.target, ) - adb.push(inputs=inputs, backends={backend}) + adb.push(inputs=inputs) adb.execute() # collect output data @@ -159,7 +141,6 @@ def main(args): ) args = parser.parse_args() - args.validate(args) try: main(args) diff --git a/examples/qualcomm/oss_scripts/efficientSAM/efficientSAM.py b/examples/qualcomm/oss_scripts/efficientSAM/efficientSAM.py index 666e6e366d4..f346e471694 100644 --- a/examples/qualcomm/oss_scripts/efficientSAM/efficientSAM.py +++ b/examples/qualcomm/oss_scripts/efficientSAM/efficientSAM.py @@ -13,20 +13,18 @@ import numpy as np import torch +from executorch.backends.qualcomm.export_utils import ( + build_executorch_binary, + QnnConfig, + setup_common_args_and_variables, + SimpleADB, +) from executorch.examples.qualcomm.oss_scripts.efficientSAM.source_transformation import ( replace_maskdecoder_with_custom_op, replace_pos_emb_with_custom_op, ) -from executorch.examples.qualcomm.utils import ( - build_executorch_binary, - class_agnostic_mIoU, - get_backend_type, - make_output_dir, - parse_skip_delegation_node, - setup_common_args_and_variables, - SimpleADB, -) +from executorch.examples.qualcomm.utils import class_agnostic_mIoU, make_output_dir from PIL import Image, ImageDraw from scipy.ndimage import label from torch.utils.data import DataLoader, Dataset @@ -211,7 +209,7 @@ def save_mask(mask, input, save_path): def main(args): - skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args) + qnn_config = QnnConfig.load_config(args.config_file if args.config_file else args) os.makedirs(args.artifact, exist_ok=True) @@ -232,38 +230,22 @@ def main(args): pte_filename = "efficientSAM_qnn" # lower to QNN - backend = get_backend_type(args.backend) build_executorch_binary( - model, - inputs[0], - args.model, - f"{args.artifact}/{pte_filename}", + model=model, + qnn_config=qnn_config, + file_name=f"{args.artifact}/{pte_filename}", dataset=inputs, - skip_node_id_set=skip_node_id_set, - 
skip_node_op_set=skip_node_op_set, - backend=backend, - shared_buffer=args.shared_buffer, - online_prepare=args.online_prepare, ) - if args.compile_only: - return - workspace = f"/data/local/tmp/{getpass.getuser()}/executorch/{pte_filename}" pte_path = f"{args.artifact}/{pte_filename}.pte" adb = SimpleADB( - qnn_sdk=os.getenv("QNN_SDK_ROOT"), - build_path=f"{args.build_folder}", + qnn_config=qnn_config, pte_path=pte_path, workspace=workspace, - device_id=args.device, - host_id=args.host, - soc_model=args.model, - shared_buffer=args.shared_buffer, - target=args.target, ) - adb.push(inputs=inputs, backends={backend}) + adb.push(inputs=inputs) adb.execute() # collect output data @@ -336,7 +318,7 @@ def post_process(): ) args = parser.parse_args() - args.validate(args) + try: main(args) except Exception as e: diff --git a/examples/qualcomm/oss_scripts/efficientnet.py b/examples/qualcomm/oss_scripts/efficientnet.py index a116ad95726..56f6d839e6e 100644 --- a/examples/qualcomm/oss_scripts/efficientnet.py +++ b/examples/qualcomm/oss_scripts/efficientnet.py @@ -13,26 +13,27 @@ import numpy as np import torch +from executorch.backends.qualcomm.export_utils import ( + build_executorch_binary, + make_quantizer, + QnnConfig, + setup_common_args_and_variables, + SimpleADB, +) from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype from executorch.backends.qualcomm.serialization.qc_schema import ( QnnExecuTorchBackendType, ) from executorch.examples.qualcomm.utils import ( - build_executorch_binary, - get_backend_type, get_imagenet_dataset, make_output_dir, - make_quantizer, - parse_skip_delegation_node, - setup_common_args_and_variables, - SimpleADB, topk_accuracy, ) from transformers import AutoModelForImageClassification def main(args): - skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args) + qnn_config = QnnConfig.load_config(args.config_file if args.config_file else args) # ensure the working directory exist. 
os.makedirs(args.artifact, exist_ok=True) @@ -57,45 +58,29 @@ def main(args): .to("cpu") ) pte_filename = "efficientnet_qnn" - backend = get_backend_type(args.backend) quantizer = { QnnExecuTorchBackendType.kGpuBackend: None, QnnExecuTorchBackendType.kHtpBackend: make_quantizer( quant_dtype=QuantDtype.use_16a16w, eps=2**-20, - backend=backend, - soc_model=args.model, + backend=qnn_config.backend, + soc_model=qnn_config.soc_model, ), - }[backend] + }[qnn_config.backend] build_executorch_binary( - module.eval(), - inputs[0], - args.model, - f"{args.artifact}/{pte_filename}", - inputs, - skip_node_id_set=skip_node_id_set, - skip_node_op_set=skip_node_op_set, + model=module, + qnn_config=qnn_config, + file_name=f"{args.artifact}/{pte_filename}", + dataset=inputs, custom_quantizer=quantizer, - backend=backend, - shared_buffer=args.shared_buffer, - online_prepare=args.online_prepare, ) - if args.compile_only: - return - adb = SimpleADB( - qnn_sdk=os.getenv("QNN_SDK_ROOT"), - build_path=f"{args.build_folder}", + qnn_config=qnn_config, pte_path=f"{args.artifact}/{pte_filename}.pte", workspace=f"/data/local/tmp/executorch/{pte_filename}", - device_id=args.device, - host_id=args.host, - soc_model=args.model, - shared_buffer=args.shared_buffer, - target=args.target, ) - adb.push(inputs=inputs, backends={backend}) + adb.push(inputs=inputs) adb.execute() # collect output data @@ -148,7 +133,6 @@ def main(args): ) args = parser.parse_args() - args.validate(args) try: main(args) diff --git a/examples/qualcomm/oss_scripts/esrgan.py b/examples/qualcomm/oss_scripts/esrgan.py index 8d35b67166d..7783f1e86ae 100644 --- a/examples/qualcomm/oss_scripts/esrgan.py +++ b/examples/qualcomm/oss_scripts/esrgan.py @@ -12,20 +12,19 @@ import numpy as np import piq import torch +from executorch.backends.qualcomm.export_utils import ( + build_executorch_binary, + QnnConfig, + setup_common_args_and_variables, + SimpleADB, +) from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype 
from executorch.backends.qualcomm.serialization.qc_schema import ( QnnExecuTorchBackendType, ) from executorch.examples.qualcomm.scripts.edsr import get_dataset -from executorch.examples.qualcomm.utils import ( - build_executorch_binary, - get_backend_type, - make_output_dir, - parse_skip_delegation_node, - setup_common_args_and_variables, - SimpleADB, -) +from executorch.examples.qualcomm.utils import make_output_dir from torchvision.transforms.functional import to_pil_image @@ -45,7 +44,7 @@ def get_instance(repo: str): def main(args): - skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args) + qnn_config = QnnConfig.load_config(args.config_file if args.config_file else args) # ensure the working directory exist. os.makedirs(args.artifact, exist_ok=True) @@ -58,40 +57,24 @@ def main(args): pte_filename = "esrgan_qnn" instance = get_instance(args.oss_repo) - backend = get_backend_type(args.backend) quant_dtype = { QnnExecuTorchBackendType.kGpuBackend: None, QnnExecuTorchBackendType.kHtpBackend: QuantDtype.use_8a8w, - }[backend] + }[qnn_config.backend] build_executorch_binary( - instance, - (inputs[0],), - args.model, - f"{args.artifact}/{pte_filename}", - [(input,) for input in inputs], - skip_node_id_set=skip_node_id_set, - skip_node_op_set=skip_node_op_set, + model=instance, + qnn_config=qnn_config, + file_name=f"{args.artifact}/{pte_filename}", + dataset=[(input,) for input in inputs], quant_dtype=quant_dtype, - backend=backend, - shared_buffer=args.shared_buffer, - online_prepare=args.online_prepare, ) - if args.compile_only: - return - adb = SimpleADB( - qnn_sdk=os.getenv("QNN_SDK_ROOT"), - build_path=f"{args.build_folder}", + qnn_config=qnn_config, pte_path=f"{args.artifact}/{pte_filename}.pte", workspace=f"/data/local/tmp/executorch/{pte_filename}", - device_id=args.device, - host_id=args.host, - soc_model=args.model, - shared_buffer=args.shared_buffer, - target=args.target, ) - adb.push(inputs=inputs, backends={backend}) + 
adb.push(inputs=inputs) adb.execute() # collect output data @@ -178,7 +161,7 @@ def post_process(): ) args = parser.parse_args() - args.validate(args) + try: main(args) except Exception as e: diff --git a/examples/qualcomm/oss_scripts/eurobert.py b/examples/qualcomm/oss_scripts/eurobert.py index b45cbe053b6..b5478efd2b1 100644 --- a/examples/qualcomm/oss_scripts/eurobert.py +++ b/examples/qualcomm/oss_scripts/eurobert.py @@ -14,6 +14,14 @@ import torch import transformers +from executorch.backends.qualcomm.export_utils import ( + build_executorch_binary, + make_quantizer, + QnnConfig, + setup_common_args_and_variables, + SimpleADB, +) + from executorch.backends.qualcomm.quantizer.custom_annotation import annotate_eurobert from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype from executorch.backends.qualcomm.serialization.qc_schema import ( @@ -21,14 +29,8 @@ ) from executorch.examples.qualcomm.utils import ( - build_executorch_binary, - get_backend_type, get_masked_language_model_dataset, make_output_dir, - make_quantizer, - parse_skip_delegation_node, - setup_common_args_and_variables, - SimpleADB, ) from transformers import AutoConfig, AutoModelForMaskedLM, AutoTokenizer @@ -36,15 +38,12 @@ def main(args): - if args.compile_only and args.pre_gen_pte: - raise RuntimeError("Cannot set both compile_only and pre_gen_pte as true") + qnn_config = QnnConfig.load_config(args.config_file if args.config_file else args) assert ( transformers.__version__ >= TRANSFORMERS_VERSION ), f"Please ensure transformers version >= {TRANSFORMERS_VERSION}, current version is {transformers.__version__}" - skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args) - os.makedirs(args.artifact, exist_ok=True) model_name = "EuroBERT/EuroBERT-210m" @@ -82,62 +81,38 @@ def main(args): pte_filename = "eurobert_qnn" - # Skip lowering/compilation if using pre-generated PTE - if not args.pre_gen_pte: - # lower to QNN - def get_custom_quantizer(backend, soc_model): - 
quantizer = make_quantizer( - quant_dtype=QuantDtype.use_16a16w, - eps=2**-20, - backend=backend, - soc_model=soc_model, - ) - quantizer.add_custom_quant_annotations((annotate_eurobert,)) - return quantizer - - backend = get_backend_type(args.backend) - quantizer = { - QnnExecuTorchBackendType.kGpuBackend: None, - QnnExecuTorchBackendType.kHtpBackend: get_custom_quantizer( - backend, args.model - ), - }[backend] - with torch.no_grad(): - build_executorch_binary( - module, - inputs[0], - args.model, - f"{args.artifact}/{pte_filename}", - dataset=inputs, - skip_node_id_set=skip_node_id_set, - skip_node_op_set=skip_node_op_set, - custom_quantizer=quantizer, - backend=backend, - shared_buffer=args.shared_buffer, - online_prepare=args.online_prepare, - ) - - if args.compile_only: - return + # lower to QNN + def get_custom_quantizer(): + quantizer = make_quantizer( + quant_dtype=QuantDtype.use_16a16w, + eps=2**-20, + backend=qnn_config.backend, + soc_model=qnn_config.soc_model, + ) + quantizer.add_custom_quant_annotations((annotate_eurobert,)) + return quantizer + + quantizer = { + QnnExecuTorchBackendType.kGpuBackend: None, + QnnExecuTorchBackendType.kHtpBackend: get_custom_quantizer(), + }[qnn_config.backend] + with torch.no_grad(): + build_executorch_binary( + model=module, + qnn_config=qnn_config, + file_name=f"{args.artifact}/{pte_filename}", + dataset=inputs, + custom_quantizer=quantizer, + ) - pte_path = ( - f"{args.pre_gen_pte}/{pte_filename}.pte" - if args.pre_gen_pte - else f"{args.artifact}/{pte_filename}.pte" - ) + pte_path = f"{args.artifact}/{pte_filename}.pte" adb = SimpleADB( - qnn_sdk=os.getenv("QNN_SDK_ROOT"), - build_path=f"{args.build_folder}", + qnn_config=qnn_config, pte_path=pte_path, workspace=f"/data/local/tmp/executorch/{pte_filename}", - device_id=args.device, - host_id=args.host, - soc_model=args.model, - shared_buffer=args.shared_buffer, - target=args.target, ) output_data_folder = f"{args.artifact}/outputs" - 
make_output_dir(output_data_folder) + make_output_dir(output_data_folder) # accuracy analysis - adb.push(inputs=inputs, backends={backend}) + adb.push(inputs=inputs) @@ -187,7 +162,7 @@ def get_custom_quantizer(backend, soc_model): ) args = parser.parse_args() - args.validate(args) + try: main(args) except Exception as e: diff --git a/examples/qualcomm/oss_scripts/fastvit.py b/examples/qualcomm/oss_scripts/fastvit.py index 87d90bb61b7..e548cc75b17 100644 --- a/examples/qualcomm/oss_scripts/fastvit.py +++ b/examples/qualcomm/oss_scripts/fastvit.py @@ -12,6 +12,13 @@ import numpy as np import torch +from executorch.backends.qualcomm.export_utils import ( + build_executorch_binary, + make_quantizer, + QnnConfig, + setup_common_args_and_variables, + SimpleADB, +) from executorch.backends.qualcomm.quantizer.observers.per_channel_param_observer import ( PerChannelParamObserver, ) @@ -27,14 +34,8 @@ ) from executorch.backends.qualcomm.utils.utils import convert_linear_to_conv2d from executorch.examples.qualcomm.utils import ( - build_executorch_binary, - get_backend_type, get_imagenet_dataset, make_output_dir, - make_quantizer, - parse_skip_delegation_node, - setup_common_args_and_variables, - SimpleADB, topk_accuracy, ) from torchao.quantization.pt2e.quantizer import QuantizationSpec @@ -56,7 +57,7 @@ def get_instance(repo_path: str, checkpoint_path: str): def main(args): - skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args) + qnn_config = QnnConfig.load_config(args.config_file if args.config_file else args) # ensure the working directory exist. 
os.makedirs(args.artifact, exist_ok=True) @@ -73,8 +74,8 @@ def main(args): - def get_custom_quantizer(backend, soc_model): + def get_custom_quantizer(): quantizer = make_quantizer( quant_dtype=QuantDtype.use_8a8w, - backend=backend, - soc_model=soc_model, + backend=qnn_config.backend, + soc_model=qnn_config.soc_model, ) # there are lots of outliers appearing in fastvit parameters @@ -115,40 +116,26 @@ def get_custom_quantizer(backend, soc_model): return quantizer # lower to QNN - backend = get_backend_type(args.backend) quantizer = { QnnExecuTorchBackendType.kGpuBackend: None, - QnnExecuTorchBackendType.kHtpBackend: get_custom_quantizer(backend, args.model), - }[backend] + QnnExecuTorchBackendType.kHtpBackend: get_custom_quantizer(), + }[qnn_config.backend] build_executorch_binary( - convert_linear_to_conv2d(get_instance(args.oss_repo, args.pretrained_weight)), - inputs[0], - args.model, - f"{args.artifact}/{pte_filename}", + model=convert_linear_to_conv2d( + get_instance(args.oss_repo, args.pretrained_weight) + ), + qnn_config=qnn_config, + file_name=f"{args.artifact}/{pte_filename}", dataset=inputs, - skip_node_id_set=skip_node_id_set, - skip_node_op_set=skip_node_op_set, custom_quantizer=quantizer, - backend=backend, - shared_buffer=args.shared_buffer, - online_prepare=args.online_prepare, ) - if args.compile_only: - return - adb = SimpleADB( - qnn_sdk=os.getenv("QNN_SDK_ROOT"), - build_path=f"{args.build_folder}", + qnn_config=qnn_config, pte_path=f"{args.artifact}/{pte_filename}.pte", workspace=f"/data/local/tmp/executorch/{pte_filename}", - device_id=args.device, - host_id=args.host, - soc_model=args.model, - shared_buffer=args.shared_buffer, - target=args.target, ) - adb.push(inputs=inputs, backends={backend}) + adb.push(inputs=inputs) adb.execute() # collect output data @@ -220,7 +207,7 @@ def get_custom_quantizer(backend, soc_model): ) args = parser.parse_args() - args.validate(args) + try: main(args) except Exception as e: diff --git a/examples/qualcomm/oss_scripts/fbnet.py b/examples/qualcomm/oss_scripts/fbnet.py index 
010fcb740a1..d2ad3c37a08 100755 --- a/examples/qualcomm/oss_scripts/fbnet.py +++ b/examples/qualcomm/oss_scripts/fbnet.py @@ -11,22 +11,26 @@ import numpy as np import timm +from executorch.backends.qualcomm.export_utils import ( + build_executorch_binary, + QnnConfig, + setup_common_args_and_variables, + SimpleADB, +) from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype from executorch.backends.qualcomm.serialization.qc_schema import ( QnnExecuTorchBackendType, ) from executorch.examples.qualcomm.utils import ( - build_executorch_binary, - get_backend_type, get_imagenet_dataset, make_output_dir, - setup_common_args_and_variables, - SimpleADB, topk_accuracy, ) def main(args): + qnn_config = QnnConfig.load_config(args.config_file if args.config_file else args) + # ensure the working directory exist. os.makedirs(args.artifact, exist_ok=True) @@ -41,38 +45,24 @@ def main(args): pte_filename = "fbnet_qnn" - backend = get_backend_type(args.backend) quant_dtype = { QnnExecuTorchBackendType.kGpuBackend: None, QnnExecuTorchBackendType.kHtpBackend: QuantDtype.use_8a8w, - }[backend] + }[qnn_config.backend] build_executorch_binary( - instance, - inputs[0], - args.model, - f"{args.artifact}/{pte_filename}", - inputs, + model=instance, + qnn_config=qnn_config, + file_name=f"{args.artifact}/{pte_filename}", + dataset=inputs, quant_dtype=quant_dtype, - backend=backend, - shared_buffer=args.shared_buffer, - online_prepare=args.online_prepare, ) - if args.compile_only: - return - adb = SimpleADB( - qnn_sdk=os.getenv("QNN_SDK_ROOT"), - build_path=f"{args.build_folder}", + qnn_config=qnn_config, pte_path=f"{args.artifact}/{pte_filename}.pte", workspace=f"/data/local/tmp/executorch/{pte_filename}", - device_id=args.device, - host_id=args.host, - soc_model=args.model, - shared_buffer=args.shared_buffer, - target=args.target, ) - adb.push(inputs=inputs, backends={backend}) + adb.push(inputs=inputs) adb.execute() # collect output data @@ -136,7 +126,7 @@ def 
post_process(): ) args = parser.parse_args() - args.validate(args) + try: main(args) except Exception as e: diff --git a/examples/qualcomm/oss_scripts/focalnet.py b/examples/qualcomm/oss_scripts/focalnet.py index b4308a48163..5fc065ded7f 100644 --- a/examples/qualcomm/oss_scripts/focalnet.py +++ b/examples/qualcomm/oss_scripts/focalnet.py @@ -13,25 +13,26 @@ import numpy as np import torch +from executorch.backends.qualcomm.export_utils import ( + build_executorch_binary, + QnnConfig, + setup_common_args_and_variables, + SimpleADB, +) from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype from executorch.backends.qualcomm.serialization.qc_schema import ( QnnExecuTorchBackendType, ) from executorch.examples.qualcomm.utils import ( - build_executorch_binary, - get_backend_type, get_imagenet_dataset, make_output_dir, - parse_skip_delegation_node, - setup_common_args_and_variables, - SimpleADB, topk_accuracy, ) from transformers import AutoModelForImageClassification def main(args): - skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args) + qnn_config = QnnConfig.load_config(args.config_file if args.config_file else args) # ensure the working directory exist. 
os.makedirs(args.artifact, exist_ok=True) @@ -56,40 +57,24 @@ def main(args): .to("cpu") ) pte_filename = "focalnet_qnn" - backend = get_backend_type(args.backend) quant_dtype = { QnnExecuTorchBackendType.kGpuBackend: None, QnnExecuTorchBackendType.kHtpBackend: QuantDtype.use_8a8w, - }[backend] + }[qnn_config.backend] build_executorch_binary( - module.eval(), - inputs[0], - args.model, - f"{args.artifact}/{pte_filename}", - inputs, - skip_node_id_set=skip_node_id_set, - skip_node_op_set=skip_node_op_set, + model=module.eval(), + qnn_config=qnn_config, + file_name=f"{args.artifact}/{pte_filename}", + dataset=inputs, quant_dtype=quant_dtype, - backend=backend, - shared_buffer=args.shared_buffer, - online_prepare=args.online_prepare, ) - if args.compile_only: - return - adb = SimpleADB( - qnn_sdk=os.getenv("QNN_SDK_ROOT"), - build_path=f"{args.build_folder}", + qnn_config=qnn_config, pte_path=f"{args.artifact}/{pte_filename}.pte", workspace=f"/data/local/tmp/executorch/{pte_filename}", - device_id=args.device, - host_id=args.host, - soc_model=args.model, - shared_buffer=args.shared_buffer, - target=args.target, ) - adb.push(inputs=inputs, backends={backend}) + adb.push(inputs=inputs) adb.execute() # collect output data @@ -142,7 +127,7 @@ def main(args): ) args = parser.parse_args() - args.validate(args) + try: main(args) except Exception as e: diff --git a/examples/qualcomm/oss_scripts/gMLP_image_classification.py b/examples/qualcomm/oss_scripts/gMLP_image_classification.py index f6070bf6105..cb339653b93 100644 --- a/examples/qualcomm/oss_scripts/gMLP_image_classification.py +++ b/examples/qualcomm/oss_scripts/gMLP_image_classification.py @@ -12,26 +12,27 @@ import numpy as np import timm -import torch + +from executorch.backends.qualcomm.export_utils import ( + build_executorch_binary, + QnnConfig, + setup_common_args_and_variables, + SimpleADB, +) from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype from 
executorch.backends.qualcomm.serialization.qc_schema import ( QnnExecuTorchBackendType, ) from executorch.examples.qualcomm.utils import ( - build_executorch_binary, - get_backend_type, get_imagenet_dataset, make_output_dir, - parse_skip_delegation_node, - setup_common_args_and_variables, - SimpleADB, topk_accuracy, ) def main(args): - skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args) + qnn_config = QnnConfig.load_config(args.config_file if args.config_file else args) # ensure the working directory exist. os.makedirs(args.artifact, exist_ok=True) @@ -45,42 +46,25 @@ def main(args): pte_filename = "gMLP_image_classification_qnn" model = timm.create_model("gmlp_s16_224", pretrained=True).eval() - sample_input = (torch.randn(1, 3, 224, 224),) - backend = get_backend_type(args.backend) quant_dtype = { QnnExecuTorchBackendType.kGpuBackend: None, QnnExecuTorchBackendType.kHtpBackend: QuantDtype.use_8a8w, - }[backend] + }[qnn_config.backend] build_executorch_binary( - model, - sample_input, - args.model, - f"{args.artifact}/{pte_filename}", + model=model, + qnn_config=qnn_config, + file_name=f"{args.artifact}/{pte_filename}", dataset=inputs, - skip_node_id_set=skip_node_id_set, - skip_node_op_set=skip_node_op_set, quant_dtype=quant_dtype, - backend=backend, - shared_buffer=args.shared_buffer, - online_prepare=args.online_prepare, ) - if args.compile_only: - return - adb = SimpleADB( - qnn_sdk=os.getenv("QNN_SDK_ROOT"), - build_path=f"{args.build_folder}", + qnn_config=qnn_config, pte_path=f"{args.artifact}/{pte_filename}.pte", workspace=f"/data/local/tmp/executorch/{pte_filename}", - device_id=args.device, - host_id=args.host, - soc_model=args.model, - shared_buffer=args.shared_buffer, - target=args.target, ) - adb.push(inputs=inputs, backends={backend}) + adb.push(inputs=inputs) adb.execute() # collect output data @@ -132,7 +116,7 @@ def main(args): ) args = parser.parse_args() - args.validate(args) + try: main(args) except Exception as e: diff --git 
a/examples/qualcomm/oss_scripts/llama/artifacts/README.md b/examples/qualcomm/oss_scripts/llama/artifacts/README.md index f0e96aee711..14c5d419a0b 100644 --- a/examples/qualcomm/oss_scripts/llama/artifacts/README.md +++ b/examples/qualcomm/oss_scripts/llama/artifacts/README.md @@ -37,7 +37,7 @@ echo '{"dim": 64, "n_layers": 5, "n_heads": 8, "n_kv_heads": 4, "vocab_size": 51 ``` bash # Checks accuracy with weight sharing disabled since x86 does not support weight sharing. -python backends/qualcomm/tests/test_qnn_delegate.py -k TestExampleLLMScript.test_llama_stories_260k --model SM8650 --build_folder build-x86/ --executorch_root . --artifact_dir ./examples/qualcomm/oss_scripts/llama/artifacts --llama_artifacts . --enable_x86_64 --compile_only +python backends/qualcomm/tests/test_qnn_delegate.py -k TestExampleLLMScript.test_llama_stories_260k --soc_model SM8650 --build_folder build-x86/ --executorch_root . --artifact_dir ./examples/qualcomm/oss_scripts/llama/artifacts --llama_artifacts . --enable_x86_64 --compile_only ``` 4. Commit the hybrid_llama_qnn.pte file to the repository. 
diff --git a/examples/qualcomm/oss_scripts/llama/decoder_runtime_evaluator.py b/examples/qualcomm/oss_scripts/llama/decoder_runtime_evaluator.py index 55d7409a1e6..92c901b3990 100644 --- a/examples/qualcomm/oss_scripts/llama/decoder_runtime_evaluator.py +++ b/examples/qualcomm/oss_scripts/llama/decoder_runtime_evaluator.py @@ -15,6 +15,12 @@ import numpy as np import torch + +from executorch.backends.qualcomm.export_utils import ( + generate_inputs, + QnnConfig, + SimpleADB, +) from executorch.examples.models.llama.evaluate.eager_eval import EagerEvalWrapper from executorch.examples.qualcomm.oss_scripts.llama.decoder_constants import ( ATTENTION_SINK_EVICTOR, @@ -29,11 +35,9 @@ INFERENCE_REGISTRY, retrieve_info_from_pte, ) -from executorch.examples.qualcomm.utils import ( - generate_inputs, - make_output_dir, - SimpleADB, -) + +from executorch.examples.qualcomm.utils import make_output_dir + from pytorch_tokenizers.hf_tokenizer import HuggingFaceTokenizer from pytorch_tokenizers.llama2c import Llama2cTokenizer as SentencePieceTokenizer from pytorch_tokenizers.tiktoken import TiktokenTokenizer @@ -196,17 +200,15 @@ def _init_runner_base_cmd(self): @final def _get_adb(self): args = self.args + qnn_config = QnnConfig.load_config( + args.config_file if args.config_file else args + ) if EvalBase._adb is None: EvalBase._adb = SimpleADB( - qnn_sdk=self.qnn_sdk, - build_path=f"{args.build_folder}", + qnn_config=qnn_config, pte_path=list(self.pte_paths.values()), workspace=self.device_workspace, - device_id=args.device, - host_id=args.host, - soc_model=args.model, runner=f"examples/qualcomm/oss_scripts/llama/{self.runner}", - target=args.target, ) return EvalBase._adb diff --git a/examples/qualcomm/oss_scripts/llama/llama.py b/examples/qualcomm/oss_scripts/llama/llama.py index 5449599acc2..3aa1fa81610 100755 --- a/examples/qualcomm/oss_scripts/llama/llama.py +++ b/examples/qualcomm/oss_scripts/llama/llama.py @@ -15,6 +15,10 @@ from typing import Dict import torch +from 
executorch.backends.qualcomm.export_utils import ( + get_backend_type, + setup_common_args_and_variables, +) from executorch.backends.qualcomm.utils.utils import ( generate_htp_compiler_spec, @@ -54,10 +58,6 @@ MultiModalManager, next_power_of_two, ) -from executorch.examples.qualcomm.utils import ( - get_backend_type, - setup_common_args_and_variables, -) from torchao.quantization.utils import compute_error @@ -128,7 +128,7 @@ def compile( use_fp16=to_skip, ) encoder_compile_specs = generate_qnn_executorch_compiler_spec( - soc_model=get_soc_to_chipset_map()[args.model], + soc_model=get_soc_to_chipset_map()[args.soc_model], backend_options=backend_options, # x86 emulator does not support shared buffer shared_buffer=not args.enable_x86_64, @@ -143,7 +143,7 @@ def compile( ) compile_specs[modality] = [ generate_qnn_executorch_compiler_spec( - soc_model=get_soc_to_chipset_map()[args.model], + soc_model=get_soc_to_chipset_map()[args.soc_model], backend_options=backend_options, # x86 emulator does not support shared buffer shared_buffer=not args.enable_x86_64, @@ -159,7 +159,7 @@ def compile( ) compile_specs[modality] = [ generate_qnn_executorch_compiler_spec( - soc_model=get_soc_to_chipset_map()[args.model], + soc_model=get_soc_to_chipset_map()[args.soc_model], backend_options=backend_options, # x86 emulator does not support shared buffer shared_buffer=not args.enable_x86_64, @@ -173,7 +173,7 @@ def compile( skip_quantize=skip_quantize, tokenizer=tokenizer, backend=get_backend_type(args.backend), - soc_model=args.model, + soc_model=args.soc_model, ) # perform compilation diff --git a/examples/qualcomm/oss_scripts/llama/range_setting_pt2e.py b/examples/qualcomm/oss_scripts/llama/range_setting_pt2e.py index 3a1bab412de..008990d5192 100644 --- a/examples/qualcomm/oss_scripts/llama/range_setting_pt2e.py +++ b/examples/qualcomm/oss_scripts/llama/range_setting_pt2e.py @@ -17,6 +17,8 @@ import torch import torch.nn as nn + +from executorch.backends.qualcomm.export_utils 
import make_quantizer from executorch.backends.qualcomm.quantizer.observers.per_channel_param_observer import ( PerChannelParamObserver, ) @@ -24,7 +26,6 @@ from executorch.backends.qualcomm.serialization.qc_schema import ( QnnExecuTorchBackendType, ) -from executorch.examples.qualcomm.utils import make_quantizer from torchao.prototype.quantization.module_swap import ( QuantizationRecipe, diff --git a/examples/qualcomm/oss_scripts/llama/wrappers/attention_sink_wrappers.py b/examples/qualcomm/oss_scripts/llama/wrappers/attention_sink_wrappers.py index 9f3bffb92f0..48386f181d8 100644 --- a/examples/qualcomm/oss_scripts/llama/wrappers/attention_sink_wrappers.py +++ b/examples/qualcomm/oss_scripts/llama/wrappers/attention_sink_wrappers.py @@ -17,6 +17,7 @@ get_capture_program_passes, ) from executorch.backends.qualcomm.builders.utils import is_graph_output +from executorch.backends.qualcomm.export_utils import make_quantizer from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype from executorch.backends.qualcomm.utils.constants import ( QCOM_DTYPE, @@ -57,7 +58,6 @@ Processor, Request, ) -from executorch.examples.qualcomm.utils import make_quantizer from executorch.exir._serialize._program import deserialize_pte_binary from executorch.exir.capture._config import ExecutorchBackendConfig from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass @@ -436,7 +436,7 @@ def compile(self, attention_sink_evictor_pte_path: str): backend_options = generate_htp_compiler_spec(use_fp16=False) compiler_specs = [ generate_qnn_executorch_compiler_spec( - soc_model=get_soc_to_chipset_map()[self.control_args.model], + soc_model=get_soc_to_chipset_map()[self.control_args.soc_model], backend_options=backend_options, shared_buffer=not self.control_args.enable_x86_64, ) diff --git a/examples/qualcomm/oss_scripts/llama/wrappers/llm_wrappers.py b/examples/qualcomm/oss_scripts/llama/wrappers/llm_wrappers.py index 4d63b2471ff..2d0713b175e 100644 --- 
a/examples/qualcomm/oss_scripts/llama/wrappers/llm_wrappers.py +++ b/examples/qualcomm/oss_scripts/llama/wrappers/llm_wrappers.py @@ -24,6 +24,7 @@ get_passes_dependency_for_capture_program, ) from executorch.backends.qualcomm.builders.utils import is_graph_output +from executorch.backends.qualcomm.export_utils import make_quantizer from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype from executorch.backends.qualcomm.utils.constants import ( @@ -81,7 +82,6 @@ Processor, Request, ) -from executorch.examples.qualcomm.utils import make_quantizer from executorch.exir.backend.compile_spec_schema import CompileSpec from executorch.exir.capture._config import ExecutorchBackendConfig from executorch.exir.dialects._ops import ops as exir_ops diff --git a/examples/qualcomm/oss_scripts/llm_utils/README.md b/examples/qualcomm/oss_scripts/llm_utils/README.md index a713d6df5de..84c23c294fa 100644 --- a/examples/qualcomm/oss_scripts/llm_utils/README.md +++ b/examples/qualcomm/oss_scripts/llm_utils/README.md @@ -32,7 +32,7 @@ The script evaluates the model by running the PTE file on a connected Qualcomm d * `--tasks`: (Optional, default: `["wikitext"]`) A list of `lm-evaluation-harness` tasks to evaluate. You can specify multiple tasks separated by spaces (e.g., `--tasks wikitext piqa`). * `--limit`: (Optional) Number of samples to evaluate per task. If not set, all samples will be evaluated. * `--num_fewshot`: (Optional) Number of examples to use in few-shot context for evaluation. -* `--model`: (Required for QNN execution) The SoC model name (e.g., `SM8550`, `SM8650`). +* `--soc_model`: (Required for QNN execution) The SoC model name (e.g., `SM8550`, `SM8650`). * `--device`: (Required for QNN execution) The ADB device ID. * `--host`: (Required for QNN execution) The ADB host ID (usually `localhost`). * `--build_folder`: (Optional, default: `build-android`) The build folder for ExecuTorch artifacts, relative to the current directory. 
@@ -44,7 +44,7 @@ python examples/qualcomm/oss_scripts/llm_utils/eval_decoder_model_qnn.py \ --artifact ./eval_output \ --tokenizer_path /path/to/your/tokenizer.model \ --pte /path/to/your/model.pte \ - --model SM8550 \ + --soc_model SM8550 \ --device YOUR_DEVICE_ID \ --host localhost \ --tasks wikitext \ @@ -62,7 +62,7 @@ python examples/qualcomm/oss_scripts/llm_utils/eval_decoder_model_qnn.py \ --tokenizer_path /path/to/your/tokenizer.model \ --pte /path/to/your/model.pte \ --logits_quant_attr_path /path/to/your/logits_quant_attrs.json \ - --model SM8550 \ + --soc_model SM8550 \ --device YOUR_DEVICE_ID \ --host localhost \ --tasks wikitext \ diff --git a/examples/qualcomm/oss_scripts/llm_utils/eval_decoder_model_qnn.py b/examples/qualcomm/oss_scripts/llm_utils/eval_decoder_model_qnn.py index 461224e1ccf..d52078b27c7 100644 --- a/examples/qualcomm/oss_scripts/llm_utils/eval_decoder_model_qnn.py +++ b/examples/qualcomm/oss_scripts/llm_utils/eval_decoder_model_qnn.py @@ -14,12 +14,13 @@ import numpy as np import torch -from executorch.examples.models.llama.evaluate.eager_eval import EagerEvalWrapper - -from executorch.examples.qualcomm.utils import ( - make_output_dir, +from executorch.backends.qualcomm.export_utils import ( + QnnConfig, setup_common_args_and_variables, ) +from executorch.examples.models.llama.evaluate.eager_eval import EagerEvalWrapper + +from executorch.examples.qualcomm.utils import make_output_dir from lm_eval.evaluator import simple_evaluate from pytorch_tokenizers import get_tokenizer @@ -110,21 +111,23 @@ def __init__( super().__init__(None, tokenizer, max_seq_length) import getpass - from executorch.examples.qualcomm.utils import SimpleADB + from executorch.backends.qualcomm.export_utils import SimpleADB self._model = model self.output_dir = output_dir self.quant_attrs = quant_attrs workspace = f"/data/local/tmp/{getpass.getuser()}/executorch/meta_llama" + qnn_config = QnnConfig( + build_folder=build_folder, + device=device, + host=host, + 
soc_model=soc_model, + target=target, + ) self.adb = SimpleADB( - qnn_sdk=os.getenv("QNN_SDK_ROOT"), - build_path=build_folder, + qnn_config=qnn_config, pte_path=model, workspace=workspace, - device_id=device, - host_id=host, - soc_model=soc_model, - target=target, ) self.adb.push() @@ -196,7 +199,7 @@ def gen_eval_wrapper( return QNNRunnerEvalWrapper( model=model, tokenizer=tokenizer, - soc_model=args.model, + soc_model=args.soc_model, device=args.device, host=args.host, max_seq_length=args.max_seq_len - 1, diff --git a/examples/qualcomm/oss_scripts/llm_utils/qnn_decoder_model_manager.py b/examples/qualcomm/oss_scripts/llm_utils/qnn_decoder_model_manager.py index c1074ba0dce..ca9f7af8fb0 100644 --- a/examples/qualcomm/oss_scripts/llm_utils/qnn_decoder_model_manager.py +++ b/examples/qualcomm/oss_scripts/llm_utils/qnn_decoder_model_manager.py @@ -15,6 +15,7 @@ get_capture_program_passes, ) from executorch.backends.qualcomm.builders.utils import is_graph_output +from executorch.backends.qualcomm.export_utils import make_quantizer from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype from executorch.backends.qualcomm.utils.constants import ( QCOM_PASS_ACTIVATE_KEY, @@ -31,7 +32,6 @@ from executorch.examples.qualcomm.oss_scripts.llm_utils.decoder_model_wrapper import ( QnnCausalLMExportableModule, ) -from executorch.examples.qualcomm.utils import make_quantizer from executorch.exir.capture._config import ExecutorchBackendConfig from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass from pytorch_tokenizers import get_tokenizer diff --git a/examples/qualcomm/oss_scripts/maxvit_t.py b/examples/qualcomm/oss_scripts/maxvit_t.py index d3b4894045d..8b5c6f28b1f 100755 --- a/examples/qualcomm/oss_scripts/maxvit_t.py +++ b/examples/qualcomm/oss_scripts/maxvit_t.py @@ -16,19 +16,21 @@ import torch import torch.nn.functional as F import torchvision +from executorch.backends.qualcomm.export_utils import ( + build_executorch_binary, + 
make_quantizer, + QnnConfig, + setup_common_args_and_variables, + SimpleADB, +) from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype from executorch.backends.qualcomm.serialization.qc_schema import ( QnnExecuTorchBackendType, ) from executorch.examples.qualcomm.utils import ( - build_executorch_binary, - get_backend_type, get_imagenet_dataset, make_output_dir, - make_quantizer, - setup_common_args_and_variables, - SimpleADB, topk_accuracy, ) from torchvision.models.maxvit import ( @@ -127,6 +129,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: def main(args): + qnn_config = QnnConfig.load_config(args.config_file if args.config_file else args) # ensure the working directory exist. os.makedirs(args.artifact, exist_ok=True) @@ -160,43 +163,29 @@ def main(args): forward, attn_sub_layer ) - backend = get_backend_type(args.backend) quantizer = { QnnExecuTorchBackendType.kGpuBackend: None, QnnExecuTorchBackendType.kHtpBackend: make_quantizer( quant_dtype=QuantDtype.use_8a8w, per_channel_linear=True, - backend=backend, - soc_model=args.model, + backend=qnn_config.backend, + soc_model=qnn_config.soc_model, ), - }[backend] + }[qnn_config.backend] build_executorch_binary( - instance, - inputs[0], - args.model, - f"{args.artifact}/{pte_filename}", - inputs, + model=instance, + qnn_config=qnn_config, + file_name=f"{args.artifact}/{pte_filename}", + dataset=inputs, custom_quantizer=quantizer, - backend=backend, - shared_buffer=args.shared_buffer, - online_prepare=args.online_prepare, ) - if args.compile_only: - return - adb = SimpleADB( - qnn_sdk=os.getenv("QNN_SDK_ROOT"), - build_path=f"{args.build_folder}", + qnn_config=qnn_config, pte_path=f"{args.artifact}/{pte_filename}.pte", workspace=f"/data/local/tmp/executorch/{pte_filename}", - device_id=args.device, - host_id=args.host, - soc_model=args.model, - shared_buffer=args.shared_buffer, - target=args.target, ) - adb.push(inputs=inputs, backends={backend}) + adb.push(inputs=inputs) adb.execute() # 
collect output data @@ -247,7 +236,7 @@ def main(args): ) args = parser.parse_args() - args.validate(args) + try: main(args) except Exception as e: diff --git a/examples/qualcomm/oss_scripts/mobilevit_v1.py b/examples/qualcomm/oss_scripts/mobilevit_v1.py index 68df356a4ab..8e78473062f 100644 --- a/examples/qualcomm/oss_scripts/mobilevit_v1.py +++ b/examples/qualcomm/oss_scripts/mobilevit_v1.py @@ -13,20 +13,18 @@ import numpy as np import torch -from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype -from executorch.backends.qualcomm.serialization.qc_schema import ( - QnnExecuTorchBackendType, -) -from executorch.examples.qualcomm.utils import ( +from executorch.backends.qualcomm.export_utils import ( build_executorch_binary, - get_backend_type, - make_output_dir, make_quantizer, - parse_skip_delegation_node, + QnnConfig, setup_common_args_and_variables, SimpleADB, - topk_accuracy, ) +from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype +from executorch.backends.qualcomm.serialization.qc_schema import ( + QnnExecuTorchBackendType, +) +from executorch.examples.qualcomm.utils import make_output_dir, topk_accuracy from PIL import Image from torchvision import datasets from transformers import AutoImageProcessor, AutoModelForImageClassification @@ -58,7 +56,7 @@ def get_data_loader(): def main(args): - skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args) + qnn_config = QnnConfig.load_config(args.config_file if args.config_file else args) # ensure the working directory exist. 
os.makedirs(args.artifact, exist_ok=True) @@ -82,45 +80,29 @@ def main(args): ) pte_filename = "mobilevit_v1_qnn" - backend = get_backend_type(args.backend) quantizer = { QnnExecuTorchBackendType.kGpuBackend: None, QnnExecuTorchBackendType.kHtpBackend: make_quantizer( quant_dtype=QuantDtype.use_16a8w, eps=2**-12, - backend=backend, - soc_model=args.model, + backend=qnn_config.backend, + soc_model=qnn_config.soc_model, ), - }[backend] + }[qnn_config.backend] build_executorch_binary( - module.eval(), - inputs[0], - args.model, - f"{args.artifact}/{pte_filename}", - inputs, - skip_node_id_set=skip_node_id_set, - skip_node_op_set=skip_node_op_set, - backend=backend, - shared_buffer=args.shared_buffer, - online_prepare=args.online_prepare, + model=module.eval(), + qnn_config=qnn_config, + file_name=f"{args.artifact}/{pte_filename}", + dataset=inputs, custom_quantizer=quantizer, ) - if args.compile_only: - return - adb = SimpleADB( - qnn_sdk=os.getenv("QNN_SDK_ROOT"), - build_path=f"{args.build_folder}", + qnn_config=qnn_config, pte_path=f"{args.artifact}/{pte_filename}.pte", workspace=f"/data/local/tmp/executorch/{pte_filename}", - device_id=args.device, - host_id=args.host, - soc_model=args.model, - shared_buffer=args.shared_buffer, - target=args.target, ) - adb.push(inputs=inputs, backends={backend}) + adb.push(inputs=inputs) adb.execute() # collect output data @@ -173,7 +155,7 @@ def main(args): ) args = parser.parse_args() - args.validate(args) + try: main(args) except Exception as e: diff --git a/examples/qualcomm/oss_scripts/mobilevit_v2.py b/examples/qualcomm/oss_scripts/mobilevit_v2.py index 15b25edb93b..2c6d13e67a1 100644 --- a/examples/qualcomm/oss_scripts/mobilevit_v2.py +++ b/examples/qualcomm/oss_scripts/mobilevit_v2.py @@ -14,20 +14,18 @@ import numpy as np import torch -from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype -from executorch.backends.qualcomm.serialization.qc_schema import ( - QnnExecuTorchBackendType, -) -from 
executorch.examples.qualcomm.utils import ( +from executorch.backends.qualcomm.export_utils import ( build_executorch_binary, - get_backend_type, - make_output_dir, make_quantizer, - parse_skip_delegation_node, + QnnConfig, setup_common_args_and_variables, SimpleADB, - topk_accuracy, ) +from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype +from executorch.backends.qualcomm.serialization.qc_schema import ( + QnnExecuTorchBackendType, +) +from executorch.examples.qualcomm.utils import make_output_dir, topk_accuracy from PIL import Image from torchvision import datasets from transformers import AutoImageProcessor, AutoModelForImageClassification @@ -64,7 +62,7 @@ def main(args): stacklevel=1, ) - skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args) + qnn_config = QnnConfig.load_config(args.config_file if args.config_file else args) # ensure the working directory exist. os.makedirs(args.artifact, exist_ok=True) @@ -90,43 +88,27 @@ def main(args): ) pte_filename = "mobilevit_v2_qnn" - backend = get_backend_type(args.backend) quantizer = { QnnExecuTorchBackendType.kGpuBackend: None, QnnExecuTorchBackendType.kHtpBackend: make_quantizer( quant_dtype=QuantDtype.use_16a8w, eps=2**-10, ), - }[backend] + }[qnn_config.backend] build_executorch_binary( - module, - inputs[0], - args.model, - f"{args.artifact}/{pte_filename}", - inputs, - skip_node_id_set=skip_node_id_set, - skip_node_op_set=skip_node_op_set, + model=module, + qnn_config=qnn_config, + file_name=f"{args.artifact}/{pte_filename}", + dataset=inputs, custom_quantizer=quantizer, - backend=backend, - shared_buffer=args.shared_buffer, - online_prepare=args.online_prepare, ) - if args.compile_only: - return - adb = SimpleADB( - qnn_sdk=os.getenv("QNN_SDK_ROOT"), - build_path=f"{args.build_folder}", + qnn_config=qnn_config, pte_path=f"{args.artifact}/{pte_filename}.pte", workspace=f"/data/local/tmp/executorch/{pte_filename}", - device_id=args.device, - host_id=args.host, - 
soc_model=args.model, - shared_buffer=args.shared_buffer, - target=args.target, ) - adb.push(inputs=inputs, backends={backend}) + adb.push(inputs=inputs) adb.execute() # collect output data @@ -178,7 +160,7 @@ def main(args): ) args = parser.parse_args() - args.validate(args) + try: main(args) except Exception as e: diff --git a/examples/qualcomm/oss_scripts/moshi/mimi.py b/examples/qualcomm/oss_scripts/moshi/mimi.py index e50d70c00f5..31cabbd8ee8 100644 --- a/examples/qualcomm/oss_scripts/moshi/mimi.py +++ b/examples/qualcomm/oss_scripts/moshi/mimi.py @@ -20,6 +20,14 @@ import torch.nn as nn import torchaudio +from executorch.backends.qualcomm.export_utils import ( + build_executorch_binary, + make_quantizer, + QnnConfig, + setup_common_args_and_variables, + SimpleADB, +) + from executorch.backends.qualcomm.quantizer.custom_annotation import ( annotate_mimi_decoder, ) @@ -40,14 +48,7 @@ get_static_mimi, MIMI_NAME, ) -from executorch.examples.qualcomm.utils import ( - build_executorch_binary, - make_output_dir, - make_quantizer, - parse_skip_delegation_node, - setup_common_args_and_variables, - SimpleADB, -) +from executorch.examples.qualcomm.utils import make_output_dir from executorch.exir.capture._config import ExecutorchBackendConfig from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass @@ -152,10 +153,9 @@ def init_inputs(): def compile_mimi_encoder( args, + qnn_config, orig_mimi, encoder_inputs, - skip_node_id_set, - skip_node_op_set, encoder_pte_filename, ): class MimiEncode(nn.Module): @@ -168,29 +168,19 @@ def forward(self, x): mimi_encoder_model = MimiEncode(orig_mimi) build_executorch_binary( - mimi_encoder_model.eval(), - encoder_inputs[0], - args.model, - f"{args.artifact}/{encoder_pte_filename}", - encoder_inputs, - skip_node_id_set=skip_node_id_set, - skip_node_op_set=skip_node_op_set, + model=mimi_encoder_model.eval(), + qnn_config=qnn_config, + file_name=f"{args.artifact}/{encoder_pte_filename}", + dataset=encoder_inputs, 
quant_dtype=QuantDtype.use_8a8w, - shared_buffer=args.shared_buffer, ) -def inference_mimi_encoder(args, encoder_inputs, encoder_pte_filename): +def inference_mimi_encoder(args, qnn_config, encoder_inputs, encoder_pte_filename): adb = SimpleADB( - qnn_sdk=os.getenv("QNN_SDK_ROOT"), - build_path=f"{args.build_folder}", + qnn_config=qnn_config, pte_path=f"{args.artifact}/{encoder_pte_filename}.pte", workspace=f"/data/local/tmp/executorch/{encoder_pte_filename}", - device_id=args.device, - host_id=args.host, - soc_model=args.model, - shared_buffer=args.shared_buffer, - target=args.target, ) adb.push(inputs=encoder_inputs) adb.execute() @@ -210,9 +200,7 @@ def inference_mimi_encoder(args, encoder_inputs, encoder_pte_filename): return encoder_predictions -def export_mimi_encoder( - args, orig_mimi, sample_pcm, pcm_chunk_size, skip_node_id_set, skip_node_op_set -): +def export_mimi_encoder(args, qnn_config, orig_mimi, sample_pcm, pcm_chunk_size): encoder_inputs = [] count = 0 cpu_encoded_results = [] @@ -235,16 +223,16 @@ def export_mimi_encoder( logging.info("Compile only for QNN Encoder") compile_mimi_encoder( args, + qnn_config, orig_mimi, encoder_inputs, - skip_node_id_set, - skip_node_op_set, encoder_pte_filename, ) elif args.pre_gen_pte: logging.info("Inference only for QNN Encoder") qnn_encoded_results = inference_mimi_encoder( args, + qnn_config, encoder_inputs, encoder_pte_filename, ) @@ -252,14 +240,14 @@ def export_mimi_encoder( logging.info("Compile and Inference for QNN Encoder") compile_mimi_encoder( args, + qnn_config, orig_mimi, encoder_inputs, - skip_node_id_set, - skip_node_op_set, encoder_pte_filename, ) qnn_encoded_results = inference_mimi_encoder( args, + qnn_config, encoder_inputs, encoder_pte_filename, ) @@ -276,10 +264,9 @@ def export_mimi_encoder( def compile_static_mimi_decoder( args, + qnn_config, static_mimi_decoder, encoded_results, - skip_node_id_set, - skip_node_op_set, static_decoder_pte_filename, ): quantizer = make_quantizer( @@ -288,7 
+275,7 @@ def compile_static_mimi_decoder( per_channel_linear=True, act_observer=MinMaxObserver, backend=QnnExecuTorchBackendType.kHtpBackend, - soc_model=args.model, + soc_model=qnn_config.soc_model, ) quantizer.add_custom_quant_annotations((annotate_mimi_decoder,)) @@ -314,7 +301,7 @@ def compile_static_mimi_decoder( backend_options = generate_htp_compiler_spec(use_fp16=False) compiler_spec = generate_qnn_executorch_compiler_spec( - soc_model=get_soc_to_chipset_map()[args.soc_model], + soc_model=get_soc_to_chipset_map()[args.soc_model], backend_options=backend_options, ) @@ -325,8 +312,8 @@ def compile_static_mimi_decoder( *static_states, ), compiler_spec, - skip_node_id_set=skip_node_id_set, - skip_node_op_set=skip_node_op_set, + skip_node_id_set=qnn_config.skip_delegate_node_ids, + skip_node_op_set=qnn_config.skip_delegate_node_ops, ) executorch_config = ExecutorchBackendConfig( @@ -342,8 +329,9 @@ def inference_static_mimi_decoder( args, + qnn_config, encoded_results, encoded_results_list, pcm_chunk_size, static_decoder_pte_filename, ): @@ -357,17 +344,10 @@ def inference_static_mimi_decoder( f"--output_folder_path {workspace}/outputs", ] ) - adb = SimpleADB( - qnn_sdk=os.getenv("QNN_SDK_ROOT"), - build_path=f"{args.build_folder}", + qnn_config=qnn_config, pte_path=pte_path, workspace=workspace, - device_id=args.device, - host_id=args.host, - soc_model=args.model, - shared_buffer=args.shared_buffer, - target=args.target, runner="examples/qualcomm/oss_scripts/moshi/qnn_mimi_decoder_runner", ) adb.push(inputs=encoded_results) @@ -392,11 +372,10 @@ def inference_static_mimi_decoder( def export_mimi_decoder( args, + qnn_config, static_mimi_decoder, encoded_results, pcm_chunk_size, - skip_node_id_set, - skip_node_op_set, ): encoded_results_list = "" for index, encoder_result in enumerate(encoded_results): @@ -411,16 +390,16 @@ def export_mimi_decoder( logging.info("Compile only for QNN Static Decoder") compile_static_mimi_decoder(
args, + qnn_config, static_mimi_decoder, encoded_results, - skip_node_id_set, - skip_node_op_set, static_decoder_pte_filename, ) elif args.pre_gen_pte: logging.info("Inference only for QNN Static Decoder") qnn_decode_res = inference_static_mimi_decoder( args, + qnn_config, encoded_results, encoded_results_list, pcm_chunk_size, @@ -430,14 +409,14 @@ def export_mimi_decoder( logging.info("Compile and Inference for QNN Static Decoder") compile_static_mimi_decoder( args, + qnn_config, static_mimi_decoder, encoded_results, - skip_node_id_set, - skip_node_op_set, static_decoder_pte_filename, ) qnn_decode_res = inference_static_mimi_decoder( args, + qnn_config, encoded_results, encoded_results_list, pcm_chunk_size, @@ -451,8 +430,7 @@ def main(args): moshi.__version__ == MOSHI_VERSION ), f"Please ensure Moshi version == {MOSHI_VERSION}, current version is {moshi.__version__}" - if args.compile_only and args.pre_gen_pte: - exit("Cannot set both compile_only and pre_gen_pte as true") + qnn_config = QnnConfig.load_config(args.config_file if args.config_file else args) logging.info("loading mimi") if args.mimi_weight is None: @@ -461,7 +439,6 @@ def main(args): static_mimi = get_static_mimi(args.mimi_weight, "cpu") # For static decoder logging.info("mimi loaded") - skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args) os.makedirs(args.artifact, exist_ok=True) sample_rate = orig_mimi.sample_rate @@ -480,19 +457,17 @@ def main(args): with torch.no_grad(): encoded_results, cpu_encoded_results = export_mimi_encoder( args, + qnn_config, orig_mimi, sample_pcm, pcm_chunk_size, - skip_node_id_set, - skip_node_op_set, ) qnn_decode_res = export_mimi_decoder( args, + qnn_config, static_mimi, encoded_results, pcm_chunk_size, - skip_node_id_set, - skip_node_op_set, ) if args.compile_only: diff --git a/examples/qualcomm/oss_scripts/pvt.py b/examples/qualcomm/oss_scripts/pvt.py index 5db3bbecc5c..b61a1fe99c6 100644 --- a/examples/qualcomm/oss_scripts/pvt.py +++ 
b/examples/qualcomm/oss_scripts/pvt.py @@ -12,25 +12,26 @@ import numpy as np import torch +from executorch.backends.qualcomm.export_utils import ( + build_executorch_binary, + QnnConfig, + setup_common_args_and_variables, + SimpleADB, +) from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype from executorch.backends.qualcomm.serialization.qc_schema import ( QnnExecuTorchBackendType, ) from executorch.examples.qualcomm.utils import ( - build_executorch_binary, - get_backend_type, get_imagenet_dataset, make_output_dir, - parse_skip_delegation_node, - setup_common_args_and_variables, - SimpleADB, topk_accuracy, ) from transformers import AutoModelForImageClassification def main(args): - skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args) + qnn_config = QnnConfig.load_config(args.config_file if args.config_file else args) # ensure the working directory exist. os.makedirs(args.artifact, exist_ok=True) @@ -56,40 +57,24 @@ def main(args): ) pte_filename = "pvt_qnn" - backend = get_backend_type(args.backend) quant_dtype = { QnnExecuTorchBackendType.kGpuBackend: None, QnnExecuTorchBackendType.kHtpBackend: QuantDtype.use_8a8w, - }[backend] + }[qnn_config.backend] build_executorch_binary( - module.eval(), - inputs[0], - args.model, - f"{args.artifact}/{pte_filename}", - inputs, - skip_node_id_set=skip_node_id_set, - skip_node_op_set=skip_node_op_set, + model=module.eval(), + qnn_config=qnn_config, + file_name=f"{args.artifact}/{pte_filename}", + dataset=inputs, quant_dtype=quant_dtype, - backend=backend, - shared_buffer=args.shared_buffer, - online_prepare=args.online_prepare, ) - if args.compile_only: - return - adb = SimpleADB( - qnn_sdk=os.getenv("QNN_SDK_ROOT"), - build_path=f"{args.build_folder}", + qnn_config=qnn_config, pte_path=f"{args.artifact}/{pte_filename}.pte", workspace=f"/data/local/tmp/executorch/{pte_filename}", - device_id=args.device, - host_id=args.host, - soc_model=args.model, - shared_buffer=args.shared_buffer, - 
target=args.target, ) - adb.push(inputs=inputs, backends={backend}) + adb.push(inputs=inputs) adb.execute() # collect output data @@ -141,7 +126,7 @@ def main(args): ) args = parser.parse_args() - args.validate(args) + try: main(args) except Exception as e: diff --git a/examples/qualcomm/oss_scripts/qwen2_5/qwen2_5.py b/examples/qualcomm/oss_scripts/qwen2_5/qwen2_5.py index 86266018da4..70641af8fb7 100644 --- a/examples/qualcomm/oss_scripts/qwen2_5/qwen2_5.py +++ b/examples/qualcomm/oss_scripts/qwen2_5/qwen2_5.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import argparse import getpass import json import logging @@ -12,6 +13,12 @@ from multiprocessing.connection import Client import torch +from executorch.backends.qualcomm.export_utils import ( + get_backend_type, + QnnConfig, + setup_common_args_and_variables, + SimpleADB, +) from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype @@ -19,14 +26,7 @@ get_qnn_llm_edge_manager, HUGGING_FACE_REPO_IDS, ) - -from executorch.examples.qualcomm.utils import ( - get_backend_type, - make_output_dir, - parse_skip_delegation_node, - setup_common_args_and_variables, - SimpleADB, -) +from executorch.examples.qualcomm.utils import make_output_dir from transformers import AutoTokenizer @@ -37,8 +37,7 @@ PTE_FILENAME = "qwen_qnn_q16" -def compile(args): # noqa: C901 - skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args) +def compile(args: argparse.Namespace, qnn_config: QnnConfig): # noqa: C901 # ensure the working directory exist. 
os.makedirs(args.artifact, exist_ok=True) @@ -76,12 +75,14 @@ def compile(args): # noqa: C901 args.calibration_limit, args.prompt, tokenizer_json_path, - get_backend_type(args.backend), - args.model, + get_backend_type(qnn_config.backend), + qnn_config.soc_model, ) manager.to_edge_transform_and_lower_to_qnn( - args.model, skip_node_id_set, skip_node_op_set + qnn_config.soc_model, + qnn_config.skip_delegate_node_ids, + qnn_config.skip_delegate_node_ops, ) if args.ptq: logits_quant_attrs = manager.get_logits_quant_attrs() @@ -96,7 +97,7 @@ def compile(args): # noqa: C901 manager.to_executorch(args.artifact, PTE_FILENAME) -def inference(args): +def inference(args: argparse.Namespace, qnn_config: QnnConfig): workspace = f"/data/local/tmp/{getpass.getuser()}/executorch/{PTE_FILENAME}" pte_path = f"{args.artifact}/{PTE_FILENAME}.pte" # collect output data @@ -151,15 +152,10 @@ def post_process(): ] ) adb = SimpleADB( - qnn_sdk=os.getenv("QNN_SDK_ROOT"), - build_path=f"{args.build_folder}", + qnn_config=qnn_config, pte_path=pte_path, workspace=workspace, - device_id=args.device, - host_id=args.host, - soc_model=args.model, runner="examples/models/llama/llama_main", - target=args.target, ) # No pregen inputs, input_list is not required adb.push(inputs=[], input_list="", files=[tokenizer_json_path]) @@ -182,16 +178,15 @@ def post_process(): def main(args): - if args.compile_only and args.pre_gen_pte: - raise RuntimeError("Cannot set both compile_only and pre_gen_pte as true") + qnn_config = QnnConfig.load_config(args.config_file if args.config_file else args) if args.compile_only: - compile(args) + compile(args, qnn_config) elif args.pre_gen_pte: - inference(args) + inference(args, qnn_config) else: - compile(args) - inference(args) + compile(args, qnn_config) + inference(args, qnn_config) if __name__ == "__main__": @@ -254,7 +249,7 @@ def main(args): try: args = parser.parse_args() - args.validate(args) + if args.artifact is None: args.artifact = args.decoder_model main(args) diff
--git a/examples/qualcomm/oss_scripts/regnet.py b/examples/qualcomm/oss_scripts/regnet.py index 7d4a10c265c..c317e1cb0e2 100644 --- a/examples/qualcomm/oss_scripts/regnet.py +++ b/examples/qualcomm/oss_scripts/regnet.py @@ -10,18 +10,19 @@ from multiprocessing.connection import Client import numpy as np +from executorch.backends.qualcomm.export_utils import ( + build_executorch_binary, + QnnConfig, + setup_common_args_and_variables, + SimpleADB, +) from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype from executorch.backends.qualcomm.serialization.qc_schema import ( QnnExecuTorchBackendType, ) from executorch.examples.qualcomm.utils import ( - build_executorch_binary, - get_backend_type, get_imagenet_dataset, make_output_dir, - parse_skip_delegation_node, - setup_common_args_and_variables, - SimpleADB, topk_accuracy, ) @@ -34,7 +35,7 @@ def main(args): - skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args) + qnn_config = QnnConfig.load_config(args.config_file if args.config_file else args) # ensure the working directory exist. 
os.makedirs(args.artifact, exist_ok=True) @@ -56,40 +57,24 @@ def main(args): model = regnet_x_400mf(weights=weights).eval() pte_filename = "regnet_x_400mf" - backend = get_backend_type(args.backend) quant_dtype = { QnnExecuTorchBackendType.kGpuBackend: None, QnnExecuTorchBackendType.kHtpBackend: QuantDtype.use_8a8w, - }[backend] + }[qnn_config.backend] build_executorch_binary( - model, - inputs[0], - args.model, - f"{args.artifact}/{pte_filename}", - inputs, - skip_node_id_set=skip_node_id_set, - skip_node_op_set=skip_node_op_set, + model=model, + qnn_config=qnn_config, + file_name=f"{args.artifact}/{pte_filename}", + dataset=inputs, quant_dtype=quant_dtype, - backend=backend, - shared_buffer=args.shared_buffer, - online_prepare=args.online_prepare, ) - if args.compile_only: - return - adb = SimpleADB( - qnn_sdk=os.getenv("QNN_SDK_ROOT"), - build_path=f"{args.build_folder}", + qnn_config=qnn_config, pte_path=f"{args.artifact}/{pte_filename}.pte", workspace=f"/data/local/tmp/executorch/{pte_filename}", - device_id=args.device, - host_id=args.host, - soc_model=args.model, - shared_buffer=args.shared_buffer, - target=args.target, ) - adb.push(inputs=inputs, backends={backend}) + adb.push(inputs=inputs) adb.execute() # collect output data @@ -148,7 +133,7 @@ def main(args): ) args = parser.parse_args() - args.validate(args) + try: main(args) except Exception as e: diff --git a/examples/qualcomm/oss_scripts/retinanet.py b/examples/qualcomm/oss_scripts/retinanet.py index a47ade58712..f3565f95a6c 100644 --- a/examples/qualcomm/oss_scripts/retinanet.py +++ b/examples/qualcomm/oss_scripts/retinanet.py @@ -7,24 +7,22 @@ import json import os -import sys from multiprocessing.connection import Client import numpy as np import torch -from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype -from executorch.backends.qualcomm.serialization.qc_schema import ( - QnnExecuTorchBackendType, -) -from executorch.examples.qualcomm.utils import ( +from 
executorch.backends.qualcomm.export_utils import ( build_executorch_binary, - get_backend_type, - make_output_dir, - parse_skip_delegation_node, + QnnConfig, setup_common_args_and_variables, SimpleADB, ) +from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype +from executorch.backends.qualcomm.serialization.qc_schema import ( + QnnExecuTorchBackendType, +) +from executorch.examples.qualcomm.utils import make_output_dir def get_instance(): @@ -213,7 +211,7 @@ def main(args): from torchvision.models.detection.image_list import ImageList - skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args) + qnn_config = QnnConfig.load_config(args.config_file if args.config_file else args) # ensure the working directory exist os.makedirs(args.artifact, exist_ok=True) @@ -228,40 +226,24 @@ def main(args): data_size=data_num, dataset_dir=args.dataset ) pte_filename = "retinanet_qnn" - backend = get_backend_type(args.backend) quant_dtype = { QnnExecuTorchBackendType.kGpuBackend: None, QnnExecuTorchBackendType.kHtpBackend: QuantDtype.use_8a8w, - }[backend] + }[qnn_config.backend] build_executorch_binary( - model, - inputs[0], - args.model, - f"{args.artifact}/{pte_filename}", - inputs, - skip_node_id_set=skip_node_id_set, - skip_node_op_set=skip_node_op_set, + model=model, + qnn_config=qnn_config, + file_name=f"{args.artifact}/{pte_filename}", + dataset=inputs, quant_dtype=quant_dtype, - backend=backend, - shared_buffer=args.shared_buffer, - online_prepare=args.online_prepare, ) - if args.compile_only: - sys.exit(0) - adb = SimpleADB( - qnn_sdk=os.getenv("QNN_SDK_ROOT"), - build_path=f"{args.build_folder}", + qnn_config=qnn_config, pte_path=f"{args.artifact}/{pte_filename}.pte", workspace=f"/data/local/tmp/executorch/{pte_filename}", - device_id=args.device, - host_id=args.host, - soc_model=args.model, - shared_buffer=args.shared_buffer, - target=args.target, ) - adb.push(inputs=inputs, backends={backend}) + adb.push(inputs=inputs) adb.execute() # 
collect output data @@ -326,7 +308,7 @@ def main(args): ) args = parser.parse_args() - args.validate(args) + try: main(args) except Exception as e: diff --git a/examples/qualcomm/oss_scripts/roberta.py b/examples/qualcomm/oss_scripts/roberta.py index 4d0c5fca5a7..7545c974822 100644 --- a/examples/qualcomm/oss_scripts/roberta.py +++ b/examples/qualcomm/oss_scripts/roberta.py @@ -13,6 +13,13 @@ import evaluate import numpy as np import torch +from executorch.backends.qualcomm.export_utils import ( + build_executorch_binary, + make_quantizer, + QnnConfig, + setup_common_args_and_variables, + SimpleADB, +) from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype from executorch.backends.qualcomm.serialization.qc_schema import ( @@ -20,21 +27,15 @@ ) from executorch.examples.qualcomm.utils import ( - build_executorch_binary, - get_backend_type, get_masked_language_model_dataset, make_output_dir, - make_quantizer, - parse_skip_delegation_node, - setup_common_args_and_variables, - SimpleADB, ) from transformers import AutoModelForMaskedLM, AutoTokenizer from transformers.masking_utils import create_bidirectional_mask def main(args): - skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args) + qnn_config = QnnConfig.load_config(args.config_file if args.config_file else args) os.makedirs(args.artifact, exist_ok=True) data_size = 100 @@ -76,46 +77,30 @@ def main(args): ] # lower to QNN - backend = get_backend_type(args.backend) quantizer = { QnnExecuTorchBackendType.kGpuBackend: None, QnnExecuTorchBackendType.kHtpBackend: make_quantizer( quant_dtype=QuantDtype.use_16a8w, eps=2**-20, - backend=backend, - soc_model=args.model, + backend=qnn_config.backend, + soc_model=qnn_config.soc_model, ), - }[backend] + }[qnn_config.backend] build_executorch_binary( - module, - inputs[0], - args.model, - f"{args.artifact}/{pte_filename}", + model=module, + qnn_config=qnn_config, + file_name=f"{args.artifact}/{pte_filename}", dataset=inputs, - 
skip_node_id_set=skip_node_id_set, - skip_node_op_set=skip_node_op_set, - backend=backend, custom_quantizer=quantizer, - shared_buffer=args.shared_buffer, - online_prepare=args.online_prepare, ) - if args.compile_only: - return - workspace = f"/data/local/tmp/{getpass.getuser()}/executorch/{pte_filename}" pte_path = f"{args.artifact}/{pte_filename}.pte" adb = SimpleADB( - qnn_sdk=os.getenv("QNN_SDK_ROOT"), - build_path=f"{args.build_folder}", + qnn_config=qnn_config, pte_path=pte_path, workspace=workspace, - device_id=args.device, - host_id=args.host, - soc_model=args.model, - shared_buffer=args.shared_buffer, - target=args.target, ) output_data_folder = f"{args.artifact}/outputs" make_output_dir(output_data_folder) @@ -137,7 +122,7 @@ def main(args): ) sample_input = tuple(sample_input.values()) golden = module(*sample_input)[0] - adb.push(inputs=[sample_input], backends={backend}) + adb.push(inputs=[sample_input]) adb.execute() adb.pull(host_output_path=args.artifact) @@ -149,7 +134,7 @@ def main(args): print(f"QNN output: {tokenizer.batch_decode(predictions.argmax(axis=2))}") # accuracy analysis - adb.push(inputs=inputs, backends={backend}) + adb.push(inputs=inputs) adb.execute() adb.pull(host_output_path=args.artifact) goldens, predictions = [], [] @@ -195,7 +180,7 @@ def main(args): ) args = parser.parse_args() - args.validate(args) + try: main(args) except Exception as e: diff --git a/examples/qualcomm/oss_scripts/squeezenet.py b/examples/qualcomm/oss_scripts/squeezenet.py index a91fb8458c9..c937f84767f 100644 --- a/examples/qualcomm/oss_scripts/squeezenet.py +++ b/examples/qualcomm/oss_scripts/squeezenet.py @@ -12,24 +12,25 @@ import numpy as np import torch +from executorch.backends.qualcomm.export_utils import ( + build_executorch_binary, + QnnConfig, + setup_common_args_and_variables, + SimpleADB, +) from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype from executorch.backends.qualcomm.serialization.qc_schema import ( 
QnnExecuTorchBackendType, ) from executorch.examples.qualcomm.utils import ( - build_executorch_binary, - get_backend_type, get_imagenet_dataset, make_output_dir, - parse_skip_delegation_node, - setup_common_args_and_variables, - SimpleADB, topk_accuracy, ) def main(args): - skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args) + qnn_config = QnnConfig.load_config(args.config_file if args.config_file else args) # ensure the working directory exist. os.makedirs(args.artifact, exist_ok=True) @@ -47,40 +48,24 @@ def main(args): "squeezenet1_1", weights="SqueezeNet1_1_Weights.DEFAULT", ) - backend = get_backend_type(args.backend) quant_dtype = { QnnExecuTorchBackendType.kGpuBackend: None, QnnExecuTorchBackendType.kHtpBackend: QuantDtype.use_8a8w, - }[backend] + }[qnn_config.backend] build_executorch_binary( - instance.eval(), - (torch.randn(1, 3, 224, 224),), - args.model, - f"{args.artifact}/{pte_filename}", - inputs, - skip_node_id_set=skip_node_id_set, - skip_node_op_set=skip_node_op_set, + model=instance.eval(), + qnn_config=qnn_config, + file_name=f"{args.artifact}/{pte_filename}", + dataset=inputs, quant_dtype=quant_dtype, - backend=backend, - shared_buffer=args.shared_buffer, - online_prepare=args.online_prepare, ) - if args.compile_only: - return - adb = SimpleADB( - qnn_sdk=os.getenv("QNN_SDK_ROOT"), - build_path=f"{args.build_folder}", + qnn_config=qnn_config, pte_path=f"{args.artifact}/{pte_filename}.pte", workspace=f"/data/local/tmp/executorch/{pte_filename}", - device_id=args.device, - host_id=args.host, - soc_model=args.model, - shared_buffer=args.shared_buffer, - target=args.target, ) - adb.push(inputs=inputs, backends={backend}) + adb.push(inputs=inputs) adb.execute() # collect output data @@ -133,7 +118,7 @@ def main(args): ) args = parser.parse_args() - args.validate(args) + try: main(args) except Exception as e: diff --git a/examples/qualcomm/oss_scripts/ssd300_vgg16.py b/examples/qualcomm/oss_scripts/ssd300_vgg16.py index 
8592babb1e3..6977a68db48 100644 --- a/examples/qualcomm/oss_scripts/ssd300_vgg16.py +++ b/examples/qualcomm/oss_scripts/ssd300_vgg16.py @@ -13,19 +13,18 @@ import numpy as np import torch +from executorch.backends.qualcomm.export_utils import ( + build_executorch_binary, + QnnConfig, + setup_common_args_and_variables, + SimpleADB, +) from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype from executorch.backends.qualcomm.serialization.qc_schema import ( QnnExecuTorchBackendType, ) -from executorch.examples.qualcomm.utils import ( - build_executorch_binary, - get_backend_type, - make_output_dir, - parse_skip_delegation_node, - setup_common_args_and_variables, - SimpleADB, -) +from executorch.examples.qualcomm.utils import make_output_dir def create_data_lists(voc07_path, data_size): @@ -125,7 +124,7 @@ def SSD300VGG16(pretrained_weight_model): def main(args): sys.path.insert(0, args.oss_repo) - skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args) + qnn_config = QnnConfig.load_config(args.config_file if args.config_file else args) # ensure the working directory exist. 
os.makedirs(args.artifact, exist_ok=True) @@ -138,41 +137,24 @@ def main(args): pte_filename = "ssd300_vgg16_qnn" model = SSD300VGG16(args.pretrained_weight) - sample_input = (torch.randn((1, 3, 300, 300)),) - backend = get_backend_type(args.backend) quant_dtype = { QnnExecuTorchBackendType.kGpuBackend: None, QnnExecuTorchBackendType.kHtpBackend: QuantDtype.use_8a8w, - }[backend] + }[qnn_config.backend] build_executorch_binary( - model, - sample_input, - args.model, - f"{args.artifact}/{pte_filename}", - inputs, - skip_node_id_set=skip_node_id_set, - skip_node_op_set=skip_node_op_set, + model=model, + qnn_config=qnn_config, + file_name=f"{args.artifact}/{pte_filename}", + dataset=inputs, quant_dtype=quant_dtype, - backend=backend, - shared_buffer=args.shared_buffer, - online_prepare=args.online_prepare, ) - if args.compile_only: - return - adb = SimpleADB( - qnn_sdk=os.getenv("QNN_SDK_ROOT"), - build_path=f"{args.build_folder}", + qnn_config=qnn_config, pte_path=f"{args.artifact}/{pte_filename}.pte", workspace=f"/data/local/tmp/executorch/{pte_filename}", - device_id=args.device, - host_id=args.host, - soc_model=args.model, - shared_buffer=args.shared_buffer, - target=args.target, ) - adb.push(inputs=inputs, backends={backend}) + adb.push(inputs=inputs) adb.execute() # collect output data @@ -282,7 +264,7 @@ def post_process(): ) args = parser.parse_args() - args.validate(args) + try: main(args) except Exception as e: diff --git a/examples/qualcomm/oss_scripts/swin_transformer.py b/examples/qualcomm/oss_scripts/swin_transformer.py index 54eb8e522a1..bc1b05869ad 100644 --- a/examples/qualcomm/oss_scripts/swin_transformer.py +++ b/examples/qualcomm/oss_scripts/swin_transformer.py @@ -13,18 +13,19 @@ import numpy as np import torch +from executorch.backends.qualcomm.export_utils import ( + build_executorch_binary, + QnnConfig, + setup_common_args_and_variables, + SimpleADB, +) from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype from 
executorch.backends.qualcomm.serialization.qc_schema import ( QnnExecuTorchBackendType, ) from executorch.examples.qualcomm.utils import ( - build_executorch_binary, - get_backend_type, get_imagenet_dataset, make_output_dir, - parse_skip_delegation_node, - setup_common_args_and_variables, - SimpleADB, topk_accuracy, ) @@ -81,7 +82,7 @@ def window_reverse(windows, window_size, height, width): def main(args): - skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args) + qnn_config = QnnConfig.load_config(args.config_file if args.config_file else args) # ensure the working directory exist. os.makedirs(args.artifact, exist_ok=True) @@ -109,40 +110,24 @@ def main(args): ) pte_filename = "swin_qnn" - backend = get_backend_type(args.backend) quant_dtype = { QnnExecuTorchBackendType.kGpuBackend: None, QnnExecuTorchBackendType.kHtpBackend: QuantDtype.use_8a8w, - }[backend] + }[qnn_config.backend] build_executorch_binary( - module.eval(), - inputs[0], - args.model, - f"{args.artifact}/{pte_filename}", - inputs, - skip_node_id_set=skip_node_id_set, - skip_node_op_set=skip_node_op_set, + model=module.eval(), + qnn_config=qnn_config, + file_name=f"{args.artifact}/{pte_filename}", + dataset=inputs, quant_dtype=quant_dtype, - backend=backend, - shared_buffer=args.shared_buffer, - online_prepare=args.online_prepare, ) - if args.compile_only: - return - adb = SimpleADB( - qnn_sdk=os.getenv("QNN_SDK_ROOT"), - build_path=f"{args.build_folder}", + qnn_config=qnn_config, pte_path=f"{args.artifact}/{pte_filename}.pte", workspace=f"/data/local/tmp/executorch/{pte_filename}", - device_id=args.device, - host_id=args.host, - soc_model=args.model, - shared_buffer=args.shared_buffer, - target=args.target, ) - adb.push(inputs=inputs, backends={backend}) + adb.push(inputs=inputs) adb.execute() # collect output data @@ -195,7 +180,7 @@ def main(args): ) args = parser.parse_args() - args.validate(args) + try: main(args) except Exception as e: diff --git 
a/examples/qualcomm/oss_scripts/swin_v2_t.py b/examples/qualcomm/oss_scripts/swin_v2_t.py index ea39a08a316..9ad23284056 100755 --- a/examples/qualcomm/oss_scripts/swin_v2_t.py +++ b/examples/qualcomm/oss_scripts/swin_v2_t.py @@ -21,19 +21,21 @@ QCOM_PASS_ACTIVATE_KEY, QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY, ) +from executorch.backends.qualcomm.export_utils import ( + build_executorch_binary, + make_quantizer, + QnnConfig, + setup_common_args_and_variables, + SimpleADB, +) from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype from executorch.backends.qualcomm.serialization.qc_schema import ( QnnExecuTorchBackendType, ) from executorch.examples.qualcomm.utils import ( - build_executorch_binary, - get_backend_type, get_imagenet_dataset, make_output_dir, - make_quantizer, - setup_common_args_and_variables, - SimpleADB, topk_accuracy, ) from executorch.exir.dialects._ops import ops as exir_ops @@ -76,6 +78,8 @@ def main(args): # ensure the working directory exist. os.makedirs(args.artifact, exist_ok=True) + qnn_config = QnnConfig.load_config(args.config_file if args.config_file else args) + data_num = 100 if args.ci: inputs = [(torch.rand(1, 3, 224, 224),)] @@ -99,45 +103,31 @@ def main(args): } passes_dep = get_passes_dependency_for_capture_program() passes_dep[RewritePartition] = [FoldQDQ] - backend = get_backend_type(args.backend) qnn_quantizer = { QnnExecuTorchBackendType.kGpuBackend: None, QnnExecuTorchBackendType.kHtpBackend: make_quantizer( quant_dtype=QuantDtype.use_8a8w, per_channel_linear=True, - backend=backend, - soc_model=args.model, + backend=qnn_config.backend, + soc_model=qnn_config.soc_model, ), - }[backend] + }[qnn_config.backend] build_executorch_binary( - instance, - inputs[0], - args.model, - f"{args.artifact}/{pte_filename}", - inputs, + model=instance, + qnn_config=qnn_config, + file_name=f"{args.artifact}/{pte_filename}", + dataset=inputs, custom_quantizer=qnn_quantizer, - backend=backend, - shared_buffer=args.shared_buffer, 
passes_job=passes_job, passes_dependency=passes_dep, - online_prepare=args.online_prepare, ) - if args.compile_only: - return - adb = SimpleADB( - qnn_sdk=os.getenv("QNN_SDK_ROOT"), - build_path=f"{args.build_folder}", + qnn_config=qnn_config, pte_path=f"{args.artifact}/{pte_filename}.pte", workspace=f"/data/local/tmp/executorch/{pte_filename}", - device_id=args.device, - host_id=args.host, - soc_model=args.model, - shared_buffer=args.shared_buffer, - target=args.target, ) - adb.push(inputs=inputs, backends={backend}) + adb.push(inputs=inputs) adb.execute() # collect output data @@ -188,7 +178,7 @@ def main(args): ) args = parser.parse_args() - args.validate(args) + try: main(args) except Exception as e: diff --git a/examples/qualcomm/oss_scripts/t5/t5.py b/examples/qualcomm/oss_scripts/t5/t5.py index c9ab864b6e4..ab8396fbe97 100644 --- a/examples/qualcomm/oss_scripts/t5/t5.py +++ b/examples/qualcomm/oss_scripts/t5/t5.py @@ -11,6 +11,12 @@ from multiprocessing.connection import Client import torch +from executorch.backends.qualcomm.export_utils import ( + make_quantizer, + QnnConfig, + setup_common_args_and_variables, + SimpleADB, +) from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype from executorch.backends.qualcomm.serialization.qc_schema import ( QcomChipset, @@ -32,12 +38,8 @@ ) from executorch.examples.qualcomm.utils import ( evaluate_squad, - get_backend_type, get_seq2seq_dataset_from_squad_csv, - make_quantizer, replace_module_with_custom_class, - setup_common_args_and_variables, - SimpleADB, ) from executorch.exir.capture._config import ExecutorchBackendConfig from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass @@ -102,9 +104,7 @@ def __init__( self.exported_decoder = None self.quant_dtype = None - def quantize( - self, backend, soc_model, inputs, quant_dtype, targets=None, metrics=None - ): + def quantize(self, qnn_config, inputs, quant_dtype, targets=None, metrics=None): assert quant_dtype is not None, 
"quant_dtype must be specified" self.quant_dtype = quant_dtype @@ -123,8 +123,8 @@ def quantize( per_channel_linear=True, quant_dtype=quant_dtype, eps=2**-20, - backend=backend, - soc_model=soc_model, + backend=qnn_config.backend, + soc_model=qnn_config.soc_model, ) self.exported_encoder = prepare_pt2e(self.exported_encoder, quantizer) @@ -233,6 +233,7 @@ def lowering_modules( def main(args): # ensure the working directory exist. os.makedirs(args.artifact, exist_ok=True) + qnn_config = QnnConfig.load_config(args.config_file if args.config_file else args) data_size = 100 max_hidden_seq_length = 384 @@ -256,18 +257,17 @@ def main(args): max_hidden_seq_length=max_hidden_seq_length, max_cache_length=max_cache_length, ) - backend = get_backend_type(args.backend) quant_dtype = { QnnExecuTorchBackendType.kGpuBackend: None, QnnExecuTorchBackendType.kHtpBackend: QuantDtype.use_16a8w, - }[backend] + }[qnn_config.backend] if quant_dtype: - t5.quantize(backend, args.model, inputs, quant_dtype) + t5.quantize(qnn_config, inputs, quant_dtype) t5.lowering_modules( args.artifact, - soc_model=getattr(QcomChipset, args.model), + soc_model=getattr(QcomChipset, args.soc_model), use_fp16=True if quant_dtype is None else False, - backend=backend, + backend=qnn_config.backend, online_prepare=args.online_prepare, ) @@ -323,23 +323,15 @@ def post_process(): runner_args, ] ) - backend = get_backend_type(args.backend) adb = SimpleADB( - qnn_sdk=os.getenv("QNN_SDK_ROOT"), - build_path=f"{args.build_folder}", + qnn_config=qnn_config, pte_path=pte_path, workspace=workspace, - device_id=args.device, - host_id=args.host, - soc_model=args.model, - shared_buffer=args.shared_buffer, - target=args.target, runner="examples/qualcomm/oss_scripts/t5/qnn_t5_runner", ) adb.push( inputs=inputs, files=[runtime_tokenizer_path], - backends={backend}, ) adb.execute(custom_runner_cmd=runner_cmd) adb.pull(host_output_path=args.artifact, callback=post_process) @@ -377,7 +369,7 @@ def post_process(): ) args = 
parser.parse_args() - args.validate(args) + try: main(args) except Exception as e: diff --git a/examples/qualcomm/oss_scripts/vit_b_16.py b/examples/qualcomm/oss_scripts/vit_b_16.py index 797cadf30c0..40eabe6c691 100755 --- a/examples/qualcomm/oss_scripts/vit_b_16.py +++ b/examples/qualcomm/oss_scripts/vit_b_16.py @@ -15,18 +15,21 @@ import torch import torchvision +from executorch.backends.qualcomm.export_utils import ( + build_executorch_binary, + make_quantizer, + QnnConfig, + setup_common_args_and_variables, + SimpleADB, +) + from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype from executorch.backends.qualcomm.serialization.qc_schema import ( QnnExecuTorchBackendType, ) from executorch.examples.qualcomm.utils import ( - build_executorch_binary, - get_backend_type, get_imagenet_dataset, make_output_dir, - make_quantizer, - setup_common_args_and_variables, - SimpleADB, topk_accuracy, ) @@ -35,6 +38,8 @@ def main(args): # ensure the working directory exist. os.makedirs(args.artifact, exist_ok=True) + qnn_config = QnnConfig.load_config(args.config_file if args.config_file else args) + data_num = 100 if args.ci: inputs = [(torch.rand(1, 3, 224, 224),)] @@ -51,43 +56,29 @@ def main(args): pte_filename = "vit_b_16_qnn" instance = torchvision.models.vit_b_16(weights="IMAGENET1K_V1").eval() - backend = get_backend_type(args.backend) qnn_quantizer = { QnnExecuTorchBackendType.kGpuBackend: None, QnnExecuTorchBackendType.kHtpBackend: make_quantizer( quant_dtype=QuantDtype.use_8a8w, per_channel_linear=True, - backend=backend, - soc_model=args.model, + backend=qnn_config.backend, + soc_model=qnn_config.soc_model, ), - }[backend] + }[qnn_config.backend] build_executorch_binary( - instance, - inputs[0], - args.model, - f"{args.artifact}/{pte_filename}", - inputs, + model=instance, + qnn_config=qnn_config, + file_name=f"{args.artifact}/{pte_filename}", + dataset=inputs, custom_quantizer=qnn_quantizer, - backend=backend, - shared_buffer=args.shared_buffer, - 
online_prepare=args.online_prepare, ) - if args.compile_only: - return - adb = SimpleADB( - qnn_sdk=os.getenv("QNN_SDK_ROOT"), - build_path=f"{args.build_folder}", + qnn_config=qnn_config, pte_path=f"{args.artifact}/{pte_filename}.pte", workspace=f"/data/local/tmp/executorch/{pte_filename}", - device_id=args.device, - host_id=args.host, - soc_model=args.model, - shared_buffer=args.shared_buffer, - target=args.target, ) - adb.push(inputs=inputs, backends={backend}) + adb.push(inputs=inputs) adb.execute() # collect output data @@ -138,7 +129,7 @@ def main(args): ) args = parser.parse_args() - args.validate(args) + try: main(args) except Exception as e: diff --git a/examples/qualcomm/oss_scripts/whisper/whisper.py b/examples/qualcomm/oss_scripts/whisper/whisper.py index 978d39a1c7c..ccd1a39795f 100644 --- a/examples/qualcomm/oss_scripts/whisper/whisper.py +++ b/examples/qualcomm/oss_scripts/whisper/whisper.py @@ -7,6 +7,7 @@ # TODO: reenable pyre after fixing the issues # pyre-ignore-all-errors +import argparse import getpass import json import logging @@ -23,6 +24,12 @@ get_capture_program_passes, ) from executorch.backends.qualcomm.builders.utils import is_graph_output +from executorch.backends.qualcomm.export_utils import ( + make_quantizer, + QnnConfig, + setup_common_args_and_variables, + SimpleADB, +) from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype from executorch.backends.qualcomm.serialization.qc_schema import ( @@ -49,14 +56,7 @@ QnnSeq2SeqLMEncoderExportableModule, ) -from executorch.examples.qualcomm.utils import ( - get_backend_type, - make_output_dir, - make_quantizer, - parse_skip_delegation_node, - setup_common_args_and_variables, - SimpleADB, -) +from executorch.examples.qualcomm.utils import make_output_dir from executorch.exir.capture._config import ExecutorchBackendConfig from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass from torchao.quantization.pt2e import MinMaxObserver @@ -251,7 +251,7 @@ def 
quantize( act_observer=MinMaxObserver, custom_annotations=custom_annotations, eps=2**-20, - backend=backend, + backend=qnn_config.backend, soc_model=soc_model, ) @@ -365,9 +365,7 @@ def lowering_modules( whisper_edge_prog_mgr.write_to_file(file) -def compile_whisper(args, inputs): - skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args) - +def compile_whisper(args: argparse.Namespace, qnn_config: QnnConfig, inputs): # ensure the working directory exist. os.makedirs(args.artifact, exist_ok=True) @@ -386,28 +384,28 @@ def compile_whisper(args, inputs): max_cache_length=max_cache_length, max_seq_length=args.max_seq_len, ) - - backend = get_backend_type(args.backend) quant_type = { QnnExecuTorchBackendType.kGpuBackend: None, QnnExecuTorchBackendType.kHtpBackend: QuantDtype.use_16a8w, - }[backend] + }[qnn_config.backend] whisper.prepare_model() if quant_type: - whisper.quantize(backend, args.model, inputs, quant_type, tokenizer) + whisper.quantize( + qnn_config.backend, qnn_config.soc_model, inputs, quant_type, tokenizer + ) whisper.lowering_modules( args.artifact, use_fp16=False, - soc_model=get_soc_to_chipset_map()[args.model], - skip_node_id_set=skip_node_id_set, - skip_node_op_set=skip_node_op_set, - backend=backend, + soc_model=get_soc_to_chipset_map()[args.soc_model], + skip_node_id_set=qnn_config.skip_delegate_node_ids, + skip_node_op_set=qnn_config.skip_delegate_node_ops, + backend=qnn_config.backend, online_prepare=args.online_prepare, ) -def inference_whisper(args, inputs, target): +def inference_whisper(args: argparse.Namespace, qnn_config: QnnConfig, inputs, target): workspace = f"/data/local/tmp/{getpass.getuser()}/executorch/whisper" tokenizer = AutoTokenizer.from_pretrained("openai/whisper-tiny") tokenizer_json = tokenizer.save_pretrained(args.artifact)[-1] @@ -465,21 +463,14 @@ def post_process(): ] ) - backend = get_backend_type(args.backend) adb = SimpleADB( - qnn_sdk=os.getenv("QNN_SDK_ROOT"), - build_path=f"{args.build_folder}", + 
qnn_config=qnn_config, pte_path=pte_path, workspace=workspace, - device_id=args.device, - host_id=args.host, - soc_model=args.model, - shared_buffer=args.shared_buffer, - target=args.target, runner="examples/qualcomm/oss_scripts/whisper/qnn_whisper_runner", ) # No pregen inputs, input_list is not required - adb.push(inputs=inputs, files=[tokenizer_json], backends={backend}) + adb.push(inputs=inputs, files=[tokenizer_json]) adb.execute(custom_runner_cmd=runner_cmd) adb.pull(host_output_path=args.artifact, callback=post_process) @@ -520,10 +511,8 @@ def post_process(): ) args = parser.parse_args() - args.validate(args) - if args.compile_only and args.pre_gen_pte: - exit("Cannot set both compile_only and pre_gen_pte as true") + qnn_config = QnnConfig.load_config(args.config_file if args.config_file else args) data_num = 20 if args.ci: @@ -535,16 +524,16 @@ def post_process(): inputs, target = get_dataset(data_num) if args.pre_gen_pte: - inference_whisper(args, inputs, target) + inference_whisper(args, qnn_config, inputs, target) exit(f"Finish the running pre_gen_pte from {args.pre_gen_pte}") if args.compile_only: - compile_whisper(args, inputs) + compile_whisper(args, qnn_config, inputs) exit(f"Finish compile_only and save to {args.artifact}") try: - compile_whisper(args, inputs) - inference_whisper(args, inputs, target) + compile_whisper(args, qnn_config, inputs) + inference_whisper(args, qnn_config, inputs, target) except Exception as e: if args.ip and args.port != -1: with Client((args.ip, args.port)) as conn: diff --git a/examples/qualcomm/qaihub_scripts/llama/llama2/qaihub_llama2_7b.py b/examples/qualcomm/qaihub_scripts/llama/llama2/qaihub_llama2_7b.py index 2e91d39e471..bc999a67de9 100644 --- a/examples/qualcomm/qaihub_scripts/llama/llama2/qaihub_llama2_7b.py +++ b/examples/qualcomm/qaihub_scripts/llama/llama2/qaihub_llama2_7b.py @@ -9,9 +9,14 @@ from multiprocessing.connection import Client import torch + +from executorch.backends.qualcomm.export_utils import ( 
+ QnnConfig, + setup_common_args_and_variables, + SimpleADB, +) from executorch.backends.qualcomm.serialization.qc_schema import QcomChipset from executorch.backends.qualcomm.utils.utils import ( - ExecutorchBackendConfig, from_context_binary, generate_htp_compiler_spec, generate_qnn_executorch_compiler_spec, @@ -21,14 +26,13 @@ gen_pte_from_ctx_bin, get_encoding, ) -from executorch.examples.qualcomm.utils import ( - setup_common_args_and_variables, - SimpleADB, -) +from executorch.exir.capture._config import ExecutorchBackendConfig from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass def main(args): + qnn_config = QnnConfig.load_config(args.config_file if args.config_file else args) + os.makedirs(args.artifact, exist_ok=True) target_names = ( @@ -49,7 +53,7 @@ def main(args): use_multi_contexts=True, ) compiler_specs = generate_qnn_executorch_compiler_spec( - soc_model=getattr(QcomChipset, args.model), + soc_model=getattr(QcomChipset, args.soc_model), backend_options=backend_options, is_from_context_binary=True, ) @@ -65,7 +69,7 @@ def main(args): if args.pre_gen_pte is None: # create custom operators as context loader - soc_model = get_soc_to_chipset_map()[args.model] + soc_model = get_soc_to_chipset_map()[args.soc_model] bundle_programs = [ from_context_binary( ctx_path=f"{args.context_binaries}/{target}", @@ -94,15 +98,10 @@ def main(args): return adb = SimpleADB( - qnn_sdk=os.getenv("QNN_SDK_ROOT"), - build_path=args.build_folder, + qnn_config=qnn_config, pte_path=pte_files, workspace=f"/data/local/tmp/executorch/{pte_name}", - device_id=args.device, - host_id=args.host, - soc_model=args.model, runner="examples/qualcomm/qaihub_scripts/llama/qaihub_llama2_7b_runner", - target=args.target, ) output_file = "result.txt" pos_embs_file = ["freq_cos", "freq_sin"] diff --git a/examples/qualcomm/qaihub_scripts/llama/llama3/qaihub_llama3_8b.py b/examples/qualcomm/qaihub_scripts/llama/llama3/qaihub_llama3_8b.py index 5864f17c335..9da728767af 100644 
--- a/examples/qualcomm/qaihub_scripts/llama/llama3/qaihub_llama3_8b.py +++ b/examples/qualcomm/qaihub_scripts/llama/llama3/qaihub_llama3_8b.py @@ -9,10 +9,13 @@ from multiprocessing.connection import Client import torch +from executorch.backends.qualcomm.export_utils import ( + QnnConfig, + setup_common_args_and_variables, + SimpleADB, +) from executorch.backends.qualcomm.serialization.qc_schema import QcomChipset - from executorch.backends.qualcomm.utils.utils import ( - ExecutorchBackendConfig, from_context_binary, generate_htp_compiler_spec, generate_qnn_executorch_compiler_spec, @@ -22,14 +25,13 @@ gen_pte_from_ctx_bin, get_encoding, ) -from executorch.examples.qualcomm.utils import ( - setup_common_args_and_variables, - SimpleADB, -) +from executorch.exir.capture._config import ExecutorchBackendConfig from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass def main(args): + qnn_config = QnnConfig.load_config(args.config_file if args.config_file else args) + os.makedirs(args.artifact, exist_ok=True) target_names = ( @@ -50,7 +52,7 @@ def main(args): use_multi_contexts=True, ) compiler_specs = generate_qnn_executorch_compiler_spec( - soc_model=getattr(QcomChipset, args.model), + soc_model=getattr(QcomChipset, args.soc_model), backend_options=backend_options, is_from_context_binary=True, ) @@ -66,7 +68,7 @@ def main(args): if args.pre_gen_pte is None: # create custom operators as context loader - soc_model = get_soc_to_chipset_map()[args.model] + soc_model = get_soc_to_chipset_map()[args.soc_model] bundle_programs = [ from_context_binary( ctx_path=f"{args.context_binaries}/{target}", @@ -95,15 +97,10 @@ def main(args): return adb = SimpleADB( - qnn_sdk=os.getenv("QNN_SDK_ROOT"), - build_path=args.build_folder, + qnn_config=qnn_config, pte_path=pte_files, workspace=f"/data/local/tmp/executorch/{pte_name}", - device_id=args.device, - host_id=args.host, - soc_model=args.model, 
runner="examples/qualcomm/qaihub_scripts/llama/qaihub_llama3_8b_runner", - target=args.target, ) output_file = "result.txt" pos_embs_file = ["freq_cos", "freq_sin"] @@ -164,8 +161,7 @@ def post_process(): freq = (freq / scale + offset).clip(min=0, max=65535).detach() freq.to(dtype=torch.uint16).numpy().tofile(custom_files[-1]) - if not args.skip_push: - adb.push(files=custom_files) + adb.push(files=custom_files) adb.execute(custom_runner_cmd=runner_cmds) adb.pull(args.artifact, callback=post_process) if args.ip and args.port != -1: diff --git a/examples/qualcomm/qaihub_scripts/stable_diffusion/qaihub_stable_diffusion.py b/examples/qualcomm/qaihub_scripts/stable_diffusion/qaihub_stable_diffusion.py index 400abd056da..931056c5444 100644 --- a/examples/qualcomm/qaihub_scripts/stable_diffusion/qaihub_stable_diffusion.py +++ b/examples/qualcomm/qaihub_scripts/stable_diffusion/qaihub_stable_diffusion.py @@ -13,6 +13,11 @@ import torch from diffusers import EulerDiscreteScheduler, UNet2DConditionModel from diffusers.models.embeddings import get_timestep_embedding +from executorch.backends.qualcomm.export_utils import ( + QnnConfig, + setup_common_args_and_variables, + SimpleADB, +) from executorch.backends.qualcomm.utils.utils import ( ExecutorchBackendConfig, @@ -30,10 +35,6 @@ gen_pte_from_ctx_bin, get_encoding, ) -from executorch.examples.qualcomm.utils import ( - setup_common_args_and_variables, - SimpleADB, -) from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass from PIL import Image from torchvision.transforms import ToTensor @@ -205,7 +206,7 @@ def save_result(output_image): print(f"Output image saved at {save_path}") -def inference(args, compiler_specs, pte_files): +def inference(args, qnn_config, compiler_specs, pte_files): # Loading a pretrained EulerDiscreteScheduler from the https://huggingface.co/stabilityai/stable-diffusion-2-1-base. 
# @lint-ignore scheduler = EulerDiscreteScheduler.from_pretrained( "stabilityai/stable-diffusion-2-1-base", subfolder="scheduler", revision="main" @@ -240,15 +241,10 @@ def inference(args, compiler_specs, pte_files): } adb = SimpleADB( - qnn_sdk=os.getenv("QNN_SDK_ROOT"), - build_path=args.build_folder, + qnn_config=qnn_config, pte_path=pte_files, workspace=f"/data/local/tmp/executorch/{args.pte_prefix}", - device_id=args.device, - host_id=args.host, - soc_model=args.model, runner="examples/qualcomm/qaihub_scripts/stable_diffusion/qaihub_stable_diffusion_runner", - target=args.target, ) input_unet = () @@ -324,8 +320,7 @@ def inference(args, compiler_specs, pte_files): file.write(flattened_tensor.numpy().tobytes()) files.append(os.path.join(args.artifact, "latents.raw")) - if not args.skip_push: - adb.push(inputs=input_unet, files=files) + adb.push(inputs=input_unet, files=files) adb.execute(custom_runner_cmd=qnn_executor_runner_args) output_image = [] @@ -345,6 +340,7 @@ def post_process_vae(): def main(args): + qnn_config = QnnConfig.load_config(args.config_file if args.config_file else args) os.makedirs(args.artifact, exist_ok=True) # common part for compile & inference backend_options = generate_htp_compiler_spec( @@ -352,14 +348,14 @@ def main(args): use_multi_contexts=True, ) compiler_specs = generate_qnn_executorch_compiler_spec( - soc_model=getattr(QcomChipset, args.model), + soc_model=getattr(QcomChipset, args.soc_model), backend_options=backend_options, is_from_context_binary=True, ) if args.pre_gen_pte is None: # Create custom operators as context loader - soc_model = get_soc_to_chipset_map()[args.model] + soc_model = get_soc_to_chipset_map()[args.soc_model] bundle_programs = [ from_context_binary(args.text_encoder_bin, "ctx_loader_0", soc_model), from_context_binary(args.unet_bin, "ctx_loader_1", soc_model), @@ -390,7 +386,7 @@ def main(args): if args.compile_only: return - inference(args, compiler_specs, pte_files) + inference(args, qnn_config, 
compiler_specs, pte_files) if __name__ == "__main__": # noqa: C901 diff --git a/examples/qualcomm/qaihub_scripts/utils/export.py b/examples/qualcomm/qaihub_scripts/utils/export.py index 00a0d397a80..a144e74a82c 100644 --- a/examples/qualcomm/qaihub_scripts/utils/export.py +++ b/examples/qualcomm/qaihub_scripts/utils/export.py @@ -15,6 +15,7 @@ import numpy as np import torch +from executorch.backends.qualcomm.export_utils import QnnConfig, SimpleADB from executorch.backends.qualcomm.serialization.qc_schema import QcomChipset from executorch.backends.qualcomm.utils.utils import ( draw_graph, @@ -24,7 +25,7 @@ generate_qnn_executorch_option, ) from executorch.examples.qualcomm.qaihub_scripts.utils.utils import preprocess_binary -from executorch.examples.qualcomm.utils import make_output_dir, SimpleADB +from executorch.examples.qualcomm.utils import make_output_dir from executorch.exir import ExecutorchBackendConfig from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass @@ -168,15 +169,17 @@ def to_context_binary( # leverage SimpleADB for model library conversion lib_name = Path(model_lib).stem sdk_root = os.getenv("QNN_SDK_ROOT") + qnn_config = QnnConfig( + soc_model=soc_model, + build_folder=build_folder, + device=device, + host=host, + target=target, + ) adb = SimpleADB( - qnn_sdk=sdk_root, - build_path=build_folder, + qnn_config=qnn_config, pte_path=model_lib, workspace=f"/data/local/tmp/executorch/{lib_name}", - device_id=device, - soc_model=soc_model, - host_id=host, - target=target, ) logger.info("pushing QNN libraries & tool") @@ -221,7 +224,7 @@ def compile(args): backend_options = generate_htp_compiler_spec(use_fp16=False) # setup general compiler spec for QNN compiler_specs = generate_qnn_executorch_compiler_spec( - soc_model=getattr(QcomChipset, args.model), + soc_model=getattr(QcomChipset, args.soc_model), backend_options=backend_options, is_from_context_binary=True, ) @@ -242,7 +245,7 @@ def compile(args): # conversion model library 
into context binary if required ctx_bin = to_context_binary( model_lib=ctx_bin, - soc_model=args.model, + soc_model=args.soc_model, device=args.device, host=args.host, target=args.target, @@ -261,7 +264,7 @@ def compile(args): # step 1: generate ExportedProgram with custom op as binary loader & lower to QnnBackend logger.info(f"({index}/{num_bins}) exporting program for {ctx_bin}") prog_info = from_context_binary( - ctx_bin, custom_op_name, getattr(QcomChipset, args.model) + ctx_bin, custom_op_name, getattr(QcomChipset, args.soc_model) ) # step 2: write pte files and IO information logger.info(f"({index}/{num_bins}) exporting {binary_name}.pte") @@ -281,7 +284,7 @@ def compile(args): ) with open(f"{output_dir}/{binary_name}.json", "w") as f: graph_info = get_io_info(prog_info, ctx_bin, compiler_specs) - graph_info["soc_model"] = args.model + graph_info["soc_model"] = args.soc_model json.dump(graph_info, f, indent=2) @@ -308,17 +311,12 @@ def execute(args): inputs = get_tensor(graph_info["inputs"], user_inputs, logger) logger.info("preparing ADB connection") + qnn_config = QnnConfig.load_config(args.config_file if args.config_file else args) # leverage SimpleADB for e2e inference adb = SimpleADB( - qnn_sdk=os.getenv("QNN_SDK_ROOT"), - build_path=args.build_folder, + qnn_config=qnn_config, pte_path=f"{args.pte_directory}/{pte_name}.pte", workspace=f"/data/local/tmp/executorch/{pte_name}", - device_id=args.device, - soc_model=graph_info["soc_model"], - host_id=args.host, - shared_buffer=args.shared_buffer, - target=args.target, ) logger.info("pushing QNN libraries & other artifacts") diff --git a/examples/qualcomm/sample_config.json b/examples/qualcomm/sample_config.json new file mode 100644 index 00000000000..ae8edd4c426 --- /dev/null +++ b/examples/qualcomm/sample_config.json @@ -0,0 +1,13 @@ +{ + "soc_model" : "SM8750", + "build_folder" : "build-android", + "backend" : "htp", + "target" : "aarch64-android", + "online_prepare" : false, + "shared_buffer" : false, + 
"dump_intermediate_outputs" : false, + "profile_level" : 0, + "enable_x86_64" : false, + "compile_only" : true, + "seed" : 42 +} diff --git a/examples/qualcomm/scripts/deeplab_v3.py b/examples/qualcomm/scripts/deeplab_v3.py index 2280ebb22da..2351cc6efdd 100755 --- a/examples/qualcomm/scripts/deeplab_v3.py +++ b/examples/qualcomm/scripts/deeplab_v3.py @@ -14,20 +14,19 @@ import numpy as np import torch +from executorch.backends.qualcomm.export_utils import ( + build_executorch_binary, + QnnConfig, + setup_common_args_and_variables, + SimpleADB, +) + from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype from executorch.backends.qualcomm.serialization.qc_schema import ( QnnExecuTorchBackendType, ) from executorch.examples.models.deeplab_v3 import DeepLabV3ResNet101Model -from executorch.examples.qualcomm.utils import ( - build_executorch_binary, - get_backend_type, - make_output_dir, - parse_skip_delegation_node, - segmentation_metrics, - setup_common_args_and_variables, - SimpleADB, -) +from executorch.examples.qualcomm.utils import make_output_dir, segmentation_metrics def get_dataset(data_size, dataset_dir, download): @@ -66,7 +65,7 @@ def get_dataset(data_size, dataset_dir, download): def main(args): - skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args) + qnn_config = QnnConfig.load_config(args.config_file if args.config_file else args) # ensure the working directory exist. 
os.makedirs(args.artifact, exist_ok=True) @@ -84,40 +83,25 @@ def main(args): pte_filename = "dl3_qnn" instance = DeepLabV3ResNet101Model() - backend = get_backend_type(args.backend) quant_dtype = { QnnExecuTorchBackendType.kGpuBackend: None, QnnExecuTorchBackendType.kHtpBackend: QuantDtype.use_8a8w, - }[backend] + }[qnn_config.backend] + build_executorch_binary( - instance.get_eager_model().eval(), - instance.get_example_inputs(), - args.model, - f"{args.artifact}/{pte_filename}", - inputs, - skip_node_id_set=skip_node_id_set, - skip_node_op_set=skip_node_op_set, + model=instance.get_eager_model().eval(), + qnn_config=qnn_config, + file_name=f"{args.artifact}/{pte_filename}", + dataset=inputs, quant_dtype=quant_dtype, - backend=backend, - shared_buffer=args.shared_buffer, - online_prepare=args.online_prepare, ) - if args.compile_only: - return - adb = SimpleADB( - qnn_sdk=os.getenv("QNN_SDK_ROOT"), - build_path=f"{args.build_folder}", + qnn_config=qnn_config, pte_path=f"{args.artifact}/{pte_filename}.pte", workspace=f"/data/local/tmp/executorch/{pte_filename}", - device_id=args.device, - host_id=args.host, - soc_model=args.model, - shared_buffer=args.shared_buffer, - target=args.target, ) - adb.push(inputs=inputs, backends={backend}) + adb.push(inputs=inputs) adb.execute() # collect output data @@ -204,7 +188,7 @@ def post_process(): ) args = parser.parse_args() - args.validate(args) + try: main(args) except Exception as e: diff --git a/examples/qualcomm/scripts/edsr.py b/examples/qualcomm/scripts/edsr.py index 8a5cb0f478f..227a3b57203 100755 --- a/examples/qualcomm/scripts/edsr.py +++ b/examples/qualcomm/scripts/edsr.py @@ -14,19 +14,19 @@ import numpy as np import piq import torch + +from executorch.backends.qualcomm.export_utils import ( + build_executorch_binary, + QnnConfig, + setup_common_args_and_variables, + SimpleADB, +) from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype from executorch.backends.qualcomm.serialization.qc_schema import 
( QnnExecuTorchBackendType, ) from executorch.examples.models.edsr import EdsrModel -from executorch.examples.qualcomm.utils import ( - build_executorch_binary, - get_backend_type, - make_output_dir, - parse_skip_delegation_node, - setup_common_args_and_variables, - SimpleADB, -) +from executorch.examples.qualcomm.utils import make_output_dir from PIL import Image from torch.utils.data import Dataset @@ -101,7 +101,7 @@ def get_dataset(hr_dir: str, lr_dir: str, default_dataset: str, dataset_dir: str def main(args): - skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args) + qnn_config = QnnConfig.load_config(args.config_file if args.config_file else args) # ensure the working directory exist. os.makedirs(args.artifact, exist_ok=True) @@ -120,40 +120,24 @@ def main(args): inputs, targets = dataset.lr, dataset.hr pte_filename = "edsr_qnn" - backend = get_backend_type(args.backend) quant_dtype = { QnnExecuTorchBackendType.kGpuBackend: None, QnnExecuTorchBackendType.kHtpBackend: QuantDtype.use_8a8w, - }[backend] + }[qnn_config.backend] build_executorch_binary( - instance.get_eager_model().eval(), - (inputs[0],), - args.model, - f"{args.artifact}/{pte_filename}", - [(input,) for input in inputs], - skip_node_id_set=skip_node_id_set, - skip_node_op_set=skip_node_op_set, + model=instance.get_eager_model().eval(), + qnn_config=qnn_config, + file_name=f"{args.artifact}/{pte_filename}", + dataset=[(input,) for input in inputs], quant_dtype=quant_dtype, - backend=backend, - shared_buffer=args.shared_buffer, - online_prepare=args.online_prepare, ) - if args.compile_only: - return - adb = SimpleADB( - qnn_sdk=os.getenv("QNN_SDK_ROOT"), - build_path=f"{args.build_folder}", + qnn_config=qnn_config, pte_path=f"{args.artifact}/{pte_filename}.pte", workspace=f"/data/local/tmp/executorch/{pte_filename}", - device_id=args.device, - host_id=args.host, - soc_model=args.model, - shared_buffer=args.shared_buffer, - target=args.target, ) - adb.push(inputs=inputs, 
backends={backend}) + adb.push(inputs=inputs) adb.execute() # collect output data @@ -236,7 +220,7 @@ def post_process(): ) args = parser.parse_args() - args.validate(args) + try: main(args) except Exception as e: diff --git a/examples/qualcomm/scripts/inception_v3.py b/examples/qualcomm/scripts/inception_v3.py index 5fb0db595d1..f1781207d0f 100755 --- a/examples/qualcomm/scripts/inception_v3.py +++ b/examples/qualcomm/scripts/inception_v3.py @@ -13,26 +13,26 @@ import numpy as np import torch +from executorch.backends.qualcomm.export_utils import ( + build_executorch_binary, + QnnConfig, + setup_common_args_and_variables, + SimpleADB, +) from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype from executorch.backends.qualcomm.serialization.qc_schema import ( QnnExecuTorchBackendType, ) from executorch.examples.models.inception_v3.model import InceptionV3Model from executorch.examples.qualcomm.utils import ( - build_executorch_binary, - get_backend_type, get_imagenet_dataset, make_output_dir, - parse_skip_delegation_node, - setup_common_args_and_variables, - SimpleADB, topk_accuracy, ) def main(args): - skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args) - + qnn_config = QnnConfig.load_config(args.config_file if args.config_file else args) # ensure the working directory exist. 
os.makedirs(args.artifact, exist_ok=True) @@ -51,40 +51,25 @@ def main(args): ) pte_filename = "ic3_qnn" instance = InceptionV3Model() - backend = get_backend_type(args.backend) quant_dtype = { QnnExecuTorchBackendType.kGpuBackend: None, QnnExecuTorchBackendType.kHtpBackend: QuantDtype.use_8a8w, - }[backend] + }[qnn_config.backend] + build_executorch_binary( - instance.get_eager_model().eval(), - instance.get_example_inputs(), - args.model, - f"{args.artifact}/{pte_filename}", - inputs, - skip_node_id_set=skip_node_id_set, - skip_node_op_set=skip_node_op_set, + model=instance.get_eager_model().eval(), + qnn_config=qnn_config, + file_name=f"{args.artifact}/{pte_filename}", + dataset=inputs, quant_dtype=quant_dtype, - backend=backend, - shared_buffer=args.shared_buffer, - online_prepare=args.online_prepare, ) - if args.compile_only: - return - adb = SimpleADB( - qnn_sdk=os.getenv("QNN_SDK_ROOT"), - build_path=f"{args.build_folder}", + qnn_config=qnn_config, pte_path=f"{args.artifact}/{pte_filename}.pte", workspace=f"/data/local/tmp/executorch/{pte_filename}", - device_id=args.device, - host_id=args.host, - soc_model=args.model, - shared_buffer=args.shared_buffer, - target=args.target, ) - adb.push(inputs=inputs, backends={backend}) + adb.push(inputs=inputs) adb.execute() # collect output data @@ -137,7 +122,6 @@ def main(args): ) args = parser.parse_args() - args.validate(args) try: main(args) except Exception as e: diff --git a/examples/qualcomm/scripts/inception_v4.py b/examples/qualcomm/scripts/inception_v4.py index 3941239f0bb..10b69d2eda1 100755 --- a/examples/qualcomm/scripts/inception_v4.py +++ b/examples/qualcomm/scripts/inception_v4.py @@ -12,25 +12,26 @@ import numpy as np import torch +from executorch.backends.qualcomm.export_utils import ( + build_executorch_binary, + QnnConfig, + setup_common_args_and_variables, + SimpleADB, +) from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype from 
executorch.backends.qualcomm.serialization.qc_schema import ( QnnExecuTorchBackendType, ) from executorch.examples.models.inception_v4 import InceptionV4Model from executorch.examples.qualcomm.utils import ( - build_executorch_binary, - get_backend_type, get_imagenet_dataset, make_output_dir, - parse_skip_delegation_node, - setup_common_args_and_variables, - SimpleADB, topk_accuracy, ) def main(args): - skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args) + qnn_config = QnnConfig.load_config(args.config_file if args.config_file else args) # ensure the working directory exist. os.makedirs(args.artifact, exist_ok=True) @@ -49,40 +50,25 @@ def main(args): ) pte_filename = "ic4_qnn" instance = InceptionV4Model() - backend = get_backend_type(args.backend) quant_dtype = { QnnExecuTorchBackendType.kGpuBackend: None, QnnExecuTorchBackendType.kHtpBackend: QuantDtype.use_8a8w, - }[backend] + }[qnn_config.backend] + build_executorch_binary( - instance.get_eager_model().eval(), - instance.get_example_inputs(), - args.model, - f"{args.artifact}/{pte_filename}", - inputs, - skip_node_id_set=skip_node_id_set, - skip_node_op_set=skip_node_op_set, + model=instance.get_eager_model().eval(), + qnn_config=qnn_config, + file_name=f"{args.artifact}/{pte_filename}", + dataset=inputs, quant_dtype=quant_dtype, - backend=backend, - shared_buffer=args.shared_buffer, - online_prepare=args.online_prepare, ) - if args.compile_only: - return - adb = SimpleADB( - qnn_sdk=os.getenv("QNN_SDK_ROOT"), - build_path=f"{args.build_folder}", + qnn_config=qnn_config, pte_path=f"{args.artifact}/{pte_filename}.pte", workspace=f"/data/local/tmp/executorch/{pte_filename}", - device_id=args.device, - host_id=args.host, - soc_model=args.model, - shared_buffer=args.shared_buffer, - target=args.target, ) - adb.push(inputs=inputs, backends={backend}) + adb.push(inputs=inputs) adb.execute() # collect output data @@ -135,7 +121,7 @@ def main(args): ) args = parser.parse_args() - args.validate(args) 
+ try: main(args) except Exception as e: diff --git a/examples/qualcomm/scripts/mobilebert_fine_tune.py b/examples/qualcomm/scripts/mobilebert_fine_tune.py index 445f6a53d92..f01b94f905f 100755 --- a/examples/qualcomm/scripts/mobilebert_fine_tune.py +++ b/examples/qualcomm/scripts/mobilebert_fine_tune.py @@ -11,6 +11,13 @@ import numpy as np import torch +from executorch.backends.qualcomm.export_utils import ( + build_executorch_binary, + make_quantizer, + QnnConfig, + setup_common_args_and_variables, + SimpleADB, +) from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype from executorch.backends.qualcomm.serialization.qc_schema import ( QcomChipset, @@ -21,14 +28,7 @@ generate_qnn_executorch_compiler_spec, skip_annotation, ) -from executorch.examples.qualcomm.utils import ( - build_executorch_binary, - make_output_dir, - make_quantizer, - parse_skip_delegation_node, - setup_common_args_and_variables, - SimpleADB, -) +from executorch.examples.qualcomm.utils import make_output_dir from executorch.exir import to_edge from transformers import BertTokenizer, MobileBertForSequenceClassification @@ -222,7 +222,7 @@ def get_fine_tuned_mobilebert(artifacts_dir, pretrained_weight, batch_size): def main(args): - skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args) + qnn_config = QnnConfig.load_config(args.config_file if args.config_file else args) # ensure the working directory exist. 
os.makedirs(args.artifact, exist_ok=True) @@ -241,18 +241,12 @@ def main(args): ) if args.use_fp16: - quant_dtype = None pte_filename = "mb_qnn" build_executorch_binary( - model, - inputs[0], - args.model, - f"{args.artifact}/{pte_filename}", - inputs, - skip_node_id_set=skip_node_id_set, - skip_node_op_set=skip_node_op_set, - quant_dtype=quant_dtype, - shared_buffer=args.shared_buffer, + model=model, + qnn_config=qnn_config, + file_name=f"{args.artifact}/{pte_filename}", + dataset=inputs, ) else: @@ -263,11 +257,11 @@ def calibrator(gm): quantizer = make_quantizer( quant_dtype=quant_dtype, backend=QnnExecuTorchBackendType.kHtpBackend, - soc_model=args.model, + soc_model=qnn_config.soc_model, ) backend_options = generate_htp_compiler_spec(quant_dtype is not None) compiler_specs = generate_qnn_executorch_compiler_spec( - soc_model=getattr(QcomChipset, args.model), + soc_model=getattr(QcomChipset, args.soc_model), backend_options=backend_options, ) # skip embedding layer cause it's quantization sensitive @@ -291,15 +285,9 @@ def calibrator(gm): return adb = SimpleADB( - qnn_sdk=os.getenv("QNN_SDK_ROOT"), - build_path=f"{args.build_folder}", + qnn_config=qnn_config, pte_path=f"{args.artifact}/{pte_filename}.pte", workspace=f"/data/local/tmp/executorch/{pte_filename}", - device_id=args.device, - host_id=args.host, - soc_model=args.model, - shared_buffer=args.shared_buffer, - target=args.target, ) adb.push(inputs=inputs) adb.execute() @@ -379,7 +367,7 @@ def calibrator(gm): ) args = parser.parse_args() - args.validate(args) + try: main(args) except Exception as e: diff --git a/examples/qualcomm/scripts/mobilenet_v2.py b/examples/qualcomm/scripts/mobilenet_v2.py index e1b076fd5c0..4cc56db2a89 100755 --- a/examples/qualcomm/scripts/mobilenet_v2.py +++ b/examples/qualcomm/scripts/mobilenet_v2.py @@ -12,25 +12,27 @@ import numpy as np import torch + +from executorch.backends.qualcomm.export_utils import ( + build_executorch_binary, + QnnConfig, + 
setup_common_args_and_variables, + SimpleADB, +) from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype from executorch.backends.qualcomm.serialization.qc_schema import ( QnnExecuTorchBackendType, ) from executorch.examples.models.mobilenet_v2 import MV2Model from executorch.examples.qualcomm.utils import ( - build_executorch_binary, - get_backend_type, get_imagenet_dataset, make_output_dir, - parse_skip_delegation_node, - setup_common_args_and_variables, - SimpleADB, topk_accuracy, ) def main(args): - skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args) + qnn_config = QnnConfig.load_config(args.config_file if args.config_file else args) # ensure the working directory exist. os.makedirs(args.artifact, exist_ok=True) @@ -50,40 +52,24 @@ def main(args): ) pte_filename = "mv2_qnn" instance = MV2Model() - backend = get_backend_type(args.backend) quant_dtype = { QnnExecuTorchBackendType.kGpuBackend: None, QnnExecuTorchBackendType.kHtpBackend: QuantDtype.use_8a8w, - }[backend] + }[qnn_config.backend] build_executorch_binary( - instance.get_eager_model().eval(), - instance.get_example_inputs(), - args.model, - f"{args.artifact}/{pte_filename}", - inputs, - skip_node_id_set=skip_node_id_set, - skip_node_op_set=skip_node_op_set, + model=instance.get_eager_model().eval(), + qnn_config=qnn_config, + file_name=f"{args.artifact}/{pte_filename}", + dataset=inputs, quant_dtype=quant_dtype, - backend=backend, - shared_buffer=args.shared_buffer, - online_prepare=args.online_prepare, ) - if args.compile_only: - return - adb = SimpleADB( - qnn_sdk=os.getenv("QNN_SDK_ROOT"), - build_path=f"{args.build_folder}", + qnn_config=qnn_config, pte_path=f"{args.artifact}/{pte_filename}.pte", workspace=f"/data/local/tmp/executorch/{pte_filename}", - device_id=args.device, - host_id=args.host, - soc_model=args.model, - shared_buffer=args.shared_buffer, - target=args.target, ) - adb.push(inputs=inputs, backends={backend}) + adb.push(inputs=inputs) adb.execute() # 
collect output data @@ -136,7 +122,7 @@ def main(args): ) args = parser.parse_args() - args.validate(args) + try: main(args) except Exception as e: diff --git a/examples/qualcomm/scripts/mobilenet_v3.py b/examples/qualcomm/scripts/mobilenet_v3.py index a41b552bc09..7cf64433971 100644 --- a/examples/qualcomm/scripts/mobilenet_v3.py +++ b/examples/qualcomm/scripts/mobilenet_v3.py @@ -13,26 +13,28 @@ import numpy as np import torch + +from executorch.backends.qualcomm.export_utils import ( + build_executorch_binary, + make_quantizer, + QnnConfig, + setup_common_args_and_variables, + SimpleADB, +) from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype from executorch.backends.qualcomm.serialization.qc_schema import ( QnnExecuTorchBackendType, ) from executorch.examples.models.mobilenet_v3 import MV3Model from executorch.examples.qualcomm.utils import ( - build_executorch_binary, - get_backend_type, get_imagenet_dataset, make_output_dir, - make_quantizer, - parse_skip_delegation_node, - setup_common_args_and_variables, - SimpleADB, topk_accuracy, ) def main(args): - skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args) + qnn_config = QnnConfig.load_config(args.config_file if args.config_file else args) # ensure the working directory exist. 
os.makedirs(args.artifact, exist_ok=True) @@ -52,45 +54,29 @@ def main(args): ) pte_filename = "mv3_qnn" instance = MV3Model() - backend = get_backend_type(args.backend) quantizer = { QnnExecuTorchBackendType.kGpuBackend: None, QnnExecuTorchBackendType.kHtpBackend: make_quantizer( quant_dtype=QuantDtype.use_16a8w, eps=2**-10, - backend=backend, - soc_model=args.model, + backend=qnn_config.backend, + soc_model=qnn_config.soc_model, ), - }[backend] + }[qnn_config.backend] build_executorch_binary( - instance.get_eager_model().eval(), - instance.get_example_inputs(), - args.model, - f"{args.artifact}/{pte_filename}", - inputs, - skip_node_id_set=skip_node_id_set, - skip_node_op_set=skip_node_op_set, - backend=backend, + model=instance.get_eager_model().eval(), + qnn_config=qnn_config, + file_name=f"{args.artifact}/{pte_filename}", + dataset=inputs, custom_quantizer=quantizer, - shared_buffer=args.shared_buffer, - online_prepare=args.online_prepare, ) - if args.compile_only: - return - adb = SimpleADB( - qnn_sdk=os.getenv("QNN_SDK_ROOT"), - build_path=f"{args.build_folder}", + qnn_config=qnn_config, pte_path=f"{args.artifact}/{pte_filename}.pte", workspace=f"/data/local/tmp/executorch/{pte_filename}", - device_id=args.device, - host_id=args.host, - soc_model=args.model, - shared_buffer=args.shared_buffer, - target=args.target, ) - adb.push(inputs=inputs, backends={backend}) + adb.push(inputs=inputs) adb.execute() # collect output data @@ -143,7 +129,7 @@ def main(args): ) args = parser.parse_args() - args.validate(args) + try: main(args) except Exception as e: diff --git a/examples/qualcomm/scripts/torchvision_vit.py b/examples/qualcomm/scripts/torchvision_vit.py index c26ac5d7d75..7bf01e95c98 100755 --- a/examples/qualcomm/scripts/torchvision_vit.py +++ b/examples/qualcomm/scripts/torchvision_vit.py @@ -15,14 +15,18 @@ import torch import torch.nn.functional as F + +from executorch.backends.qualcomm.export_utils import ( + build_executorch_binary, + QnnConfig, + 
setup_common_args_and_variables, + SimpleADB, +) from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype from executorch.examples.models.torchvision_vit.model import TorchVisionViTModel from executorch.examples.qualcomm.utils import ( - build_executorch_binary, get_imagenet_dataset, make_output_dir, - setup_common_args_and_variables, - SimpleADB, topk_accuracy, ) @@ -78,6 +82,8 @@ def PermuteInProjectionPacked(): def main(args): + qnn_config = QnnConfig.load_config(args.config_file if args.config_file else args) + # ensure the working directory exist. os.makedirs(args.artifact, exist_ok=True) @@ -100,28 +106,17 @@ def main(args): with PermuteInProjectionPacked(): build_executorch_binary( - instance, - inputs[0], - args.model, - f"{args.artifact}/{pte_filename}", - inputs, + model=instance, + qnn_config=qnn_config, + file_name=f"{args.artifact}/{pte_filename}", + dataset=inputs, quant_dtype=QuantDtype.use_8a8w, - shared_buffer=args.shared_buffer, ) - if args.compile_only: - return - adb = SimpleADB( - qnn_sdk=os.getenv("QNN_SDK_ROOT"), - build_path=f"{args.build_folder}", + qnn_config=qnn_config, pte_path=f"{args.artifact}/{pte_filename}.pte", workspace=f"/data/local/tmp/executorch/{pte_filename}", - device_id=args.device, - host_id=args.host, - soc_model=args.model, - shared_buffer=args.shared_buffer, - target=args.target, ) adb.push(inputs=inputs) adb.execute() @@ -174,7 +169,7 @@ def main(args): ) args = parser.parse_args() - args.validate(args) + try: main(args) except Exception as e: diff --git a/examples/qualcomm/scripts/wav2letter.py b/examples/qualcomm/scripts/wav2letter.py index 8bf22bae266..3e25383edfe 100644 --- a/examples/qualcomm/scripts/wav2letter.py +++ b/examples/qualcomm/scripts/wav2letter.py @@ -8,21 +8,20 @@ import logging import os -import sys from multiprocessing.connection import Client import numpy as np import torch -from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype -from 
executorch.examples.models.wav2letter import Wav2LetterModel -from executorch.examples.qualcomm.utils import ( +from executorch.backends.qualcomm.export_utils import ( build_executorch_binary, - make_output_dir, - parse_skip_delegation_node, + QnnConfig, setup_common_args_and_variables, SimpleADB, ) +from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype +from executorch.examples.models.wav2letter import Wav2LetterModel +from executorch.examples.qualcomm.utils import make_output_dir class Conv2D(torch.nn.Module): @@ -97,7 +96,7 @@ def parse(ids): def main(args): - skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args) + qnn_config = QnnConfig.load_config(args.config_file if args.config_file else args) # ensure the working directory exist os.makedirs(args.artifact, exist_ok=True) @@ -137,30 +136,17 @@ def main(args): inputs, targets = get_dataset(data_size=data_num, artifact_dir=args.artifact) pte_filename = "w2l_qnn" build_executorch_binary( - model, - inputs[0], - args.model, - f"{args.artifact}/{pte_filename}", - inputs, - skip_node_id_set=skip_node_id_set, - skip_node_op_set=skip_node_op_set, + model=model, + qnn_config=qnn_config, + file_name=f"{args.artifact}/{pte_filename}", + dataset=inputs, quant_dtype=QuantDtype.use_8a8w, - shared_buffer=args.shared_buffer, ) - if args.compile_only: - sys.exit(0) - adb = SimpleADB( - qnn_sdk=os.getenv("QNN_SDK_ROOT"), - build_path=f"{args.build_folder}", + qnn_config=qnn_config, pte_path=f"{args.artifact}/{pte_filename}.pte", workspace=f"/data/local/tmp/executorch/{pte_filename}", - device_id=args.device, - host_id=args.host, - soc_model=args.model, - shared_buffer=args.shared_buffer, - target=args.target, ) adb.push(inputs=inputs) adb.execute() @@ -221,7 +207,7 @@ def main(args): ) args = parser.parse_args() - args.validate(args) + try: main(args) except Exception as e: diff --git a/examples/qualcomm/util_scripts/README.md b/examples/qualcomm/util_scripts/README.md index 
45c68d3bc04..1d85bbe134f 100644 --- a/examples/qualcomm/util_scripts/README.md +++ b/examples/qualcomm/util_scripts/README.md @@ -86,7 +86,7 @@ This section describes how to generate an ET record for a .pte program using the PYTHONPATH=.. python -m examples.qualcomm.util_scripts.gen_etrecord \ -b build-android \ --device $DEVICE_SERIAL \ - --model SM8750 \ + --soc_model SM8750 \ ``` * This script will: - Quantize and compile a sample model to generate `.pte` file. diff --git a/examples/qualcomm/util_scripts/cli.py b/examples/qualcomm/util_scripts/cli.py index 79b9d7fb1f2..b62711c12f8 100644 --- a/examples/qualcomm/util_scripts/cli.py +++ b/examples/qualcomm/util_scripts/cli.py @@ -23,6 +23,12 @@ from executorch.backends.qualcomm._passes.qnn_pass_manager import ( get_capture_program_passes, ) +from executorch.backends.qualcomm.export_utils import ( + get_backend_type, + make_quantizer, + QnnConfig, + SimpleADB, +) from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype from executorch.backends.qualcomm.serialization.qc_schema import QcomChipset from executorch.backends.qualcomm.utils.constants import QCOM_PASS_ACTIVATE_KEY @@ -38,11 +44,6 @@ to_edge_transform_and_lower_to_qnn, ) from executorch.examples.qualcomm.qaihub_scripts.utils.utils import preprocess_binary -from executorch.examples.qualcomm.utils import ( - get_backend_type, - make_quantizer, - SimpleADB, -) from executorch.exir import ExecutorchBackendConfig from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass from torchao.quantization import pt2e @@ -145,7 +146,7 @@ def quantize(args): per_channel_linear=args.per_row, act_observer=act_observer, backend=get_backend_type(args.backend), - soc_model=args.model, + soc_model=args.soc_model, eps=args.eps, ) except Exception: @@ -195,7 +196,7 @@ def compile(args): backend_options = generate_htp_compiler_spec(use_fp16=True) # setup general compiler spec for QNN compiler_specs = generate_qnn_executorch_compiler_spec( - 
soc_model=getattr(QcomChipset, args.model), + soc_model=getattr(QcomChipset, args.soc_model), backend_options=backend_options, is_from_context_binary=extension == "bin", ) @@ -204,7 +205,7 @@ def compile(args): # step 1: generate ExportedProgram with custom op as a binary loader & lower it w/QnnBackend logger.info(f"exporting program for {args.artifact}") prog_info = from_context_binary( - args.artifact, custom_op_name, getattr(QcomChipset, args.model) + args.artifact, custom_op_name, getattr(QcomChipset, args.soc_model) ) # step 2: write pte files and store final graph logger.info(f"exporting {file_name}.pte") @@ -294,22 +295,25 @@ def execute(args): backend_options = generate_htp_compiler_spec(use_fp16=True) # setup general compiler spec for QNN compiler_specs = generate_qnn_executorch_compiler_spec( - soc_model=getattr(QcomChipset, args.model), + soc_model=getattr(QcomChipset, args.soc_model), backend_options=backend_options, ) io_info = get_io_info(args.artifact, compiler_specs) logger.info("preparing ADB connection") + + qnn_config = QnnConfig( + build_folder=args.build_folder, + device=args.device, + soc_model=args.soc_model, + host=args.host, + shared_buffer=args.shared_buffer, + target=args.target, + ) # leverage SimpleADB for e2e inference adb = SimpleADB( - qnn_sdk=os.getenv("QNN_SDK_ROOT"), - build_path=args.build_folder, + qnn_config=qnn_config, pte_path=args.artifact, workspace=f"/data/local/tmp/executorch/{pte_name}", - device_id=args.device, - soc_model=args.model, - host_id=args.host, - shared_buffer=args.shared_buffer, - target=args.target, ) logger.info("pushing QNN libraries & other artifacts") @@ -358,7 +362,7 @@ def post_process(): torch.save(output, f"{output_result_folder}/output_{output_index}.pt") logger.info("collecting output data") - adb.pull(tmp_dir, post_process) + adb.pull(host_output_path=tmp_dir, callback=post_process) shutil.rmtree(tmp_dir) logger.info(f"execution finished, please check {args.output_folder} for results") @@ -438,7 
+442,7 @@ def main(): ) sub_quantize.add_argument( "-m", - "--model", + "--soc_model", type=str, required=True, help="SoC model. e.g. SM8750", @@ -474,7 +478,7 @@ def main(): ) sub_compile.add_argument( "-m", - "--model", + "--soc_model", type=str, required=True, help="SoC model. e.g. SM8750", @@ -528,7 +532,7 @@ def main(): ) sub_execute.add_argument( "-m", - "--model", + "--soc_model", type=str, required=True, help="SoC model. e.g. SM8750", diff --git a/examples/qualcomm/util_scripts/gen_etrecord.py b/examples/qualcomm/util_scripts/gen_etrecord.py index 2cc00fc1db9..005e8ef969a 100644 --- a/examples/qualcomm/util_scripts/gen_etrecord.py +++ b/examples/qualcomm/util_scripts/gen_etrecord.py @@ -1,6 +1,16 @@ -import os +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. import torch +from executorch.backends.qualcomm.export_utils import ( + make_quantizer, + QnnConfig, + setup_common_args_and_variables, + SimpleADB, +) from executorch.backends.qualcomm.serialization.qc_schema import ( QnnExecuTorchBackendType, @@ -15,16 +25,13 @@ ) from executorch.devtools import Inspector from executorch.devtools.inspector._inspector_utils import TimeScale -from executorch.examples.qualcomm.utils import ( - make_quantizer, - setup_common_args_and_variables, - SimpleADB, -) from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e def main(args): + qnn_config = QnnConfig.load_config(args.config_file if args.config_file else args) + # capture nn.Module into ExportedProgram sample_input = (torch.randn(1, 32, 28, 28), torch.randn(1, 32, 28, 28)) model = torch.export.export(SimpleModel(), sample_input).module() @@ -33,7 +40,7 @@ def main(args): # Quantize the model quantizer = make_quantizer( - backend=QnnExecuTorchBackendType.kHtpBackend, soc_model=args.model + backend=QnnExecuTorchBackendType.kHtpBackend, 
soc_model=qnn_config.soc_model ) prepared = prepare_pt2e(model, quantizer) prepared(*sample_input) @@ -61,13 +68,9 @@ def main(args): # setup ADB for on-device execution adb = SimpleADB( - qnn_sdk=os.getenv("QNN_SDK_ROOT"), - build_path=f"{args.build_folder}", + qnn_config=qnn_config, pte_path=f"{pte_filename}.pte", workspace=f"/data/local/tmp/executorch/{pte_filename}", - device_id=args.device, - soc_model=args.model, - target=args.target, ) adb.push(inputs=[sample_input]) adb.execute() diff --git a/examples/qualcomm/util_scripts/qairt_visualizer_demo.py b/examples/qualcomm/util_scripts/qairt_visualizer_demo.py index d6042191bff..9d2ce7a8806 100644 --- a/examples/qualcomm/util_scripts/qairt_visualizer_demo.py +++ b/examples/qualcomm/util_scripts/qairt_visualizer_demo.py @@ -11,49 +11,47 @@ import qairt_visualizer import torch from executorch.backends.qualcomm.debugger.utils import generate_optrace -from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype -from executorch.backends.qualcomm.tests.models import SimpleModel -from executorch.backends.qualcomm.utils.utils import get_soc_to_chipset_map -from executorch.examples.qualcomm.utils import ( +from executorch.backends.qualcomm.export_utils import ( build_executorch_binary, + QnnConfig, setup_common_args_and_variables, SimpleADB, ) +from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype +from executorch.backends.qualcomm.tests.models import SimpleModel +from executorch.backends.qualcomm.utils.utils import get_soc_to_chipset_map def main(args) -> None: + qnn_config = QnnConfig.load_config(args.config_file if args.config_file else args) model = SimpleModel() example_inputs = [(torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28))] pte_filename = "qnn_simple_model" os.makedirs(args.artifact, exist_ok=True) + assert ( + qnn_config.profile_level == 3 + ), "Please turn profile_level to 3 for the purpose of this tutorial." 
+ # lower to QNN build_executorch_binary( - model, - example_inputs[0], - args.model, - f"{args.artifact}/{pte_filename}", - example_inputs, + model=model, + qnn_config=qnn_config, + file_name=f"{args.artifact}/{pte_filename}", + dataset=example_inputs, quant_dtype=QuantDtype.use_8a8w, - online_prepare=args.online_prepare, - optrace=True, ) # generate optrace and QHAS adb = SimpleADB( - qnn_sdk=os.getenv("QNN_SDK_ROOT"), - build_path=f"{args.build_folder}", + qnn_config=qnn_config, pte_path=f"{args.artifact}/{pte_filename}.pte", workspace=f"/data/local/tmp/executorch/{pte_filename}", - device_id=args.device, - host_id=args.host, - soc_model=args.model, - target=args.target, ) binaries_trace = generate_optrace( args.artifact, - get_soc_to_chipset_map()[args.model], + get_soc_to_chipset_map()[args.soc_model], adb, f"{args.artifact}/{pte_filename}.pte", example_inputs, diff --git a/examples/qualcomm/util_scripts/qnn_intermediate_debugger_demo.py b/examples/qualcomm/util_scripts/qnn_intermediate_debugger_demo.py index a02ed60cf83..07dfc9c9558 100644 --- a/examples/qualcomm/util_scripts/qnn_intermediate_debugger_demo.py +++ b/examples/qualcomm/util_scripts/qnn_intermediate_debugger_demo.py @@ -21,16 +21,19 @@ OutputFormat, QNNIntermediateDebugger, ) + +from executorch.backends.qualcomm.export_utils import ( + build_executorch_binary, + QnnConfig, + setup_common_args_and_variables, + SimpleADB, +) from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype from executorch.devtools import Inspector from executorch.examples.models.inception_v3.model import InceptionV3Model from executorch.examples.qualcomm.utils import ( - build_executorch_binary, get_imagenet_dataset, make_output_dir, - parse_skip_delegation_node, - setup_common_args_and_variables, - SimpleADB, topk_accuracy, ) @@ -41,7 +44,7 @@ def main(args): - skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args) + qnn_config = QnnConfig.load_config(args.config_file if args.config_file else 
args) # ensure the working directory exist. os.makedirs(args.artifact, exist_ok=True) @@ -66,16 +69,11 @@ def main(args): # Init our QNNIntermediateDebugger and pass it in to build_executorch_binary(). qnn_intermediate_debugger = QNNIntermediateDebugger() build_executorch_binary( - source_model, - instance.get_example_inputs(), - args.model, - f"{args.artifact}/{pte_filename}", - inputs, - skip_node_id_set=skip_node_id_set, - skip_node_op_set=skip_node_op_set, + model=source_model, + qnn_config=qnn_config, + file_name=f"{args.artifact}/{pte_filename}", + dataset=inputs, quant_dtype=QuantDtype.use_8a8w, - shared_buffer=args.shared_buffer, - dump_intermediate_outputs=args.dump_intermediate_outputs, qnn_intermediate_debugger=qnn_intermediate_debugger, ) @@ -89,15 +87,9 @@ def main(args): # Please ensure that dump_intermediate_outputs are set to true when creating SimpleADB adb = SimpleADB( - qnn_sdk=os.getenv("QNN_SDK_ROOT"), - build_path=f"{args.build_folder}", + qnn_config=qnn_config, pte_path=f"{args.artifact}/{pte_filename}.pte", workspace=f"/data/local/tmp/executorch/{pte_filename}", - device_id=args.device, - host_id=args.host, - soc_model=args.model, - shared_buffer=args.shared_buffer, - dump_intermediate_outputs=args.dump_intermediate_outputs, ) adb.push(inputs=inputs) adb.execute() diff --git a/examples/qualcomm/utils.py b/examples/qualcomm/utils.py index d2c3bf9cfd3..81369570249 100755 --- a/examples/qualcomm/utils.py +++ b/examples/qualcomm/utils.py @@ -6,389 +6,16 @@ # TODO: reenable pyre after fixing the issues # pyre-ignore-all-errors -import argparse import csv import inspect -import logging import os import random import shutil -import subprocess -import sys -import tempfile -from pathlib import Path - -from typing import Callable, Dict, List, Optional, Set, Tuple +from typing import Dict, List, Optional import numpy as np import torch -import torchao import transformers -from executorch.backends.qualcomm.debugger.qnn_intermediate_debugger import ( - 
QNNIntermediateDebugger, -) -from executorch.backends.qualcomm.quantizer.quantizer import ( - ModuleQConfig, - QnnQuantizer, - QuantDtype, -) -from executorch.backends.qualcomm.serialization.qc_schema import ( - QcomChipset, - QnnExecuTorchBackendType, - QnnExecuTorchOpPackageOptions, -) -from executorch.backends.qualcomm.utils.constants import ( - DSP_VERSION, - HEXAGON_SDK_ROOT, - HEXAGON_TOOLS_ROOT, -) -from executorch.backends.qualcomm.utils.utils import ( - generate_gpu_compiler_spec, - generate_htp_compiler_spec, - generate_qnn_executorch_compiler_spec, - get_qnn_context_binary_alignment, - get_soc_to_arch_map, - to_edge_transform_and_lower_to_qnn, -) -from executorch.exir.backend.utils import get_delegates -from executorch.exir.capture._config import ExecutorchBackendConfig -from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass -from torchao.quantization.pt2e import MovingAverageMinMaxObserver -from torchao.quantization.pt2e.quantize_pt2e import ( - convert_pt2e, - prepare_pt2e, - prepare_qat_pt2e, -) - - -class SimpleADB: - """ - A wrapper class for communicating with Android device - - Attributes: - qnn_sdk (str): QNN SDK path setup in environment variable - build_path (str): Path where artifacts were built - pte_path (str): Path where executorch binary was stored - workspace (str): Folder for storing artifacts on android device - device_id (str): Serial number of android device - soc_model (str): Chipset of device - host_id (str): Hostname of machine where device connects - error_only (bool): Redirect stdio and leave error messages only - shared_buffer (bool): Apply zero-copy mechanism in runtime - runner (str): Runtime executor binary - target (str): Target toolchain name - expected_input_shape (Tuple[torch.Size]): Input shape of dynamic graph - expected_output_shape (Tuple[torch.Size]): Output shape of dynamic graph - """ - - def __init__( - self, - qnn_sdk, - build_path, - pte_path, - workspace, - device_id, - soc_model, - 
direct_mode_build_path=None, - host_id=None, - error_only=False, - shared_buffer=False, - dump_intermediate_outputs=False, - runner=None, - target="aarch64-android", - expected_input_shape=None, - expected_output_shape=None, - ): - if runner is None: - runner = ( - "examples/qualcomm/executor_runner/qnn_executor_runner" - if direct_mode_build_path is None - else "examples/qualcomm/direct_executor_runner/qnn_executor_direct_runner" - ) - if direct_mode_build_path: - required_env = [HEXAGON_SDK_ROOT, HEXAGON_TOOLS_ROOT, DSP_VERSION] - assert all( - var in os.environ for var in required_env - ), f"Please ensure the following environment variables are set{required_env}" - self.hexagon_sdk_root = os.getenv(HEXAGON_SDK_ROOT) - self.hexagon_tools_root = os.getenv(HEXAGON_TOOLS_ROOT) - self.dsp_arch = os.getenv(DSP_VERSION) - logging.info(f"{HEXAGON_SDK_ROOT}={self.hexagon_sdk_root}") - logging.info(f"{HEXAGON_TOOLS_ROOT}={self.hexagon_tools_root}") - logging.info(f"{DSP_VERSION}={self.dsp_arch}") - self.qnn_sdk = qnn_sdk - self.build_path = build_path - self.direct_mode_build_path = direct_mode_build_path - self.pte_path = pte_path if isinstance(pte_path, list) else [pte_path] - self.workspace = workspace - self.device_id = device_id - self.host_id = host_id - if len(self.pte_path) > 0: - self.working_dir = Path(self.pte_path[0]).parent.absolute() - else: - self.working_dir = Path.cwd() - self.input_list_filename = "input_list.txt" - self.etdump_path = f"{self.workspace}/etdump.etdp" - self.dump_intermediate_outputs = dump_intermediate_outputs - self.debug_output_path = f"{self.workspace}/debug_output.bin" - self.output_folder = f"{self.workspace}/outputs" - self.htp_arch = get_soc_to_arch_map()[soc_model] - self.error_only = error_only - self.shared_buffer = shared_buffer - self.runner = runner - self.target = target - self.expected_input_shape = expected_input_shape - self.expected_output_shape = expected_output_shape - self.extra_cmds = "" - self.backend_library_paths 
= {} - - if self.direct_mode_build_path: - direct_general_artifacts = [ - f"{self.build_path}/examples/qualcomm/direct_executor_runner/libqnn_executorch_stub.so", - f"{self.direct_mode_build_path}/backends/qualcomm/libqnn_executorch_backend.so", - f"{self.direct_mode_build_path}/backends/qualcomm/qnn_executorch/direct_mode/libqnn_executorch_skel.so", - ] - self.backend_library_paths.update( - { - QnnExecuTorchBackendType.kHtpBackend: [ - f"{self.qnn_sdk}/lib/hexagon-v{self.htp_arch}/unsigned/libQnnHtpV{self.htp_arch}.so", - f"{self.qnn_sdk}/lib/hexagon-v{self.htp_arch}/unsigned/libQnnSystem.so", - f"{self.hexagon_tools_root}/Tools/target/hexagon/lib/v{self.htp_arch}/G0/pic/libc++abi.so.1", - f"{self.hexagon_tools_root}/Tools/target/hexagon/lib/v{self.htp_arch}/G0/pic/libc++.so.1", - ] - } - ) - for _, library_paths in self.backend_library_paths.items(): - library_paths.extend(direct_general_artifacts) - else: - traditional_general_artifacts = [ - f"{self.qnn_sdk}/lib/{self.target}/libQnnSystem.so", - f"{self.build_path}/backends/qualcomm/libqnn_executorch_backend.so", - f"{self.qnn_sdk}/lib/{self.target}/libQnnModelDlc.so", - ] - self.backend_library_paths.update( - { - QnnExecuTorchBackendType.kHtpBackend: [ - f"{self.qnn_sdk}/lib/{self.target}/libQnnHtp.so", - ( - f"{self.qnn_sdk}/lib/hexagon-v{self.htp_arch}/" - f"unsigned/libQnnHtpV{self.htp_arch}Skel.so" - ), - ( - f"{self.qnn_sdk}/lib/{self.target}/" - f"libQnnHtpV{self.htp_arch}Stub.so" - ), - f"{self.qnn_sdk}/lib/{self.target}/libQnnHtpPrepare.so", - ], - QnnExecuTorchBackendType.kGpuBackend: [ - f"{self.qnn_sdk}/lib/{self.target}/libQnnGpu.so", - ], - } - ) - for _, library_paths in self.backend_library_paths.items(): - library_paths.extend(traditional_general_artifacts) - - def _adb(self, cmd, output_callback: Optional[Callable[[str], None]] = None): - if not self.host_id: - cmds = ["adb", "-s", self.device_id] - else: - cmds = ["adb", "-H", self.host_id, "-s", self.device_id] - cmds.extend(cmd) - - if 
output_callback: - result = subprocess.run( - cmds, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True - ) - output_callback(result) - else: - subprocess.run( - cmds, stdout=subprocess.DEVNULL if self.error_only else sys.stdout - ) - - def push( - self, - inputs=None, - input_list=None, - files=None, - backends: Optional[Set[QnnExecuTorchBackendType]] = None, - init_env=True, - ): - artifacts = [*self.pte_path, f"{self.build_path}/{self.runner}"] - - if init_env: - self._adb(["shell", f"rm -rf {self.workspace}"]) - self._adb(["shell", f"mkdir -p {self.workspace}"]) - - if backends is None: - backends = {QnnExecuTorchBackendType.kHtpBackend} - - # backend libraries - for backend in backends: - artifacts.extend(self.backend_library_paths[backend]) - - with tempfile.TemporaryDirectory() as tmp_dir: - input_list_file, input_files = generate_inputs( - tmp_dir, self.input_list_filename, inputs - ) - - if input_list_file is not None: - # prepare input list - artifacts.append(input_list_file) - - for artifact in artifacts: - self._adb(["push", artifact, self.workspace]) - - # input data - for file_name in input_files: - self._adb(["push", file_name, self.workspace]) - - # dynamic shape related - if self.expected_input_shape and self.expected_output_shape: - shape_info = { - "input_shape": self.expected_input_shape, - "output_shape": self.expected_output_shape, - } - for name, shapes in shape_info.items(): - with open(f"{tmp_dir}/{name}.txt", "w") as f: - for s in shapes: - f.write(str(tuple(s)).strip("()") + "\n") - self._adb(["push", f"{tmp_dir}/{name}.txt", self.workspace]) - self.extra_cmds += f" --{name}_path {name}.txt" - - # custom files - if files is not None: - for file_name in files: - self._adb(["push", file_name, self.workspace]) - - def execute( - self, - custom_runner_cmd=None, - method_index=0, - output_callback: Optional[Callable[[str], None]] = None, - ): - self._adb(["shell", f"rm -rf {self.output_folder}"]) - self._adb(["shell", f"mkdir -p 
{self.output_folder}"]) - # run the delegation - if custom_runner_cmd is None: - qnn_executor_runner_args = ( - " ".join( - [ - f"--model_path {os.path.basename(self.pte_path[0])}", - f"--output_folder_path {self.output_folder}", - f"--input_list_path {self.input_list_filename}", - f"--etdump_path {self.etdump_path}", - "--shared_buffer" if self.shared_buffer else "", - f"--debug_output_path {self.debug_output_path}", - ( - "--dump_intermediate_outputs" - if self.dump_intermediate_outputs - else "" - ), - f"--method_index {method_index}", - ] - ) - + self.extra_cmds - ) - qnn_executor_runner_cmds = " ".join( - [ - f"cd {self.workspace} &&", - f"chmod +x {os.path.basename(self.runner)} &&", - f"export LD_LIBRARY_PATH=. && export ADSP_LIBRARY_PATH=. && echo 0x0C > {os.path.basename(self.runner)}.farf && ./{os.path.basename(self.runner)} {qnn_executor_runner_args}", - ] - ) - else: - qnn_executor_runner_cmds = custom_runner_cmd - self._adb( - ["shell", f"{qnn_executor_runner_cmds}"], output_callback=output_callback - ) - - def pull(self, host_output_path, device_output_path=None, callback=None): - if device_output_path is None: - device_output_path = self.output_folder - self._adb(["pull", "-a", device_output_path, host_output_path]) - if callback: - callback() - - def pull_etdump(self, output_path, callback=None): - self._adb(["pull", self.etdump_path, output_path]) - if callback: - callback() - - def pull_debug_output(self, etdump_path, debug_ouput_path, callback=None): - self._adb(["pull", self.etdump_path, etdump_path]) - self._adb(["pull", self.debug_output_path, debug_ouput_path]) - if callback: - callback() - - -def ptq_calibrate(captured_model, quantizer, dataset): - annotated_model = prepare_pt2e(captured_model, quantizer) - print("Quantizing(PTQ) the model...") - # calibration - if callable(dataset): - dataset(annotated_model) - else: - for data in dataset: - annotated_model(*data) - return annotated_model - - -def qat_train(ori_model, captured_model, 
quantizer, dataset): - data, targets = dataset - annotated_model = torchao.quantization.pt2e.move_exported_model_to_train( - prepare_qat_pt2e(captured_model, quantizer) - ) - optimizer = torch.optim.SGD(annotated_model.parameters(), lr=0.00001) - criterion = torch.nn.CrossEntropyLoss() - for i, d in enumerate(data): - print(f"Epoch {i}") - if i > 3: - # Freeze quantizer parameters - annotated_model.apply( - torchao.quantization.pt2e.fake_quantize.disable_observer - ) - if i > 2: - # Freeze batch norm mean and variance estimates - annotated_model.apply(torch.nn.intrinsic.qat.freeze_bn_stats) - - output = annotated_model(*d) - loss = criterion(output, targets[i]) - optimizer.zero_grad() - loss.backward() - optimizer.step() - - return convert_pt2e( - torchao.quantization.pt2e.move_exported_model_to_eval(annotated_model), - ) - - -def make_quantizer( - quant_dtype: Optional[QuantDtype] = QuantDtype.use_8a8w, - custom_annotations=(), - per_channel_conv=True, - per_channel_linear=False, - act_observer=MovingAverageMinMaxObserver, - act_symmetric=False, - is_qat=False, - submodule_qconfig_list: Optional[List[Tuple[Callable, ModuleQConfig]]] = None, - backend=QnnExecuTorchBackendType.kHtpBackend, - soc_model="SM8750", - eps=None, -): - quantizer = QnnQuantizer(backend=backend, soc_model=getattr(QcomChipset, soc_model)) - quantizer.add_custom_quant_annotations(custom_annotations) - quantizer.set_default_quant_config( - quant_dtype, - is_qat=is_qat, - is_conv_per_channel=per_channel_conv, - is_linear_per_channel=per_channel_linear, - act_observer=act_observer, - act_symmetric=act_symmetric, - eps=eps, - ) - submodule_qconfig_list = submodule_qconfig_list or [] - quantizer.set_submodule_qconfig_list(submodule_qconfig_list) - return quantizer def replace_module_with_custom_class( @@ -459,149 +86,6 @@ def extract_init_args_from_instance(instance): ) -# TODO: refactor to support different backends -def build_executorch_binary( - model, # noqa: B006 - inputs, # noqa: B006 - 
soc_model, - file_name, - dataset: List[torch.Tensor] | Callable[[torch.fx.GraphModule], None], - skip_node_id_set=None, - skip_node_op_set=None, - quant_dtype: Optional[QuantDtype] = None, - custom_quantizer: Optional[QnnQuantizer] = None, - shared_buffer=False, - metadata=None, - dump_intermediate_outputs=False, - qnn_intermediate_debugger: QNNIntermediateDebugger = None, - backend=QnnExecuTorchBackendType.kHtpBackend, - passes_job=None, - passes_dependency=None, - qat_training_data=None, - online_prepare=False, - optrace=False, - op_package_options: QnnExecuTorchOpPackageOptions = None, - direct_mode_build_path=None, -): - """ - A function to generate an ExecuTorch binary for Qualcomm platforms. - - Attributes: - model (torch.nn.Module): The model to be converted into an ExecuTorch binary. - inputs (torch.Tensor): Sample input tensors required for model export. - soc_model (QcomChipset): The target Qualcomm System on Chip (SoC) model. - backend (QnnExecuTorchBackendType): The target backend. - file_name (str): Name for the output binary file (.pte). - dataset (List[torch.Tensor] | Callable): A dataset for quantization calibration. - skip_node_id_set (set, optional): Set of node IDs to be skipped during partition. - skip_node_op_set (set, optional): Set of operation node to be skipped during partition. - quant_dtype (QuantDtype, optional): Data type for quantization. - custom_quantizer (Callable, optional): Custom quantizer. - shared_buffer (bool, optional): Applies zero-copy mechanism to optimize runtime memory allocation. - metadata (dict, optional): An optional dictionary that maps each method name to a constant value in eager mode. - dump_intermediate_outputs (bool, optional): Enables dumping model intermediate outputs. - passes_job (OrderedDict, optional): Custom passes job in capture_program, users can enable/disable specific passes or modify their attributes. 
- passes_dependency (Dict, optional): A dictionary mapping each pass to its corresponding list of dependencies. - qat_training_data (List[torch.Tensor], optional): A dataset for quantization aware training(QAT). Typically is a pair of tensors, such as [features, ground truth]. - online_prepare (bool, optional): Compose QNN graph on device if set to True. - optrace (bool, optional): Enable optrace mode for performance analysis if set to True. - op_package_options: Optional structure to specify op packages - loaded and used by the backend. - - Returns: - None: The function writes the output to a specified .pte file. - """ - if backend == QnnExecuTorchBackendType.kGpuBackend and not online_prepare: - raise RuntimeError("Currently GPU backend only supports online_prepare.") - backend_options = { - QnnExecuTorchBackendType.kGpuBackend: generate_gpu_compiler_spec(), - QnnExecuTorchBackendType.kHtpBackend: generate_htp_compiler_spec( - use_fp16=False if quant_dtype is not None else True - ), - }[backend] - compile_spec = generate_qnn_executorch_compiler_spec( - soc_model=getattr(QcomChipset, soc_model), - backend_options=backend_options, - online_prepare=online_prepare, - optrace=optrace, - shared_buffer=shared_buffer, - dump_intermediate_outputs=dump_intermediate_outputs, - op_package_options=op_package_options, - ) - if quant_dtype is not None or custom_quantizer is not None: - captured_model = torch.export.export(model, inputs, strict=False).module() - if qat_training_data: - quantizer = custom_quantizer or make_quantizer( - quant_dtype=quant_dtype, - is_qat=True, - backend=backend, - soc_model=soc_model, - ) - # qat training - annotated_model = qat_train( - model, captured_model, quantizer, qat_training_data - ) - else: - quantizer = custom_quantizer or make_quantizer( - quant_dtype=quant_dtype, backend=backend, soc_model=soc_model - ) - # ptq calibration - with torch.no_grad(): - annotated_model = ptq_calibrate(captured_model, quantizer, dataset) - - quantized_model 
= convert_pt2e(annotated_model) - edge_prog_mgr = to_edge_transform_and_lower_to_qnn( - quantized_model, - inputs, - compile_spec, - constant_methods=metadata, - passes_job=passes_job, - dep_table=passes_dependency, - skip_node_id_set=skip_node_id_set, - skip_node_op_set=skip_node_op_set, - ) - else: - edge_prog_mgr = to_edge_transform_and_lower_to_qnn( - model, - inputs, - compile_spec, - constant_methods=metadata, - passes_job=passes_job, - skip_node_id_set=skip_node_id_set, - skip_node_op_set=skip_node_op_set, - ) - - if qnn_intermediate_debugger: - lowered_module_nodes = get_delegates(edge_prog_mgr.exported_program().graph) - assert ( - len(lowered_module_nodes) == 1 - ), "Graph with partitions are currently unsupported." - - lowered_module_node = lowered_module_nodes[0] - lower_module = getattr( - edge_prog_mgr.exported_program().graph_module, lowered_module_node.name - ) - edge_module = lower_module.original_module.module() - qnn_intermediate_debugger.set_edge_module(edge_module=edge_module) - - allocate_io = not (shared_buffer or direct_mode_build_path) - executorch_config = ExecutorchBackendConfig( - # For shared buffer, user must pass the memory address - # which is allocated by RPC memory to executor runner. - # Therefore, won't want to pre-allocate - # by memory manager in runtime. 
- memory_planning_pass=MemoryPlanningPass( - alloc_graph_input=allocate_io, - alloc_graph_output=allocate_io, - ), - segment_alignment=get_qnn_context_binary_alignment(), - ) - pte_name = f"{file_name}.pte" - exec_prog_mgr = edge_prog_mgr.to_executorch(config=executorch_config) - with open(pte_name, "wb") as file: - exec_prog_mgr.write_to_file(file) - - def make_output_dir(path: str): if os.path.exists(path): shutil.rmtree(path, ignore_errors=True) @@ -686,10 +170,6 @@ def evaluate_squad(predicted_texts: List[str], target_texts: List[str]): return results -def get_backend_type(backend: str): - return getattr(QnnExecuTorchBackendType, f"k{backend.title()}Backend") - - def get_imagenet_dataset( dataset_path, data_size, image_shape, crop_size=None, shuffle=True ): @@ -880,238 +360,3 @@ def __getitem__(self, idx): targets.append(labels) return inputs, targets - - -def setup_common_args_and_variables(): - parser = argparse.ArgumentParser() - - parser.add_argument( - "-m", - "--model", - help="SoC model of current device. e.g. 'SM8550' for Snapdragon 8 Gen 2.", - type=str, - required=True, - ) - - parser.add_argument( - "-b", - "--build_folder", - help="path to cmake binary directory for target platform, e.g., /path/to/build-android", - type=str, - required=True, - ) - - parser.add_argument( - "-H", - "--host", - help="hostname where android device is connected.", - default=None, - type=str, - ) - - parser.add_argument( - "--online_prepare", - help="If specified, compose QNN graph on device.", - action="store_true", - default=False, - ) - - parser.add_argument( - "--ip", - help="IPC address for delivering execution result", - default="", - type=str, - ) - - parser.add_argument( - "--port", - help="IPC port for delivering execution result", - default=-1, - type=int, - ) - - parser.add_argument( - "-S", - "--skip_delegate_node_ids", - help="If specified, skip delegation for the specified node based on node ids. Node ids should be separated by comma. 
e.g., aten_relu_default_10,aten_relu_default_2", - default=None, - type=str, - ) - - parser.add_argument( - "-f", - "--skip_delegate_node_ops", - help="If specified, skip delegation for the specified op. Node ops should be separated by comma. e.g., aten.add.Tensor,aten.relu.default", - default=None, - type=str, - ) - - parser.add_argument( - "-c", - "--compile_only", - help="If specified, only compile the model.", - action="store_true", - default=False, - ) - - parser.add_argument( - "-s", - "--device", - help="serial number for android device communicated via ADB.", - type=str, - ) - - parser.add_argument( - "--backend", - help="Backend to be deployed ('htp'/'gpu' are currently supported).", - choices=["htp", "gpu"], - default="htp", - type=str, - ) - - parser.add_argument( - "-z", - "--shared_buffer", - help="Enables usage of shared buffer between application and backend for graph I/O.", - action="store_true", - ) - - parser.add_argument( - "--skip_push", - help="If specified, skip pushing files to device.", - action="store_true", - default=False, - ) - - parser.add_argument( - "-D", - "--dump_intermediate_outputs", - help="If specified, enable dump intermediate outputs", - action="store_true", - default=False, - ) - - parser.add_argument( - "-x", - "--enable_x86_64", - help="Enable unittest to be executed on x86_64 platform", - action="store_true", - ) - - parser.add_argument( - "--ci", - help="This flag is for Continuous Integration(CI) purpose and is NOT recommended to turn on for typical use cases. 
It will use random inputs instead of real inputs.", - action="store_true", - default=False, - ) - - parser.add_argument( - "--seed", - help="Set the seed for generating random numbers in both torch and random.", - type=int, - ) - - parser.add_argument( - "-t", - "--target", - help="Target platform for deployment", - choices=[ - "aarch64-android", - "aarch64-oe-linux-gcc9.3", - "aarch64-oe-linux-gcc11.2", - ], - default="aarch64-android", - type=str, - ) - - parser.add_argument( - "--pre_gen_pte", - help="Run the pre-generated pte in the given directory.", - type=str, - ) - - parser.add_argument( - "--direct_build_folder", - help="Path to cmake binary directory for direct_mode. E.g., path/to/build-hexagon." - "If enabled, run self-defined protocol to control fastrpc communication.", - type=str, - ) - - # QNN_SDK_ROOT might also be an argument, but it is used in various places. - # So maybe it's fine to just use the environment. - if "QNN_SDK_ROOT" not in os.environ: - raise RuntimeError("Environment variable QNN_SDK_ROOT must be set") - print(f"QNN_SDK_ROOT={os.getenv('QNN_SDK_ROOT')}") - - def validate(args): - if not args.compile_only and args.device is None: - raise RuntimeError( - "device serial is required if not compile only. " - "Please specify a device serial by -s/--device argument." 
- ) - if args.seed: - torch.manual_seed(args.seed) - np.random.seed(args.seed) - random.seed(args.seed) - - parser.set_defaults(validate=validate) - - return parser - - -def parse_skip_delegation_node(args): - skip_node_id_set = set() - skip_node_op_set = set() - - if args.skip_delegate_node_ids is not None: - skip_node_id_set = set(map(str, args.skip_delegate_node_ids.split(","))) - print("Skipping following node ids: ", skip_node_id_set) - - if args.skip_delegate_node_ops is not None: - skip_node_op_set = set(map(str, args.skip_delegate_node_ops.split(","))) - print("Skipping following node ops: ", skip_node_op_set) - - return skip_node_id_set, skip_node_op_set - - -def generate_inputs( - dest_path: str, - input_list_filename: str, - inputs=None, - prefix_input_filename: str = "", -): - input_list_file = None - input_files = [] - - def prepare_input_file(tensor, fd, index, sub_index): - # transform torch.Tensor to raw file - input_file_name = f"{prefix_input_filename}_input_{index}_{sub_index}.raw" - input_file_path = f"{dest_path}/{input_file_name}" - if not isinstance(tensor, torch.Tensor): - tensor = torch.tensor(tensor) - tensor.detach().numpy().tofile(input_file_path) - input_files.append(input_file_path) - # prepare input_list - if sub_index > 0: - fd.write(" ") - fd.write(input_file_name) - - # Prepare input data - if inputs is not None: - input_list_file = f"{dest_path}/{input_list_filename}" - - with open(input_list_file, "w") as f: - for idx, data in enumerate(inputs): - sub_index = 0 - for d in data: - if isinstance(d, (list, tuple)): - for sub_d in d: - prepare_input_file(sub_d, f, idx, sub_index) - sub_index += 1 - else: - prepare_input_file(d, f, idx, sub_index) - sub_index += 1 - - f.write("\n") - - return input_list_file, input_files