diff --git a/backends/qualcomm/builders/README.md b/backends/qualcomm/builders/README.md index 382fe50c525..01e8503dc26 100644 --- a/backends/qualcomm/builders/README.md +++ b/backends/qualcomm/builders/README.md @@ -458,6 +458,7 @@ Please help update following table if you are contributing new operators: | PoolMax2d | ✓ | | Prelu | ✓ | | Quantize | ✓ | +| Rand | ✓ | | ReduceMax | ✓ | | ReduceMean | ✓ | | ReduceMin | ✓ | diff --git a/backends/qualcomm/builders/__init__.py b/backends/qualcomm/builders/__init__.py index 518a7b7fa8c..14f53840dd7 100644 --- a/backends/qualcomm/builders/__init__.py +++ b/backends/qualcomm/builders/__init__.py @@ -79,6 +79,7 @@ op_pow, op_prelu, op_quantize, + op_rand, op_relu, op_repeat, op_reshape, @@ -190,7 +191,7 @@ op_pow, op_prelu, op_quantize, - op_unary, + op_rand, op_relu, op_repeat, op_reshape, @@ -218,6 +219,7 @@ op_topk, op_to, op_transpose, + op_unary, op_unbind, op_unsqueeze, op_upsample_bilinear2d, diff --git a/backends/qualcomm/builders/op_rand.py b/backends/qualcomm/builders/op_rand.py new file mode 100644 index 00000000000..79b1b90f16e --- /dev/null +++ b/backends/qualcomm/builders/op_rand.py @@ -0,0 +1,79 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +from typing import Dict + +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager + +import numpy as np +import torch +from executorch.backends.qualcomm.utils.constants import QCOM_DATA + +from .node_visitor import NodeVisitor +from .node_visitor_manager import register_node_visitor +from .qnn_constants import OpRandomUniformLike, QNN_OP_PACKAGE_NAME_QTI_AISW + + +@register_node_visitor +class Rand(NodeVisitor): + target = ["aten.rand.default", "aten.rand_like.default"] + + def __init__(self, *args) -> None: + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: + output_tensor = node.meta["val"] + output_shape = list(output_tensor.shape) + + shape_data = np.array(output_shape, dtype=np.uint32) + shape_dims = [len(output_shape)] + + shape_tensor_wrapper = PyQnnManager.TensorWrapper( + f"{node.name}_shape", + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32, # QNN only supports UINT32 for the RandomUniformLike op input + PyQnnManager.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_UNDEFINED, + {}, + len(shape_dims), + shape_dims, + [], + shape_data, + True, + ) + + output_tensor_wrapper = self.define_tensor( + node, + node, + output_tensor, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + ) + + rand_op = PyQnnManager.PyQnnOpWrapper( + node.name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + OpRandomUniformLike.op_name, + ) + + rand_op.AddInputTensors([shape_tensor_wrapper]) + rand_op.AddOutputTensors([output_tensor_wrapper]) + + rand_op.AddScalarParam( + OpRandomUniformLike.param_low, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32, + {QCOM_DATA: np.float32(0.0)}, + ) + + rand_op.AddScalarParam( + OpRandomUniformLike.param_high, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32, + {QCOM_DATA: np.float32(1.0)}, + ) + + return rand_op diff --git a/backends/qualcomm/builders/qnn_constants.py b/backends/qualcomm/builders/qnn_constants.py index 94febc4123f..58037459bc8 100644 --- a/backends/qualcomm/builders/qnn_constants.py +++ b/backends/qualcomm/builders/qnn_constants.py @@ -497,6 +497,13 @@ class OpQuantize: op_name: str = "Quantize" +@dataclass(init=False, frozen=True) +class OpRandomUniformLike: + op_name: str = "RandomUniformLike" + param_low: str = "low" + param_high: str = "high" + + @dataclass(init=False, frozen=True) class OpReduceMax: op_name: str = "ReduceMax" diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py index bd4060fa9f3..b00417fd7e2 100644 --- a/backends/qualcomm/tests/models.py +++ b/backends/qualcomm/tests/models.py @@ -1781,6 +1781,14 @@ def forward(self, x): return self.prelu(x) +class Rand(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.rand_like(x) + x + + class Reciprocal(torch.nn.Module): def __init__(self): super().__init__() diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index bdea2161d2d..5a04878bb92 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -1663,6 +1663,24 @@ def test_qnn_backend_prelu(self): index += 1 self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_rand(self): + sample_inputs = [ + (torch.randn(3, 4, 5),), + (torch.randn(2, 8),), + ( + torch.randn( + 10, + ), + ), + (torch.randn(1, 3, 32, 32),), + ] + for i, sample_input in enumerate(sample_inputs): + with self.subTest(i=i): + module = Rand() # noqa: F405 + self.lower_module_and_test_output( + module, sample_input, assert_output_equal=False + ) + def test_qnn_backend_reciprocal(self): module = Reciprocal() # noqa: F405 sample_input = (torch.randn([2, 2, 2, 2]),) @@ -3993,6 +4011,25 @@ def test_qnn_backend_prelu(self): module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_rand(self): + sample_inputs = [ + (torch.randn(3, 4, 5),), + (torch.randn(2, 8),), + ( + torch.randn( + 10, + ), + ), + (torch.randn(1, 3, 32, 32),), + ] + for i, sample_input in enumerate(sample_inputs): + with self.subTest(i=i): + module = Rand() # noqa: F405 + module = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output( + module, sample_input, assert_output_equal=False + ) + def test_qnn_backend_reciprocal(self): module = Reciprocal() # noqa: F405 sample_input = (torch.randn([2, 5, 1, 3]),)