
Contribution for Operator Annotation

Thank you for contributing to Qualcomm AI Engine Direct delegate for ExecuTorch. Reading and following these guidelines will help you quickly get the essentials of annotating an operator in QnnQuantizer to unblock yourself and land pull requests more efficiently.

Sections

References

Qualcomm AI Engine Direct

PyTorch

Getting Started

Before extending an operator for quantization annotation, please make sure the operator builder has been well implemented (learn more in this tutorial).

Behavior of Annotation

To conduct PTQ on a floating point graph, observers must be inserted after each graph node. The observed numeric ranges are then processed by different algorithms to produce the scale and offset statistics used to represent the data in fixed point.
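For instance, PyTorch's MinMaxObserver tracks the running min/max of the tensors it sees and derives an affine scale / zero-point from that range. A minimal standalone sketch of that behavior:

import torch
from torch.ao.quantization.observer import MinMaxObserver

# Observe a floating point tensor and derive 8-bit affine qparams from its range.
observer = MinMaxObserver(dtype=torch.uint8, qscheme=torch.per_tensor_affine)
observer(torch.randn(16, 32))                    # record min / max of the calibration data
scale, zero_point = observer.calculate_qparams() # fixed-point encoding for the observed range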

The stages can be illustrated as follows:

  • Floating point nn.Module after torch.export.export

    flowchart TB
        input & kernel & bias --> id1(convolution) --> output
    
  • Inserting observers for inspecting numeric range

    flowchart TB
        input --> id2(input_act_obs) --> id1(convolution) --> id3(output_act_obs) --> output
        kernel --> id4(weight_obs) --> id1(convolution)
        bias --> id5(bias_obs) --> id1(convolution)
    
  • Cascade QDQ pairs after landing encodings

    flowchart TB
        input --> id2(Q_i) --> id3(DQ_i) --> id1(convolution) --> id4(Q_o) --> id5(DQ_o) --> output
        kernel --> id6(Q_k) --> id7(DQ_k) --> id1(convolution)
        bias --> id8(Q_b) --> id9(DQ_b) --> id1(convolution)
    

The Qualcomm backend will consume the generated encodings and lower operators in fixed precision. This tutorial walks you through the details of inserting observers and introduces some useful utilities.
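For context, the annotators described below are driven by the standard PT2E flow. A minimal sketch (the model and calibration data are placeholders, and import paths may differ slightly across versions):

import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer

module = MyModel().eval()                        # placeholder nn.Module
example_inputs = (torch.randn(1, 3, 224, 224),)  # placeholder calibration input

exported = torch.export.export(module, example_inputs).module()
prepared = prepare_pt2e(exported, QnnQuantizer())  # inserts observers via the annotators
prepared(*example_inputs)                          # calibration pass to collect numeric ranges
quantized = convert_pt2e(prepared)                 # cascades QDQ pairs with the landed encodings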

Register Annotation via Operator Type

Let's start by hooking a callback for the designated operator targets in annotators/{backend}_rules.py:

def register_annotator(aten_ops: List[OpOverload], qnn_op: Optional[str]):
    def _wrap(op_def: GeneralOpDef):
        for aten_op in aten_ops:
            annotate_fn = op_def.annotate
            validate_fn = op_def.validate
            rule = OpQuantRule(
                aten_op=aten_op,
                qnn_op=qnn_op,
                annotate_fn=annotate_fn,
                validate_fn=validate_fn,
            )
            _RULES[rule.aten_op] = rule
        return op_def

    return _wrap

The register_annotator decorator provides a convenient way to attach your own annotation and validation logic. It takes a list of operator overloads and a QNN operation name as its input arguments.
For example, the torch activation functions come in copy and in-place implementations whose only difference is the naming (an extra _ postfix), and both map to the same Core ATen operator after to_edge:

@register_annotator(
    [torch.ops.aten.relu.default, torch.ops.aten.relu_.default],
    QnnConstants.OpRelu.op_name,
)

Here torch.ops.aten.relu.default / torch.ops.aten.relu_.default map to the copy / in-place versions, and both will ultimately be converted into torch.ops.aten.relu.default.
The qnn_op argument specifies the quantization constraints used for validation against the BackendOpInfo library. If an operator doesn't directly correspond to a QNN operator, set qnn_op to None, which skips validation for that operator.

@register_annotator([operator.getitem], qnn_op=None)

The operator.getitem function acts as a skip operator in the QNN backend and does not correspond to any QNN operator. Therefore, we assign qnn_op=None.
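For a clearer picture of what register_annotator builds, the rule registry can be thought of as a mapping from ATen targets to rules that the quantizer walks when annotating a graph. A hedged sketch (the OpQuantRule and _RULES definitions here are illustrative, not the exact source):

from dataclasses import dataclass
from typing import Callable, Dict, Optional

from torch._ops import OpOverload
from torch.fx import GraphModule


@dataclass
class OpQuantRule:
    aten_op: OpOverload
    qnn_op: Optional[str]
    annotate_fn: Callable
    validate_fn: Callable


# Populated by register_annotator, keyed by the ATen operator overload.
_RULES: Dict[OpOverload, OpQuantRule] = {}


def annotate_graph(graph_module: GraphModule, quantization_config) -> None:
    # Dispatch each call_function node to its registered annotation callback.
    for node in graph_module.graph.nodes:
        if node.op != "call_function":
            continue
        rule = _RULES.get(node.target)
        if rule is not None:
            rule.annotate_fn(node, quantization_config)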

A base class GeneralOpDef establishes the default annotation and validation behaviors:

class GeneralOpDef:
    @staticmethod
    def annotate(node: Node, quantization_config: QuantizationConfig):
        annotate_single_in_single_out(node, quantization_config)

    @staticmethod
    def validate(
        node: Node, constraints_list: List[NormalizedConstraints], soc_info: SocInfo
    ) -> bool:
        valid = True
        # If there's no quantization annotation, we can't validate against constraints.
        if not _is_annotated([node]):
            return valid
        valid &= validate_against_backend_constraints(node, constraints_list)
        return valid
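Here annotate_single_in_single_out is the common helper for one-input, one-output operators. A simplified sketch of its behavior (mirroring the Conv2d walkthrough below):

def annotate_single_in_single_out(
    node: Node, quantization_config: QuantizationConfig
) -> None:
    # Skip nodes that already carry an annotation.
    if _is_annotated([node]):
        return

    # Map the single input to the activation spec and tag the node's own output.
    input_qspec_map = {node.args[0]: quantization_config.input_activation}
    node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
        input_qspec_map=input_qspec_map,
        output_qspec=quantization_config.output_activation,
        _annotated=True,
    )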

The annotate function signature is defined as follows, with two arguments:

@staticmethod
def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
  • node: the graph node to be observed
  • quantization_config: a data structure describing the quantization configurations for IO activations / weight / bias

Example of Conv2d Annotation

Conv2d accepts up to three input tensors: input activation, kernel, and bias. There are constraints imposed by the Qualcomm AI Engine Direct manual.
Take 8-bit fixed point as an example:

  • weight: must be symmetrically quantized if per-channel observer is applied
  • bias: must have QNN_DATATYPE_SFIXED_POINT_32 and be symmetrically quantized with expected encoding scales = weight.scales * input.scale, offset = 0 if per-channel observer is applied.

Let's look at the simplified per-channel quantization configuration used in QnnQuantizer:

def ptq_per_channel_quant_config(
    act_dtype=torch.uint8, weight_dtype=torch.int8
) -> QuantizationConfig:
    ...
    act_quantization_spec = QuantizationSpec(
        dtype=act_dtype,
        quant_min=torch.iinfo(act_dtype).min,
        quant_max=torch.iinfo(act_dtype).max,
        qscheme=torch.per_tensor_affine,
        observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args),
    )

    weight_quantization_spec = QuantizationSpec(
        dtype=weight_dtype,
        quant_min=torch.iinfo(weight_dtype).min + 1,
        quant_max=torch.iinfo(weight_dtype).max,
        qscheme=torch.per_channel_symmetric,
        ch_axis=0,
        observer_or_fake_quant_ctr=PerChannelMinMaxObserver.with_args(**extra_args),
    )

    bias_quantization_spec = _derived_bias_quant_spec

    quantization_config = QuantizationConfig(
        input_activation=act_quantization_spec,
        output_activation=act_quantization_spec,
        weight=weight_quantization_spec,
        bias=bias_quantization_spec,
    )

    return quantization_config

Here we choose torch.uint8 + MinMaxObserver for better coverage of the IO activations, and apply PerChannelMinMaxObserver to the weight and _derived_bias_quant_spec (a callable that calculates the encoding in the desired way) to the bias in order to meet the aforementioned constraints. The well-defined quantization_config will then be passed to the callback for annotation.
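As a reference for _derived_bias_quant_spec, such a callable can be built on PyTorch's DerivedQuantizationSpec so that the bias scale is derived as input.scale * weight.scales with offset = 0. A hedged sketch (not the exact implementation):

import torch
from torch.ao.quantization.quantizer import DerivedQuantizationSpec
from torch.fx import Node


def _derived_bias_quant_spec(node: Node) -> DerivedQuantizationSpec:
    def _derive_bias_qparams(observers):
        # Observers for (input activation, weight), in derived_from order.
        act_obs, weight_obs = observers
        act_scale, _ = act_obs.calculate_qparams()
        weight_scales, _ = weight_obs.calculate_qparams()
        # bias scale = input scale * per-channel weight scales, offset = 0
        scales = act_scale * weight_scales
        return scales, torch.zeros_like(scales, dtype=torch.int32)

    return DerivedQuantizationSpec(
        derived_from=[(node.args[0], node), (node.args[1], node)],
        derive_qparams_fn=_derive_bias_qparams,
        dtype=torch.int32,  # QNN_DATATYPE_SFIXED_POINT_32
        quant_min=torch.iinfo(torch.int32).min,
        quant_max=torch.iinfo(torch.int32).max,
        ch_axis=0,
        qscheme=torch.per_channel_symmetric,
    )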

Now, we can start to fill in the function body:

  • Register annotator

    @register_annotator(
        [
            torch.ops.aten.conv1d.default,
            torch.ops.aten.conv2d.default,
            torch.ops.aten.conv2d.padding,
            torch.ops.aten.convolution.default,
        ],
        QnnConstants.OpConv2d.op_name,
    )
    class Conv2d(GeneralOpDef):

    Multiple targets are expected to meet the same annotation criteria; registering them together like this is encouraged for code reuse.

  • Define an annotation function interface

        @staticmethod
        def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
  • Define map of input quantization spec

            if _is_annotated([node]):
                return
    
            # block quantization
            if quantization_config.block_size is not None:
                quantization_config.weight.observer_or_fake_quant_ctr.p.keywords.update(
                    {QCOM_BLOCK_SIZE: quantization_config.block_size}
                )
    
            input_qspec_map = {}
    
            # annotate input activation
            input_act = node.args[0]
            input_spec = quantization_config.input_activation
            input_qspec_map[input_act] = input_spec
    
            # annotate kernel
            kernel = node.args[1]
            input_qspec_map[kernel] = quantization_config.weight
    
            # annotate bias
            if len(node.args) > 2:
                bias = node.args[2]
                input_qspec_map[bias] = quantization_config.bias(node)

    We first check whether the current graph node has already been annotated. If not, an input_qspec_map dictionary required by the PyTorch framework is declared to map graph nodes to their configurations.
    The order of the parameters can be found in the ATen Operator Definitions. Since the bias node is optional, the implementation invokes _derived_bias_quant_spec to calculate the per-channel bias encoding only if the bias exists.

  • Update node's meta with framework compatible data structure

            node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
                input_qspec_map=input_qspec_map,
                output_qspec=quantization_config.output_activation,
                _annotated=True,
            )

    After processing input_qspec_map, it must be stored in the node's meta under the special tag (Q_ANNOTATION_KEY) so that prepare_pt2e can properly insert observers.

  • Define a validation function interface

        @staticmethod
        def validate(
            node: Node, constraints_list: List[NormalizedConstraints], soc_info: SocInfo
        ) -> bool:
  • Check if current node is annotated

            valid = True
            if not _is_annotated([node]):
                return valid
  • Check if current node supports LPBQ

            weight_node = node.args[1]
            weight_qspec = node.meta[Q_ANNOTATION_KEY].input_qspec_map.get(
                weight_node, None
            )
            if (
                weight_qspec
                and weight_qspec.observer_or_fake_quant_ctr.p.keywords.get(
                    QCOM_BLOCK_SIZE, None
                )
                is not None
            ):
                valid &= validate_lpbq_support(soc_info)
                if not valid:
                    logging.warning(
                        f"LPBQ (16a4w block-wise quantization) requires V69 or newer for {node.name}"
                    )
  • Check if current node supports 16a16w quantization

            act_node = node.args[0]
            act_qspec = node.meta[Q_ANNOTATION_KEY].input_qspec_map.get(act_node, None)
            if (
                act_qspec
                and act_qspec.dtype == torch.int32
                and weight_qspec
                and weight_qspec.dtype == torch.int32
            ):
                valid &= validate_16a16w_support(soc_info)
                if not valid:
                    logging.warning(
                        f"16-bit activations + 16-bit weights requires V73 or newer for {node.name}"
                    )
  • Validate the current node against the backend constraints obtained from BackendOpInfo based on the qnn_op.

            valid &= validate_against_backend_constraints(node, constraints_list)
            return valid
    • Validate against the backend constraints by doing the following:
      • Make sure that SharedQuantizationSpec is applied for math-invariant (is_math_invariant) operators, such as view operations.
      • Check the scale and zero_point values for specific operations. For example, the sigmoid op requires scale = 1 / (q_max - q_min + 1) and zero_point = 0 (see the sketch after this list).
      • Ensure that the qscheme satisfies symmetric constraints.
      • Verify that the input and output dtypes are supported.
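For example, the fixed sigmoid encoding above could be verified along the lines of the sketch below (the helper and the way the observed scale / zero_point are obtained are illustrative):

def _check_sigmoid_encoding(node: Node, scale: float, zero_point: int) -> bool:
    # Sigmoid outputs live in [0, 1], so the backend expects a fixed affine encoding:
    # scale = 1 / (q_max - q_min + 1) and zero_point = 0.
    out_qspec = node.meta[Q_ANNOTATION_KEY].output_qspec
    expected_scale = 1.0 / (out_qspec.quant_max - out_qspec.quant_min + 1)
    return zero_point == 0 and abs(scale - expected_scale) < 1e-9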

Common Annotators

For operators without extra parameters to be observed, there are pre-defined annotation methods for convenience:

  • Single in single out operators, e.g.:

    @register_annotator(
        [torch.ops.aten.relu.default, torch.ops.aten.relu_.default],
        QnnConstants.OpRelu.op_name,
    )
    class Relu(GeneralOpDef):
        pass
  • Binary in single out operators, e.g.:

    @register_annotator(
        [torch.ops.aten.add, torch.ops.aten.add.Tensor, torch.ops.aten.add_.Tensor],
        QnnConstants.OpElementWiseAdd.op_name,
    )
    class Add(GeneralOpDef):
        @staticmethod
        def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
            annotate_binary(node, quantization_config)
  • Shared encodings between input / output, e.g.:

    # For operators without an arithmetical function, IOs are expected to share the same encodings.
    @register_annotator(
        [
            torch.ops.aten.permute.default,
            torch.ops.aten.swapaxes.default,
            torch.ops.aten.transpose.int,
        ],
        QnnConstants.OpTranspose.op_name,
    )
    class Permute(GeneralOpDef):
        @staticmethod
        def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
            annotate_in_out_obs_sharing_op(node, quantization_config)
            if not _is_annotated([node]):
                annotate_single_in_share_out(node, quantization_config)

    This annotator only works in the single-in-single-out scenario where the node's input has already been annotated. If it hasn't, we still need to invoke annotate_single_in_share_out as a fallback (this path should be less likely).
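Under the hood, annotate_in_out_obs_sharing_op leans on PyTorch's SharedQuantizationSpec so that the output reuses the encoding of the already-annotated input. A simplified sketch of that idea:

from torch.ao.quantization.quantizer import SharedQuantizationSpec


def annotate_in_out_obs_sharing_op(
    node: Node, quantization_config: QuantizationConfig
) -> None:
    if _is_annotated([node]):
        return

    input_act = node.args[0]
    # Only share encodings when the (single) input node has already been annotated.
    if not isinstance(input_act, Node) or not _is_annotated([input_act]):
        return

    shared_spec = SharedQuantizationSpec(input_act)
    node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
        input_qspec_map={input_act: shared_spec},
        output_qspec=shared_spec,
        _annotated=True,
    )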

Issues

Please refer to the issue section for more information.

Pull Requests

Please refer to the PR section for more information.