From dd1824a647a10736ee934c6153409ec1967b94b1 Mon Sep 17 00:00:00 2001 From: LindseyMei <648816901@qq.com> Date: Mon, 29 Jun 2026 08:39:09 +0000 Subject: [PATCH] feat: support softmax operator on metax Add MetaX backend for softmax operator, reusing cuda/kernel.cuh with hccub/cub block reduce headers. Tested on MetaX C500: all shapes/axes/dtypes/inplace modes pass. Signed-off-by: LindseyMei <648816901@qq.com> --- .../ops/softmax/metax/softmax_metax.h | 8 + .../ops/softmax/metax/softmax_metax.maca | 162 ++++++++++++++++++ src/infiniop/ops/softmax/operator.cc | 16 ++ 3 files changed, 186 insertions(+) create mode 100644 src/infiniop/ops/softmax/metax/softmax_metax.h create mode 100644 src/infiniop/ops/softmax/metax/softmax_metax.maca diff --git a/src/infiniop/ops/softmax/metax/softmax_metax.h b/src/infiniop/ops/softmax/metax/softmax_metax.h new file mode 100644 index 000000000..d05575647 --- /dev/null +++ b/src/infiniop/ops/softmax/metax/softmax_metax.h @@ -0,0 +1,8 @@ +#ifndef __SOFTMAX_METAX_H__ +#define __SOFTMAX_METAX_H__ + +#include "../softmax.h" + +DESCRIPTOR(metax) + +#endif // __SOFTMAX_METAX_H__ diff --git a/src/infiniop/ops/softmax/metax/softmax_metax.maca b/src/infiniop/ops/softmax/metax/softmax_metax.maca new file mode 100644 index 000000000..d2ea09ac2 --- /dev/null +++ b/src/infiniop/ops/softmax/metax/softmax_metax.maca @@ -0,0 +1,162 @@ +#include "../../../devices/metax/metax_common.h" +#include "softmax_metax.h" + +#ifdef ENABLE_METAX_MC_API +#include +#else +#include +#endif +#include "../../../devices/metax/metax_kernel_common.h" + +#include "../cuda/kernel.cuh" + +template +INFINIOP_METAX_KERNEL blockSoftmax( + Tdata *y, const Tdata *x, + size_t dimsize, + ptrdiff_t stride) { + blockSoftmaxKernel(x, y, dimsize, stride); +} + +template +INFINIOP_METAX_KERNEL warpSoftmax( + Tdata *y, const Tdata *x, + size_t othersize, + size_t dimsize, + ptrdiff_t stride) { + warpSoftmaxKernel(x, y, othersize, dimsize, stride); +} + +namespace op::softmax::metax { + +struct Descriptor::Opaque { + std::shared_ptr internal; +}; + +Descriptor::~Descriptor() { + delete _opaque; +} + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t x_desc, + int axis) { + auto info = SoftmaxInfo::create(y_desc, x_desc, axis); + CHECK_RESULT(info); + *desc_ptr = new Descriptor( + new Opaque{reinterpret_cast(handle)->internal()}, + info.take(), 0, handle->device, handle->device_id); + return INFINI_STATUS_SUCCESS; +} + +template +infiniStatus_t launchKernel(void *y, const void *x, infiniDtype_t dtype, + size_t othersize, size_t dimsize, ptrdiff_t stride, + hcStream_t stream) { + int num_blocks = (int)othersize; + if (dtype == INFINI_DTYPE_F16) { + if (dimsize > 1024) { + blockSoftmax + <<>>((half *)y, (const half *)x, + dimsize, stride); + } else if (dimsize > 31) { + constexpr unsigned int BLOCK_SIZE_x = 32; + constexpr unsigned int BLOCK_SIZE_y = 32; + constexpr int numPerThreadx = 32; + int num_block_x = (num_blocks + BLOCK_SIZE_y - 1) / BLOCK_SIZE_y; + dim3 block_dim(BLOCK_SIZE_x, BLOCK_SIZE_y, 1); + dim3 grid_dim(num_block_x, 1, 1); + warpSoftmax + <<>>((half *)y, (const half *)x, + othersize, dimsize, stride); + } else { + constexpr unsigned int BLOCK_SIZE_x = 16; + constexpr unsigned int BLOCK_SIZE_y = 32; + constexpr int numPerThreadx = 2; + int num_block_x = (num_blocks + BLOCK_SIZE_y - 1) / BLOCK_SIZE_y; + dim3 block_dim(BLOCK_SIZE_x, BLOCK_SIZE_y, 1); + dim3 grid_dim(num_block_x, 1, 1); + warpSoftmax + <<>>((half *)y, (const half *)x, + othersize, dimsize, stride); + } + + } else if (dtype == INFINI_DTYPE_BF16) { + if (dimsize > 1024) { + blockSoftmax + <<>>((cuda_bfloat16 *)y, (const cuda_bfloat16 *)x, + dimsize, stride); + } else if (dimsize > 31) { + constexpr unsigned int BLOCK_SIZE_x = 32; + constexpr unsigned int BLOCK_SIZE_y = 32; + constexpr int numPerThreadx = 32; + int num_block_x = (num_blocks + BLOCK_SIZE_y - 1) / BLOCK_SIZE_y; + dim3 block_dim(BLOCK_SIZE_x, BLOCK_SIZE_y, 1); + dim3 grid_dim(num_block_x, 1, 1); + warpSoftmax + <<>>((cuda_bfloat16 *)y, (const cuda_bfloat16 *)x, + othersize, dimsize, stride); + } else { + constexpr unsigned int BLOCK_SIZE_x = 16; + constexpr unsigned int BLOCK_SIZE_y = 32; + constexpr int numPerThreadx = 2; + int num_block_x = (num_blocks + BLOCK_SIZE_y - 1) / BLOCK_SIZE_y; + dim3 block_dim(BLOCK_SIZE_x, BLOCK_SIZE_y, 1); + dim3 grid_dim(num_block_x, 1, 1); + warpSoftmax + <<>>((cuda_bfloat16 *)y, (const cuda_bfloat16 *)x, + othersize, dimsize, stride); + } + + } else if (dtype == INFINI_DTYPE_F32) { + if (dimsize > 1024) { + blockSoftmax + <<>>((float *)y, (const float *)x, + dimsize, stride); + } else if (dimsize > 31) { + constexpr unsigned int BLOCK_SIZE_x = 32; + constexpr unsigned int BLOCK_SIZE_y = 32; + constexpr int numPerThreadx = 32; + int num_block_x = (num_blocks + BLOCK_SIZE_y - 1) / BLOCK_SIZE_y; + dim3 block_dim(BLOCK_SIZE_x, BLOCK_SIZE_y, 1); + dim3 grid_dim(num_block_x, 1, 1); + warpSoftmax + <<>>((float *)y, (const float *)x, + othersize, dimsize, stride); + } else { + constexpr unsigned int BLOCK_SIZE_x = 16; + constexpr unsigned int BLOCK_SIZE_y = 32; + constexpr int numPerThreadx = 2; + int num_block_x = (num_blocks + BLOCK_SIZE_y - 1) / BLOCK_SIZE_y; + dim3 block_dim(BLOCK_SIZE_x, BLOCK_SIZE_y, 1); + dim3 grid_dim(num_block_x, 1, 1); + warpSoftmax + <<>>((float *)y, (const float *)x, + othersize, dimsize, stride); + } + } else { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size, + void *y, + const void *x, + void *stream_) const { + hcStream_t stream = (hcStream_t)stream_; + if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_1024) { + CHECK_STATUS(launchKernel( + y, x, _info.dtype, _info.othersize, _info.dimsize, _info.stride, stream)); + } else if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_512) { + CHECK_STATUS(launchKernel( + y, x, _info.dtype, _info.othersize, _info.dimsize, _info.stride, stream)); + } else { + return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED; + } + return INFINI_STATUS_SUCCESS; +} + +} // namespace op::softmax::metax diff --git a/src/infiniop/ops/softmax/operator.cc b/src/infiniop/ops/softmax/operator.cc index 1664e257f..040c6160c 100644 --- a/src/infiniop/ops/softmax/operator.cc +++ b/src/infiniop/ops/softmax/operator.cc @@ -10,6 +10,10 @@ #include "nvidia/softmax_nvidia.cuh" #endif +#ifdef ENABLE_METAX_API +#include "metax/softmax_metax.h" +#endif + __INFINI_C infiniStatus_t infiniopCreateSoftmaxDescriptor( infiniopHandle_t handle, infiniopSoftmaxDescriptor_t *desc_ptr, @@ -43,6 +47,9 @@ __INFINI_C infiniStatus_t infiniopCreateSoftmaxDescriptor( #endif #ifdef ENABLE_CAMBRICON_API CREATE(INFINI_DEVICE_CAMBRICON, bang); +#endif +#ifdef ENABLE_METAX_API + CREATE(INFINI_DEVICE_METAX, metax); #endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -74,6 +81,9 @@ __INFINI_C infiniStatus_t infiniopGetSoftmaxWorkspaceSize(infiniopSoftmaxDescrip #endif #ifdef ENABLE_CAMBRICON_API GET(INFINI_DEVICE_CAMBRICON, bang); +#endif +#ifdef ENABLE_METAX_API + GET(INFINI_DEVICE_METAX, metax); #endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -110,6 +120,9 @@ __INFINI_C infiniStatus_t infiniopSoftmax( #endif #ifdef ENABLE_CAMBRICON_API CALCULATE(INFINI_DEVICE_CAMBRICON, bang); +#endif +#ifdef ENABLE_METAX_API + CALCULATE(INFINI_DEVICE_METAX, metax); #endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -141,6 +154,9 @@ __INFINI_C infiniStatus_t infiniopDestroySoftmaxDescriptor(infiniopSoftmaxDescri #endif #ifdef ENABLE_CAMBRICON_API DESTROY(INFINI_DEVICE_CAMBRICON, bang); +#endif +#ifdef ENABLE_METAX_API + DESTROY(INFINI_DEVICE_METAX, metax); #endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;