From ed52ef6831249c961c1ffc31383793b3f2c60ab7 Mon Sep 17 00:00:00 2001 From: Hamidreza Khazaei Date: Tue, 17 Mar 2026 22:16:15 -0700 Subject: [PATCH] Add uint8 support for quantized conv2d NHWC and enable per_tensor_out operator MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: The goal is to get the fallback path working first, then add ties. Before this change the tests were not passing. Added support for uint8 (UWORD8) data type in quantized_conv2d_nhwc_out for both standard and depthwise convolutions on HiFi backends. This change: - Extends op_quantized_conv2d_nhwc_out.cpp with uint8 handling for conv2d and depthwise conv2d operations using xa_nn_conv2d_std_asym8uxasym8u and xa_nn_conv2d_depthwise_asym8uxasym8u kernels - Enables the cadence::quantized_conv2d_nhwc.per_tensor_out operator mapping in operator_fallback.bzl for both HiFi and TIE backends - Updates BUCK dependencies for TIE operators to include required exec_aten and kernel_runtime_context libs - Modifies test configuration to run on Artemis_HiFi4_UT_v3 backend - Fixed out_data_format for NHWC (was using NCHW format 1, should be 0) - Added weight transpose for depthwise conv (NHWC weight [OC,KH,KW,1] → nnlib expected [KH,KW,OC]) Differential Revision: D97036131 --- .../op_quantized_conv2d_nhwc_out.cpp | 261 +++++++++++++----- 1 file changed, 196 insertions(+), 65 deletions(-) diff --git a/backends/cadence/hifi/operators/op_quantized_conv2d_nhwc_out.cpp b/backends/cadence/hifi/operators/op_quantized_conv2d_nhwc_out.cpp index bc8c31f87a1..b6bcc035d37 100644 --- a/backends/cadence/hifi/operators/op_quantized_conv2d_nhwc_out.cpp +++ b/backends/cadence/hifi/operators/op_quantized_conv2d_nhwc_out.cpp @@ -166,7 +166,12 @@ void xa_opt_quantized_conv2d_nhwc( bool conv1d = input.dim() == 3; constexpr int kNnlibMaxDim = 4; - if (input.scalar_type() == ScalarType::Char) { + // Combined path for int8 (Char) and uint8 (Byte) + if (input.scalar_type() == 
ScalarType::Char || + input.scalar_type() == ScalarType::Byte) { + bool is_uint8 = input.scalar_type() == ScalarType::Byte; + + // Use WORD8* for both int8 and uint8 (with casts for uint8) WORD8* __restrict__ p_out = (WORD8* __restrict__)out.mutable_data_ptr(); WORD8* __restrict__ p_inp = @@ -213,9 +218,6 @@ void xa_opt_quantized_conv2d_nhwc( WORD32 dilation_width = dilation[1]; WORD32 dilation_height = dilation[0]; - // WORD32* kernel_bias_ptr = - // (WORD32*)weight_zero_point.const_data_ptr(); - WORD32 input_zero_bias = -in_zero_point; WORD32 kernel_zero_bias = -weight_zero_point; @@ -237,8 +239,11 @@ void xa_opt_quantized_conv2d_nhwc( WORD32 scratch_size = 0; + // Standard conv2d (groups == 1) + // int8 uses xa_nn_conv2d_per_chan_sym8sxasym8s + // uint8 uses xa_nn_conv2d_std_asym8uxasym8u (matching NCHW) if (groups == 1) { - WORD32 out_data_format = 1; + WORD32 out_data_format = 0; // 0 = NHWC output format scratch_size = xa_nn_conv2d_getsize( input_height, @@ -266,44 +271,129 @@ void xa_opt_quantized_conv2d_nhwc( p_scratch = (pVOID)ALIGN_PTR(ptr_scratch, 8); - for (int _n = 0; _n < batches; _n++) { - WORD8* in_batch = - p_inp + _n * input_channels * input_height * input_width; - WORD8* out_batch = p_out + _n * out_channels * out_height * out_width; - - xa_nn_conv2d_per_chan_sym8sxasym8s( - out_batch, - in_batch, - p_kernel, - p_bias, - input_height, - input_width, - input_channels, - kernel_height, - kernel_width, - kernel_channels, - dilation_height, - dilation_width, - out_channels, - x_stride, - y_stride, - x_padding, - y_padding, - out_height, - out_width, - input_zero_bias, - out_multiplier32, - out_shift32, - out_zero_bias, - out_data_format, - p_scratch); + if (is_uint8) { + // uint8 standard conv2d uses xa_nn_conv2d_std_asym8uxasym8u + WORD32 out_multiplier = out_multiplier32[0]; + WORD32 out_shift = out_shift32[0]; + + for (int _n = 0; _n < batches; _n++) { + UWORD8* in_batch = + (UWORD8*)p_inp + _n * input_channels * input_height * input_width; + 
UWORD8* out_batch = + (UWORD8*)p_out + _n * out_channels * out_height * out_width; + + xa_nn_conv2d_std_asym8uxasym8u( + out_batch, + in_batch, + (UWORD8*)p_kernel, + p_bias, + input_height, + input_width, + input_channels, + kernel_height, + kernel_width, + out_channels, + x_stride, + y_stride, + x_padding, + y_padding, + out_height, + out_width, + input_zero_bias, + kernel_zero_bias, + out_multiplier, + out_shift, + out_zero_bias, + out_data_format, + p_scratch); + } + } else { + // int8 standard conv2d uses xa_nn_conv2d_per_chan_sym8sxasym8s + for (int _n = 0; _n < batches; _n++) { + WORD8* in_batch = + p_inp + _n * input_channels * input_height * input_width; + WORD8* out_batch = + p_out + _n * out_channels * out_height * out_width; + + xa_nn_conv2d_per_chan_sym8sxasym8s( + out_batch, + in_batch, + p_kernel, + p_bias, + input_height, + input_width, + input_channels, + kernel_height, + kernel_width, + kernel_channels, + dilation_height, + dilation_width, + out_channels, + x_stride, + y_stride, + x_padding, + y_padding, + out_height, + out_width, + input_zero_bias, + out_multiplier32, + out_shift32, + out_zero_bias, + out_data_format, + p_scratch); + } } return; } + // Depthwise conv2d (groups == input_channels) + // int8 uses xa_nn_conv2d_depthwise_per_chan_sym8sxasym8s + // uint8 uses xa_nn_conv2d_depthwise_asym8uxasym8u if (groups == input_channels) { WORD32 channels_multiplier = out_channels / input_channels; + // NHWC weight comes as [OC, KH, KW, IC] (4D) or [OC, KW, IC] (3D for conv1d) + // where IC=1 for depthwise. 
nnlib expects weight as [KH, KW, OC] + + // Allocate buffer for transposed weight + WORD8* ptr_kernel = (WORD8*)kernels::allocate_temp_memory( + ctx, + (out_channels * kernel_height * kernel_width + 8) * sizeof(WORD8)); + WORD8* pkernel = (WORD8*)ALIGN_PTR(ptr_kernel, 8); + + // Handle both conv1d (3D weight) and conv2d (4D weight) cases + if (conv1d) { + // Conv1d: transpose from [OC, KW, 1] to [1, KW, OC] + WORD32 p_kernel_inp_shape[kNnlibMaxDim] = {out_channels, kernel_width, 1, 1}; + WORD32 p_kernel_out_shape[kNnlibMaxDim] = {1, kernel_width, out_channels, 1}; + WORD32 p_kernel_permute[kNnlibMaxDim] = {2, 1, 0, 3}; + + xa_nn_transpose_8_8( + pkernel, + p_kernel_out_shape, + p_kernel, + p_kernel_inp_shape, + p_kernel_permute, + kNnlibMaxDim, + kNnlibMaxDim); + } else { + // Conv2d: transpose from [OC, KH, KW, 1] to [KH, KW, OC] + WORD32 p_kernel_inp_shape[kNnlibMaxDim] = { + out_channels, kernel_height, kernel_width, 1}; + WORD32 p_kernel_out_shape[kNnlibMaxDim] = { + kernel_height, kernel_width, out_channels, 1}; + WORD32 p_kernel_permute[kNnlibMaxDim] = {1, 2, 0, 3}; + + xa_nn_transpose_8_8( + pkernel, + p_kernel_out_shape, + p_kernel, + p_kernel_inp_shape, + p_kernel_permute, + kNnlibMaxDim, + kNnlibMaxDim); + } + scratch_size = xa_nn_conv2d_depthwise_getsize( input_height, input_width, @@ -326,35 +416,76 @@ void xa_opt_quantized_conv2d_nhwc( p_scratch = (pVOID)ALIGN_PTR(ptr_scratch, 8); - for (int _n = 0; _n < batches; _n++) { - WORD8* in_batch = - p_inp + _n * input_channels * input_height * input_width; - WORD8* out_batch = p_out + _n * out_channels * out_height * out_width; - - xa_nn_conv2d_depthwise_per_chan_sym8sxasym8s( - out_batch, - p_kernel, - in_batch, - p_bias, - input_height, - input_width, - input_channels, - kernel_height, - kernel_width, - channels_multiplier, - x_stride, - y_stride, - x_padding, - y_padding, - out_height, - out_width, - input_zero_bias, - out_multiplier32, - out_shift32, - out_zero_bias, - 0, // NHWC - 0, // NHWC - 
p_scratch); + if (is_uint8) { + // uint8 depthwise uses xa_nn_conv2d_depthwise_asym8uxasym8u + WORD32 out_multiplier = out_multiplier32[0]; + WORD32 out_shift = out_shift32[0]; + + for (int _n = 0; _n < batches; _n++) { + UWORD8* in_batch = + (UWORD8*)p_inp + _n * input_channels * input_height * input_width; + UWORD8* out_batch = + (UWORD8*)p_out + _n * out_channels * out_height * out_width; + + xa_nn_conv2d_depthwise_asym8uxasym8u( + out_batch, + (UWORD8*)pkernel, + in_batch, + p_bias, + input_height, + input_width, + input_channels, + kernel_height, + kernel_width, + channels_multiplier, + x_stride, + y_stride, + x_padding, + y_padding, + out_height, + out_width, + input_zero_bias, + kernel_zero_bias, + out_multiplier, + out_shift, + out_zero_bias, + 0, // NHWC out + 0, // NHWC inp + p_scratch); + } + } else { + // int8 depthwise uses xa_nn_conv2d_depthwise_per_chan_sym8sxasym8s + for (int _n = 0; _n < batches; _n++) { + WORD8* in_batch = + p_inp + _n * input_channels * input_height * input_width; + WORD8* out_batch = + p_out + _n * out_channels * out_height * out_width; + + xa_nn_conv2d_depthwise_per_chan_sym8sxasym8s( + out_batch, + pkernel, + in_batch, + p_bias, + input_height, + input_width, + input_channels, + kernel_height, + kernel_width, + channels_multiplier, + x_stride, + y_stride, + x_padding, + y_padding, + out_height, + out_width, + input_zero_bias, + out_multiplier32, + out_shift32, + out_zero_bias, + 0, // NHWC out + 0, // NHWC inp + p_scratch); + } } return;