Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/qualcomm/builders/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,7 @@ Please help update following table if you are contributing new operators:
| PoolMax2d | ✓ |
| Prelu | ✓ |
| Quantize | ✓ |
| Rand | ✓ |
| ReduceMax | ✓ |
| ReduceMean | ✓ |
| ReduceMin | ✓ |
Expand Down
4 changes: 3 additions & 1 deletion backends/qualcomm/builders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
op_pow,
op_prelu,
op_quantize,
op_rand,
op_relu,
op_repeat,
op_reshape,
Expand Down Expand Up @@ -190,7 +191,7 @@
op_pow,
op_prelu,
op_quantize,
op_unary,
op_rand,
op_relu,
op_repeat,
op_reshape,
Expand Down Expand Up @@ -218,6 +219,7 @@
op_topk,
op_to,
op_transpose,
op_unary,
op_unbind,
op_unsqueeze,
op_upsample_bilinear2d,
Expand Down
79 changes: 79 additions & 0 deletions backends/qualcomm/builders/op_rand.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict

import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager

import numpy as np
import torch
from executorch.backends.qualcomm.utils.constants import QCOM_DATA

from .node_visitor import NodeVisitor
from .node_visitor_manager import register_node_visitor
from .qnn_constants import OpRandomUniformLike, QNN_OP_PACKAGE_NAME_QTI_AISW


@register_node_visitor
class Rand(NodeVisitor):
target = ["aten.rand.default", "aten.rand_like.default"]

def __init__(self, *args) -> None:
super().__init__(*args)

def define_node(
self,
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper],
) -> PyQnnManager.PyQnnOpWrapper:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Add input type guard to support only int32?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Valid comment, I'll clarify since the description was inaccurate. This op doesn't use input values, just the input's shape, so technically it can "support" input types other than UINT32, it just doesn't make a difference. I clarified the description, added a comment in the code and a floating-point test for good measure.

output_tensor = node.meta["val"]
output_shape = list(output_tensor.shape)

shape_data = np.array(output_shape, dtype=np.uint32)
shape_dims = [len(output_shape)]

shape_tensor_wrapper = PyQnnManager.TensorWrapper(
f"{node.name}_shape",
PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32, # QNN only supports UINT32 for the RandomUniformLike op input
PyQnnManager.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_UNDEFINED,
{},
len(shape_dims),
shape_dims,
[],
shape_data,
True,
)

output_tensor_wrapper = self.define_tensor(
node,
node,
output_tensor,
PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
)

rand_op = PyQnnManager.PyQnnOpWrapper(
node.name,
QNN_OP_PACKAGE_NAME_QTI_AISW,
OpRandomUniformLike.op_name,
)

rand_op.AddInputTensors([shape_tensor_wrapper])
rand_op.AddOutputTensors([output_tensor_wrapper])

rand_op.AddScalarParam(
OpRandomUniformLike.param_low,
PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32,
{QCOM_DATA: np.float32(0.0)},
)

rand_op.AddScalarParam(
OpRandomUniformLike.param_high,
PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32,
{QCOM_DATA: np.float32(1.0)},
)

return rand_op
7 changes: 7 additions & 0 deletions backends/qualcomm/builders/qnn_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,13 @@ class OpQuantize:
op_name: str = "Quantize"


@dataclass(init=False, frozen=True)
class OpRandomUniformLike:
op_name: str = "RandomUniformLike"
param_low: str = "low"
param_high: str = "high"


@dataclass(init=False, frozen=True)
class OpReduceMax:
op_name: str = "ReduceMax"
Expand Down
8 changes: 8 additions & 0 deletions backends/qualcomm/tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1781,6 +1781,14 @@ def forward(self, x):
return self.prelu(x)


class Rand(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
return torch.rand_like(x) + x


class Reciprocal(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down
37 changes: 37 additions & 0 deletions backends/qualcomm/tests/test_qnn_delegate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1663,6 +1663,24 @@ def test_qnn_backend_prelu(self):
index += 1
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_rand(self):
sample_inputs = [
(torch.randn(3, 4, 5),),
(torch.randn(2, 8),),
(
torch.randn(
10,
),
),
(torch.randn(1, 3, 32, 32),),
]
for i, sample_input in enumerate(sample_inputs):
with self.subTest(i=i):
module = Rand() # noqa: F405
self.lower_module_and_test_output(
module, sample_input, assert_output_equal=False
)

def test_qnn_backend_reciprocal(self):
module = Reciprocal() # noqa: F405
sample_input = (torch.randn([2, 2, 2, 2]),)
Expand Down Expand Up @@ -3993,6 +4011,25 @@ def test_qnn_backend_prelu(self):
module = self.get_qdq_module(module, sample_input)
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_rand(self):
sample_inputs = [
(torch.randn(3, 4, 5),),
(torch.randn(2, 8),),
(
torch.randn(
10,
),
),
(torch.randn(1, 3, 32, 32),),
]
for i, sample_input in enumerate(sample_inputs):
with self.subTest(i=i):
module = Rand() # noqa: F405
module = self.get_qdq_module(module, sample_input)
self.lower_module_and_test_output(
module, sample_input, assert_output_equal=False
)

def test_qnn_backend_reciprocal(self):
module = Reciprocal() # noqa: F405
sample_input = (torch.randn([2, 5, 1, 3]),)
Expand Down
Loading