diff --git a/examples/models/llama/runner/targets.bzl b/examples/models/llama/runner/targets.bzl
index f5f30f73054..81a2df117f4 100644
--- a/examples/models/llama/runner/targets.bzl
+++ b/examples/models/llama/runner/targets.bzl
@@ -6,6 +6,21 @@ def _get_operator_lib(aten = False):
     else:
         return ["//executorch/configurations:optimized_native_cpu_ops", "//executorch/extension/llm/custom_ops:custom_ops"]
 
+def _get_torchao_lowbit_deps():
+    """Returns torchao lowbit kernel deps for shared embedding and linear on ARM builds."""
+    # OSS builds cannot see the internal torchao kernel targets at all.
+    if runtime.is_oss:
+        return []
+    # Internal builds: gate the torchao lowbit kernels behind a select() so they
+    # are only linked on ARM64, the sole architecture these kernels are built for.
+    return select({
+        "DEFAULT": [],
+        "ovr_config//cpu:arm64": [
+            "//xplat/pytorch/ao/torchao/csrc/cpu/shared_kernels/embedding_xbit:op_embedding_xbit_executorch",
+            "//xplat/pytorch/ao/torchao/csrc/cpu/shared_kernels/linear_8bit_act_xbit_weight:op_linear_8bit_act_xbit_weight_executorch",
+        ],
+    })
+
 def get_qnn_dependency():
     # buck build -c executorch.enable_qnn=true //executorch/examples/models/llama/runner:runner
     # Check if QNN is enabled before including the dependency
@@ -48,7 +63,7 @@ def define_common_targets():
             "//pytorch/tokenizers:llama2c_tokenizer",
             "//pytorch/tokenizers:hf_tokenizer",
             "//pytorch/tokenizers:regex_lookahead",
-        ] + (_get_operator_lib(aten)) + ([
+        ] + (_get_operator_lib(aten)) + _get_torchao_lowbit_deps() + ([
             # Vulkan API currently cannot build on some platforms (e.g. Apple, FBCODE)
             # Therefore enable it explicitly for now to avoid failing tests
             "//executorch/backends/vulkan:vulkan_backend_lib",