From 050d217aeb0e6d273a11dbb65d6bb6f1725ec926 Mon Sep 17 00:00:00 2001 From: xuzhengzhong Date: Tue, 28 Apr 2026 06:42:19 +0000 Subject: [PATCH 01/25] Add `blas_amax` and `blas_amin` operators --- include/infiniop.h | 2 + include/infiniop/ops/blas_amax.h | 24 ++++ include/infiniop/ops/blas_amin.h | 24 ++++ src/infiniop/devices/metax/metax_ht2mc.h | 7 + .../ops/blas_amax/bang/blas_amax_bang.h | 8 ++ .../ops/blas_amax/bang/blas_amax_bang.mlu | 96 +++++++++++++ .../blas_amax/bang/blas_amax_bang_kernel.mlu | 134 ++++++++++++++++++ src/infiniop/ops/blas_amax/blas_amax.h | 90 ++++++++++++ .../ops/blas_amax/cpu/blas_amax_cpu.cc | 103 ++++++++++++++ .../ops/blas_amax/cpu/blas_amax_cpu.h | 8 ++ .../ops/blas_amax/metax/blas_amax_metax.cc | 71 ++++++++++ .../ops/blas_amax/metax/blas_amax_metax.h | 8 ++ src/infiniop/ops/blas_amax/operator.cc | 122 ++++++++++++++++ .../ops/blas_amin/bang/blas_amin_bang.h | 8 ++ .../ops/blas_amin/bang/blas_amin_bang.mlu | 96 +++++++++++++ .../blas_amin/bang/blas_amin_bang_kernel.mlu | 133 +++++++++++++++++ src/infiniop/ops/blas_amin/blas_amin.h | 90 ++++++++++++ .../ops/blas_amin/cpu/blas_amin_cpu.cc | 103 ++++++++++++++ .../ops/blas_amin/cpu/blas_amin_cpu.h | 8 ++ .../ops/blas_amin/metax/blas_amin_metax.cc | 71 ++++++++++ .../ops/blas_amin/metax/blas_amin_metax.h | 8 ++ src/infiniop/ops/blas_amin/operator.cc | 122 ++++++++++++++++ test/infiniop/blas_amax.py | 132 +++++++++++++++++ test/infiniop/blas_amin.py | 132 +++++++++++++++++ test/infiniop/libinfiniop/op_register.py | 64 +++++++++ 25 files changed, 1664 insertions(+) create mode 100644 include/infiniop/ops/blas_amax.h create mode 100644 include/infiniop/ops/blas_amin.h create mode 100644 src/infiniop/ops/blas_amax/bang/blas_amax_bang.h create mode 100644 src/infiniop/ops/blas_amax/bang/blas_amax_bang.mlu create mode 100644 src/infiniop/ops/blas_amax/bang/blas_amax_bang_kernel.mlu create mode 100644 src/infiniop/ops/blas_amax/blas_amax.h create mode 100644 
src/infiniop/ops/blas_amax/cpu/blas_amax_cpu.cc create mode 100644 src/infiniop/ops/blas_amax/cpu/blas_amax_cpu.h create mode 100644 src/infiniop/ops/blas_amax/metax/blas_amax_metax.cc create mode 100644 src/infiniop/ops/blas_amax/metax/blas_amax_metax.h create mode 100644 src/infiniop/ops/blas_amax/operator.cc create mode 100644 src/infiniop/ops/blas_amin/bang/blas_amin_bang.h create mode 100644 src/infiniop/ops/blas_amin/bang/blas_amin_bang.mlu create mode 100644 src/infiniop/ops/blas_amin/bang/blas_amin_bang_kernel.mlu create mode 100644 src/infiniop/ops/blas_amin/blas_amin.h create mode 100644 src/infiniop/ops/blas_amin/cpu/blas_amin_cpu.cc create mode 100644 src/infiniop/ops/blas_amin/cpu/blas_amin_cpu.h create mode 100644 src/infiniop/ops/blas_amin/metax/blas_amin_metax.cc create mode 100644 src/infiniop/ops/blas_amin/metax/blas_amin_metax.h create mode 100644 src/infiniop/ops/blas_amin/operator.cc create mode 100644 test/infiniop/blas_amax.py create mode 100644 test/infiniop/blas_amin.py diff --git a/include/infiniop.h b/include/infiniop.h index 0ec995823..b456de868 100644 --- a/include/infiniop.h +++ b/include/infiniop.h @@ -20,6 +20,8 @@ #include "infiniop/ops/avg_pool1d.h" #include "infiniop/ops/avg_pool3d.h" #include "infiniop/ops/binary_cross_entropy_with_logits.h" +#include "infiniop/ops/blas_amax.h" +#include "infiniop/ops/blas_amin.h" #include "infiniop/ops/block_diag.h" #include "infiniop/ops/broadcast_to.h" #include "infiniop/ops/causal_softmax.h" diff --git a/include/infiniop/ops/blas_amax.h b/include/infiniop/ops/blas_amax.h new file mode 100644 index 000000000..be69ad4df --- /dev/null +++ b/include/infiniop/ops/blas_amax.h @@ -0,0 +1,24 @@ +#ifndef __INFINIOP_BLAS_AMAX_API_H__ +#define __INFINIOP_BLAS_AMAX_API_H__ + +#include "../operator_descriptor.h" + +typedef struct InfiniopDescriptor *infiniopBlasAmaxDescriptor_t; + +__INFINI_C __export infiniStatus_t infiniopCreateBlasAmaxDescriptor(infiniopHandle_t handle, + infiniopBlasAmaxDescriptor_t 
*desc_ptr, + infiniopTensorDescriptor_t x, + infiniopTensorDescriptor_t result); + +__INFINI_C __export infiniStatus_t infiniopGetBlasAmaxWorkspaceSize(infiniopBlasAmaxDescriptor_t desc, size_t *size); + +__INFINI_C __export infiniStatus_t infiniopBlasAmax(infiniopBlasAmaxDescriptor_t desc, + void *workspace, + size_t workspace_size, + const void *x, + void *result, + void *stream); + +__INFINI_C __export infiniStatus_t infiniopDestroyBlasAmaxDescriptor(infiniopBlasAmaxDescriptor_t desc); + +#endif // __INFINIOP_BLAS_AMAX_API_H__ \ No newline at end of file diff --git a/include/infiniop/ops/blas_amin.h b/include/infiniop/ops/blas_amin.h new file mode 100644 index 000000000..f568acd4a --- /dev/null +++ b/include/infiniop/ops/blas_amin.h @@ -0,0 +1,24 @@ +#ifndef __INFINIOP_BLAS_AMIN_API_H__ +#define __INFINIOP_BLAS_AMIN_API_H__ + +#include "../operator_descriptor.h" + +typedef struct InfiniopDescriptor *infiniopBlasAminDescriptor_t; + +__INFINI_C __export infiniStatus_t infiniopCreateBlasAminDescriptor(infiniopHandle_t handle, + infiniopBlasAminDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t x, + infiniopTensorDescriptor_t result); + +__INFINI_C __export infiniStatus_t infiniopGetBlasAminWorkspaceSize(infiniopBlasAminDescriptor_t desc, size_t *size); + +__INFINI_C __export infiniStatus_t infiniopBlasAmin(infiniopBlasAminDescriptor_t desc, + void *workspace, + size_t workspace_size, + const void *x, + void *result, + void *stream); + +__INFINI_C __export infiniStatus_t infiniopDestroyBlasAminDescriptor(infiniopBlasAminDescriptor_t desc); + +#endif // __INFINIOP_BLAS_AMIN_API_H__ \ No newline at end of file diff --git a/src/infiniop/devices/metax/metax_ht2mc.h b/src/infiniop/devices/metax/metax_ht2mc.h index d391f61d2..8efa3b5a2 100644 --- a/src/infiniop/devices/metax/metax_ht2mc.h +++ b/src/infiniop/devices/metax/metax_ht2mc.h @@ -118,6 +118,11 @@ #define hcblasGemmEx mcblasGemmEx #define hcblasCreate mcblasCreate #define hcblasComputeType_t mcblasComputeType_t 
+#define hcblasSetPointerMode mcblasSetPointerMode +#define hcblasIsamax mcblasIsamax +#define hcblasIdamax mcblasIdamax +#define hcblasIsamin mcblasIsamin +#define hcblasIdamin mcblasIdamin #define HCBLAS_STATUS_SUCCESS MCBLAS_STATUS_SUCCESS #define HCBLAS_OP_T MCBLAS_OP_T #define HCBLAS_OP_N MCBLAS_OP_N @@ -125,6 +130,8 @@ #define HCBLAS_GEMM_DEFAULT MCBLAS_GEMM_DEFAULT #define HCBLAS_COMPUTE_32F_FAST_TF32 MCBLAS_COMPUTE_32F_FAST_TF32 #define HCBLAS_COMPUTE_32F MCBLAS_COMPUTE_32F +#define HCBLAS_POINTER_MODE_DEVICE MCBLAS_POINTER_MODE_DEVICE +#define HCBLAS_POINTER_MODE_HOST MCBLAS_POINTER_MODE_HOST #define __hpcc_fp8_e4m3 __maca_fp8_e4m3 #define __hpcc_bfloat16 __maca_bfloat16 #endif diff --git a/src/infiniop/ops/blas_amax/bang/blas_amax_bang.h b/src/infiniop/ops/blas_amax/bang/blas_amax_bang.h new file mode 100644 index 000000000..88a0250a7 --- /dev/null +++ b/src/infiniop/ops/blas_amax/bang/blas_amax_bang.h @@ -0,0 +1,8 @@ +#ifndef __BLAS_AMAX_BANG_H__ +#define __BLAS_AMAX_BANG_H__ + +#include "../blas_amax.h" + +DESCRIPTOR(bang) + +#endif // __BLAS_AMAX_BANG_H__ \ No newline at end of file diff --git a/src/infiniop/ops/blas_amax/bang/blas_amax_bang.mlu b/src/infiniop/ops/blas_amax/bang/blas_amax_bang.mlu new file mode 100644 index 000000000..38b450caf --- /dev/null +++ b/src/infiniop/ops/blas_amax/bang/blas_amax_bang.mlu @@ -0,0 +1,96 @@ +#include "../../../devices/bang/common_bang.h" +#include "blas_amax_bang.h" +#include "blas_amax_bang_kernel.mlu" + +namespace op::blas_amax::bang { + +Descriptor::~Descriptor() = default; + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t result_desc) { + + auto handle = reinterpret_cast(handle_); + auto info = BlasAmaxInfo::createBlasAmaxInfo(x_desc, result_desc); + CHECK_RESULT(info); + + // Create descriptor + *desc_ptr = new Descriptor( + info.take(), + 0, + nullptr, + handle->device, + handle->device_id); + + 
return INFINI_STATUS_SUCCESS; +} + +template +infiniStatus_t calculateBlasAmax( + const BlasAmaxInfo &info, + const Tdata *x, + int *result, + cnrtQueue_t queue) { + + const size_t size = info.getSize(); + const ptrdiff_t incx = info.getIncx(); + + cnrtDim3_t k_dim; + cnrtFunctionType_t k_type; + + k_dim.x = 4; + k_dim.y = 1; + k_dim.z = 1; + k_type = cnrtFuncTypeUnion1; + + if (incx == 1) { + blasAmaxKernelContiguous<<>>( + size, + x, + result); + } else { + blasAmaxKernelStrided<<>>( + size, + x, + incx, + result); + } + + cnrtQueueSync(queue); + + return INFINI_STATUS_SUCCESS; +} + +#define CALCULATE_BLAS_AMAX(TDATA) \ + calculateBlasAmax(_info, \ + (const TDATA *)x, \ + (int *)result, \ + (cnrtQueue_t)stream) + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + const void *x, + void *result, + void *stream) const { + + (void)workspace; + (void)workspace_size; + + switch (_info.getDtype()) { + case INFINI_DTYPE_F16: + return CALCULATE_BLAS_AMAX(half); + case INFINI_DTYPE_F32: + return CALCULATE_BLAS_AMAX(float); + case INFINI_DTYPE_BF16: + return CALCULATE_BLAS_AMAX(bfloat16_t); + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } +} + +#undef CALCULATE_BLAS_AMAX + +} // namespace op::blas_amax::bang \ No newline at end of file diff --git a/src/infiniop/ops/blas_amax/bang/blas_amax_bang_kernel.mlu b/src/infiniop/ops/blas_amax/bang/blas_amax_bang_kernel.mlu new file mode 100644 index 000000000..be6efb60a --- /dev/null +++ b/src/infiniop/ops/blas_amax/bang/blas_amax_bang_kernel.mlu @@ -0,0 +1,134 @@ +#include "../../../devices/bang/common_bang.h" +#include "blas_amax_bang.h" + +__nram__ char nram_buffer[NRAM_MAX_SIZE]; + +template +__mlu_global__ void blasAmaxKernelContiguous( + size_t n, + const Tdata *x, + int *result) { + + __mlu_shared__ int shared_max_index[4]; + __mlu_shared__ Tdata shared_max_value[4]; + + Tdata *nram_x = (Tdata *)(((size_t)nram_buffer + ALIGN_SIZE - 1) & ~(ALIGN_SIZE - 1)); + + size_t nram_usable = 
NRAM_MAX_SIZE - ((char *)nram_x - nram_buffer); + size_t max_chunk_elements = nram_usable / sizeof(Tdata); + + int align_elements = ALIGN_SIZE / sizeof(Tdata); + if (align_elements == 0) { + align_elements = 1; + } + max_chunk_elements = (max_chunk_elements / align_elements) * align_elements; + + int elements_per_core = n / taskDim; + int remain = n % taskDim; + int core_elements = elements_per_core + (taskId < remain ? 1 : 0); + int core_offset = taskId < remain ? taskId * core_elements : taskId * elements_per_core + remain; + + int chunks = core_elements / max_chunk_elements; + int chunk_rem = core_elements % max_chunk_elements; + + int max_index = -1; + Tdata max_value = static_cast(0); + + for (int c = 0; c < chunks; c++) { + size_t current_offset = core_offset + c * max_chunk_elements; + __memcpy(nram_x, x + current_offset, max_chunk_elements * sizeof(Tdata), GDRAM2NRAM); + + __bang_abs(nram_x, nram_x, max_chunk_elements); + + for (int i = 0; i < max_chunk_elements; i++) { + Tdata abs_val = nram_x[i]; + if (abs_val > max_value) { + max_value = abs_val; + max_index = current_offset + i; + } + } + } + + if (chunk_rem > 0) { + size_t current_offset = core_offset + chunks * max_chunk_elements; + + __memcpy(nram_x, x + current_offset, chunk_rem * sizeof(Tdata), GDRAM2NRAM); + + __bang_abs(nram_x, nram_x, chunk_rem); + + for (int i = 0; i < chunk_rem; i++) { + Tdata abs_val = nram_x[i]; + if (abs_val > max_value) { + max_value = abs_val; + max_index = current_offset + i; + } + } + } + + shared_max_index[coreId] = max_index; + shared_max_value[coreId] = max_value; + + __sync_cluster(); + + if (coreId == 0) { + int cluster_max_index = -1; + Tdata cluster_max_value = static_cast(0); + + for (int i = 0; i < coreDim; i++) { + if (shared_max_value[i] > cluster_max_value) { + cluster_max_value = shared_max_value[i]; + cluster_max_index = shared_max_index[i]; + } + } + + result[0] = cluster_max_index + 1; // Convert to 1-based index + } +} + +template +__mlu_global__ void 
blasAmaxKernelStrided( + size_t n, + const Tdata *x, + size_t incx, + int *result) { + + __mlu_shared__ int shared_max_index[4]; + __mlu_shared__ Tdata shared_max_value[4]; + + int elements_per_core = n / taskDim; + int remain = n % taskDim; + int actual_tasks = elements_per_core + (taskId < remain ? 1 : 0); + int start_idx = taskId < remain ? taskId * actual_tasks : taskId * elements_per_core + remain; + + int max_index = -1; + Tdata max_value = static_cast(0); + + for (int i = start_idx; i < start_idx + actual_tasks; ++i) { + size_t offset = i * incx; + Tdata abs_val = x[offset] > static_cast(0) ? x[offset] : -x[offset]; + + if (abs_val > max_value) { + max_value = abs_val; + max_index = i; + } + } + + shared_max_index[coreId] = max_index; + shared_max_value[coreId] = max_value; + + __sync_cluster(); + + if (coreId == 0) { + int cluster_max_index = -1; + Tdata cluster_max_value = static_cast(0); + + for (int i = 0; i < coreDim; i++) { + if (shared_max_value[i] > cluster_max_value) { + cluster_max_value = shared_max_value[i]; + cluster_max_index = shared_max_index[i]; + } + } + + result[0] = cluster_max_index + 1; // Convert to 1-based index + } +} \ No newline at end of file diff --git a/src/infiniop/ops/blas_amax/blas_amax.h b/src/infiniop/ops/blas_amax/blas_amax.h new file mode 100644 index 000000000..06247fe29 --- /dev/null +++ b/src/infiniop/ops/blas_amax/blas_amax.h @@ -0,0 +1,90 @@ +#ifndef __BLAS_AMAX_H__ +#define __BLAS_AMAX_H__ + +#include "../../../utils.h" +#include "../../operator.h" +#include "../../tensor.h" +#include "infiniop/ops/blas_amax.h" + +#define DESCRIPTOR(NAMESPACE) \ + \ + namespace op::blas_amax::NAMESPACE { \ + class Descriptor final : public InfiniopDescriptor { \ + struct Opaque; \ + Opaque *_opaque; \ + BlasAmaxInfo _info; \ + size_t _workspace_size; \ + \ + Descriptor( \ + BlasAmaxInfo info, \ + size_t workspace_size_, \ + Opaque *opaque, \ + infiniDevice_t device_type, \ + int device_id) \ + : InfiniopDescriptor{device_type, 
device_id}, \ + _opaque(opaque), \ + _info(std::move(info)), \ + _workspace_size(workspace_size_) {} \ + \ + public: \ + ~Descriptor(); \ + \ + size_t workspaceSize() const { return _workspace_size; } \ + \ + static infiniStatus_t create( \ + infiniopHandle_t handle, \ + Descriptor **desc_ptr, \ + infiniopTensorDescriptor_t x_desc, \ + infiniopTensorDescriptor_t result_desc); \ + \ + infiniStatus_t calculate( \ + void *workspace, \ + size_t workspace_size, \ + const void *x, \ + void *result, \ + void *stream) const; \ + }; \ + } + +class BlasAmaxInfo { +private: + size_t _size; + ptrdiff_t _incx; + infiniDtype_t _dtype; + + BlasAmaxInfo(size_t size, + ptrdiff_t incx, + infiniDtype_t dtype) + : _size(size), _incx(incx), _dtype(dtype) {} + +public: + inline size_t getSize() const { return _size; } + inline ptrdiff_t getIncx() const { return _incx; } + inline infiniDtype_t getDtype() const { return _dtype; } + + using ResultType = utils::Result; + + static utils::Result createBlasAmaxInfo( + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t result_desc) { + + CHECK_OR_RETURN(x_desc != nullptr, INFINI_STATUS_NULL_POINTER); + CHECK_OR_RETURN(result_desc != nullptr, INFINI_STATUS_NULL_POINTER); + + auto dtype = x_desc->dtype(); + auto itype = result_desc->dtype(); + CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32, INFINI_DTYPE_F64); + CHECK_DTYPE(itype, INFINI_DTYPE_I32); + + CHECK_OR_RETURN(x_desc->ndim() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(result_desc->numel() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); + + auto size = x_desc->numel(); + auto incx = x_desc->stride(0); + + BlasAmaxInfo info(size, incx, dtype); + return ResultType(std::move(info)); + } +}; + +#endif // __BLAS_AMAX_H__ \ No newline at end of file diff --git a/src/infiniop/ops/blas_amax/cpu/blas_amax_cpu.cc b/src/infiniop/ops/blas_amax/cpu/blas_amax_cpu.cc new file mode 100644 index 000000000..237dfde64 --- /dev/null +++ 
b/src/infiniop/ops/blas_amax/cpu/blas_amax_cpu.cc @@ -0,0 +1,103 @@ +#include "blas_amax_cpu.h" +#include "../../../devices/cpu/common_cpu.h" + +namespace op::blas_amax::cpu { + +Descriptor::~Descriptor() = default; + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t result_desc) { + + auto handle = reinterpret_cast(handle_); + auto info = BlasAmaxInfo::createBlasAmaxInfo(x_desc, result_desc); + CHECK_RESULT(info); + + // Create descriptor + *desc_ptr = new Descriptor( + info.take(), + 0, + nullptr, + handle->device, + handle->device_id); + + return INFINI_STATUS_SUCCESS; +} + +template +infiniStatus_t calculateBlasAmax( + const BlasAmaxInfo &info, + const Tdata *x, + int *result) { + + const ptrdiff_t size = info.getSize(); + const ptrdiff_t incx = info.getIncx(); + + if (size < 1 || incx == 0) { + result[0] = 0; + return INFINI_STATUS_SUCCESS; + } + + int max_index = 0; + if constexpr (std::is_same::value || std::is_same::value) { + float max_value = std::abs(utils::cast(x[0])); + + for (ptrdiff_t i = 1; i < size; ++i) { + float current_value = std::abs(utils::cast(x[i * incx])); + if (current_value > max_value) { + max_value = current_value; + max_index = i; + } + } + } else { + Tdata max_value = std::abs(x[0]); + + for (ptrdiff_t i = 1; i < size; ++i) { + Tdata current_value = std::abs(x[i * incx]); + if (current_value > max_value) { + max_value = current_value; + max_index = i; + } + } + } + + result[0] = utils::cast(max_index) + 1; + + return INFINI_STATUS_SUCCESS; +} + +#define CALCULATE_BLAS_AMAX(TDATA) \ + calculateBlasAmax(_info, \ + (const TDATA *)x, \ + (int *)result) + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + const void *x, + void *result, + void *stream) const { + + (void)workspace; + (void)workspace_size; + (void)stream; + + switch (_info.getDtype()) { + case INFINI_DTYPE_F16: + return 
CALCULATE_BLAS_AMAX(fp16_t); + case INFINI_DTYPE_F32: + return CALCULATE_BLAS_AMAX(float); + case INFINI_DTYPE_F64: + return CALCULATE_BLAS_AMAX(double); + case INFINI_DTYPE_BF16: + return CALCULATE_BLAS_AMAX(bf16_t); + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } +} + +#undef CALCULATE_BLAS_AMAX + +} // namespace op::blas_amax::cpu \ No newline at end of file diff --git a/src/infiniop/ops/blas_amax/cpu/blas_amax_cpu.h b/src/infiniop/ops/blas_amax/cpu/blas_amax_cpu.h new file mode 100644 index 000000000..2aa42e756 --- /dev/null +++ b/src/infiniop/ops/blas_amax/cpu/blas_amax_cpu.h @@ -0,0 +1,8 @@ +#ifndef __BLAS_AMAX_CPU_H__ +#define __BLAS_AMAX_CPU_H__ + +#include "../blas_amax.h" + +DESCRIPTOR(cpu) + +#endif // __BLAS_AMAX_CPU_H__ \ No newline at end of file diff --git a/src/infiniop/ops/blas_amax/metax/blas_amax_metax.cc b/src/infiniop/ops/blas_amax/metax/blas_amax_metax.cc new file mode 100644 index 000000000..298ad5510 --- /dev/null +++ b/src/infiniop/ops/blas_amax/metax/blas_amax_metax.cc @@ -0,0 +1,71 @@ +#include "blas_amax_metax.h" +#include "../../../devices/metax/metax_common.h" +#include "../../../devices/metax/metax_handle.h" + +namespace op::blas_amax::metax { + +struct Descriptor::Opaque { + std::shared_ptr internal; +}; + +Descriptor::~Descriptor() { + delete _opaque; +} + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t result_desc) { + + auto handle = reinterpret_cast(handle_); + auto info = BlasAmaxInfo::createBlasAmaxInfo(x_desc, result_desc); + CHECK_RESULT(info); + + *desc_ptr = new Descriptor( + info.take(), + 0, + new Opaque{handle->internal()}, + handle->device, + handle->device_id); + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + const void *x, + void *result, + void *stream) const { + + (void)workspace; + (void)workspace_size; + + const size_t size = 
_info.getSize(); + const ptrdiff_t incx = _info.getIncx(); + const infiniDtype_t data_type = _info.getDtype(); + + CHECK_STATUS(_opaque->internal->useMcblas( + (hcStream_t)stream, + [&](hcblasHandle_t handle) { + CHECK_MCBLAS(hcblasSetPointerMode(handle, HCBLAS_POINTER_MODE_DEVICE)); + + switch (data_type) { + case INFINI_DTYPE_F32: + CHECK_MCBLAS(hcblasIsamax(handle, size, (const float *)x, incx, (int *)result)); + break; + case INFINI_DTYPE_F64: + CHECK_MCBLAS(hcblasIdamax(handle, size, (const double *)x, incx, (int *)result)); + break; + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + return INFINI_STATUS_SUCCESS; + })); + + return INFINI_STATUS_SUCCESS; +} + +} // namespace op::blas_amax::metax \ No newline at end of file diff --git a/src/infiniop/ops/blas_amax/metax/blas_amax_metax.h b/src/infiniop/ops/blas_amax/metax/blas_amax_metax.h new file mode 100644 index 000000000..e4f9af0c0 --- /dev/null +++ b/src/infiniop/ops/blas_amax/metax/blas_amax_metax.h @@ -0,0 +1,8 @@ +#ifndef __BLAS_AMAX_METAX_H__ +#define __BLAS_AMAX_METAX_H__ + +#include "../blas_amax.h" + +DESCRIPTOR(metax) + +#endif // __BLAS_AMAX_METAX_H__ \ No newline at end of file diff --git a/src/infiniop/ops/blas_amax/operator.cc b/src/infiniop/ops/blas_amax/operator.cc new file mode 100644 index 000000000..2e8e37247 --- /dev/null +++ b/src/infiniop/ops/blas_amax/operator.cc @@ -0,0 +1,122 @@ +#include "../../operator.h" +#include "../../handle.h" +#include "infiniop/ops/blas_amax.h" + +#ifdef ENABLE_CPU_API +#include "cpu/blas_amax_cpu.h" +#endif +#ifdef ENABLE_METAX_API +#include "metax/blas_amax_metax.h" +#endif +#ifdef ENABLE_CAMBRICON_API +#include "bang/blas_amax_bang.h" +#endif + +__INFINI_C infiniStatus_t infiniopCreateBlasAmaxDescriptor( + infiniopHandle_t handle, + infiniopBlasAmaxDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t result_desc) { + +#define CREATE(CASE, NAMESPACE) \ + case CASE: \ + return 
op::blas_amax::NAMESPACE::Descriptor::create( \ + handle, \ + reinterpret_cast(desc_ptr), \ + x_desc, \ + result_desc) + + switch (handle->device) { +#ifdef ENABLE_CPU_API + CREATE(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_METAX_API + CREATE(INFINI_DEVICE_METAX, metax); +#endif +#ifdef ENABLE_CAMBRICON_API + CREATE(INFINI_DEVICE_CAMBRICON, bang); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CREATE +} + +__INFINI_C infiniStatus_t infiniopGetBlasAmaxWorkspaceSize(infiniopBlasAmaxDescriptor_t desc, size_t *size) { + +#define GET(CASE, NAMESPACE) \ + case CASE: \ + *size = reinterpret_cast(desc)->workspaceSize(); \ + return INFINI_STATUS_SUCCESS + + switch (desc->device_type) { +#ifdef ENABLE_CPU_API + GET(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_METAX_API + GET(INFINI_DEVICE_METAX, metax); +#endif +#ifdef ENABLE_CAMBRICON_API + GET(INFINI_DEVICE_CAMBRICON, bang); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef GET +} + +__INFINI_C infiniStatus_t infiniopBlasAmax( + infiniopBlasAmaxDescriptor_t desc, + void *workspace, + size_t workspace_size, + const void *x, + void *result, + void *stream) { + +#define CALCULATE(CASE, NAMESPACE) \ + case CASE: \ + return reinterpret_cast(desc) \ + ->calculate(workspace, workspace_size, x, result, stream) + + switch (desc->device_type) { +#ifdef ENABLE_CPU_API + CALCULATE(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_METAX_API + CALCULATE(INFINI_DEVICE_METAX, metax); +#endif +#ifdef ENABLE_CAMBRICON_API + CALCULATE(INFINI_DEVICE_CAMBRICON, bang); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CALCULATE +} + +__INFINI_C infiniStatus_t infiniopDestroyBlasAmaxDescriptor(infiniopBlasAmaxDescriptor_t desc) { + +#define DELETE(CASE, NAMESPACE) \ + case CASE: \ + delete reinterpret_cast(desc); \ + return INFINI_STATUS_SUCCESS + + switch (desc->device_type) { +#ifdef ENABLE_CPU_API + DELETE(INFINI_DEVICE_CPU, cpu); 
+#endif +#ifdef ENABLE_METAX_API + DELETE(INFINI_DEVICE_METAX, metax); +#endif +#ifdef ENABLE_CAMBRICON_API + DELETE(INFINI_DEVICE_CAMBRICON, bang); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef DELETE +} \ No newline at end of file diff --git a/src/infiniop/ops/blas_amin/bang/blas_amin_bang.h b/src/infiniop/ops/blas_amin/bang/blas_amin_bang.h new file mode 100644 index 000000000..6cc84d3a8 --- /dev/null +++ b/src/infiniop/ops/blas_amin/bang/blas_amin_bang.h @@ -0,0 +1,8 @@ +#ifndef __BLAS_AMIN_BANG_H__ +#define __BLAS_AMIN_BANG_H__ + +#include "../blas_amin.h" + +DESCRIPTOR(bang) + +#endif // __BLAS_AMIN_BANG_H__ \ No newline at end of file diff --git a/src/infiniop/ops/blas_amin/bang/blas_amin_bang.mlu b/src/infiniop/ops/blas_amin/bang/blas_amin_bang.mlu new file mode 100644 index 000000000..4b727093c --- /dev/null +++ b/src/infiniop/ops/blas_amin/bang/blas_amin_bang.mlu @@ -0,0 +1,96 @@ +#include "../../../devices/bang/common_bang.h" +#include "blas_amin_bang.h" +#include "blas_amin_bang_kernel.mlu" + +namespace op::blas_amin::bang { + +Descriptor::~Descriptor() = default; + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t result_desc) { + + auto handle = reinterpret_cast(handle_); + auto info = BlasAminInfo::createBlasAminInfo(x_desc, result_desc); + CHECK_RESULT(info); + + // Create descriptor + *desc_ptr = new Descriptor( + info.take(), + 0, + nullptr, + handle->device, + handle->device_id); + + return INFINI_STATUS_SUCCESS; +} + +template +infiniStatus_t calculateBlasAmin( + const BlasAminInfo &info, + const Tdata *x, + int *result, + cnrtQueue_t queue) { + + const size_t size = info.getSize(); + const ptrdiff_t incx = info.getIncx(); + + cnrtDim3_t k_dim; + cnrtFunctionType_t k_type; + + k_dim.x = 4; + k_dim.y = 1; + k_dim.z = 1; + k_type = cnrtFuncTypeUnion1; + + if (incx == 1) { + blasAminKernelContiguous<<>>( + 
size, + x, + result); + } else { + blasAminKernelStrided<<>>( + size, + x, + incx, + result); + } + + cnrtQueueSync(queue); + + return INFINI_STATUS_SUCCESS; +} + +#define CALCULATE_BLAS_AMIN(TDATA) \ + calculateBlasAmin(_info, \ + (const TDATA *)x, \ + (int *)result, \ + (cnrtQueue_t)stream) + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + const void *x, + void *result, + void *stream) const { + + (void)workspace; + (void)workspace_size; + + switch (_info.getDtype()) { + case INFINI_DTYPE_F16: + return CALCULATE_BLAS_AMIN(half); + case INFINI_DTYPE_F32: + return CALCULATE_BLAS_AMIN(float); + case INFINI_DTYPE_BF16: + return CALCULATE_BLAS_AMIN(bfloat16_t); + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } +} + +#undef CALCULATE_BLAS_AMIN + +} // namespace op::blas_amin::bang \ No newline at end of file diff --git a/src/infiniop/ops/blas_amin/bang/blas_amin_bang_kernel.mlu b/src/infiniop/ops/blas_amin/bang/blas_amin_bang_kernel.mlu new file mode 100644 index 000000000..57b14169f --- /dev/null +++ b/src/infiniop/ops/blas_amin/bang/blas_amin_bang_kernel.mlu @@ -0,0 +1,133 @@ +#include "../../../devices/bang/common_bang.h" +#include "blas_amin_bang.h" + +__nram__ char nram_buffer[NRAM_MAX_SIZE]; + +template +__mlu_global__ void blasAminKernelContiguous( + size_t n, + const Tdata *x, + int *result) { + + __mlu_shared__ int shared_min_index[4]; + __mlu_shared__ Tdata shared_min_value[4]; + + Tdata *nram_x = (Tdata *)(((size_t)nram_buffer + ALIGN_SIZE - 1) & ~(ALIGN_SIZE - 1)); + + size_t nram_usable = NRAM_MAX_SIZE - ((char *)nram_x - nram_buffer); + size_t max_chunk_elements = nram_usable / sizeof(Tdata); + + int align_elements = ALIGN_SIZE / sizeof(Tdata); + if (align_elements == 0) { + align_elements = 1; + } + max_chunk_elements = (max_chunk_elements / align_elements) * align_elements; + + int elements_per_core = n / taskDim; + int remain = n % taskDim; + int core_elements = elements_per_core + (taskId < remain ? 
1 : 0); + int core_offset = taskId < remain ? taskId * core_elements : taskId * elements_per_core + remain; + + int chunks = core_elements / max_chunk_elements; + int chunk_rem = core_elements % max_chunk_elements; + + int min_index = -1; + Tdata min_value = static_cast(0); + bool initialized = false; + + for (int c = 0; c < chunks; c++) { + size_t current_offset = core_offset + c * max_chunk_elements; + __memcpy(nram_x, x + current_offset, max_chunk_elements * sizeof(Tdata), GDRAM2NRAM); + + __bang_abs(nram_x, nram_x, max_chunk_elements); + + for (int i = 0; i < max_chunk_elements; i++) { + Tdata abs_val = nram_x[i]; + if (!initialized || abs_val < min_value) { + min_value = abs_val; + min_index = current_offset + i; + initialized = true; + } + } + } + + if (chunk_rem > 0) { + size_t current_offset = core_offset + chunks * max_chunk_elements; + + __memcpy(nram_x, x + current_offset, chunk_rem * sizeof(Tdata), GDRAM2NRAM); + + __bang_abs(nram_x, nram_x, chunk_rem); + + for (int i = 0; i < chunk_rem; i++) { + Tdata abs_val = nram_x[i]; + if (!initialized || abs_val < min_value) { + min_value = abs_val; + min_index = current_offset + i; + initialized = true; + } + } + } + + shared_min_index[coreId] = min_index; + shared_min_value[coreId] = min_value; + + __sync_cluster(); + + if (coreId == 0) { + for (int i = 1; i < coreDim; i++) { + if (shared_min_index[i] >= 0 && shared_min_value[i] < min_value) { + min_value = shared_min_value[i]; + min_index = shared_min_index[i]; + } + } + + result[0] = min_index + 1; // Convert to 1-based index + } +} + +template +__mlu_global__ void blasAminKernelStrided( + size_t n, + const Tdata *x, + size_t incx, + int *result) { + + __mlu_shared__ int shared_min_index[4]; + __mlu_shared__ Tdata shared_min_value[4]; + + int elements_per_core = n / taskDim; + int remain = n % taskDim; + int actual_tasks = elements_per_core + (taskId < remain ? 1 : 0); + int start_idx = taskId < remain ? 
taskId * actual_tasks : taskId * elements_per_core + remain;
+
+    int min_index = -1;
+    Tdata min_value = static_cast(0);
+    bool initialized = false;
+
+    for (int i = start_idx; i < start_idx + actual_tasks; ++i) {
+        size_t offset = i * incx;
+        Tdata abs_val = x[offset] > static_cast(0) ? x[offset] : -x[offset];
+
+        if (!initialized || abs_val < min_value) {
+            min_value = abs_val;
+            min_index = i;
+            initialized = true;
+        }
+    }
+
+    shared_min_index[coreId] = min_index;
+    shared_min_value[coreId] = min_value;
+
+    __sync_cluster();
+
+    if (coreId == 0) {
+        for (int i = 1; i < coreDim; i++) {
+            if (shared_min_index[i] >= 0 && shared_min_value[i] < min_value) {
+                min_value = shared_min_value[i];
+                min_index = shared_min_index[i];
+            }
+        }
+
+        result[0] = min_index + 1; // Convert to 1-based index
+    }
+}
\ No newline at end of file
diff --git a/src/infiniop/ops/blas_amin/blas_amin.h b/src/infiniop/ops/blas_amin/blas_amin.h
new file mode 100644
index 000000000..a05db711e
--- /dev/null
+++ b/src/infiniop/ops/blas_amin/blas_amin.h
@@ -0,0 +1,90 @@
+#ifndef __BLAS_AMIN_H__
+#define __BLAS_AMIN_H__
+
+#include "../../../utils.h"
+#include "../../operator.h"
+#include "../../tensor.h"
+#include "infiniop/ops/blas_amin.h"
+
+#define DESCRIPTOR(NAMESPACE) /* declares class op::blas_amin::NAMESPACE::Descriptor; each backend defines Opaque, create, calculate and the destructor */ \
+ \
+    namespace op::blas_amin::NAMESPACE { \
+    class Descriptor final : public InfiniopDescriptor { \
+        struct Opaque; \
+        Opaque *_opaque; \
+        BlasAminInfo _info; \
+        size_t _workspace_size; \
+ \
+        Descriptor( /* private: instances are only built inside create() */ \
+            BlasAminInfo info, \
+            size_t workspace_size_, \
+            Opaque *opaque, \
+            infiniDevice_t device_type, \
+            int device_id) \
+            : InfiniopDescriptor{device_type, device_id}, \
+              _opaque(opaque), \
+              _info(std::move(info)), \
+              _workspace_size(workspace_size_) {} \
+ \
+    public: \
+        ~Descriptor(); \
+ \
+        size_t workspaceSize() const { return _workspace_size; } \
+ \
+        static infiniStatus_t create( \
+            infiniopHandle_t handle, \
+            Descriptor **desc_ptr, \
+            infiniopTensorDescriptor_t x_desc, \
+            infiniopTensorDescriptor_t result_desc); \
+ \
+        infiniStatus_t calculate( \
+            void *workspace, \
+            size_t workspace_size, \
+            const void *x, \
+            void *result, \
+            void *stream) const; \
+    }; \
+    }
+
+class BlasAminInfo { // validated problem description shared by all backends
+private:
+    size_t _size; // x_desc->numel()
+    ptrdiff_t _incx; // x_desc->stride(0), in elements
+    infiniDtype_t _dtype; // dtype of x; the result dtype is checked to be I32 and not stored
+
+    BlasAminInfo(size_t size,
+                 ptrdiff_t incx,
+                 infiniDtype_t dtype)
+        : _size(size), _incx(incx), _dtype(dtype) {}
+
+public:
+    inline size_t getSize() const { return _size; }
+    inline ptrdiff_t getIncx() const { return _incx; }
+    inline infiniDtype_t getDtype() const { return _dtype; }
+
+    using ResultType = utils::Result;
+
+    static utils::Result createBlasAminInfo( // validates both descriptors, then captures size/stride/dtype
+        infiniopTensorDescriptor_t x_desc,
+        infiniopTensorDescriptor_t result_desc) {
+
+        CHECK_OR_RETURN(x_desc != nullptr, INFINI_STATUS_NULL_POINTER);
+        CHECK_OR_RETURN(result_desc != nullptr, INFINI_STATUS_NULL_POINTER);
+
+        auto dtype = x_desc->dtype();
+        auto itype = result_desc->dtype();
+        CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32, INFINI_DTYPE_F64); // floating-point input only
+        CHECK_DTYPE(itype, INFINI_DTYPE_I32); // the index is returned as a 32-bit int
+
+        CHECK_OR_RETURN(x_desc->ndim() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); // x must be a vector
+        CHECK_OR_RETURN(result_desc->numel() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); // result holds exactly one element
+
+        auto size = x_desc->numel();
+        auto incx = x_desc->stride(0);
+
+        BlasAminInfo info(size, incx, dtype);
+        return ResultType(std::move(info));
+    }
+};
+
+#endif // __BLAS_AMIN_H__
\ No newline at end of file
diff --git a/src/infiniop/ops/blas_amin/cpu/blas_amin_cpu.cc b/src/infiniop/ops/blas_amin/cpu/blas_amin_cpu.cc
new file mode 100644
index 000000000..ce6b567c0
--- /dev/null
+++ b/src/infiniop/ops/blas_amin/cpu/blas_amin_cpu.cc
@@ -0,0 +1,103 @@
+#include "blas_amin_cpu.h"
+#include "../../../devices/cpu/common_cpu.h"
+
+namespace op::blas_amin::cpu {
+
+Descriptor::~Descriptor() = default; // the CPU backend stores a null Opaque, nothing to release
+
+infiniStatus_t Descriptor::create(
+    infiniopHandle_t handle_,
+    Descriptor **desc_ptr,
+    infiniopTensorDescriptor_t x_desc,
+    infiniopTensorDescriptor_t result_desc) {
+
+    auto handle = reinterpret_cast(handle_);
+    auto info = BlasAminInfo::createBlasAminInfo(x_desc, result_desc);
+    CHECK_RESULT(info); // propagate the validation status on failure
+
+    // Create descriptor
+    *desc_ptr = new Descriptor(
+        info.take(),
+        0, // workspace size: the CPU path needs no scratch memory
+        nullptr, // no backend-specific state
+        handle->device,
+        handle->device_id);
+
+    return INFINI_STATUS_SUCCESS;
+}
+
+template
+infiniStatus_t calculateBlasAmin( // writes the 1-based index of the smallest |x[i * incx]| to result[0]
+    const BlasAminInfo &info,
+    const Tdata *x,
+    int *result) {
+
+    const ptrdiff_t size = info.getSize();
+    const ptrdiff_t incx = info.getIncx();
+
+    if (size < 1 || incx == 0) { // nothing to scan: report index 0
+        result[0] = 0;
+        return INFINI_STATUS_SUCCESS;
+    }
+
+    int min_index = 0; // 0-based while scanning; converted to 1-based on store
+    if constexpr (std::is_same::value || std::is_same::value) {
+        float min_value = std::abs(utils::cast(x[0])); // this branch compares in float after utils::cast
+
+        for (ptrdiff_t i = 1; i < size; ++i) {
+            float current_value = std::abs(utils::cast(x[i * incx]));
+            if (current_value < min_value) { // strict '<' keeps the first of equal minima
+                min_value = current_value;
+                min_index = i;
+            }
+        }
+    } else {
+        Tdata min_value = std::abs(x[0]);
+
+        for (ptrdiff_t i = 1; i < size; ++i) {
+            Tdata current_value = std::abs(x[i * incx]);
+            if (current_value < min_value) { // strict '<' keeps the first of equal minima
+                min_value = current_value;
+                min_index = i;
+            }
+        }
+    }
+
+    result[0] = utils::cast(min_index) + 1;
+
+    return INFINI_STATUS_SUCCESS;
+}
+
+#define CALCULATE_BLAS_AMIN(TDATA) \
+    calculateBlasAmin(_info, \
+                      (const TDATA *)x, \
+                      (int *)result)
+
+infiniStatus_t Descriptor::calculate(
+    void *workspace,
+    size_t workspace_size,
+    const void *x,
+    void *result,
+    void *stream) const {
+
+    (void)workspace; // unused on CPU
+    (void)workspace_size;
+    (void)stream;
+
+    switch (_info.getDtype()) { // dispatch on the element type recorded at create() time
+    case INFINI_DTYPE_F16:
+        return CALCULATE_BLAS_AMIN(fp16_t);
+    case INFINI_DTYPE_F32:
+        return CALCULATE_BLAS_AMIN(float);
+    case INFINI_DTYPE_F64:
+        return CALCULATE_BLAS_AMIN(double);
+    case INFINI_DTYPE_BF16:
+        return CALCULATE_BLAS_AMIN(bf16_t);
+    default:
+        return INFINI_STATUS_BAD_TENSOR_DTYPE;
+    }
+}
+
+#undef CALCULATE_BLAS_AMIN
+
+} // namespace op::blas_amin::cpu
\ No newline at end of file
diff --git
a/src/infiniop/ops/blas_amin/cpu/blas_amin_cpu.h b/src/infiniop/ops/blas_amin/cpu/blas_amin_cpu.h
new file mode 100644
index 000000000..ea4daa397
--- /dev/null
+++ b/src/infiniop/ops/blas_amin/cpu/blas_amin_cpu.h
@@ -0,0 +1,8 @@
+#ifndef __BLAS_AMIN_CPU_H__
+#define __BLAS_AMIN_CPU_H__
+
+#include "../blas_amin.h"
+
+DESCRIPTOR(cpu) // declares op::blas_amin::cpu::Descriptor
+
+#endif // __BLAS_AMIN_CPU_H__
\ No newline at end of file
diff --git a/src/infiniop/ops/blas_amin/metax/blas_amin_metax.cc b/src/infiniop/ops/blas_amin/metax/blas_amin_metax.cc
new file mode 100644
index 000000000..d924f7f63
--- /dev/null
+++ b/src/infiniop/ops/blas_amin/metax/blas_amin_metax.cc
@@ -0,0 +1,71 @@
+#include "blas_amin_metax.h"
+#include "../../../devices/metax/metax_common.h"
+#include "../../../devices/metax/metax_handle.h"
+
+namespace op::blas_amin::metax {
+
+struct Descriptor::Opaque { // backend state: shared handle internals used to reach the BLAS handle
+    std::shared_ptr internal;
+};
+
+Descriptor::~Descriptor() {
+    delete _opaque; // allocated with new in create()
+}
+
+infiniStatus_t Descriptor::create(
+    infiniopHandle_t handle_,
+    Descriptor **desc_ptr,
+    infiniopTensorDescriptor_t x_desc,
+    infiniopTensorDescriptor_t result_desc) {
+
+    auto handle = reinterpret_cast(handle_);
+    auto info = BlasAminInfo::createBlasAminInfo(x_desc, result_desc);
+    CHECK_RESULT(info); // propagate the validation status on failure
+
+    *desc_ptr = new Descriptor(
+        info.take(),
+        0, // no workspace needed
+        new Opaque{handle->internal()},
+        handle->device,
+        handle->device_id);
+
+    return INFINI_STATUS_SUCCESS;
+}
+
+infiniStatus_t Descriptor::calculate(
+    void *workspace,
+    size_t workspace_size,
+    const void *x,
+    void *result,
+    void *stream) const {
+
+    (void)workspace;
+    (void)workspace_size;
+
+    const size_t size = _info.getSize();
+    const ptrdiff_t incx = _info.getIncx();
+    const infiniDtype_t data_type = _info.getDtype();
+
+    CHECK_STATUS(_opaque->internal->useMcblas(
+        (hcStream_t)stream,
+        [&](hcblasHandle_t handle) {
+            CHECK_MCBLAS(hcblasSetPointerMode(handle, HCBLAS_POINTER_MODE_DEVICE)); // presumably: result is written through a device pointer
+
+            switch (data_type) { // NOTE(review): create() also accepts F16/BF16, which are rejected only here
+            case INFINI_DTYPE_F32:
+                CHECK_MCBLAS(hcblasIsamin(handle, size, (const float *)x, incx, (int *)result)); // NOTE(review): size/incx are size_t/ptrdiff_t here - confirm against the parameter types of the BLAS call
+                break;
+            case INFINI_DTYPE_F64:
+                CHECK_MCBLAS(hcblasIdamin(handle, size, (const double *)x, incx, (int *)result));
+                break;
+            default:
+                return INFINI_STATUS_BAD_TENSOR_DTYPE;
+            }
+
+            return INFINI_STATUS_SUCCESS;
+        }));
+
+    return INFINI_STATUS_SUCCESS;
+}
+
+} // namespace op::blas_amin::metax
\ No newline at end of file
diff --git a/src/infiniop/ops/blas_amin/metax/blas_amin_metax.h b/src/infiniop/ops/blas_amin/metax/blas_amin_metax.h
new file mode 100644
index 000000000..ce4b6b0b7
--- /dev/null
+++ b/src/infiniop/ops/blas_amin/metax/blas_amin_metax.h
@@ -0,0 +1,8 @@
+#ifndef __BLAS_AMIN_METAX_H__
+#define __BLAS_AMIN_METAX_H__
+
+#include "../blas_amin.h"
+
+DESCRIPTOR(metax) // declares op::blas_amin::metax::Descriptor
+
+#endif // __BLAS_AMIN_METAX_H__
\ No newline at end of file
diff --git a/src/infiniop/ops/blas_amin/operator.cc b/src/infiniop/ops/blas_amin/operator.cc
new file mode 100644
index 000000000..f47044263
--- /dev/null
+++ b/src/infiniop/ops/blas_amin/operator.cc
@@ -0,0 +1,122 @@
+#include "../../operator.h"
+#include "../../handle.h"
+#include "infiniop/ops/blas_amin.h"
+
+#ifdef ENABLE_CPU_API
+#include "cpu/blas_amin_cpu.h"
+#endif
+#ifdef ENABLE_METAX_API
+#include "metax/blas_amin_metax.h"
+#endif
+#ifdef ENABLE_CAMBRICON_API
+#include "bang/blas_amin_bang.h"
+#endif
+
+__INFINI_C infiniStatus_t infiniopCreateBlasAminDescriptor( // C entry point: forwards to the backend selected by handle->device
+    infiniopHandle_t handle,
+    infiniopBlasAminDescriptor_t *desc_ptr,
+    infiniopTensorDescriptor_t x_desc,
+    infiniopTensorDescriptor_t result_desc) {
+
+#define CREATE(CASE, NAMESPACE) \
+    case CASE: \
+        return op::blas_amin::NAMESPACE::Descriptor::create( \
+            handle, \
+            reinterpret_cast(desc_ptr), \
+            x_desc, \
+            result_desc)
+
+    switch (handle->device) { // only backends compiled in via ENABLE_*_API get a case
+#ifdef ENABLE_CPU_API
+        CREATE(INFINI_DEVICE_CPU, cpu);
+#endif
+#ifdef ENABLE_METAX_API
+        CREATE(INFINI_DEVICE_METAX, metax);
+#endif
+#ifdef ENABLE_CAMBRICON_API
+        CREATE(INFINI_DEVICE_CAMBRICON, bang);
+#endif
+    default:
+        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
+    }
+
+#undef CREATE
+}
+
+__INFINI_C infiniStatus_t infiniopGetBlasAminWorkspaceSize(infiniopBlasAminDescriptor_t desc, size_t *size) { // stores the descriptor's workspaceSize() in *size
+
+#define GET(CASE, NAMESPACE) \
+    case CASE: \
+        *size = reinterpret_cast(desc)->workspaceSize(); \
+        return INFINI_STATUS_SUCCESS
+
+    switch (desc->device_type) {
+#ifdef ENABLE_CPU_API
+        GET(INFINI_DEVICE_CPU, cpu);
+#endif
+#ifdef ENABLE_METAX_API
+        GET(INFINI_DEVICE_METAX, metax);
+#endif
+#ifdef ENABLE_CAMBRICON_API
+        GET(INFINI_DEVICE_CAMBRICON, bang);
+#endif
+    default:
+        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
+    }
+
+#undef GET
+}
+
+__INFINI_C infiniStatus_t infiniopBlasAmin( // runs the backend's calculate() with the caller's buffers and stream
+    infiniopBlasAminDescriptor_t desc,
+    void *workspace,
+    size_t workspace_size,
+    const void *x,
+    void *result,
+    void *stream) {
+
+#define CALCULATE(CASE, NAMESPACE) \
+    case CASE: \
+        return reinterpret_cast(desc) \
+            ->calculate(workspace, workspace_size, x, result, stream)
+
+    switch (desc->device_type) {
+#ifdef ENABLE_CPU_API
+        CALCULATE(INFINI_DEVICE_CPU, cpu);
+#endif
+#ifdef ENABLE_METAX_API
+        CALCULATE(INFINI_DEVICE_METAX, metax);
+#endif
+#ifdef ENABLE_CAMBRICON_API
+        CALCULATE(INFINI_DEVICE_CAMBRICON, bang);
+#endif
+    default:
+        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
+    }
+
+#undef CALCULATE
+}
+
+__INFINI_C infiniStatus_t infiniopDestroyBlasAminDescriptor(infiniopBlasAminDescriptor_t desc) { // deletes the descriptor through its concrete backend type
+
+#define DELETE(CASE, NAMESPACE) \
+    case CASE: \
+        delete reinterpret_cast(desc); \
+        return INFINI_STATUS_SUCCESS
+
+    switch (desc->device_type) {
+#ifdef ENABLE_CPU_API
+        DELETE(INFINI_DEVICE_CPU, cpu);
+#endif
+#ifdef ENABLE_METAX_API
+        DELETE(INFINI_DEVICE_METAX, metax);
+#endif
+#ifdef ENABLE_CAMBRICON_API
+        DELETE(INFINI_DEVICE_CAMBRICON, bang);
+#endif
+    default:
+        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
+    }
+
+#undef DELETE
+}
\ No newline at end of file
diff --git a/test/infiniop/blas_amax.py b/test/infiniop/blas_amax.py
new file mode 100644
index 000000000..d4d806305
--- /dev/null
+++ b/test/infiniop/blas_amax.py
@@ -0,0 +1,132 @@
+import ctypes
+from ctypes import c_uint64
+
+import torch
+from libinfiniop import (
+    LIBINFINIOP,
+    InfiniDeviceNames,
+    InfiniDtype,
+    InfiniDtypeNames,
+    TestTensor,
+    TestWorkspace,
+    check_error,
+    debug,
+    get_args,
+    get_test_devices,
+    infiniopOperatorDescriptor_t,
+    profile_operation,
+    test_operator,
+)
+
+# ==============================================================================
+# Configuration
+# ==============================================================================
+
+_TEST_CASES = [
+    # n, x_stride
+    (3, None),
+    (8, (2,)),
+    (32, None),
+    (257, (3,)),
+    (65535, None),
+]
+
+_TENSOR_DTYPES = [
+    # InfiniDtype.F16,
+    InfiniDtype.F32,
+    # InfiniDtype.F64,
+    # InfiniDtype.BF16,
+]
+
+DEBUG = False
+PROFILE = False
+NUM_PRERUN = 10
+NUM_ITERATIONS = 1000
+
+
+def test(  # one case: compare infiniopBlasAmax against a torch reference
+    handle,
+    device,
+    n,
+    x_stride=None,
+    dtype=torch.float16,
+    sync=None,
+):
+    torch.manual_seed(0)  # fixed seed so every run sees the same input
+    if device != 0:
+        torch.cuda.manual_seed_all(0)
+
+    x = TestTensor((n,), x_stride, dtype, device)
+    result = TestTensor(tuple(), None, InfiniDtype.I32, device, mode="zeros")  # 0-d int32 output
+
+    print(
+        f"Testing blas_amax on {InfiniDeviceNames[device]} with n:{n} x_stride:{x_stride} "
+        f"dtype:{InfiniDtypeNames[dtype]}"
+    )
+
+    result_ref = torch.argmax(x.torch_tensor().abs()).to(torch.int32) + 1  # 1-based index of the largest |x|
+    result.update_torch_tensor(result_ref)
+
+    if sync is not None:
+        sync()
+
+    descriptor = infiniopOperatorDescriptor_t()
+    check_error(
+        LIBINFINIOP.infiniopCreateBlasAmaxDescriptor(
+            handle,
+            ctypes.byref(descriptor),
+            x.descriptor,
+            result.descriptor,
+        )
+    )
+
+    for tensor in [x, result]:
+        tensor.destroy_desc()  # tensor descriptors are dropped before the operator runs
+
+    workspace_size = c_uint64(0)
+    check_error(
+        LIBINFINIOP.infiniopGetBlasAmaxWorkspaceSize(
+            descriptor, ctypes.byref(workspace_size)
+        )
+    )
+    workspace = TestWorkspace(workspace_size.value, result.device)
+
+    def lib_blas_amax():  # single library invocation, reused for the check and for profiling
+        check_error(
+            LIBINFINIOP.infiniopBlasAmax(
+                descriptor,
+                workspace.data(),
+                workspace.size(),
+                x.data(),
+                result.data(),
+                None,
+            )
+        )
+
+    lib_blas_amax()
+
+    if DEBUG:
+        debug(result.actual_tensor(), result.torch_tensor())
+    assert torch.equal(result.actual_tensor(), result.torch_tensor())  # indices must match exactly
+
+    if PROFILE:
+        # fmt: off
+        profile_operation("PyTorch", lambda: torch.argmax(x.torch_tensor().abs()), device, NUM_PRERUN, NUM_ITERATIONS)
+        profile_operation(" lib", lambda: lib_blas_amax(), device, NUM_PRERUN, NUM_ITERATIONS)
+        # fmt: on
+
+    check_error(LIBINFINIOP.infiniopDestroyBlasAmaxDescriptor(descriptor))
+
+
+if __name__ == "__main__":
+    args = get_args()
+
+    DEBUG = args.debug
+    PROFILE = args.profile
+    NUM_PRERUN = args.num_prerun
+    NUM_ITERATIONS = args.num_iterations
+
+    for device in get_test_devices(args):
+        test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES)
+
+    print("\033[92m Test passed! \033[0m")
diff --git a/test/infiniop/blas_amin.py b/test/infiniop/blas_amin.py
new file mode 100644
index 000000000..4899f85dd
--- /dev/null
+++ b/test/infiniop/blas_amin.py
@@ -0,0 +1,132 @@
+import ctypes
+from ctypes import c_uint64
+
+import torch
+from libinfiniop import (
+    LIBINFINIOP,
+    InfiniDeviceNames,
+    InfiniDtype,
+    InfiniDtypeNames,
+    TestTensor,
+    TestWorkspace,
+    check_error,
+    debug,
+    get_args,
+    get_test_devices,
+    infiniopOperatorDescriptor_t,
+    profile_operation,
+    test_operator,
+)
+
+# ==============================================================================
+# Configuration
+# ==============================================================================
+
+_TEST_CASES = [
+    # n, x_stride
+    (3, None),
+    (8, (2,)),
+    (32, None),
+    (257, (3,)),
+    (65535, None),
+]
+
+_TENSOR_DTYPES = [
+    # InfiniDtype.F16,
+    InfiniDtype.F32,
+    # InfiniDtype.F64,
+    # InfiniDtype.BF16,
+]
+
+DEBUG = False
+PROFILE = False
+NUM_PRERUN = 10
+NUM_ITERATIONS = 1000
+
+
+def test(  # one case: compare infiniopBlasAmin against a torch reference
+    handle,
+    device,
+    n,
+    x_stride=None,
+    dtype=torch.float16,
+    sync=None,
+):
+    torch.manual_seed(0)  # fixed seed so every run sees the same input
+    if device != 0:
+        torch.cuda.manual_seed_all(0)
+
+    x = TestTensor((n,), x_stride, dtype, device)
+    result = TestTensor(tuple(), None, InfiniDtype.I32, device, mode="zeros")  # 0-d int32 output
+
+    print(
+        f"Testing blas_amin on {InfiniDeviceNames[device]} with n:{n} x_stride:{x_stride} "
+        f"dtype:{InfiniDtypeNames[dtype]}"
+    )
+
+    result_ref = torch.argmin(x.torch_tensor().abs()).to(torch.int32) + 1  # 1-based index of the smallest |x|
+    result.update_torch_tensor(result_ref)
+
+    if sync is not None:
+        sync()
+
+    descriptor = infiniopOperatorDescriptor_t()
+    check_error(
+        LIBINFINIOP.infiniopCreateBlasAminDescriptor(
+            handle,
+            ctypes.byref(descriptor),
+            x.descriptor,
+            result.descriptor,
+        )
+    )
+
+    for tensor in [x, result]:
+        tensor.destroy_desc()  # tensor descriptors are dropped before the operator runs
+
+    workspace_size = c_uint64(0)
+    check_error(
+        LIBINFINIOP.infiniopGetBlasAminWorkspaceSize(
+            descriptor, ctypes.byref(workspace_size)
+        )
+    )
+    workspace = TestWorkspace(workspace_size.value, result.device)
+
+    def lib_blas_amin():  # single library invocation, reused for the check and for profiling
+        check_error(
+            LIBINFINIOP.infiniopBlasAmin(
+                descriptor,
+                workspace.data(),
+                workspace.size(),
+                x.data(),
+                result.data(),
+                None,
+            )
+        )
+
+    lib_blas_amin()
+
+    if DEBUG:
+        debug(result.actual_tensor(), result.torch_tensor())
+    assert torch.equal(result.actual_tensor(), result.torch_tensor())  # indices must match exactly
+
+    if PROFILE:
+        # fmt: off
+        profile_operation("PyTorch", lambda: torch.argmin(x.torch_tensor().abs()), device, NUM_PRERUN, NUM_ITERATIONS)  # baseline is argmin, the same op as result_ref
+        profile_operation(" lib", lambda: lib_blas_amin(), device, NUM_PRERUN, NUM_ITERATIONS)
+        # fmt: on
+
+    check_error(LIBINFINIOP.infiniopDestroyBlasAminDescriptor(descriptor))
+
+
+if __name__ == "__main__":
+    args = get_args()
+
+    DEBUG = args.debug
+    PROFILE = args.profile
+    NUM_PRERUN = args.num_prerun
+    NUM_ITERATIONS = args.num_iterations
+
+    for device in get_test_devices(args):
+        test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES)
+
+    print("\033[92m Test passed! \033[0m")
diff --git a/test/infiniop/libinfiniop/op_register.py b/test/infiniop/libinfiniop/op_register.py
index 2802bc5bc..b92b17e58 100644
--- a/test/infiniop/libinfiniop/op_register.py
+++ b/test/infiniop/libinfiniop/op_register.py
@@ -2159,3 +2159,67 @@ def fused_ffn_(lib):
     lib.infiniopDestroyFusedFFNDescriptor.argtypes = [
         infiniopOperatorDescriptor_t,
     ]
+
+
+@OpRegister.operator
+def blas_amax_(lib):  # ctypes signatures for the four BlasAmax entry points
+    lib.infiniopCreateBlasAmaxDescriptor.restype = c_int32
+    lib.infiniopCreateBlasAmaxDescriptor.argtypes = [
+        infiniopHandle_t,
+        POINTER(infiniopOperatorDescriptor_t),
+        infiniopTensorDescriptor_t,
+        infiniopTensorDescriptor_t,
+    ]
+
+    lib.infiniopGetBlasAmaxWorkspaceSize.restype = c_int32
+    lib.infiniopGetBlasAmaxWorkspaceSize.argtypes = [
+        infiniopOperatorDescriptor_t,
+        POINTER(c_size_t),
+    ]
+
+    lib.infiniopBlasAmax.restype = c_int32
+    lib.infiniopBlasAmax.argtypes = [
+        infiniopOperatorDescriptor_t,
+        c_void_p,
+        c_size_t,
+        c_void_p,
+        c_void_p,
+        c_void_p,
+    ]
+
+    lib.infiniopDestroyBlasAmaxDescriptor.restype = c_int32
+    lib.infiniopDestroyBlasAmaxDescriptor.argtypes = [
+        infiniopOperatorDescriptor_t,
+    ]
+
+
+@OpRegister.operator
+def blas_amin_(lib):  # ctypes signatures for the four BlasAmin entry points
+    lib.infiniopCreateBlasAminDescriptor.restype = c_int32
+    lib.infiniopCreateBlasAminDescriptor.argtypes = [
+        infiniopHandle_t,
+        POINTER(infiniopOperatorDescriptor_t),
+        infiniopTensorDescriptor_t,
+        infiniopTensorDescriptor_t,
+    ]
+
+    lib.infiniopGetBlasAminWorkspaceSize.restype = c_int32
+    lib.infiniopGetBlasAminWorkspaceSize.argtypes = [
+        infiniopOperatorDescriptor_t,
+        POINTER(c_size_t),
+    ]
+
+    lib.infiniopBlasAmin.restype = c_int32
+    lib.infiniopBlasAmin.argtypes = [
+        infiniopOperatorDescriptor_t,
+        c_void_p,
+        c_size_t,
+        c_void_p,
+        c_void_p,
+        c_void_p,
+    ]
+
+    lib.infiniopDestroyBlasAminDescriptor.restype = c_int32
+    lib.infiniopDestroyBlasAminDescriptor.argtypes = [
+        infiniopOperatorDescriptor_t,
+    ]
From c2dba21b2e34f5f4c5b0fd4c0b02e38d8a0a10f3 Mon Sep 17 00:00:00 2001
From: xuzhengzhong
Date: Tue, 28 Apr 2026 06:42:53 +0000
Subject: [PATCH 02/25] Add `asum` operator

---
 include/infiniop.h | 1 +
 include/infiniop/ops/asum.h | 24 +++
 src/infiniop/devices/metax/metax_ht2mc.h | 2 +
 src/infiniop/ops/asum/asum.h | 90 +++++++++++
 src/infiniop/ops/asum/bang/asum_bang.h | 8 +
 src/infiniop/ops/asum/bang/asum_bang.mlu | 95 ++++++++++++
 .../ops/asum/bang/asum_bang_kernel.mlu | 105 +++++++++++++
 src/infiniop/ops/asum/cpu/asum_cpu.cc | 89 +++++++++++
 src/infiniop/ops/asum/cpu/asum_cpu.h | 8 +
 src/infiniop/ops/asum/metax/asum_metax.cc | 71 +++++++++
 src/infiniop/ops/asum/metax/asum_metax.h | 8 +
 src/infiniop/ops/asum/operator.cc | 124 +++++++++++++++
 test/infiniop/asum.py | 143 ++++++++++++++++++
 test/infiniop/libinfiniop/op_register.py | 32 ++++
 14 files changed, 800 insertions(+)
 create mode 100644 include/infiniop/ops/asum.h
 create mode 100644 src/infiniop/ops/asum/asum.h
 create mode 100644 src/infiniop/ops/asum/bang/asum_bang.h
 create mode 100644 src/infiniop/ops/asum/bang/asum_bang.mlu
 create mode 100644 src/infiniop/ops/asum/bang/asum_bang_kernel.mlu
 create mode 100644 src/infiniop/ops/asum/cpu/asum_cpu.cc
 create mode 100644 src/infiniop/ops/asum/cpu/asum_cpu.h
 create mode 100644 src/infiniop/ops/asum/metax/asum_metax.cc
 create mode 100644 src/infiniop/ops/asum/metax/asum_metax.h
 create mode 100644 src/infiniop/ops/asum/operator.cc
 create mode 100644 test/infiniop/asum.py

diff --git a/include/infiniop.h b/include/infiniop.h
index b456de868..df03a1865 100644
--- a/include/infiniop.h
+++ b/include/infiniop.h
@@ -15,6 +15,7 @@
 #include "infiniop/ops/all.h"
 #include "infiniop/ops/asin.h"
 #include "infiniop/ops/asinh.h"
+#include "infiniop/ops/asum.h"
 #include "infiniop/ops/atanh.h"
 #include "infiniop/ops/attention.h"
 #include "infiniop/ops/avg_pool1d.h"
diff --git a/include/infiniop/ops/asum.h b/include/infiniop/ops/asum.h
new file mode 100644
index 000000000..89336bce7
--- /dev/null
+++ b/include/infiniop/ops/asum.h
@@ -0,0 +1,24 @@
+#ifndef __INFINIOP_ASUM_API_H__
+#define __INFINIOP_ASUM_API_H__
+
+#include "../operator_descriptor.h"
+
+typedef struct InfiniopDescriptor *infiniopAsumDescriptor_t; // opaque handle to an asum operator instance
+
+__INFINI_C __export infiniStatus_t infiniopCreateAsumDescriptor(infiniopHandle_t handle, // creates a descriptor for input x and a one-element result
+                                                                infiniopAsumDescriptor_t *desc_ptr,
+                                                                infiniopTensorDescriptor_t x,
+                                                                infiniopTensorDescriptor_t result);
+
+__INFINI_C __export infiniStatus_t infiniopGetAsumWorkspaceSize(infiniopAsumDescriptor_t desc, size_t *size); // workspace bytes to pass to infiniopAsum
+
+__INFINI_C __export infiniStatus_t infiniopAsum(infiniopAsumDescriptor_t desc, // runs the operator on x, writing to result
+                                                void *workspace,
+                                                size_t workspace_size,
+                                                const void *x,
+                                                void *result,
+                                                void *stream);
+
+__INFINI_C __export infiniStatus_t infiniopDestroyAsumDescriptor(infiniopAsumDescriptor_t desc); // releases the descriptor
+
+#endif // __INFINIOP_ASUM_API_H__
diff --git a/src/infiniop/devices/metax/metax_ht2mc.h b/src/infiniop/devices/metax/metax_ht2mc.h
index 8efa3b5a2..89658a1ad 100644
--- a/src/infiniop/devices/metax/metax_ht2mc.h
+++ b/src/infiniop/devices/metax/metax_ht2mc.h
@@ -123,6 +123,8 @@
 #define hcblasIdamax mcblasIdamax
 #define hcblasIsamin mcblasIsamin
 #define hcblasIdamin mcblasIdamin
+#define hcblasSasum mcblasSasum
+#define hcblasDasum mcblasDasum
 #define HCBLAS_STATUS_SUCCESS MCBLAS_STATUS_SUCCESS
 #define HCBLAS_OP_T MCBLAS_OP_T
 #define HCBLAS_OP_N MCBLAS_OP_N
diff --git a/src/infiniop/ops/asum/asum.h b/src/infiniop/ops/asum/asum.h
new file mode 100644
index 000000000..0174e3591
--- /dev/null
+++ b/src/infiniop/ops/asum/asum.h
@@ -0,0 +1,90 @@
+#ifndef __ASUM_H__
+#define __ASUM_H__
+
+#include "../../../utils.h"
+#include "../../operator.h"
+#include "../../tensor.h"
+#include "infiniop/ops/asum.h"
+
+#define DESCRIPTOR(NAMESPACE) /* declares class op::asum::NAMESPACE::Descriptor; each backend defines Opaque, create, calculate and the destructor */ \
+ \
+    namespace op::asum::NAMESPACE { \
+    class Descriptor final : public InfiniopDescriptor { \
+        struct Opaque; \
+        Opaque *_opaque; \
+        AsumInfo _info; \
+        size_t _workspace_size; \
+ \
+        Descriptor( /* private: instances are only built inside create() */ \
+            AsumInfo info, \
+            size_t workspace_size_, \
+            Opaque *opaque, \
+            infiniDevice_t device_type, \
+            int device_id) \
+            : InfiniopDescriptor{device_type, device_id}, \
+              _opaque(opaque), \
+              _info(std::move(info)), \
+              _workspace_size(workspace_size_) {} \
+ \
+    public: \
+        ~Descriptor(); \
+ \
+        size_t workspaceSize() const { return _workspace_size; } \
+ \
+        static infiniStatus_t create( \
+            infiniopHandle_t handle, \
+            Descriptor **desc_ptr, \
+            infiniopTensorDescriptor_t x_desc, \
+            infiniopTensorDescriptor_t result_desc); \
+ \
+        infiniStatus_t calculate( \
+            void *workspace, \
+            size_t workspace_size, \
+            const void *x, \
+            void *result, \
+            void *stream) const; \
+    }; \
+    }
+
+class AsumInfo { // validated problem description shared by all backends
+private:
+    size_t _size; // x_desc->numel()
+    ptrdiff_t _incx; // x_desc->stride(0), in elements
+    infiniDtype_t _dtype; // dtype of x; result is checked to have the same dtype
+
+    AsumInfo(size_t size,
+             ptrdiff_t incx,
+             infiniDtype_t dtype)
+        : _size(size), _incx(incx), _dtype(dtype) {}
+
+public:
+    inline size_t getSize() const { return _size; }
+    inline ptrdiff_t getIncx() const { return _incx; }
+    inline infiniDtype_t getDtype() const { return _dtype; }
+
+    using ResultType = utils::Result;
+
+    static ResultType createAsumInfo( // validates both descriptors, then captures size/stride/dtype
+        infiniopTensorDescriptor_t x_desc,
+        infiniopTensorDescriptor_t result_desc) {
+
+        CHECK_OR_RETURN(x_desc != nullptr, INFINI_STATUS_NULL_POINTER);
+        CHECK_OR_RETURN(result_desc != nullptr, INFINI_STATUS_NULL_POINTER);
+
+        auto dtype = x_desc->dtype();
+
+        CHECK_OR_RETURN(result_desc->dtype() == dtype, INFINI_STATUS_BAD_TENSOR_DTYPE); // result is produced in the input's dtype
+        CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32, INFINI_DTYPE_F64); // floating-point input only
+
+        CHECK_OR_RETURN(x_desc->ndim() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); // x must be a vector
+        CHECK_OR_RETURN(result_desc->numel() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); // result holds exactly one element
+
+        auto size = x_desc->numel();
+        auto incx = x_desc->stride(0);
+
+        AsumInfo info(size, incx, dtype);
+        return ResultType(std::move(info));
+    }
+};
+
+#endif // __ASUM_H__
\ No newline at end of file
diff --git a/src/infiniop/ops/asum/bang/asum_bang.h b/src/infiniop/ops/asum/bang/asum_bang.h
new file mode 100644
index 000000000..e5d99dac3
--- /dev/null
+++
b/src/infiniop/ops/asum/bang/asum_bang.h
@@ -0,0 +1,8 @@
+#ifndef __ASUM_BANG_H__
+#define __ASUM_BANG_H__
+
+#include "../asum.h"
+
+DESCRIPTOR(bang) // declares op::asum::bang::Descriptor
+
+#endif // __ASUM_BANG_H__
\ No newline at end of file
diff --git a/src/infiniop/ops/asum/bang/asum_bang.mlu b/src/infiniop/ops/asum/bang/asum_bang.mlu
new file mode 100644
index 000000000..f9658b072
--- /dev/null
+++ b/src/infiniop/ops/asum/bang/asum_bang.mlu
@@ -0,0 +1,95 @@
+#include "../../../devices/bang/common_bang.h"
+#include "asum_bang.h"
+#include "asum_bang_kernel.mlu"
+
+namespace op::asum::bang {
+
+Descriptor::~Descriptor() = default; // this backend stores a null Opaque, nothing to release
+
+infiniStatus_t Descriptor::create(
+    infiniopHandle_t handle_,
+    Descriptor **desc_ptr,
+    infiniopTensorDescriptor_t x_desc,
+    infiniopTensorDescriptor_t result_desc) {
+
+    auto handle = reinterpret_cast(handle_);
+    auto info = AsumInfo::createAsumInfo(x_desc, result_desc);
+    CHECK_RESULT(info); // propagate the validation status on failure
+
+    *desc_ptr = new Descriptor(
+        info.take(),
+        0, // no workspace needed
+        nullptr,
+        handle->device,
+        handle->device_id);
+
+    return INFINI_STATUS_SUCCESS;
+}
+
+template
+infiniStatus_t calculateAsum( // launches the contiguous or strided kernel and waits for the queue
+    const AsumInfo &info,
+    const Tdata *x,
+    Tdata *result,
+    cnrtQueue_t queue) {
+
+    const size_t size = info.getSize();
+    const ptrdiff_t incx = info.getIncx();
+
+    cnrtDim3_t k_dim;
+    cnrtFunctionType_t k_type;
+
+    k_dim.x = 4; // matches the 4-entry shared_partial_sum arrays in the kernels
+    k_dim.y = 1;
+    k_dim.z = 1;
+    k_type = cnrtFuncTypeUnion1;
+
+    if (incx == 1) { // unit stride: kernel can bulk-copy chunks into NRAM
+        asumKernelContiguous<<>>(
+            size,
+            x,
+            result);
+    } else {
+        asumKernelStrided<<>>(
+            size,
+            x,
+            incx, // NOTE(review): ptrdiff_t passed to a size_t kernel parameter - confirm negative strides are intended to work
+            result);
+    }
+
+    cnrtQueueSync(queue); // blocks until the kernel has finished
+
+    return INFINI_STATUS_SUCCESS;
+}
+
+#define CALCULATE_ASUM(TDATA) \
+    calculateAsum(_info, \
+                  (const TDATA *)x, \
+                  (TDATA *)result, \
+                  (cnrtQueue_t)stream)
+
+infiniStatus_t Descriptor::calculate(
+    void *workspace,
+    size_t workspace_size,
+    const void *x,
+    void *result,
+    void *stream) const {
+
+    (void)workspace;
+    (void)workspace_size;
+
+    switch (_info.getDtype()) { // F64 has no case here and falls to the dtype error
+    case INFINI_DTYPE_F16:
+        return CALCULATE_ASUM(half);
+    case INFINI_DTYPE_F32:
+        return CALCULATE_ASUM(float);
+    case INFINI_DTYPE_BF16:
+        return CALCULATE_ASUM(bfloat16_t);
+    default:
+        return INFINI_STATUS_BAD_TENSOR_DTYPE;
+    }
+}
+
+#undef CALCULATE_ASUM
+
+} // namespace op::asum::bang
\ No newline at end of file
diff --git a/src/infiniop/ops/asum/bang/asum_bang_kernel.mlu b/src/infiniop/ops/asum/bang/asum_bang_kernel.mlu
new file mode 100644
index 000000000..37ba8f681
--- /dev/null
+++ b/src/infiniop/ops/asum/bang/asum_bang_kernel.mlu
@@ -0,0 +1,105 @@
+#include "../../../devices/bang/common_bang.h"
+#include "asum_bang.h"
+
+__nram__ char nram_buffer[NRAM_MAX_SIZE];
+
+template
+__mlu_global__ void asumKernelContiguous( // sum of |x[i]| over n unit-stride elements, written to result[0]
+    size_t n,
+    const Tdata *x,
+    Tdata *result) {
+
+    __mlu_shared__ Tdata shared_partial_sum[4]; // one slot per core; sized for the 4-task launch
+
+    Tdata *nram_x = (Tdata *)(((size_t)nram_buffer + ALIGN_SIZE - 1) & ~(ALIGN_SIZE - 1)); // round the buffer start up to ALIGN_SIZE
+
+    size_t nram_usable = NRAM_MAX_SIZE - ((char *)nram_x - nram_buffer);
+    size_t max_chunk_elements = nram_usable / sizeof(Tdata);
+
+    int align_elements = ALIGN_SIZE / sizeof(Tdata);
+    if (align_elements == 0) {
+        align_elements = 1;
+    }
+    max_chunk_elements = (max_chunk_elements / align_elements) * align_elements; // keep full chunks a multiple of the alignment
+
+    int elements_per_core = n / taskDim; // NOTE(review): n is size_t, narrowed to int here
+    int remain = n % taskDim;
+    int core_elements = elements_per_core + (taskId < remain ? 1 : 0); // the first `remain` tasks take one extra element
+    int core_offset = taskId < remain ? taskId * core_elements : taskId * elements_per_core + remain;
+
+    int chunks = core_elements / max_chunk_elements;
+    int chunk_rem = core_elements % max_chunk_elements;
+
+    Tdata partial_sum = static_cast(0); // NOTE(review): accumulates in Tdata, so half/bfloat16 sums accumulate at reduced precision
+
+    for (int c = 0; c < chunks; c++) { // full-size chunks
+        size_t current_offset = core_offset + c * max_chunk_elements;
+        __memcpy(nram_x, x + current_offset, max_chunk_elements * sizeof(Tdata), GDRAM2NRAM);
+
+        __bang_abs(nram_x, nram_x, max_chunk_elements);
+
+        partial_sum += __bang_sum(nram_x, max_chunk_elements);
+    }
+
+    if (chunk_rem > 0) { // trailing partial chunk
+        size_t current_offset = core_offset + chunks * max_chunk_elements;
+
+        __memcpy(nram_x, x + current_offset, chunk_rem * sizeof(Tdata), GDRAM2NRAM);
+
+        __bang_abs(nram_x, nram_x, chunk_rem);
+
+        partial_sum += __bang_sum(nram_x, chunk_rem);
+    }
+
+    shared_partial_sum[coreId] = partial_sum;
+
+    __sync_cluster(); // all partial sums must be stored before core 0 reads them
+
+    if (coreId == 0) {
+        Tdata cluster_sum = static_cast(0);
+
+        for (int i = 0; i < coreDim; i++) {
+            cluster_sum += shared_partial_sum[i];
+        }
+
+        result[0] = cluster_sum;
+    }
+}
+
+template
+__mlu_global__ void asumKernelStrided( // sum of |x[i * incx]| over n elements, written to result[0]
+    size_t n,
+    const Tdata *x,
+    size_t incx,
+    Tdata *result) {
+
+    __mlu_shared__ Tdata shared_partial_sum[4]; // one slot per core; sized for the 4-task launch
+
+    int elements_per_core = n / taskDim;
+    int remain = n % taskDim;
+    int actual_tasks = elements_per_core + (taskId < remain ? 1 : 0); // the first `remain` tasks take one extra element
+    int start_idx = taskId < remain ? taskId * actual_tasks : taskId * elements_per_core + remain;
+
+    Tdata partial_sum = static_cast(0);
+
+    for (int i = start_idx; i < start_idx + actual_tasks; ++i) {
+        size_t offset = i * incx;
+        Tdata abs_val = x[offset] > static_cast(0) ? x[offset] : -x[offset];
+
+        partial_sum += abs_val;
+    }
+
+    shared_partial_sum[coreId] = partial_sum;
+
+    __sync_cluster(); // all partial sums must be stored before core 0 reads them
+
+    if (coreId == 0) {
+        Tdata cluster_sum = static_cast(0);
+
+        for (int i = 0; i < coreDim; i++) {
+            cluster_sum += shared_partial_sum[i];
+        }
+
+        result[0] = cluster_sum;
+    }
+}
\ No newline at end of file
diff --git a/src/infiniop/ops/asum/cpu/asum_cpu.cc b/src/infiniop/ops/asum/cpu/asum_cpu.cc
new file mode 100644
index 000000000..52692d984
--- /dev/null
+++ b/src/infiniop/ops/asum/cpu/asum_cpu.cc
@@ -0,0 +1,89 @@
+#include "asum_cpu.h"
+#include "../../../devices/cpu/common_cpu.h"
+
+namespace op::asum::cpu {
+
+Descriptor::~Descriptor() = default; // the CPU backend stores a null Opaque, nothing to release
+
+infiniStatus_t Descriptor::create(
+    infiniopHandle_t handle_,
+    Descriptor **desc_ptr,
+    infiniopTensorDescriptor_t x_desc,
+    infiniopTensorDescriptor_t result_desc) {
+
+    auto handle = reinterpret_cast(handle_);
+    auto info = AsumInfo::createAsumInfo(x_desc, result_desc);
+    CHECK_RESULT(info); // propagate the validation status on failure
+
+    *desc_ptr = new Descriptor(
+        info.take(),
+        0, // no workspace needed
+        nullptr,
+        handle->device,
+        handle->device_id);
+
+    return INFINI_STATUS_SUCCESS;
+}
+
+template
+infiniStatus_t calculateAsum( // writes the sum of |x[i * incx]| for i in [0, size) to result[0]
+    const AsumInfo &info,
+    const Tdata *x,
+    Tdata *result) {
+
+    const ptrdiff_t size = info.getSize();
+    const ptrdiff_t incx = info.getIncx();
+
+    if constexpr (std::is_same::value || std::is_same::value) {
+        float total_sum = 0.0; // this branch accumulates in float and casts back once at the end
+
+        for (ptrdiff_t i = 0; i < size; ++i) {
+            total_sum += std::abs(utils::cast(x[i * incx]));
+        }
+
+        result[0] = utils::cast(total_sum);
+    } else {
+        Tdata total_sum = 0.0;
+
+        for (ptrdiff_t i = 0; i < size; ++i) {
+            total_sum += std::abs(x[i * incx]);
+        }
+
+        result[0] = total_sum;
+    }
+
+    return INFINI_STATUS_SUCCESS;
+}
+
+#define CALCULATE_ASUM(TDATA) \
+    calculateAsum(_info, \
+                  (const TDATA *)x, \
+                  (TDATA *)result)
+
+infiniStatus_t Descriptor::calculate(
+    void *workspace,
+    size_t workspace_size,
+    const void *x,
+    void *result,
+    void *stream) const {
+
+    (void)workspace; // unused on CPU
+    (void)workspace_size;
+
+    switch (_info.getDtype()) { // dispatch on the element type recorded at create() time
+    case INFINI_DTYPE_F16:
+        return CALCULATE_ASUM(fp16_t);
+    case INFINI_DTYPE_BF16:
+        return CALCULATE_ASUM(bf16_t);
+    case INFINI_DTYPE_F32:
+        return CALCULATE_ASUM(float);
+    case INFINI_DTYPE_F64:
+        return CALCULATE_ASUM(double);
+    default:
+        return INFINI_STATUS_BAD_TENSOR_DTYPE;
+    }
+}
+
+#undef CALCULATE_ASUM
+
+} // namespace op::asum::cpu
diff --git a/src/infiniop/ops/asum/cpu/asum_cpu.h b/src/infiniop/ops/asum/cpu/asum_cpu.h
new file mode 100644
index 000000000..6cebbaf34
--- /dev/null
+++ b/src/infiniop/ops/asum/cpu/asum_cpu.h
@@ -0,0 +1,8 @@
+#ifndef __ASUM_CPU_H__
+#define __ASUM_CPU_H__
+
+#include "../asum.h"
+
+DESCRIPTOR(cpu) // declares op::asum::cpu::Descriptor
+
+#endif // __ASUM_CPU_H__
\ No newline at end of file
diff --git a/src/infiniop/ops/asum/metax/asum_metax.cc b/src/infiniop/ops/asum/metax/asum_metax.cc
new file mode 100644
index 000000000..edd3febed
--- /dev/null
+++ b/src/infiniop/ops/asum/metax/asum_metax.cc
@@ -0,0 +1,71 @@
+#include "asum_metax.h"
+#include "../../../devices/metax/metax_common.h"
+#include "../../../devices/metax/metax_handle.h"
+
+namespace op::asum::metax {
+
+struct Descriptor::Opaque { // backend state: shared handle internals used to reach the BLAS handle
+    std::shared_ptr internal;
+};
+
+Descriptor::~Descriptor() {
+    delete _opaque; // allocated with new in create()
+}
+
+infiniStatus_t Descriptor::create(
+    infiniopHandle_t handle_,
+    Descriptor **desc_ptr,
+    infiniopTensorDescriptor_t x_desc,
+    infiniopTensorDescriptor_t result_desc) {
+
+    auto handle = reinterpret_cast(handle_);
+    auto info = AsumInfo::createAsumInfo(x_desc, result_desc);
+    CHECK_RESULT(info); // propagate the validation status on failure
+
+    *desc_ptr = new Descriptor(
+        info.take(),
+        0, // no workspace needed
+        new Opaque{handle->internal()},
+        handle->device,
+        handle->device_id);
+
+    return INFINI_STATUS_SUCCESS;
+}
+
+infiniStatus_t Descriptor::calculate(
+    void *workspace,
+    size_t workspace_size,
+    const void *x,
+    void *result,
+    void *stream) const {
+
+    (void)workspace;
+    (void)workspace_size;
+
+    const size_t size = _info.getSize();
+    const ptrdiff_t incx = _info.getIncx();
+    const infiniDtype_t data_type = _info.getDtype();
+
+    CHECK_STATUS(_opaque->internal->useMcblas(
+        (hcStream_t)stream,
+        [&](hcblasHandle_t handle) {
+            CHECK_MCBLAS(hcblasSetPointerMode(handle, HCBLAS_POINTER_MODE_DEVICE)); // presumably: result is written through a device pointer
+
+            switch (data_type) { // only F32 and F64 are wired to a BLAS routine here
+            case INFINI_DTYPE_F32:
+                CHECK_MCBLAS(hcblasSasum(handle, size, (const float *)x, incx, (float *)result));
+                break;
+            case INFINI_DTYPE_F64:
+                CHECK_MCBLAS(hcblasDasum(handle, size, (const double *)x, incx, (double *)result));
+                break;
+            default:
+                return INFINI_STATUS_BAD_TENSOR_DTYPE; // the device is supported; the element type is not
+            }
+
+            return INFINI_STATUS_SUCCESS;
+        }));
+
+    return INFINI_STATUS_SUCCESS;
+}
+
+} // namespace op::asum::metax
\ No newline at end of file
diff --git a/src/infiniop/ops/asum/metax/asum_metax.h b/src/infiniop/ops/asum/metax/asum_metax.h
new file mode 100644
index 000000000..4b675b583
--- /dev/null
+++ b/src/infiniop/ops/asum/metax/asum_metax.h
@@ -0,0 +1,8 @@
+#ifndef __ASUM_METAX_H__
+#define __ASUM_METAX_H__
+
+#include "../asum.h"
+
+DESCRIPTOR(metax) // declares op::asum::metax::Descriptor
+
+#endif // __ASUM_METAX_H__
\ No newline at end of file
diff --git a/src/infiniop/ops/asum/operator.cc b/src/infiniop/ops/asum/operator.cc
new file mode 100644
index 000000000..51769f0d3
--- /dev/null
+++ b/src/infiniop/ops/asum/operator.cc
@@ -0,0 +1,124 @@
+#include "../../operator.h"
+#include "../../handle.h"
+#include "infiniop/ops/asum.h"
+
+#ifdef ENABLE_CPU_API
+#include "cpu/asum_cpu.h"
+#endif
+#ifdef ENABLE_METAX_API
+#include "metax/asum_metax.h"
+#endif
+#ifdef ENABLE_CAMBRICON_API
+#include "bang/asum_bang.h"
+#endif
+
+__INFINI_C infiniStatus_t infiniopCreateAsumDescriptor( // C entry point: forwards to the backend selected by handle->device
+    infiniopHandle_t handle,
+    infiniopAsumDescriptor_t *desc_ptr,
+    infiniopTensorDescriptor_t x_desc,
+    infiniopTensorDescriptor_t result_desc) {
+
+#define CREATE(CASE, NAMESPACE) \
+    case CASE: \
+        return op::asum::NAMESPACE::Descriptor::create( \
+            handle, \
+            reinterpret_cast(desc_ptr), \
+            x_desc, \
+            result_desc)
+
+    switch (handle->device) { // only backends compiled in via ENABLE_*_API get a case
+#ifdef ENABLE_CPU_API
+        CREATE(INFINI_DEVICE_CPU, cpu);
+#endif
+#ifdef ENABLE_METAX_API
+        CREATE(INFINI_DEVICE_METAX, metax);
+#endif
+#ifdef ENABLE_CAMBRICON_API
+        CREATE(INFINI_DEVICE_CAMBRICON, bang);
+#endif
+    default:
+        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
+    }
+
+#undef CREATE
+}
+
+__INFINI_C infiniStatus_t infiniopGetAsumWorkspaceSize(infiniopAsumDescriptor_t desc, size_t *size) { // stores the descriptor's workspaceSize() in *size
+
+#define GET(CASE, NAMESPACE) \
+    case CASE: \
+        *size = reinterpret_cast(desc)->workspaceSize(); \
+        return INFINI_STATUS_SUCCESS;
+
+    switch (desc->device_type) {
+#ifdef ENABLE_CPU_API
+        GET(INFINI_DEVICE_CPU, cpu);
+#endif
+#ifdef ENABLE_METAX_API
+        GET(INFINI_DEVICE_METAX, metax);
+#endif
+#ifdef ENABLE_CAMBRICON_API
+        GET(INFINI_DEVICE_CAMBRICON, bang);
+#endif
+    default:
+        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
+    }
+#undef GET
+
+    return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; // not reached: every switch path above returns
+}
+
+__INFINI_C infiniStatus_t infiniopAsum( // runs the backend's calculate() with the caller's buffers and stream
+    infiniopAsumDescriptor_t desc,
+    void *workspace,
+    size_t workspace_size,
+    const void *x,
+    void *result,
+    void *stream) {
+
+#define CALCULATE(CASE, NAMESPACE) \
+    case CASE: \
+        return reinterpret_cast(desc) \
+            ->calculate(workspace, workspace_size, x, result, stream)
+
+    switch (desc->device_type) {
+#ifdef ENABLE_CPU_API
+        CALCULATE(INFINI_DEVICE_CPU, cpu);
+#endif
+#ifdef ENABLE_METAX_API
+        CALCULATE(INFINI_DEVICE_METAX, metax);
+#endif
+#ifdef ENABLE_CAMBRICON_API
+        CALCULATE(INFINI_DEVICE_CAMBRICON, bang);
+#endif
+    default:
+        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
+    }
+
+#undef CALCULATE
+}
+
+__INFINI_C infiniStatus_t
+infiniopDestroyAsumDescriptor(infiniopAsumDescriptor_t desc) { // deletes the descriptor through its concrete backend type
+
+#define DELETE(CASE, NAMESPACE) \
+    case CASE: \
+        delete reinterpret_cast(desc); \
+        return INFINI_STATUS_SUCCESS;
+
+    switch (desc->device_type) {
+#ifdef ENABLE_CPU_API
+        DELETE(INFINI_DEVICE_CPU, cpu);
+#endif
+#ifdef ENABLE_METAX_API
+        DELETE(INFINI_DEVICE_METAX, metax);
+#endif
+#ifdef ENABLE_CAMBRICON_API
+        DELETE(INFINI_DEVICE_CAMBRICON, bang);
+#endif
+    default:
+        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
+    }
+
+#undef DELETE
+}
diff --git a/test/infiniop/asum.py b/test/infiniop/asum.py
new file mode 100644
index 000000000..189b3dc7a
--- /dev/null
+++ b/test/infiniop/asum.py
@@ -0,0 +1,143 @@
+import ctypes
+from ctypes import c_uint64
+
+import torch
+from libinfiniop import (
+    LIBINFINIOP,
+    InfiniDeviceNames,
+    InfiniDtype,
+    InfiniDtypeNames,
+    TestTensor,
+    TestWorkspace,
+    check_error,
+    debug,
+    get_args,
+    get_test_devices,
+    get_tolerance,
+    infiniopOperatorDescriptor_t,
+    profile_operation,
+    test_operator,
+)
+
+# ==============================================================================
+# Configuration
+# ==============================================================================
+
+_TEST_CASES = [
+    # n, x_stride
+    (3, None),
+    (8, (2,)),
+    (32, None),
+    (257, (3,)),
+    (65535, None),
+]
+
+_TENSOR_DTYPES = [
+    # InfiniDtype.F16,
+    InfiniDtype.F32,
+    # InfiniDtype.F64,
+    # InfiniDtype.BF16,
+]
+
+_TOLERANCE_MAP = {
+    InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-3},
+    InfiniDtype.F32: {"atol": 1e-5, "rtol": 1e-5},
+    InfiniDtype.F64: {"atol": 1e-9, "rtol": 1e-9},
+    InfiniDtype.BF16: {"atol": 1e-2, "rtol": 1e-2},
+}
+
+DEBUG = False
+PROFILE = False
+NUM_PRERUN = 10
+NUM_ITERATIONS = 1000
+
+
+def test(
+    handle,
+    device,
+    n,
+    x_stride=None,
+    dtype=torch.float16,
+    sync=None,
+):
+    torch.manual_seed(0)
+    if device != 0:
+        torch.cuda.manual_seed_all(0)
+
+    x = TestTensor((n,), x_stride, dtype, device)
+    result = TestTensor(tuple(), None, dtype, device, mode="zeros")
+
+    print(
+        f"Testing asum on {InfiniDeviceNames[device]} with n:{n} x_stride:{x_stride} "
+        f"dtype:{InfiniDtypeNames[dtype]}"
+    )
+
+    result_ref = torch.sum(x.torch_tensor().abs())
+    result.update_torch_tensor(result_ref)
+
+    if sync is not None:
+        sync()
+
+    descriptor = infiniopOperatorDescriptor_t()
+    check_error(
+        LIBINFINIOP.infiniopCreateAsumDescriptor(
+            handle,
+            ctypes.byref(descriptor),
+            x.descriptor,
+
result.descriptor, + ) + ) + + for tensor in [x, result]: + tensor.destroy_desc() + + workspace_size = c_uint64(0) + check_error( + LIBINFINIOP.infiniopGetAsumWorkspaceSize( + descriptor, ctypes.byref(workspace_size) + ) + ) + workspace = TestWorkspace(workspace_size.value, result.device) + + def lib_asum(): + check_error( + LIBINFINIOP.infiniopAsum( + descriptor, + workspace.data(), + workspace.size(), + x.data(), + result.data(), + None, + ) + ) + + lib_asum() + + atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) + if DEBUG: + debug(result.actual_tensor(), result.torch_tensor(), atol=atol, rtol=rtol) + assert torch.allclose( + result.actual_tensor(), result.torch_tensor(), atol=atol, rtol=rtol + ) + + if PROFILE: + # fmt: off + profile_operation("PyTorch", lambda: torch.sum(x.torch_tensor().abs()), device, NUM_PRERUN, NUM_ITERATIONS) + profile_operation(" lib", lambda: lib_asum(), device, NUM_PRERUN, NUM_ITERATIONS) + # fmt: on + + check_error(LIBINFINIOP.infiniopDestroyAsumDescriptor(descriptor)) + + +if __name__ == "__main__": + args = get_args() + + DEBUG = args.debug + PROFILE = args.profile + NUM_PRERUN = args.num_prerun + NUM_ITERATIONS = args.num_iterations + + for device in get_test_devices(args): + test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES) + + print("\033[92m Test passed! 
\033[0m") diff --git a/test/infiniop/libinfiniop/op_register.py b/test/infiniop/libinfiniop/op_register.py index b92b17e58..16f85aca1 100644 --- a/test/infiniop/libinfiniop/op_register.py +++ b/test/infiniop/libinfiniop/op_register.py @@ -2223,3 +2223,35 @@ def blas_amin_(lib): lib.infiniopDestroyBlasAminDescriptor.argtypes = [ infiniopOperatorDescriptor_t, ] + + +@OpRegister.operator +def asum_(lib): + lib.infiniopCreateAsumDescriptor.restype = c_int32 + lib.infiniopCreateAsumDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopOperatorDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + ] + + lib.infiniopGetAsumWorkspaceSize.restype = c_int32 + lib.infiniopGetAsumWorkspaceSize.argtypes = [ + infiniopOperatorDescriptor_t, + POINTER(c_size_t), + ] + + lib.infiniopAsum.restype = c_int32 + lib.infiniopAsum.argtypes = [ + infiniopOperatorDescriptor_t, + c_void_p, + c_size_t, + c_void_p, + c_void_p, + c_void_p, + ] + + lib.infiniopDestroyAsumDescriptor.restype = c_int32 + lib.infiniopDestroyAsumDescriptor.argtypes = [ + infiniopOperatorDescriptor_t, + ] From 5df0cc85da9cbd412347bd5afb28f5e051a438e0 Mon Sep 17 00:00:00 2001 From: xuzhengzhong Date: Tue, 28 Apr 2026 06:52:17 +0000 Subject: [PATCH 03/25] Add `axpy` operator --- include/infiniop.h | 1 + include/infiniop/ops/axpy.h | 26 ++++ src/infiniop/devices/metax/metax_ht2mc.h | 2 + src/infiniop/ops/axpy/axpy.h | 101 ++++++++++++ src/infiniop/ops/axpy/bang/axpy_bang.h | 8 + src/infiniop/ops/axpy/bang/axpy_bang.mlu | 104 +++++++++++++ .../ops/axpy/bang/axpy_bang_kernel.mlu | 84 ++++++++++ src/infiniop/ops/axpy/cpu/axpy_cpu.cc | 91 +++++++++++ src/infiniop/ops/axpy/cpu/axpy_cpu.h | 8 + src/infiniop/ops/axpy/metax/axpy_metax.cc | 74 +++++++++ src/infiniop/ops/axpy/metax/axpy_metax.h | 8 + src/infiniop/ops/axpy/operator.cc | 127 +++++++++++++++ test/infiniop/axpy.py | 145 ++++++++++++++++++ test/infiniop/libinfiniop/op_register.py | 34 ++++ 14 files changed, 813 insertions(+) create mode 
100644 include/infiniop/ops/axpy.h create mode 100644 src/infiniop/ops/axpy/axpy.h create mode 100644 src/infiniop/ops/axpy/bang/axpy_bang.h create mode 100644 src/infiniop/ops/axpy/bang/axpy_bang.mlu create mode 100644 src/infiniop/ops/axpy/bang/axpy_bang_kernel.mlu create mode 100644 src/infiniop/ops/axpy/cpu/axpy_cpu.cc create mode 100644 src/infiniop/ops/axpy/cpu/axpy_cpu.h create mode 100644 src/infiniop/ops/axpy/metax/axpy_metax.cc create mode 100644 src/infiniop/ops/axpy/metax/axpy_metax.h create mode 100644 src/infiniop/ops/axpy/operator.cc create mode 100644 test/infiniop/axpy.py diff --git a/include/infiniop.h b/include/infiniop.h index df03a1865..5b6674adc 100644 --- a/include/infiniop.h +++ b/include/infiniop.h @@ -20,6 +20,7 @@ #include "infiniop/ops/attention.h" #include "infiniop/ops/avg_pool1d.h" #include "infiniop/ops/avg_pool3d.h" +#include "infiniop/ops/axpy.h" #include "infiniop/ops/binary_cross_entropy_with_logits.h" #include "infiniop/ops/blas_amax.h" #include "infiniop/ops/blas_amin.h" diff --git a/include/infiniop/ops/axpy.h b/include/infiniop/ops/axpy.h new file mode 100644 index 000000000..1cf459602 --- /dev/null +++ b/include/infiniop/ops/axpy.h @@ -0,0 +1,26 @@ +#ifndef __INFINIOP_AXPY_API_H__ +#define __INFINIOP_AXPY_API_H__ + +#include "../operator_descriptor.h" + +typedef struct InfiniopDescriptor *infiniopAxpyDescriptor_t; + +__INFINI_C __export infiniStatus_t infiniopCreateAxpyDescriptor(infiniopHandle_t handle, + infiniopAxpyDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t alpha, + infiniopTensorDescriptor_t x, + infiniopTensorDescriptor_t y); + +__INFINI_C __export infiniStatus_t infiniopGetAxpyWorkspaceSize(infiniopAxpyDescriptor_t desc, size_t *size); + +__INFINI_C __export infiniStatus_t infiniopAxpy(infiniopAxpyDescriptor_t desc, + void *workspace, + size_t workspace_size, + const void *alpha, + const void *x, + void *y, + void *stream); + +__INFINI_C __export infiniStatus_t 
infiniopDestroyAxpyDescriptor(infiniopAxpyDescriptor_t desc); + +#endif // __INFINIOP_AXPY_API_H__ \ No newline at end of file diff --git a/src/infiniop/devices/metax/metax_ht2mc.h b/src/infiniop/devices/metax/metax_ht2mc.h index 89658a1ad..c771bf85b 100644 --- a/src/infiniop/devices/metax/metax_ht2mc.h +++ b/src/infiniop/devices/metax/metax_ht2mc.h @@ -125,6 +125,8 @@ #define hcblasIdamin mcblasIdamin #define hcblasSasum mcblasSasum #define hcblasDasum mcblasDasum +#define hcblasSaxpy mcblasSaxpy +#define hcblasDaxpy mcblasDaxpy #define HCBLAS_STATUS_SUCCESS MCBLAS_STATUS_SUCCESS #define HCBLAS_OP_T MCBLAS_OP_T #define HCBLAS_OP_N MCBLAS_OP_N diff --git a/src/infiniop/ops/axpy/axpy.h b/src/infiniop/ops/axpy/axpy.h new file mode 100644 index 000000000..9d25ef29c --- /dev/null +++ b/src/infiniop/ops/axpy/axpy.h @@ -0,0 +1,101 @@ +#ifndef __AXPY_H__ +#define __AXPY_H__ + +#include "../../../utils.h" +#include "../../operator.h" +#include "../../tensor.h" +#include "infiniop/ops/axpy.h" + +#define DESCRIPTOR(NAMESPACE) \ + \ + namespace op::axpy::NAMESPACE { \ + class Descriptor final : public InfiniopDescriptor { \ + struct Opaque; \ + Opaque *_opaque; \ + AxpyInfo _info; \ + size_t _workspace_size; \ + \ + Descriptor( \ + AxpyInfo info, \ + size_t workspace_size_, \ + Opaque *opaque, \ + infiniDevice_t device_type, \ + int device_id) \ + : InfiniopDescriptor{device_type, device_id}, \ + _opaque(opaque), \ + _info(std::move(info)), \ + _workspace_size(workspace_size_) {} \ + \ + public: \ + ~Descriptor(); \ + \ + size_t workspaceSize() const { return _workspace_size; } \ + \ + static infiniStatus_t create( \ + infiniopHandle_t handle, \ + Descriptor **desc_ptr, \ + infiniopTensorDescriptor_t alpha_desc, \ + infiniopTensorDescriptor_t x_desc, \ + infiniopTensorDescriptor_t y_desc); \ + \ + infiniStatus_t calculate( \ + void *workspace, \ + size_t workspace_size, \ + const void *alpha, \ + const void *x, \ + void *y, \ + void *stream) const; \ + }; \ + } + +class 
AxpyInfo { +private: + size_t _size; + ptrdiff_t _incx; + ptrdiff_t _incy; + infiniDtype_t _dtype; + + AxpyInfo(size_t size, + ptrdiff_t incx, + ptrdiff_t incy, + infiniDtype_t dtype) + : _size(size), _incx(incx), _incy(incy), _dtype(dtype) {} + +public: + inline size_t getSize() const { return _size; } + inline ptrdiff_t getIncx() const { return _incx; } + inline ptrdiff_t getIncy() const { return _incy; } + inline infiniDtype_t getDtype() const { return _dtype; } + + using ResultType = utils::Result; + + static ResultType createAxpyInfo( + infiniopTensorDescriptor_t alpha_desc, + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t y_desc) { + + CHECK_OR_RETURN(alpha_desc != nullptr, INFINI_STATUS_NULL_POINTER); + CHECK_OR_RETURN(x_desc != nullptr, INFINI_STATUS_NULL_POINTER); + CHECK_OR_RETURN(y_desc != nullptr, INFINI_STATUS_NULL_POINTER); + + auto dtype = x_desc->dtype(); + + CHECK_OR_RETURN(alpha_desc->dtype() == dtype, INFINI_STATUS_BAD_TENSOR_DTYPE); + CHECK_OR_RETURN(x_desc->dtype() == dtype, INFINI_STATUS_BAD_TENSOR_DTYPE); + CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32, INFINI_DTYPE_F64); + + CHECK_OR_RETURN(alpha_desc->numel() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(x_desc->ndim() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(y_desc->ndim() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(x_desc->numel() == y_desc->numel(), INFINI_STATUS_BAD_TENSOR_SHAPE); + + auto size = x_desc->numel(); + auto incx = x_desc->stride(0); + auto incy = y_desc->stride(0); + + AxpyInfo info(size, incx, incy, dtype); + return ResultType(std::move(info)); + } +}; + +#endif // __AXPY_H__ \ No newline at end of file diff --git a/src/infiniop/ops/axpy/bang/axpy_bang.h b/src/infiniop/ops/axpy/bang/axpy_bang.h new file mode 100644 index 000000000..a448303bb --- /dev/null +++ b/src/infiniop/ops/axpy/bang/axpy_bang.h @@ -0,0 +1,8 @@ +#ifndef __AXPY_BANG_H__ +#define __AXPY_BANG_H__ + +#include "../axpy.h" + 
+DESCRIPTOR(bang) + +#endif // __AXPY_BANG_H__ \ No newline at end of file diff --git a/src/infiniop/ops/axpy/bang/axpy_bang.mlu b/src/infiniop/ops/axpy/bang/axpy_bang.mlu new file mode 100644 index 000000000..004897e52 --- /dev/null +++ b/src/infiniop/ops/axpy/bang/axpy_bang.mlu @@ -0,0 +1,104 @@ +#include "../../../devices/bang/common_bang.h" +#include "axpy_bang.h" +#include "axpy_bang_kernel.mlu" + +namespace op::axpy::bang { + +Descriptor::~Descriptor() = default; + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t alpha_desc, + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t y_desc) { + + auto handle = reinterpret_cast(handle_); + auto info = AxpyInfo::createAxpyInfo(alpha_desc, x_desc, y_desc); + CHECK_RESULT(info); + + // Create descriptor + *desc_ptr = new Descriptor( + info.take(), + 0, + nullptr, + handle->device, + handle->device_id); + + return INFINI_STATUS_SUCCESS; +} + +template +infiniStatus_t calculateAxpy( + const AxpyInfo &info, + const Tdata *alpha, + const Tdata *x, + Tdata *y, + cnrtQueue_t queue) { + + const size_t size = info.getSize(); + const ptrdiff_t incx = info.getIncx(); + const ptrdiff_t incy = info.getIncy(); + + cnrtDim3_t k_dim; + cnrtFunctionType_t k_type; + + k_dim.x = 4; + k_dim.y = 1; + k_dim.z = 1; + k_type = cnrtFuncTypeUnion1; + + if (incx == 1 && incy == 1) { + axpyKernelContiguous<<>>( + size, + alpha, + x, + y); + } else { + axpyKernelStrided<<>>( + size, + alpha, + x, + incx, + y, + incy); + } + + cnrtQueueSync(queue); + + return INFINI_STATUS_SUCCESS; +} + +#define CALCULATE_AXPY(TDATA) \ + calculateAxpy(_info, \ + (const TDATA *)alpha, \ + (const TDATA *)x, \ + (TDATA *)y, \ + (cnrtQueue_t)stream) + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + const void *alpha, + const void *x, + void *y, + void *stream) const { + + (void)workspace; + (void)workspace_size; + + switch (_info.getDtype()) { + case 
INFINI_DTYPE_F16: + return CALCULATE_AXPY(half); + case INFINI_DTYPE_F32: + return CALCULATE_AXPY(float); + case INFINI_DTYPE_BF16: + return CALCULATE_AXPY(bfloat16_t); + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } +} + +#undef CALCULATE_AXPY + +} // namespace op::axpy::bang \ No newline at end of file diff --git a/src/infiniop/ops/axpy/bang/axpy_bang_kernel.mlu b/src/infiniop/ops/axpy/bang/axpy_bang_kernel.mlu new file mode 100644 index 000000000..8e5ee3db0 --- /dev/null +++ b/src/infiniop/ops/axpy/bang/axpy_bang_kernel.mlu @@ -0,0 +1,84 @@ +#include "../../../devices/bang/common_bang.h" +#include "axpy_bang.h" + +__nram__ char nram_buffer[NRAM_MAX_SIZE]; + +template +__mlu_global__ void axpyKernelContiguous( + size_t n, + const Tdata *alpha, + const Tdata *x, + Tdata *y) { + + Tdata *nram_align = (Tdata *)(((size_t)nram_buffer + ALIGN_SIZE - 1) & ~(ALIGN_SIZE - 1)); + + size_t nram_usable = NRAM_MAX_SIZE - ((char *)nram_align - nram_buffer); + size_t max_chunk_elements = nram_usable / (2 * sizeof(Tdata)); + + int align_elements = ALIGN_SIZE / sizeof(Tdata); + if (align_elements == 0) { + align_elements = 1; + } + max_chunk_elements = (max_chunk_elements / align_elements) * align_elements; + + Tdata *nram_x = nram_align; + Tdata *nram_y = nram_align + max_chunk_elements; + + int elements_per_core = n / taskDim; + int remain = n % taskDim; + int core_elements = elements_per_core + (taskId < remain ? 1 : 0); + int core_offset = taskId < remain ? 
taskId * core_elements : taskId * elements_per_core + remain; + + if (core_elements <= 0) { + return; + } + + int chunks = core_elements / max_chunk_elements; + int chunk_rem = core_elements % max_chunk_elements; + + for (int c = 0; c < chunks; c++) { + size_t current_offset = core_offset + c * max_chunk_elements; + __memcpy(nram_x, x + current_offset, max_chunk_elements * sizeof(Tdata), GDRAM2NRAM); + __memcpy(nram_y, y + current_offset, max_chunk_elements * sizeof(Tdata), GDRAM2NRAM); + + __bang_mul_scalar(nram_x, nram_x, alpha[0], max_chunk_elements); + __bang_add(nram_y, nram_y, nram_x, max_chunk_elements); + + __memcpy(y + current_offset, nram_y, max_chunk_elements * sizeof(Tdata), NRAM2GDRAM); + } + + if (chunk_rem > 0) { + size_t current_offset = core_offset + chunks * max_chunk_elements; + int align_rem = ((chunk_rem + align_elements - 1) / align_elements) * align_elements; + + __memcpy(nram_x, x + current_offset, chunk_rem * sizeof(Tdata), GDRAM2NRAM); + __memcpy(nram_y, y + current_offset, chunk_rem * sizeof(Tdata), GDRAM2NRAM); + + __bang_mul_scalar(nram_x, nram_x, alpha[0], align_rem); + __bang_add(nram_y, nram_y, nram_x, align_rem); + + __memcpy(y + current_offset, nram_y, chunk_rem * sizeof(Tdata), NRAM2GDRAM); + } +} + +template +__mlu_global__ void axpyKernelStrided( + size_t n, + const Tdata *alpha, + const Tdata *x, + size_t incx, + Tdata *y, + size_t incy) { + + int elements_per_core = n / taskDim; + int remain = n % taskDim; + int actual_tasks = elements_per_core + (taskId < remain ? 1 : 0); + int start_idx = taskId < remain ? 
taskId * actual_tasks : taskId * elements_per_core + remain; + + for (int i = start_idx; i < start_idx + actual_tasks; ++i) { + size_t idx_x = i * incx; + size_t idx_y = i * incy; + + y[idx_y] += alpha[0] * x[idx_x]; + } +} \ No newline at end of file diff --git a/src/infiniop/ops/axpy/cpu/axpy_cpu.cc b/src/infiniop/ops/axpy/cpu/axpy_cpu.cc new file mode 100644 index 000000000..ab2b5e673 --- /dev/null +++ b/src/infiniop/ops/axpy/cpu/axpy_cpu.cc @@ -0,0 +1,91 @@ +#include "axpy_cpu.h" +#include "../../../devices/cpu/common_cpu.h" + +namespace op::axpy::cpu { + +Descriptor::~Descriptor() = default; + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t alpha_desc, + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t y_desc) { + + auto handle = reinterpret_cast(handle_); + auto info = AxpyInfo::createAxpyInfo(alpha_desc, x_desc, y_desc); + CHECK_RESULT(info); + + *desc_ptr = new Descriptor( + info.take(), + 0, + nullptr, + handle->device, + handle->device_id); + + return INFINI_STATUS_SUCCESS; +} + +template +infiniStatus_t calculateAxpy( + const AxpyInfo &info, + const Tdata *alpha, + const Tdata *x, + Tdata *y) { + + const ptrdiff_t size = info.getSize(); + const ptrdiff_t incx = info.getIncx(); + const ptrdiff_t incy = info.getIncy(); + + if constexpr (std::is_same::value || std::is_same::value) { + const float alpha_f = utils::cast(alpha[0]); + for (ptrdiff_t i = 0; i < size; ++i) { + const float x_f = utils::cast(x[i * incx]); + const float y_f = utils::cast(y[i * incy]); + y[i * incy] = utils::cast(alpha_f * x_f + y_f); + } + } else { + const Tdata alpha_v = alpha[0]; + for (ptrdiff_t i = 0; i < size; ++i) { + y[i * incy] = alpha_v * x[i * incx] + y[i * incy]; + } + } + + return INFINI_STATUS_SUCCESS; +} + +#define CALCULATE_AXPY(TDATA) \ + calculateAxpy(_info, \ + (const TDATA *)alpha, \ + (const TDATA *)x, \ + (TDATA *)y) + +infiniStatus_t Descriptor::calculate( + void *workspace, + 
size_t workspace_size, + const void *alpha, + const void *x, + void *y, + void *stream) const { + + (void)workspace; + (void)workspace_size; + (void)stream; + + switch (_info.getDtype()) { + case INFINI_DTYPE_F16: + return CALCULATE_AXPY(fp16_t); + case INFINI_DTYPE_F32: + return CALCULATE_AXPY(float); + case INFINI_DTYPE_F64: + return CALCULATE_AXPY(double); + case INFINI_DTYPE_BF16: + return CALCULATE_AXPY(bf16_t); + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } +} + +#undef CALCULATE_AXPY + +} // namespace op::axpy::cpu diff --git a/src/infiniop/ops/axpy/cpu/axpy_cpu.h b/src/infiniop/ops/axpy/cpu/axpy_cpu.h new file mode 100644 index 000000000..f4bf63602 --- /dev/null +++ b/src/infiniop/ops/axpy/cpu/axpy_cpu.h @@ -0,0 +1,8 @@ +#ifndef __AXPY_CPU_H__ +#define __AXPY_CPU_H__ + +#include "../axpy.h" + +DESCRIPTOR(cpu) + +#endif // __AXPY_CPU_H__ \ No newline at end of file diff --git a/src/infiniop/ops/axpy/metax/axpy_metax.cc b/src/infiniop/ops/axpy/metax/axpy_metax.cc new file mode 100644 index 000000000..b31f586f7 --- /dev/null +++ b/src/infiniop/ops/axpy/metax/axpy_metax.cc @@ -0,0 +1,74 @@ +#include "axpy_metax.h" +#include "../../../devices/metax/metax_common.h" +#include "../../../devices/metax/metax_handle.h" + +namespace op::axpy::metax { + +struct Descriptor::Opaque { + std::shared_ptr internal; +}; + +Descriptor::~Descriptor() { + delete _opaque; +} + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t alpha_desc, + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t y_desc) { + + auto handle = reinterpret_cast(handle_); + auto info = AxpyInfo::createAxpyInfo(alpha_desc, x_desc, y_desc); + CHECK_RESULT(info); + + *desc_ptr = new Descriptor( + info.take(), + 0, + new Opaque{handle->internal()}, + handle->device, + handle->device_id); + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + const void 
*alpha, + const void *x, + void *y, + void *stream) const { + + (void)workspace; + (void)workspace_size; + + const size_t size = _info.getSize(); + const ptrdiff_t incx = _info.getIncx(); + const ptrdiff_t incy = _info.getIncy(); + const infiniDtype_t data_type = _info.getDtype(); + + CHECK_STATUS(_opaque->internal->useMcblas( + (hcStream_t)stream, + [&](hcblasHandle_t handle) { + CHECK_MCBLAS(hcblasSetPointerMode(handle, HCBLAS_POINTER_MODE_DEVICE)); + + switch (data_type) { + case INFINI_DTYPE_F32: + CHECK_MCBLAS(hcblasSaxpy(handle, size, (const float *)alpha, (const float *)x, incx, (float *)y, incy)); + break; + case INFINI_DTYPE_F64: + CHECK_MCBLAS(hcblasDaxpy(handle, size, (const double *)alpha, (const double *)x, incx, (double *)y, incy)); + break; + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + + return INFINI_STATUS_SUCCESS; + })); + + return INFINI_STATUS_SUCCESS; +} + +} // namespace op::axpy::metax \ No newline at end of file diff --git a/src/infiniop/ops/axpy/metax/axpy_metax.h b/src/infiniop/ops/axpy/metax/axpy_metax.h new file mode 100644 index 000000000..09144a535 --- /dev/null +++ b/src/infiniop/ops/axpy/metax/axpy_metax.h @@ -0,0 +1,8 @@ +#ifndef __AXPY_METAX_H__ +#define __AXPY_METAX_H__ + +#include "../axpy.h" + +DESCRIPTOR(metax) + +#endif // __AXPY_METAX_H__ \ No newline at end of file diff --git a/src/infiniop/ops/axpy/operator.cc b/src/infiniop/ops/axpy/operator.cc new file mode 100644 index 000000000..1c45f0d35 --- /dev/null +++ b/src/infiniop/ops/axpy/operator.cc @@ -0,0 +1,127 @@ +#include "../../operator.h" +#include "../../handle.h" +#include "infiniop/ops/axpy.h" + +#ifdef ENABLE_CPU_API +#include "cpu/axpy_cpu.h" +#endif +#ifdef ENABLE_METAX_API +#include "metax/axpy_metax.h" +#endif +#ifdef ENABLE_CAMBRICON_API +#include "bang/axpy_bang.h" +#endif + +__INFINI_C infiniStatus_t infiniopCreateAxpyDescriptor( + infiniopHandle_t handle, + infiniopAxpyDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t alpha_desc, + 
infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t y_desc) { + +#define CREATE(CASE, NAMESPACE) \ + case CASE: \ + return op::axpy::NAMESPACE::Descriptor::create( \ + handle, \ + reinterpret_cast(desc_ptr), \ + alpha_desc, \ + x_desc, \ + y_desc) + + switch (handle->device) { +#ifdef ENABLE_CPU_API + CREATE(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_METAX_API + CREATE(INFINI_DEVICE_METAX, metax); +#endif +#ifdef ENABLE_CAMBRICON_API + CREATE(INFINI_DEVICE_CAMBRICON, bang); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CREATE +} + +__INFINI_C infiniStatus_t infiniopGetAxpyWorkspaceSize(infiniopAxpyDescriptor_t desc, size_t *size) { + +#define GET(CASE, NAMESPACE) \ + case CASE: \ + *size = reinterpret_cast(desc)->workspaceSize(); \ + return INFINI_STATUS_SUCCESS; + + switch (desc->device_type) { +#ifdef ENABLE_CPU_API + GET(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_METAX_API + GET(INFINI_DEVICE_METAX, metax); +#endif +#ifdef ENABLE_CAMBRICON_API + GET(INFINI_DEVICE_CAMBRICON, bang); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +#undef GET + + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; +} + +__INFINI_C infiniStatus_t infiniopAxpy( + infiniopAxpyDescriptor_t desc, + void *workspace, + size_t workspace_size, + const void *alpha, + const void *x, + void *y, + void *stream) { + +#define CALCULATE(CASE, NAMESPACE) \ + case CASE: \ + return reinterpret_cast(desc) \ + ->calculate(workspace, workspace_size, alpha, x, y, stream) + + switch (desc->device_type) { +#ifdef ENABLE_CPU_API + CALCULATE(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_METAX_API + CALCULATE(INFINI_DEVICE_METAX, metax); +#endif +#ifdef ENABLE_CAMBRICON_API + CALCULATE(INFINI_DEVICE_CAMBRICON, bang); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CALCULATE +} + +__INFINI_C infiniStatus_t +infiniopDestroyAxpyDescriptor(infiniopAxpyDescriptor_t desc) { + +#define DELETE(CASE, 
NAMESPACE) \ + case CASE: \ + delete reinterpret_cast(desc); \ + return INFINI_STATUS_SUCCESS; + + switch (desc->device_type) { +#ifdef ENABLE_CPU_API + DELETE(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_METAX_API + DELETE(INFINI_DEVICE_METAX, metax); +#endif +#ifdef ENABLE_CAMBRICON_API + DELETE(INFINI_DEVICE_CAMBRICON, bang); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef DELETE +} \ No newline at end of file diff --git a/test/infiniop/axpy.py b/test/infiniop/axpy.py new file mode 100644 index 000000000..278e44cd3 --- /dev/null +++ b/test/infiniop/axpy.py @@ -0,0 +1,145 @@ +import ctypes +from ctypes import c_uint64 + +import torch +from libinfiniop import ( + LIBINFINIOP, + InfiniDeviceNames, + InfiniDtype, + InfiniDtypeNames, + TestTensor, + TestWorkspace, + check_error, + debug, + get_args, + get_test_devices, + get_tolerance, + infiniopOperatorDescriptor_t, + profile_operation, + test_operator, +) + +# ============================================================================== +# Configuration +# ============================================================================== + +_TEST_CASES = [ + # n, x_stride, y_stride + (3, None, None), + (8, (2,), (3,)), + (32, None, (2,)), + (257, (3,), None), + (65535, None, None), +] + +_TENSOR_DTYPES = [ + # InfiniDtype.F16, + InfiniDtype.F32, + # InfiniDtype.F64, + # InfiniDtype.BF16, +] + +_TOLERANCE_MAP = { + InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-3}, + InfiniDtype.F32: {"atol": 1e-5, "rtol": 1e-5}, + InfiniDtype.F64: {"atol": 1e-9, "rtol": 1e-9}, + InfiniDtype.BF16: {"atol": 1e-2, "rtol": 1e-2}, +} + +DEBUG = False +PROFILE = False +NUM_PRERUN = 10 +NUM_ITERATIONS = 1000 + + +def test( + handle, + device, + n, + x_stride=None, + y_stride=None, + dtype=torch.float16, + sync=None, +): + torch.manual_seed(0) + if device != 0: + torch.cuda.manual_seed_all(0) + + alpha = TestTensor(tuple(), None, dtype, device) + x = TestTensor((n,), x_stride, dtype, device) + y = 
TestTensor((n,), y_stride, dtype, device) + + print( + f"Testing axpy on {InfiniDeviceNames[device]} with n:{n} x_stride:{x_stride} y_stride:{y_stride} " + f"dtype:{InfiniDtypeNames[dtype]}" + ) + + y_ref = alpha.torch_tensor() * x.torch_tensor() + y.torch_tensor() + y.update_torch_tensor(y_ref) + + if sync is not None: + sync() + + descriptor = infiniopOperatorDescriptor_t() + check_error( + LIBINFINIOP.infiniopCreateAxpyDescriptor( + handle, + ctypes.byref(descriptor), + alpha.descriptor, + x.descriptor, + y.descriptor, + ) + ) + + for tensor in [alpha, x, y]: + tensor.destroy_desc() + + workspace_size = c_uint64(0) + check_error( + LIBINFINIOP.infiniopGetAxpyWorkspaceSize( + descriptor, ctypes.byref(workspace_size) + ) + ) + workspace = TestWorkspace(workspace_size.value, y.device) + + def lib_axpy(): + check_error( + LIBINFINIOP.infiniopAxpy( + descriptor, + workspace.data(), + workspace.size(), + alpha.data(), + x.data(), + y.data(), + None, + ) + ) + + lib_axpy() + + atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) + if DEBUG: + debug(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol) + assert torch.allclose(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol) + + if PROFILE: + # fmt: off + profile_operation("PyTorch", lambda: alpha.torch_tensor() * x.torch_tensor() + y.torch_tensor(), device, NUM_PRERUN, NUM_ITERATIONS) + profile_operation(" lib", lambda: lib_axpy(), device, NUM_PRERUN, NUM_ITERATIONS) + # fmt: on + + check_error(LIBINFINIOP.infiniopDestroyAxpyDescriptor(descriptor)) + + +if __name__ == "__main__": + args = get_args() + + DEBUG = args.debug + PROFILE = args.profile + NUM_PRERUN = args.num_prerun + NUM_ITERATIONS = args.num_iterations + + for device in get_test_devices(args): + test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES) + + print("\033[92m Test passed! 
\033[0m") diff --git a/test/infiniop/libinfiniop/op_register.py b/test/infiniop/libinfiniop/op_register.py index 16f85aca1..f620ee870 100644 --- a/test/infiniop/libinfiniop/op_register.py +++ b/test/infiniop/libinfiniop/op_register.py @@ -2255,3 +2255,37 @@ def asum_(lib): lib.infiniopDestroyAsumDescriptor.argtypes = [ infiniopOperatorDescriptor_t, ] + + +@OpRegister.operator +def axpy_(lib): + lib.infiniopCreateAxpyDescriptor.restype = c_int32 + lib.infiniopCreateAxpyDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopOperatorDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + ] + + lib.infiniopGetAxpyWorkspaceSize.restype = c_int32 + lib.infiniopGetAxpyWorkspaceSize.argtypes = [ + infiniopOperatorDescriptor_t, + POINTER(c_size_t), + ] + + lib.infiniopAxpy.restype = c_int32 + lib.infiniopAxpy.argtypes = [ + infiniopOperatorDescriptor_t, + c_void_p, + c_size_t, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + ] + + lib.infiniopDestroyAxpyDescriptor.restype = c_int32 + lib.infiniopDestroyAxpyDescriptor.argtypes = [ + infiniopOperatorDescriptor_t, + ] From d6584f72c3db185aea566b7321279f207aa101c3 Mon Sep 17 00:00:00 2001 From: xuzhengzhong Date: Tue, 28 Apr 2026 07:00:27 +0000 Subject: [PATCH 04/25] Add `blas_copy` operator --- include/infiniop.h | 1 + include/infiniop/ops/blas_copy.h | 24 +++ src/infiniop/devices/metax/metax_ht2mc.h | 2 + .../ops/blas_copy/bang/blas_copy_bang.h | 8 + .../ops/blas_copy/bang/blas_copy_bang.mlu | 92 +++++++++++ .../blas_copy/bang/blas_copy_bang_kernel.mlu | 64 ++++++++ src/infiniop/ops/blas_copy/blas_copy.h | 95 +++++++++++ .../ops/blas_copy/cpu/blas_copy_cpu.cc | 77 +++++++++ .../ops/blas_copy/cpu/blas_copy_cpu.h | 8 + .../ops/blas_copy/metax/blas_copy_metax.cc | 72 +++++++++ .../ops/blas_copy/metax/blas_copy_metax.h | 8 + src/infiniop/ops/blas_copy/operator.cc | 121 ++++++++++++++ test/infiniop/blas_copy.py | 151 ++++++++++++++++++ 
test/infiniop/libinfiniop/op_register.py | 32 ++++ 14 files changed, 755 insertions(+) create mode 100644 include/infiniop/ops/blas_copy.h create mode 100644 src/infiniop/ops/blas_copy/bang/blas_copy_bang.h create mode 100644 src/infiniop/ops/blas_copy/bang/blas_copy_bang.mlu create mode 100644 src/infiniop/ops/blas_copy/bang/blas_copy_bang_kernel.mlu create mode 100644 src/infiniop/ops/blas_copy/blas_copy.h create mode 100644 src/infiniop/ops/blas_copy/cpu/blas_copy_cpu.cc create mode 100644 src/infiniop/ops/blas_copy/cpu/blas_copy_cpu.h create mode 100644 src/infiniop/ops/blas_copy/metax/blas_copy_metax.cc create mode 100644 src/infiniop/ops/blas_copy/metax/blas_copy_metax.h create mode 100644 src/infiniop/ops/blas_copy/operator.cc create mode 100644 test/infiniop/blas_copy.py diff --git a/include/infiniop.h b/include/infiniop.h index 5b6674adc..6473d3108 100644 --- a/include/infiniop.h +++ b/include/infiniop.h @@ -24,6 +24,7 @@ #include "infiniop/ops/binary_cross_entropy_with_logits.h" #include "infiniop/ops/blas_amax.h" #include "infiniop/ops/blas_amin.h" +#include "infiniop/ops/blas_copy.h" #include "infiniop/ops/block_diag.h" #include "infiniop/ops/broadcast_to.h" #include "infiniop/ops/causal_softmax.h" diff --git a/include/infiniop/ops/blas_copy.h b/include/infiniop/ops/blas_copy.h new file mode 100644 index 000000000..7c6f3611c --- /dev/null +++ b/include/infiniop/ops/blas_copy.h @@ -0,0 +1,24 @@ +#ifndef __INFINIOP_BLAS_COPY_API_H__ +#define __INFINIOP_BLAS_COPY_API_H__ + +#include "../operator_descriptor.h" + +typedef struct InfiniopDescriptor *infiniopBlasCopyDescriptor_t; + +__INFINI_C __export infiniStatus_t infiniopCreateBlasCopyDescriptor(infiniopHandle_t handle, + infiniopBlasCopyDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t x, + infiniopTensorDescriptor_t y); + +__INFINI_C __export infiniStatus_t infiniopGetBlasCopyWorkspaceSize(infiniopBlasCopyDescriptor_t desc, size_t *size); + +__INFINI_C __export infiniStatus_t 
infiniopBlasCopy(infiniopBlasCopyDescriptor_t desc, + void *workspace, + size_t workspace_size, + const void *x, + void *y, + void *stream); + +__INFINI_C __export infiniStatus_t infiniopDestroyBlasCopyDescriptor(infiniopBlasCopyDescriptor_t desc); + +#endif // __INFINIOP_BLAS_COPY_API_H__ diff --git a/src/infiniop/devices/metax/metax_ht2mc.h b/src/infiniop/devices/metax/metax_ht2mc.h index c771bf85b..89dbc9173 100644 --- a/src/infiniop/devices/metax/metax_ht2mc.h +++ b/src/infiniop/devices/metax/metax_ht2mc.h @@ -127,6 +127,8 @@ #define hcblasDasum mcblasDasum #define hcblasSaxpy mcblasSaxpy #define hcblasDaxpy mcblasDaxpy +#define hcblasScopy mcblasScopy +#define hcblasDcopy mcblasDcopy #define HCBLAS_STATUS_SUCCESS MCBLAS_STATUS_SUCCESS #define HCBLAS_OP_T MCBLAS_OP_T #define HCBLAS_OP_N MCBLAS_OP_N diff --git a/src/infiniop/ops/blas_copy/bang/blas_copy_bang.h b/src/infiniop/ops/blas_copy/bang/blas_copy_bang.h new file mode 100644 index 000000000..fb326fb3d --- /dev/null +++ b/src/infiniop/ops/blas_copy/bang/blas_copy_bang.h @@ -0,0 +1,8 @@ +#ifndef __BLAS_COPY_BANG_H__ +#define __BLAS_COPY_BANG_H__ + +#include "../blas_copy.h" + +DESCRIPTOR(bang) + +#endif // __BLAS_COPY_BANG_H__ diff --git a/src/infiniop/ops/blas_copy/bang/blas_copy_bang.mlu b/src/infiniop/ops/blas_copy/bang/blas_copy_bang.mlu new file mode 100644 index 000000000..20a787792 --- /dev/null +++ b/src/infiniop/ops/blas_copy/bang/blas_copy_bang.mlu @@ -0,0 +1,92 @@ +#include "../../../devices/bang/common_bang.h" +#include "blas_copy_bang.h" +#include "blas_copy_bang_kernel.mlu" + +namespace op::blas_copy::bang { + +Descriptor::~Descriptor() = default; + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t y_desc) { + + auto handle = reinterpret_cast(handle_); + auto info = BlasCopyInfo::createBlasCopyInfo(x_desc, y_desc); + CHECK_RESULT(info); + + *desc_ptr = new Descriptor( + info.take(), + 0, + 
nullptr, + handle->device, + handle->device_id); + + return INFINI_STATUS_SUCCESS; +} + +template +infiniStatus_t calculateBlasCopy( + const BlasCopyInfo &info, + const Tdata *x, + Tdata *y, + cnrtQueue_t queue) { + + cnrtDim3_t k_dim; + cnrtFunctionType_t k_type; + + k_dim.x = 4; + k_dim.y = 1; + k_dim.z = 1; + k_type = cnrtFuncTypeUnion1; + + if (info.getIncx() == 1 && info.getIncy() == 1) { + blasCopyKernelContiguous<<>>( + info.getSize(), + x, + y); + } else { + blasCopyKernelStrided<<>>( + info.getSize(), + x, + info.getIncx(), + y, + info.getIncy()); + } + + cnrtQueueSync(queue); + return INFINI_STATUS_SUCCESS; +} + +#define CALCULATE_BLAS_COPY(TDATA) \ + calculateBlasCopy(_info, \ + (const TDATA *)x, \ + (TDATA *)y, \ + (cnrtQueue_t)stream) + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + const void *x, + void *y, + void *stream) const { + + (void)workspace; + (void)workspace_size; + + switch (_info.getDtype()) { + case INFINI_DTYPE_F16: + return CALCULATE_BLAS_COPY(half); + case INFINI_DTYPE_F32: + return CALCULATE_BLAS_COPY(float); + case INFINI_DTYPE_BF16: + return CALCULATE_BLAS_COPY(bfloat16_t); + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } +} + +#undef CALCULATE_BLAS_COPY + +} // namespace op::blas_copy::bang diff --git a/src/infiniop/ops/blas_copy/bang/blas_copy_bang_kernel.mlu b/src/infiniop/ops/blas_copy/bang/blas_copy_bang_kernel.mlu new file mode 100644 index 000000000..2ab596ffa --- /dev/null +++ b/src/infiniop/ops/blas_copy/bang/blas_copy_bang_kernel.mlu @@ -0,0 +1,64 @@ +#include "../../../devices/bang/common_bang.h" +#include "blas_copy_bang.h" + +__nram__ char nram_buffer[NRAM_MAX_SIZE]; + +template +__mlu_global__ void blasCopyKernelContiguous( + size_t n, + const Tdata *x, + Tdata *y) { + + Tdata *nram_x = (Tdata *)(((size_t)nram_buffer + ALIGN_SIZE - 1) & ~(ALIGN_SIZE - 1)); + + size_t nram_usable = NRAM_MAX_SIZE - ((char *)nram_x - nram_buffer); + size_t max_chunk_elements = nram_usable / 
sizeof(Tdata); + + int align_elements = ALIGN_SIZE / sizeof(Tdata); + if (align_elements == 0) { + align_elements = 1; + } + max_chunk_elements = (max_chunk_elements / align_elements) * align_elements; + + int elements_per_core = n / taskDim; + int remain = n % taskDim; + int core_elements = elements_per_core + (taskId < remain ? 1 : 0); + int core_offset = taskId < remain ? taskId * core_elements : taskId * elements_per_core + remain; + + if (core_elements <= 0) { + return; + } + + int chunks = core_elements / max_chunk_elements; + int chunk_rem = core_elements % max_chunk_elements; + + for (int c = 0; c < chunks; c++) { + size_t current_offset = core_offset + c * max_chunk_elements; + __memcpy(nram_x, x + current_offset, max_chunk_elements * sizeof(Tdata), GDRAM2NRAM); + __memcpy(y + current_offset, nram_x, max_chunk_elements * sizeof(Tdata), NRAM2GDRAM); + } + + if (chunk_rem > 0) { + size_t current_offset = core_offset + chunks * max_chunk_elements; + __memcpy(nram_x, x + current_offset, chunk_rem * sizeof(Tdata), GDRAM2NRAM); + __memcpy(y + current_offset, nram_x, chunk_rem * sizeof(Tdata), NRAM2GDRAM); + } +} + +template +__mlu_global__ void blasCopyKernelStrided( + size_t n, + const Tdata *x, + size_t incx, + Tdata *y, + size_t incy) { + + int elements_per_core = n / taskDim; + int remain = n % taskDim; + int actual_tasks = elements_per_core + (taskId < remain ? 1 : 0); + int start_idx = taskId < remain ? 
taskId * actual_tasks : taskId * elements_per_core + remain; + + for (int i = start_idx; i < start_idx + actual_tasks; ++i) { + y[i * incy] = x[i * incx]; + } +} diff --git a/src/infiniop/ops/blas_copy/blas_copy.h b/src/infiniop/ops/blas_copy/blas_copy.h new file mode 100644 index 000000000..b2fa23970 --- /dev/null +++ b/src/infiniop/ops/blas_copy/blas_copy.h @@ -0,0 +1,95 @@ +#ifndef __BLAS_COPY_H__ +#define __BLAS_COPY_H__ + +#include "../../../utils.h" +#include "../../operator.h" +#include "../../tensor.h" +#include "infiniop/ops/blas_copy.h" + +#define DESCRIPTOR(NAMESPACE) \ + \ + namespace op::blas_copy::NAMESPACE { \ + class Descriptor final : public InfiniopDescriptor { \ + struct Opaque; \ + Opaque *_opaque; \ + BlasCopyInfo _info; \ + size_t _workspace_size; \ + \ + Descriptor( \ + BlasCopyInfo info, \ + size_t workspace_size_, \ + Opaque *opaque, \ + infiniDevice_t device_type, \ + int device_id) \ + : InfiniopDescriptor{device_type, device_id}, \ + _opaque(opaque), \ + _info(std::move(info)), \ + _workspace_size(workspace_size_) {} \ + \ + public: \ + ~Descriptor(); \ + \ + size_t workspaceSize() const { return _workspace_size; } \ + \ + static infiniStatus_t create( \ + infiniopHandle_t handle, \ + Descriptor **desc_ptr, \ + infiniopTensorDescriptor_t x_desc, \ + infiniopTensorDescriptor_t y_desc); \ + \ + infiniStatus_t calculate( \ + void *workspace, \ + size_t workspace_size, \ + const void *x, \ + void *y, \ + void *stream) const; \ + }; \ + } + +class BlasCopyInfo { +private: + size_t _size; + ptrdiff_t _incx; + ptrdiff_t _incy; + infiniDtype_t _dtype; + + BlasCopyInfo(size_t size, + ptrdiff_t incx, + ptrdiff_t incy, + infiniDtype_t dtype) + : _size(size), _incx(incx), _incy(incy), _dtype(dtype) {} + +public: + inline size_t getSize() const { return _size; } + inline ptrdiff_t getIncx() const { return _incx; } + inline ptrdiff_t getIncy() const { return _incy; } + inline infiniDtype_t getDtype() const { return _dtype; } + + using ResultType = 
utils::Result; + + static utils::Result createBlasCopyInfo( + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t y_desc) { + + CHECK_OR_RETURN(x_desc != nullptr, INFINI_STATUS_NULL_POINTER); + CHECK_OR_RETURN(y_desc != nullptr, INFINI_STATUS_NULL_POINTER); + + auto dtype = x_desc->dtype(); + + CHECK_OR_RETURN(y_desc->dtype() == dtype, INFINI_STATUS_BAD_TENSOR_DTYPE); + CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32, INFINI_DTYPE_F64); + + CHECK_OR_RETURN(x_desc->ndim() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(y_desc->ndim() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(x_desc->numel() == y_desc->numel(), INFINI_STATUS_BAD_TENSOR_SHAPE); + + auto size = x_desc->numel(); + auto incx = x_desc->stride(0); + auto incy = y_desc->stride(0); + + BlasCopyInfo info(size, incx, incy, dtype); + return ResultType(std::move(info)); + } +}; + +#endif // __BLAS_COPY_H__ diff --git a/src/infiniop/ops/blas_copy/cpu/blas_copy_cpu.cc b/src/infiniop/ops/blas_copy/cpu/blas_copy_cpu.cc new file mode 100644 index 000000000..979aa9df0 --- /dev/null +++ b/src/infiniop/ops/blas_copy/cpu/blas_copy_cpu.cc @@ -0,0 +1,77 @@ +#include "blas_copy_cpu.h" +#include "../../../devices/cpu/common_cpu.h" + +namespace op::blas_copy::cpu { + +Descriptor::~Descriptor() = default; + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t y_desc) { + + auto handle = reinterpret_cast(handle_); + auto info = BlasCopyInfo::createBlasCopyInfo(x_desc, y_desc); + CHECK_RESULT(info); + + *desc_ptr = new Descriptor( + info.take(), + 0, + nullptr, + handle->device, + handle->device_id); + + return INFINI_STATUS_SUCCESS; +} + +template +infiniStatus_t calculateBlasCopy( + const BlasCopyInfo &info, + const Tdata *x, + Tdata *y) { + + const ptrdiff_t size = info.getSize(); + + for (ptrdiff_t i = 0; i < size; ++i) { + size_t x_idx = i * info.getIncx(); + 
size_t y_idx = i * info.getIncy(); + y[y_idx] = x[x_idx]; + } + + return INFINI_STATUS_SUCCESS; +} + +#define CALCULATE_BLAS_COPY(TDATA) \ + calculateBlasCopy(_info, \ + (const TDATA *)x, \ + (TDATA *)y) + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + const void *x, + void *y, + void *stream) const { + + (void)workspace; + (void)workspace_size; + (void)stream; + + switch (_info.getDtype()) { + case INFINI_DTYPE_F16: + return CALCULATE_BLAS_COPY(fp16_t); + case INFINI_DTYPE_F32: + return CALCULATE_BLAS_COPY(float); + case INFINI_DTYPE_F64: + return CALCULATE_BLAS_COPY(double); + case INFINI_DTYPE_BF16: + return CALCULATE_BLAS_COPY(bf16_t); + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } +} + +#undef CALCULATE_BLAS_COPY + +} // namespace op::blas_copy::cpu diff --git a/src/infiniop/ops/blas_copy/cpu/blas_copy_cpu.h b/src/infiniop/ops/blas_copy/cpu/blas_copy_cpu.h new file mode 100644 index 000000000..7cfe48752 --- /dev/null +++ b/src/infiniop/ops/blas_copy/cpu/blas_copy_cpu.h @@ -0,0 +1,8 @@ +#ifndef __BLAS_COPY_CPU_H__ +#define __BLAS_COPY_CPU_H__ + +#include "../blas_copy.h" + +DESCRIPTOR(cpu) + +#endif // __BLAS_COPY_CPU_H__ diff --git a/src/infiniop/ops/blas_copy/metax/blas_copy_metax.cc b/src/infiniop/ops/blas_copy/metax/blas_copy_metax.cc new file mode 100644 index 000000000..8e8a8db10 --- /dev/null +++ b/src/infiniop/ops/blas_copy/metax/blas_copy_metax.cc @@ -0,0 +1,72 @@ +#include "blas_copy_metax.h" +#include "../../../devices/metax/metax_common.h" +#include "../../../devices/metax/metax_handle.h" + +namespace op::blas_copy::metax { + +struct Descriptor::Opaque { + std::shared_ptr internal; +}; + +Descriptor::~Descriptor() { + delete _opaque; +} + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t y_desc) { + + auto handle = reinterpret_cast(handle_); + auto info = BlasCopyInfo::createBlasCopyInfo(x_desc, 
y_desc); + CHECK_RESULT(info); + + *desc_ptr = new Descriptor( + info.take(), + 0, + new Opaque{handle->internal()}, + handle->device, + handle->device_id); + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + const void *x, + void *y, + void *stream) const { + + (void)workspace; + (void)workspace_size; + + const size_t size = _info.getSize(); + const ptrdiff_t incx = _info.getIncx(); + const ptrdiff_t incy = _info.getIncy(); + const infiniDtype_t data_type = _info.getDtype(); + + CHECK_STATUS(_opaque->internal->useMcblas( + (hcStream_t)stream, + [&](hcblasHandle_t handle) { + CHECK_MCBLAS(hcblasSetPointerMode(handle, HCBLAS_POINTER_MODE_DEVICE)); + + switch (data_type) { + case INFINI_DTYPE_F32: + CHECK_MCBLAS(hcblasScopy(handle, size, (const float *)x, incx, (float *)y, incy)); + break; + case INFINI_DTYPE_F64: + CHECK_MCBLAS(hcblasDcopy(handle, size, (const double *)x, incx, (double *)y, incy)); + break; + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + + return INFINI_STATUS_SUCCESS; + })); + + return INFINI_STATUS_SUCCESS; +} + +} // namespace op::blas_copy::metax diff --git a/src/infiniop/ops/blas_copy/metax/blas_copy_metax.h b/src/infiniop/ops/blas_copy/metax/blas_copy_metax.h new file mode 100644 index 000000000..88f118dbf --- /dev/null +++ b/src/infiniop/ops/blas_copy/metax/blas_copy_metax.h @@ -0,0 +1,8 @@ +#ifndef __BLAS_COPY_METAX_H__ +#define __BLAS_COPY_METAX_H__ + +#include "../blas_copy.h" + +DESCRIPTOR(metax) + +#endif // __BLAS_COPY_METAX_H__ diff --git a/src/infiniop/ops/blas_copy/operator.cc b/src/infiniop/ops/blas_copy/operator.cc new file mode 100644 index 000000000..394adc665 --- /dev/null +++ b/src/infiniop/ops/blas_copy/operator.cc @@ -0,0 +1,121 @@ +#include "../../operator.h" +#include "../../handle.h" +#include "infiniop/ops/blas_copy.h" + +#ifdef ENABLE_CPU_API +#include "cpu/blas_copy_cpu.h" +#endif +#ifdef ENABLE_METAX_API +#include 
"metax/blas_copy_metax.h" +#endif +#ifdef ENABLE_CAMBRICON_API +#include "bang/blas_copy_bang.h" +#endif + +__INFINI_C infiniStatus_t infiniopCreateBlasCopyDescriptor( + infiniopHandle_t handle, + infiniopBlasCopyDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t y_desc) { + +#define CREATE(CASE, NAMESPACE) \ + case CASE: \ + return op::blas_copy::NAMESPACE::Descriptor::create( \ + handle, \ + reinterpret_cast(desc_ptr), \ + x_desc, y_desc) + + switch (handle->device) { +#ifdef ENABLE_CPU_API + CREATE(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_METAX_API + CREATE(INFINI_DEVICE_METAX, metax); +#endif +#ifdef ENABLE_CAMBRICON_API + CREATE(INFINI_DEVICE_CAMBRICON, bang); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CREATE +} + +__INFINI_C infiniStatus_t infiniopGetBlasCopyWorkspaceSize(infiniopBlasCopyDescriptor_t desc, size_t *size) { + +#define GET(CASE, NAMESPACE) \ + case CASE: \ + *size = reinterpret_cast(desc)->workspaceSize(); \ + return INFINI_STATUS_SUCCESS + + switch (desc->device_type) { +#ifdef ENABLE_CPU_API + GET(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_METAX_API + GET(INFINI_DEVICE_METAX, metax); +#endif +#ifdef ENABLE_CAMBRICON_API + GET(INFINI_DEVICE_CAMBRICON, bang); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef GET +} + +__INFINI_C infiniStatus_t infiniopBlasCopy( + infiniopBlasCopyDescriptor_t desc, + void *workspace, + size_t workspace_size, + const void *x, + void *y, + void *stream) { + +#define CALCULATE(CASE, NAMESPACE) \ + case CASE: \ + return reinterpret_cast(desc) \ + ->calculate(workspace, workspace_size, x, y, stream) + + switch (desc->device_type) { +#ifdef ENABLE_CPU_API + CALCULATE(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_METAX_API + CALCULATE(INFINI_DEVICE_METAX, metax); +#endif +#ifdef ENABLE_CAMBRICON_API + CALCULATE(INFINI_DEVICE_CAMBRICON, bang); +#endif + default: + return 
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CALCULATE +} + +__INFINI_C infiniStatus_t infiniopDestroyBlasCopyDescriptor(infiniopBlasCopyDescriptor_t desc) { + +#define DELETE(CASE, NAMESPACE) \ + case CASE: \ + delete reinterpret_cast(desc); \ + return INFINI_STATUS_SUCCESS + + switch (desc->device_type) { +#ifdef ENABLE_CPU_API + DELETE(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_METAX_API + DELETE(INFINI_DEVICE_METAX, metax); +#endif +#ifdef ENABLE_CAMBRICON_API + DELETE(INFINI_DEVICE_CAMBRICON, bang); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef DELETE +} diff --git a/test/infiniop/blas_copy.py b/test/infiniop/blas_copy.py new file mode 100644 index 000000000..c1d083d2d --- /dev/null +++ b/test/infiniop/blas_copy.py @@ -0,0 +1,151 @@ +import ctypes +from ctypes import c_uint64 + +import torch +from libinfiniop import ( + LIBINFINIOP, + InfiniDeviceNames, + InfiniDtype, + InfiniDtypeNames, + TestTensor, + TestWorkspace, + check_error, + debug, + get_args, + get_test_devices, + get_tolerance, + infiniopOperatorDescriptor_t, + profile_operation, + test_operator, +) + +# ============================================================================== +# Configuration +# ============================================================================== + +_TEST_CASES = [ + # n, x_stride, y_stride + (3, None, None), + (8, (2,), (3,)), + (32, None, (2,)), + (257, (3,), None), + (65535, None, None), +] + +_TENSOR_DTYPES = [ + # InfiniDtype.F16, + InfiniDtype.F32, + # InfiniDtype.F64, + # InfiniDtype.BF16, +] + +_TOLERANCE_MAP = { + InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-3}, + InfiniDtype.F32: {"atol": 1e-5, "rtol": 1e-5}, + InfiniDtype.F64: {"atol": 1e-9, "rtol": 1e-9}, + InfiniDtype.BF16: {"atol": 1e-2, "rtol": 1e-2}, +} + +DEBUG = False +PROFILE = False +NUM_PRERUN = 10 +NUM_ITERATIONS = 1000 + + +def torch_copy(x, y): + y.copy_(x) + + +def test( + handle, + device, + n, + x_stride=None, + y_stride=None, + 
dtype=torch.float16, + sync=None, +): + x = TestTensor((n,), x_stride, dtype, device) + y = TestTensor((n,), y_stride, dtype, device) + + if x.is_broadcast(): + return + + print( + f"Testing BlasCopy on {InfiniDeviceNames[device]} with n:{n} x_stride:{x_stride} " + f"y_stride:{y_stride} dtype:{InfiniDtypeNames[dtype]}" + ) + + torch_copy(x.torch_tensor(), y.torch_tensor()) + + if sync is not None: + sync() + + descriptor = infiniopOperatorDescriptor_t() + check_error( + LIBINFINIOP.infiniopCreateBlasCopyDescriptor( + handle, + ctypes.byref(descriptor), + x.descriptor, + y.descriptor, + ) + ) + + x.destroy_desc() + y.destroy_desc() + + workspace_size = c_uint64(0) + check_error( + LIBINFINIOP.infiniopGetBlasCopyWorkspaceSize( + descriptor, ctypes.byref(workspace_size) + ) + ) + workspace = TestWorkspace(workspace_size.value, x.device) + + def lib_copy(): + check_error( + LIBINFINIOP.infiniopBlasCopy( + descriptor, + workspace.data(), + workspace.size(), + x.data(), + y.data(), + None, + ) + ) + + lib_copy() + + atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) + if DEBUG: + debug(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol) + + assert torch.allclose(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol) + + if PROFILE: + profile_operation( + "PyTorch", + lambda: torch_copy(x.torch_tensor(), y.torch_tensor()), + device, + NUM_PRERUN, + NUM_ITERATIONS, + ) + profile_operation( + " lib", lambda: lib_copy(), device, NUM_PRERUN, NUM_ITERATIONS + ) + + check_error(LIBINFINIOP.infiniopDestroyBlasCopyDescriptor(descriptor)) + + +if __name__ == "__main__": + args = get_args() + + DEBUG = args.debug + PROFILE = args.profile + NUM_PRERUN = args.num_prerun + NUM_ITERATIONS = args.num_iterations + + for device in get_test_devices(args): + test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES) + + print("\033[92mTest passed!\033[0m") diff --git a/test/infiniop/libinfiniop/op_register.py b/test/infiniop/libinfiniop/op_register.py index f620ee870..a045c1da5 
100644 --- a/test/infiniop/libinfiniop/op_register.py +++ b/test/infiniop/libinfiniop/op_register.py @@ -2289,3 +2289,35 @@ def axpy_(lib): lib.infiniopDestroyAxpyDescriptor.argtypes = [ infiniopOperatorDescriptor_t, ] + + +@OpRegister.operator +def blas_copy_(lib): + lib.infiniopCreateBlasCopyDescriptor.restype = c_int32 + lib.infiniopCreateBlasCopyDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopOperatorDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + ] + + lib.infiniopGetBlasCopyWorkspaceSize.restype = c_int32 + lib.infiniopGetBlasCopyWorkspaceSize.argtypes = [ + infiniopOperatorDescriptor_t, + POINTER(c_size_t), + ] + + lib.infiniopBlasCopy.restype = c_int32 + lib.infiniopBlasCopy.argtypes = [ + infiniopOperatorDescriptor_t, + c_void_p, + c_size_t, + c_void_p, + c_void_p, + c_void_p, + ] + + lib.infiniopDestroyBlasCopyDescriptor.restype = c_int32 + lib.infiniopDestroyBlasCopyDescriptor.argtypes = [ + infiniopOperatorDescriptor_t, + ] From ac01531b207ba1c5d155c06c6b43ab75442229f0 Mon Sep 17 00:00:00 2001 From: xuzhengzhong Date: Tue, 28 Apr 2026 07:40:27 +0000 Subject: [PATCH 05/25] Add `blas_dot` operator --- include/infiniop.h | 1 + include/infiniop/ops/blas_dot.h | 26 +++ src/infiniop/devices/metax/metax_ht2mc.h | 2 + .../ops/blas_dot/bang/blas_dot_bang.h | 8 + .../ops/blas_dot/bang/blas_dot_bang.mlu | 99 +++++++++++ .../blas_dot/bang/blas_dot_bang_kernel.mlu | 107 ++++++++++++ src/infiniop/ops/blas_dot/blas_dot.h | 101 ++++++++++++ src/infiniop/ops/blas_dot/cpu/blas_dot_cpu.cc | 102 ++++++++++++ src/infiniop/ops/blas_dot/cpu/blas_dot_cpu.h | 8 + .../ops/blas_dot/metax/blas_dot_metax.cc | 74 +++++++++ .../ops/blas_dot/metax/blas_dot_metax.h | 8 + src/infiniop/ops/blas_dot/operator.cc | 125 ++++++++++++++ test/infiniop/blas_dot.py | 154 ++++++++++++++++++ test/infiniop/libinfiniop/op_register.py | 34 ++++ 14 files changed, 849 insertions(+) create mode 100644 include/infiniop/ops/blas_dot.h create mode 100644 
src/infiniop/ops/blas_dot/bang/blas_dot_bang.h create mode 100644 src/infiniop/ops/blas_dot/bang/blas_dot_bang.mlu create mode 100644 src/infiniop/ops/blas_dot/bang/blas_dot_bang_kernel.mlu create mode 100644 src/infiniop/ops/blas_dot/blas_dot.h create mode 100644 src/infiniop/ops/blas_dot/cpu/blas_dot_cpu.cc create mode 100644 src/infiniop/ops/blas_dot/cpu/blas_dot_cpu.h create mode 100644 src/infiniop/ops/blas_dot/metax/blas_dot_metax.cc create mode 100644 src/infiniop/ops/blas_dot/metax/blas_dot_metax.h create mode 100644 src/infiniop/ops/blas_dot/operator.cc create mode 100644 test/infiniop/blas_dot.py diff --git a/include/infiniop.h b/include/infiniop.h index 6473d3108..0bd58f66a 100644 --- a/include/infiniop.h +++ b/include/infiniop.h @@ -25,6 +25,7 @@ #include "infiniop/ops/blas_amax.h" #include "infiniop/ops/blas_amin.h" #include "infiniop/ops/blas_copy.h" +#include "infiniop/ops/blas_dot.h" #include "infiniop/ops/block_diag.h" #include "infiniop/ops/broadcast_to.h" #include "infiniop/ops/causal_softmax.h" diff --git a/include/infiniop/ops/blas_dot.h b/include/infiniop/ops/blas_dot.h new file mode 100644 index 000000000..9d03af2c0 --- /dev/null +++ b/include/infiniop/ops/blas_dot.h @@ -0,0 +1,26 @@ +#ifndef __INFINIOP_BLAS_DOT_API_H__ +#define __INFINIOP_BLAS_DOT_API_H__ + +#include "../operator_descriptor.h" + +typedef struct InfiniopDescriptor *infiniopBlasDotDescriptor_t; + +__INFINI_C __export infiniStatus_t infiniopCreateBlasDotDescriptor(infiniopHandle_t handle, + infiniopBlasDotDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t x, + infiniopTensorDescriptor_t y, + infiniopTensorDescriptor_t result); + +__INFINI_C __export infiniStatus_t infiniopGetBlasDotWorkspaceSize(infiniopBlasDotDescriptor_t desc, size_t *size); + +__INFINI_C __export infiniStatus_t infiniopBlasDot(infiniopBlasDotDescriptor_t desc, + void *workspace, + size_t workspace_size, + const void *x, + const void *y, + void *result, + void *stream); + +__INFINI_C __export infiniStatus_t 
infiniopDestroyBlasDotDescriptor(infiniopBlasDotDescriptor_t desc); + +#endif // __INFINIOP_BLAS_DOT_API_H__ diff --git a/src/infiniop/devices/metax/metax_ht2mc.h b/src/infiniop/devices/metax/metax_ht2mc.h index 89dbc9173..29a1f5f93 100644 --- a/src/infiniop/devices/metax/metax_ht2mc.h +++ b/src/infiniop/devices/metax/metax_ht2mc.h @@ -129,6 +129,8 @@ #define hcblasDaxpy mcblasDaxpy #define hcblasScopy mcblasScopy #define hcblasDcopy mcblasDcopy +#define hcblasSdot mcblasSdot +#define hcblasDdot mcblasDdot #define HCBLAS_STATUS_SUCCESS MCBLAS_STATUS_SUCCESS #define HCBLAS_OP_T MCBLAS_OP_T #define HCBLAS_OP_N MCBLAS_OP_N diff --git a/src/infiniop/ops/blas_dot/bang/blas_dot_bang.h b/src/infiniop/ops/blas_dot/bang/blas_dot_bang.h new file mode 100644 index 000000000..2020935ad --- /dev/null +++ b/src/infiniop/ops/blas_dot/bang/blas_dot_bang.h @@ -0,0 +1,8 @@ +#ifndef __BLAS_DOT_BANG_H__ +#define __BLAS_DOT_BANG_H__ + +#include "../blas_dot.h" + +DESCRIPTOR(bang) + +#endif // __BLAS_DOT_BANG_H__ \ No newline at end of file diff --git a/src/infiniop/ops/blas_dot/bang/blas_dot_bang.mlu b/src/infiniop/ops/blas_dot/bang/blas_dot_bang.mlu new file mode 100644 index 000000000..481905be1 --- /dev/null +++ b/src/infiniop/ops/blas_dot/bang/blas_dot_bang.mlu @@ -0,0 +1,99 @@ +#include "../../../devices/bang/common_bang.h" +#include "blas_dot_bang.h" +#include "blas_dot_bang_kernel.mlu" + +namespace op::blas_dot::bang { + +Descriptor::~Descriptor() = default; + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t result_desc) { + + auto handle = reinterpret_cast(handle_); + auto info = BlasDotInfo::createBlasDotInfo(x_desc, y_desc, result_desc); + CHECK_RESULT(info); + + *desc_ptr = new Descriptor( + info.take(), + 0, + nullptr, + handle->device, + handle->device_id); + + return INFINI_STATUS_SUCCESS; +} + +template +infiniStatus_t 
calculateBlasDot( + const BlasDotInfo &info, + const Tdata *x, + const Tdata *y, + Tdata *result, + cnrtQueue_t queue) { + + cnrtDim3_t k_dim; + cnrtFunctionType_t k_type; + + k_dim.x = 4; + k_dim.y = 1; + k_dim.z = 1; + k_type = cnrtFuncTypeUnion1; + + if (info.getIncx() == 1 && info.getIncy() == 1) { + blasDotKernelContiguous<<>>( + info.getSize(), + x, + y, + result); + } else { + blasDotKernelStrided<<>>( + info.getSize(), + x, + info.getIncx(), + y, + info.getIncy(), + result); + } + + cnrtQueueSync(queue); + + return INFINI_STATUS_SUCCESS; +} + +#define CALCULATE_BLAS_DOT(TDATA) \ + calculateBlasDot(_info, \ + (const TDATA *)x, \ + (const TDATA *)y, \ + (TDATA *)result, \ + (cnrtQueue_t)stream) + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + const void *x, + const void *y, + void *result, + void *stream) const { + + (void)workspace; + (void)workspace_size; + + switch (_info.getDtype()) { + case INFINI_DTYPE_F16: + return CALCULATE_BLAS_DOT(half); + case INFINI_DTYPE_F32: + return CALCULATE_BLAS_DOT(float); + case INFINI_DTYPE_BF16: + return CALCULATE_BLAS_DOT(bfloat16_t); + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } +} + +#undef CALCULATE_BLAS_DOT + +} // namespace op::blas_dot::bang diff --git a/src/infiniop/ops/blas_dot/bang/blas_dot_bang_kernel.mlu b/src/infiniop/ops/blas_dot/bang/blas_dot_bang_kernel.mlu new file mode 100644 index 000000000..a1e3a84d9 --- /dev/null +++ b/src/infiniop/ops/blas_dot/bang/blas_dot_bang_kernel.mlu @@ -0,0 +1,107 @@ +#include "../../../devices/bang/common_bang.h" +#include "blas_dot_bang.h" + +__nram__ char nram_buffer[NRAM_MAX_SIZE]; + +template +__mlu_global__ void blasDotKernelContiguous( + size_t n, + const Tdata *x, + const Tdata *y, + Tdata *result) { + __mlu_shared__ Tdata shared_partial_sum[4]; + + Tdata *nram_align = (Tdata *)(((size_t)nram_buffer + ALIGN_SIZE - 1) & ~(ALIGN_SIZE - 1)); + + size_t nram_usable = NRAM_MAX_SIZE - ((char *)nram_align - nram_buffer); + 
size_t max_chunk_elements = nram_usable / (2 * sizeof(Tdata)); + + int align_elements = ALIGN_SIZE / sizeof(Tdata); + if (align_elements == 0) { + align_elements = 1; + } + max_chunk_elements = (max_chunk_elements / align_elements) * align_elements; + + Tdata *nram_x = nram_align; + Tdata *nram_y = nram_align + max_chunk_elements; + + int elements_per_core = n / taskDim; + int remain = n % taskDim; + int core_elements = elements_per_core + (taskId < remain ? 1 : 0); + int core_offset = taskId < remain ? taskId * core_elements : taskId * elements_per_core + remain; + + int chunks = core_elements / max_chunk_elements; + int chunk_rem = core_elements % max_chunk_elements; + + Tdata partial_sum = 0; + + for (int c = 0; c < chunks; c++) { + size_t current_offset = core_offset + c * max_chunk_elements; + __memcpy(nram_x, x + current_offset, max_chunk_elements * sizeof(Tdata), GDRAM2NRAM); + __memcpy(nram_y, y + current_offset, max_chunk_elements * sizeof(Tdata), GDRAM2NRAM); + + __bang_mul(nram_x, nram_x, nram_y, max_chunk_elements); + partial_sum += __bang_sum(nram_x, max_chunk_elements); + } + + if (chunk_rem > 0) { + size_t current_offset = core_offset + chunks * max_chunk_elements; + __memcpy(nram_x, x + current_offset, chunk_rem * sizeof(Tdata), GDRAM2NRAM); + __memcpy(nram_y, y + current_offset, chunk_rem * sizeof(Tdata), GDRAM2NRAM); + + __bang_mul(nram_x, nram_x, nram_y, chunk_rem); + partial_sum += __bang_sum(nram_x, chunk_rem); + } + + shared_partial_sum[coreId] = partial_sum; + + __sync_cluster(); + + if (coreId == 0) { + Tdata cluster_sum = 0; + + for (int i = 0; i < coreDim; i++) { + cluster_sum += shared_partial_sum[i]; + } + + result[0] = cluster_sum; + } +} + +template +__mlu_global__ void blasDotKernelStrided( + size_t n, + const Tdata *x, + ptrdiff_t incx, + const Tdata *y, + ptrdiff_t incy, + Tdata *result) { + __mlu_shared__ Tdata shared_partial_sum[4]; + + size_t elements_per_core = n / taskDim; + size_t remain = n % taskDim; + size_t core_elements = 
elements_per_core + (taskId < remain ? 1 : 0); + size_t start_idx = taskId < remain ? taskId * core_elements : taskId * elements_per_core + remain; + + Tdata partial_sum = 0; + ptrdiff_t x_offset = static_cast(start_idx) * incx; + ptrdiff_t y_offset = static_cast(start_idx) * incy; + + for (size_t i = 0; i < core_elements; ++i) { + partial_sum += x[x_offset] * y[y_offset]; + x_offset += incx; + y_offset += incy; + } + + shared_partial_sum[coreId] = partial_sum; + + __sync_cluster(); + + if (coreId == 0) { + Tdata cluster_sum = 0; + for (int i = 0; i < coreDim; ++i) { + cluster_sum += shared_partial_sum[i]; + } + result[0] = cluster_sum; + } +} \ No newline at end of file diff --git a/src/infiniop/ops/blas_dot/blas_dot.h b/src/infiniop/ops/blas_dot/blas_dot.h new file mode 100644 index 000000000..00675e55c --- /dev/null +++ b/src/infiniop/ops/blas_dot/blas_dot.h @@ -0,0 +1,101 @@ +#ifndef __BLAS_DOT_H__ +#define __BLAS_DOT_H__ + +#include "../../../utils.h" +#include "../../operator.h" +#include "../../tensor.h" +#include "infiniop/ops/blas_dot.h" + +#define DESCRIPTOR(NAMESPACE) \ + \ + namespace op::blas_dot::NAMESPACE { \ + class Descriptor final : public InfiniopDescriptor { \ + struct Opaque; \ + Opaque *_opaque; \ + BlasDotInfo _info; \ + size_t _workspace_size; \ + \ + Descriptor( \ + BlasDotInfo info, \ + size_t workspace_size_, \ + Opaque *opaque, \ + infiniDevice_t device_type, \ + int device_id) \ + : InfiniopDescriptor{device_type, device_id}, \ + _opaque(opaque), \ + _info(std::move(info)), \ + _workspace_size(workspace_size_) {} \ + \ + public: \ + ~Descriptor(); \ + \ + size_t workspaceSize() const { return _workspace_size; } \ + \ + static infiniStatus_t create( \ + infiniopHandle_t handle, \ + Descriptor **desc_ptr, \ + infiniopTensorDescriptor_t x_desc, \ + infiniopTensorDescriptor_t y_desc, \ + infiniopTensorDescriptor_t result_desc); \ + \ + infiniStatus_t calculate( \ + void *workspace, \ + size_t workspace_size, \ + const void *x, \ + const 
void *y, \ + void *result, \ + void *stream) const; \ + }; \ + } + +class BlasDotInfo { +private: + size_t _size; + ptrdiff_t _incx; + ptrdiff_t _incy; + infiniDtype_t _dtype; + + BlasDotInfo(size_t size, + ptrdiff_t incx, + ptrdiff_t incy, + infiniDtype_t dtype) + : _size(size), _incx(incx), _incy(incy), _dtype(dtype) {} + +public: + inline size_t getSize() const { return _size; } + inline ptrdiff_t getIncx() const { return _incx; } + inline ptrdiff_t getIncy() const { return _incy; } + inline infiniDtype_t getDtype() const { return _dtype; } + + using ResultType = utils::Result; + + static ResultType createBlasDotInfo( + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t result_desc) { + + CHECK_OR_RETURN(x_desc != nullptr, INFINI_STATUS_NULL_POINTER); + CHECK_OR_RETURN(y_desc != nullptr, INFINI_STATUS_NULL_POINTER); + CHECK_OR_RETURN(result_desc != nullptr, INFINI_STATUS_NULL_POINTER); + + auto dtype = x_desc->dtype(); + + CHECK_OR_RETURN(y_desc->dtype() == dtype, INFINI_STATUS_BAD_TENSOR_DTYPE); + CHECK_OR_RETURN(result_desc->dtype() == dtype, INFINI_STATUS_BAD_TENSOR_DTYPE); + CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32, INFINI_DTYPE_F64); + + CHECK_OR_RETURN(x_desc->ndim() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(y_desc->ndim() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(x_desc->numel() == y_desc->numel(), INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(result_desc->numel() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); + + auto size = x_desc->numel(); + auto incx = x_desc->stride(0); + auto incy = y_desc->stride(0); + + BlasDotInfo info(size, incx, incy, dtype); + return ResultType(std::move(info)); + } +}; + +#endif // __BLAS_DOT_H__ diff --git a/src/infiniop/ops/blas_dot/cpu/blas_dot_cpu.cc b/src/infiniop/ops/blas_dot/cpu/blas_dot_cpu.cc new file mode 100644 index 000000000..c138baa78 --- /dev/null +++ b/src/infiniop/ops/blas_dot/cpu/blas_dot_cpu.cc @@ 
-0,0 +1,102 @@ +#include "blas_dot_cpu.h" +#include "../../../devices/cpu/common_cpu.h" + +namespace op::blas_dot::cpu { + +Descriptor::~Descriptor() = default; + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t result_desc) { + + auto handle = reinterpret_cast(handle_); + auto info = BlasDotInfo::createBlasDotInfo(x_desc, y_desc, result_desc); + CHECK_RESULT(info); + + *desc_ptr = new Descriptor( + info.take(), + 0, + nullptr, + handle->device, + handle->device_id); + + return INFINI_STATUS_SUCCESS; +} + +template +infiniStatus_t calculateBlasDot( + const BlasDotInfo &info, + const Tdata *x, + const Tdata *y, + Tdata *result) { + + const ptrdiff_t n = info.getSize(); + const ptrdiff_t incx = info.getIncx(); + const ptrdiff_t incy = info.getIncy(); + + ptrdiff_t ix = (incx < 0) ? (1 - n) * incx : 0; + ptrdiff_t iy = (incy < 0) ? (1 - n) * incy : 0; + + if constexpr (std::is_same::value || std::is_same::value) { + float total = 0.0f; + + for (ptrdiff_t i = 0; i < n; ++i) { + total += utils::cast(x[ix]) * utils::cast(y[iy]); + ix += incx; + iy += incy; + } + + result[0] = utils::cast(total); + } else { + Tdata total = utils::cast(0); + + for (ptrdiff_t i = 0; i < n; ++i) { + total += x[ix] * y[iy]; + ix += incx; + iy += incy; + } + + result[0] = total; + } + + return INFINI_STATUS_SUCCESS; +} + +#define CALCULATE_BLAS_DOT(TDATA) \ + calculateBlasDot(_info, \ + (const TDATA *)x, \ + (const TDATA *)y, \ + (TDATA *)result) + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + const void *x, + const void *y, + void *result, + void *stream) const { + + (void)workspace; + (void)workspace_size; + (void)stream; + + switch (_info.getDtype()) { + case INFINI_DTYPE_F16: + return CALCULATE_BLAS_DOT(fp16_t); + case INFINI_DTYPE_F32: + return CALCULATE_BLAS_DOT(float); + case INFINI_DTYPE_F64: + return 
CALCULATE_BLAS_DOT(double); + case INFINI_DTYPE_BF16: + return CALCULATE_BLAS_DOT(bf16_t); + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } +} + +#undef CALCULATE_BLAS_DOT + +} // namespace op::blas_dot::cpu diff --git a/src/infiniop/ops/blas_dot/cpu/blas_dot_cpu.h b/src/infiniop/ops/blas_dot/cpu/blas_dot_cpu.h new file mode 100644 index 000000000..e589316a9 --- /dev/null +++ b/src/infiniop/ops/blas_dot/cpu/blas_dot_cpu.h @@ -0,0 +1,8 @@ +#ifndef __BLAS_DOT_CPU_H__ +#define __BLAS_DOT_CPU_H__ + +#include "../blas_dot.h" + +DESCRIPTOR(cpu) + +#endif // __BLAS_DOT_CPU_H__ \ No newline at end of file diff --git a/src/infiniop/ops/blas_dot/metax/blas_dot_metax.cc b/src/infiniop/ops/blas_dot/metax/blas_dot_metax.cc new file mode 100644 index 000000000..16e0bd623 --- /dev/null +++ b/src/infiniop/ops/blas_dot/metax/blas_dot_metax.cc @@ -0,0 +1,74 @@ +#include "blas_dot_metax.h" +#include "../../../devices/metax/metax_common.h" +#include "../../../devices/metax/metax_handle.h" + +namespace op::blas_dot::metax { + +struct Descriptor::Opaque { + std::shared_ptr internal; +}; + +Descriptor::~Descriptor() { + delete _opaque; +} + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t result_desc) { + + auto handle = reinterpret_cast(handle_); + auto info = BlasDotInfo::createBlasDotInfo(x_desc, y_desc, result_desc); + CHECK_RESULT(info); + + *desc_ptr = new Descriptor( + info.take(), + 0, + new Opaque{handle->internal()}, + handle->device, + handle->device_id); + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + const void *x, + const void *y, + void *result, + void *stream) const { + + (void)workspace; + (void)workspace_size; + + const size_t size = _info.getSize(); + const ptrdiff_t incx = _info.getIncx(); + const ptrdiff_t incy = _info.getIncy(); + const 
infiniDtype_t data_type = _info.getDtype(); + + CHECK_STATUS(_opaque->internal->useMcblas( + (hcStream_t)stream, + [&](hcblasHandle_t handle) { + CHECK_MCBLAS(hcblasSetPointerMode(handle, HCBLAS_POINTER_MODE_DEVICE)); + + switch (data_type) { + case INFINI_DTYPE_F32: + CHECK_MCBLAS(hcblasSdot(handle, size, (const float *)x, incx, (const float *)y, incy, (float *)result)); + break; + case INFINI_DTYPE_F64: + CHECK_MCBLAS(hcblasDdot(handle, size, (const double *)x, incx, (const double *)y, incy, (double *)result)); + break; + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + return INFINI_STATUS_SUCCESS; + })); + + return INFINI_STATUS_SUCCESS; +} + +} // namespace op::blas_dot::metax diff --git a/src/infiniop/ops/blas_dot/metax/blas_dot_metax.h b/src/infiniop/ops/blas_dot/metax/blas_dot_metax.h new file mode 100644 index 000000000..19a2385b0 --- /dev/null +++ b/src/infiniop/ops/blas_dot/metax/blas_dot_metax.h @@ -0,0 +1,8 @@ +#ifndef __BLAS_DOT_METAX_H__ +#define __BLAS_DOT_METAX_H__ + +#include "../blas_dot.h" + +DESCRIPTOR(metax) + +#endif // __BLAS_DOT_METAX_H__ \ No newline at end of file diff --git a/src/infiniop/ops/blas_dot/operator.cc b/src/infiniop/ops/blas_dot/operator.cc new file mode 100644 index 000000000..e28943b8b --- /dev/null +++ b/src/infiniop/ops/blas_dot/operator.cc @@ -0,0 +1,125 @@ +#include "../../operator.h" +#include "../../handle.h" +#include "infiniop/ops/blas_dot.h" + +#ifdef ENABLE_CPU_API +#include "cpu/blas_dot_cpu.h" +#endif +#ifdef ENABLE_METAX_API +#include "metax/blas_dot_metax.h" +#endif +#ifdef ENABLE_CAMBRICON_API +#include "bang/blas_dot_bang.h" +#endif + +__INFINI_C infiniStatus_t infiniopCreateBlasDotDescriptor( + infiniopHandle_t handle, + infiniopBlasDotDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t result_desc) { + +#define CREATE(CASE, NAMESPACE) \ + case CASE: \ + return op::blas_dot::NAMESPACE::Descriptor::create( \ + handle, \ + 
reinterpret_cast(desc_ptr), \ + x_desc, \ + y_desc, \ + result_desc) + + switch (handle->device) { +#ifdef ENABLE_CPU_API + CREATE(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_METAX_API + CREATE(INFINI_DEVICE_METAX, metax); +#endif +#ifdef ENABLE_CAMBRICON_API + CREATE(INFINI_DEVICE_CAMBRICON, bang); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CREATE +} + +__INFINI_C infiniStatus_t infiniopGetBlasDotWorkspaceSize(infiniopBlasDotDescriptor_t desc, size_t *size) { + +#define GET(CASE, NAMESPACE) \ + case CASE: \ + *size = reinterpret_cast(desc)->workspaceSize(); \ + return INFINI_STATUS_SUCCESS + + switch (desc->device_type) { +#ifdef ENABLE_CPU_API + GET(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_METAX_API + GET(INFINI_DEVICE_METAX, metax); +#endif +#ifdef ENABLE_CAMBRICON_API + GET(INFINI_DEVICE_CAMBRICON, bang); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef GET +} + +__INFINI_C infiniStatus_t infiniopBlasDot( + infiniopBlasDotDescriptor_t desc, + void *workspace, + size_t workspace_size, + const void *x, + const void *y, + void *result, + void *stream) { + +#define CALCULATE(CASE, NAMESPACE) \ + case CASE: \ + return reinterpret_cast(desc) \ + ->calculate(workspace, workspace_size, x, y, result, stream) + + switch (desc->device_type) { +#ifdef ENABLE_CPU_API + CALCULATE(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_METAX_API + CALCULATE(INFINI_DEVICE_METAX, metax); +#endif +#ifdef ENABLE_CAMBRICON_API + CALCULATE(INFINI_DEVICE_CAMBRICON, bang); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CALCULATE +} + +__INFINI_C infiniStatus_t infiniopDestroyBlasDotDescriptor(infiniopBlasDotDescriptor_t desc) { + +#define DELETE(CASE, NAMESPACE) \ + case CASE: \ + delete reinterpret_cast(desc); \ + return INFINI_STATUS_SUCCESS + + switch (desc->device_type) { +#ifdef ENABLE_CPU_API + DELETE(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_METAX_API + 
def test(
    handle,
    device,
    n,
    x_stride=None,
    y_stride=None,
    dtype=torch.float16,
    sync=None,
):
    """Check infiniopBlasDot against torch.dot for one configuration.

    Args:
        handle: library handle passed to descriptor creation.
        device: device identifier; any value other than 0 also seeds CUDA.
        n: number of elements in each input vector.
        x_stride: optional stride tuple for x (None for the default layout).
        y_stride: optional stride tuple for y (None for the default layout).
        dtype: element type of x, y and the scalar result.
        sync: optional callable invoked after the reference is computed.

    Raises:
        AssertionError: if the library result is not within the tolerance
            configured in _TOLERANCE_MAP for `dtype`.
    """
    torch.manual_seed(0)
    if device != 0:
        torch.cuda.manual_seed_all(0)

    x = TestTensor((n,), x_stride, dtype, device)
    y = TestTensor((n,), y_stride, dtype, device)
    result = TestTensor(tuple(), None, dtype, device, mode="zeros")

    print(
        f"Testing blas_dot on {InfiniDeviceNames[device]} with n:{n} x_stride:{x_stride} y_stride:{y_stride} "
        f"dtype:{InfiniDtypeNames[dtype]}"
    )

    def torch_blas_dot():
        # Half-precision inputs are multiplied and summed in float32 and only
        # the final scalar is cast back to the input dtype.
        if dtype in (InfiniDtype.F16, InfiniDtype.BF16):
            return torch.dot(x.torch_tensor().float(), y.torch_tensor().float()).to(
                x.torch_tensor().dtype
            )
        return torch.dot(x.torch_tensor(), y.torch_tensor())

    result_ref = torch_blas_dot()
    result.update_torch_tensor(result_ref)

    if sync is not None:
        sync()

    descriptor = infiniopOperatorDescriptor_t()
    check_error(
        LIBINFINIOP.infiniopCreateBlasDotDescriptor(
            handle,
            ctypes.byref(descriptor),
            x.descriptor,
            y.descriptor,
            result.descriptor,
        )
    )

    # Drop every tensor descriptor handed to the operator (y included) so the
    # computation below can only rely on what the operator descriptor captured.
    for tensor in [x, y, result]:
        tensor.destroy_desc()

    workspace_size = c_uint64(0)
    check_error(
        LIBINFINIOP.infiniopGetBlasDotWorkspaceSize(
            descriptor, ctypes.byref(workspace_size)
        )
    )
    workspace = TestWorkspace(workspace_size.value, result.device)

    def lib_blas_dot():
        check_error(
            LIBINFINIOP.infiniopBlasDot(
                descriptor,
                workspace.data(),
                workspace.size(),
                x.data(),
                y.data(),
                result.data(),
                None,
            )
        )

    lib_blas_dot()

    atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
    if DEBUG:
        debug(result.actual_tensor(), result.torch_tensor(), atol=atol, rtol=rtol)
    assert torch.allclose(
        result.actual_tensor(), result.torch_tensor(), atol=atol, rtol=rtol
    )

    if PROFILE:
        # fmt: off
        profile_operation("PyTorch", lambda: torch_blas_dot(), device, NUM_PRERUN, NUM_ITERATIONS)
        profile_operation("    lib", lambda: lib_blas_dot(), device, NUM_PRERUN, NUM_ITERATIONS)
        # fmt: on

    check_error(LIBINFINIOP.infiniopDestroyBlasDotDescriptor(descriptor))
@OpRegister.operator
def blas_dot_(lib):
    """Declare the ctypes signatures of the four BlasDot entry points on `lib`.

    Every entry point returns a 32-bit status code.
    """
    status_t = c_int32
    op_desc_t = infiniopOperatorDescriptor_t
    tensor_desc_t = infiniopTensorDescriptor_t

    # create(handle, &descriptor, x_desc, y_desc, result_desc)
    create = lib.infiniopCreateBlasDotDescriptor
    create.restype = status_t
    create.argtypes = [
        infiniopHandle_t,
        POINTER(op_desc_t),
        tensor_desc_t,
        tensor_desc_t,
        tensor_desc_t,
    ]

    # get_workspace_size(descriptor, &size)
    get_workspace_size = lib.infiniopGetBlasDotWorkspaceSize
    get_workspace_size.restype = status_t
    get_workspace_size.argtypes = [op_desc_t, POINTER(c_size_t)]

    # run(descriptor, workspace, workspace_size, x, y, result, stream)
    run = lib.infiniopBlasDot
    run.restype = status_t
    run.argtypes = [
        op_desc_t,
        c_void_p,
        c_size_t,
        c_void_p,
        c_void_p,
        c_void_p,
        c_void_p,
    ]

    # destroy(descriptor)
    destroy = lib.infiniopDestroyBlasDotDescriptor
    destroy.restype = status_t
    destroy.argtypes = [op_desc_t]
| 152 ++++++++++++++++ 14 files changed, 940 insertions(+) create mode 100644 include/infiniop/ops/nrm2.h create mode 100644 src/infiniop/ops/nrm2/bang/nrm2_bang.h create mode 100644 src/infiniop/ops/nrm2/bang/nrm2_bang.mlu create mode 100644 src/infiniop/ops/nrm2/bang/nrm2_bang_kernel.mlu create mode 100644 src/infiniop/ops/nrm2/cpu/nrm2_cpu.cc create mode 100644 src/infiniop/ops/nrm2/cpu/nrm2_cpu.h create mode 100644 src/infiniop/ops/nrm2/metax/nrm2_metax.cc create mode 100644 src/infiniop/ops/nrm2/metax/nrm2_metax.h create mode 100644 src/infiniop/ops/nrm2/nrm2.h create mode 100644 src/infiniop/ops/nrm2/operator.cc create mode 100644 test/infiniop/nrm2.py diff --git a/include/infiniop.h b/include/infiniop.h index 0bd58f66a..6ddde667d 100644 --- a/include/infiniop.h +++ b/include/infiniop.h @@ -84,6 +84,7 @@ #include "infiniop/ops/matrix_power.h" #include "infiniop/ops/mul.h" #include "infiniop/ops/multi_margin_loss.h" +#include "infiniop/ops/nrm2.h" #include "infiniop/ops/ones.h" #include "infiniop/ops/pad.h" #include "infiniop/ops/paged_attention.h" diff --git a/include/infiniop/ops/nrm2.h b/include/infiniop/ops/nrm2.h new file mode 100644 index 000000000..2f1eed348 --- /dev/null +++ b/include/infiniop/ops/nrm2.h @@ -0,0 +1,24 @@ +#ifndef __INFINIOP_NRM2_API_H__ +#define __INFINIOP_NRM2_API_H__ + +#include "../operator_descriptor.h" + +typedef struct InfiniopDescriptor *infiniopNrm2Descriptor_t; + +__INFINI_C __export infiniStatus_t infiniopCreateNrm2Descriptor(infiniopHandle_t handle, + infiniopNrm2Descriptor_t *desc_ptr, + infiniopTensorDescriptor_t x, + infiniopTensorDescriptor_t result); + +__INFINI_C __export infiniStatus_t infiniopGetNrm2WorkspaceSize(infiniopNrm2Descriptor_t desc, size_t *size); + +__INFINI_C __export infiniStatus_t infiniopNrm2(infiniopNrm2Descriptor_t desc, + void *workspace, + size_t workspace_size, + const void *x, + void *result, + void *stream); + +__INFINI_C __export infiniStatus_t 
infiniopDestroyNrm2Descriptor(infiniopNrm2Descriptor_t desc); + +#endif // __INFINIOP_NRM2_API_H__ diff --git a/src/infiniop/devices/metax/metax_ht2mc.h b/src/infiniop/devices/metax/metax_ht2mc.h index 29a1f5f93..b2dfae91f 100644 --- a/src/infiniop/devices/metax/metax_ht2mc.h +++ b/src/infiniop/devices/metax/metax_ht2mc.h @@ -131,6 +131,8 @@ #define hcblasDcopy mcblasDcopy #define hcblasSdot mcblasSdot #define hcblasDdot mcblasDdot +#define hcblasSnrm2 mcblasSnrm2 +#define hcblasDnrm2 mcblasDnrm2 #define HCBLAS_STATUS_SUCCESS MCBLAS_STATUS_SUCCESS #define HCBLAS_OP_T MCBLAS_OP_T #define HCBLAS_OP_N MCBLAS_OP_N diff --git a/src/infiniop/ops/nrm2/bang/nrm2_bang.h b/src/infiniop/ops/nrm2/bang/nrm2_bang.h new file mode 100644 index 000000000..af7961392 --- /dev/null +++ b/src/infiniop/ops/nrm2/bang/nrm2_bang.h @@ -0,0 +1,8 @@ +#ifndef __NRM2_BANG_H__ +#define __NRM2_BANG_H__ + +#include "../nrm2.h" + +DESCRIPTOR(bang) + +#endif // __NRM2_BANG_H__ \ No newline at end of file diff --git a/src/infiniop/ops/nrm2/bang/nrm2_bang.mlu b/src/infiniop/ops/nrm2/bang/nrm2_bang.mlu new file mode 100644 index 000000000..d935c688e --- /dev/null +++ b/src/infiniop/ops/nrm2/bang/nrm2_bang.mlu @@ -0,0 +1,93 @@ +#include "../../../devices/bang/common_bang.h" +#include "nrm2_bang.h" +#include "nrm2_bang_kernel.mlu" + +namespace op::nrm2::bang { + +Descriptor::~Descriptor() = default; + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t result_desc) { + + auto handle = reinterpret_cast(handle_); + auto info = Nrm2Info::createNrm2Info(x_desc, result_desc); + CHECK_RESULT(info); + + // Create descriptor + *desc_ptr = new Descriptor( + info.take(), + 0, + nullptr, + handle->device, + handle->device_id); + + return INFINI_STATUS_SUCCESS; +} + +template +infiniStatus_t calculateNrm2( + const Nrm2Info &info, + const Tdata *x, + Tdata *result, + cnrtQueue_t queue) { + + cnrtDim3_t k_dim; + 
cnrtFunctionType_t k_type; + + k_dim.x = 4; + k_dim.y = 1; + k_dim.z = 1; + k_type = cnrtFuncTypeUnion1; + + if (info.getIncx() == 1) { + Nrm2KernelContiguous<<>>( + info.getSize(), + x, + result); + } else { + Nrm2KernelStrided<<>>( + info.getSize(), + x, + info.getIncx(), + result); + } + + cnrtQueueSync(queue); + + return INFINI_STATUS_SUCCESS; +} + +#define CALCULATE_NRM2(TDATA) \ + calculateNrm2(_info, \ + (const TDATA *)x, \ + (TDATA *)result, \ + (cnrtQueue_t)stream) + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + const void *x, + void *result, + void *stream) const { + + (void)workspace; + (void)workspace_size; + + switch (_info.getDtype()) { + case INFINI_DTYPE_F16: + return CALCULATE_NRM2(half); + case INFINI_DTYPE_F32: + return CALCULATE_NRM2(float); + case INFINI_DTYPE_BF16: + return CALCULATE_NRM2(bfloat16_t); + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } +} + +#undef CALCULATE_NRM2 + +} // namespace op::nrm2::bang diff --git a/src/infiniop/ops/nrm2/bang/nrm2_bang_kernel.mlu b/src/infiniop/ops/nrm2/bang/nrm2_bang_kernel.mlu new file mode 100644 index 000000000..f48a4ad52 --- /dev/null +++ b/src/infiniop/ops/nrm2/bang/nrm2_bang_kernel.mlu @@ -0,0 +1,166 @@ +#include "../../../devices/bang/common_bang.h" +#include "nrm2_bang.h" + +#include + +__nram__ char nram_buffer[NRAM_MAX_SIZE]; + +template +__mlu_device__ void nrm2ToCompute(float *dst, const Tdata *src, size_t size) { + if constexpr (std::is_same_v) { + __bang_half2float(dst, src, size); + } else if constexpr (std::is_same_v) { + __bang_bfloat162float(dst, src, size); + } else { + __memcpy(dst, src, size * sizeof(float), NRAM2NRAM); + } +} + +template +__mlu_device__ void nrm2StoreResult(Tdata *result, Tdata *nram_result, float *nram_compute, float value) { + nram_compute[0] = value; + if constexpr (std::is_same_v) { + __bang_float2half(nram_result, nram_compute, 1); + result[0] = nram_result[0]; + } else if constexpr (std::is_same_v) { + 
__bang_float2bfloat16(nram_result, nram_compute, 1); + result[0] = nram_result[0]; + } else { + result[0] = nram_compute[0]; + } +} + +template +__mlu_global__ void Nrm2KernelContiguous( + size_t n, + const Tdata *x, + Tdata *result) { + __mlu_shared__ float shared_partial_sum[4]; + + Tdata *nram_x = (Tdata *)(((size_t)nram_buffer + ALIGN_SIZE - 1) & ~(ALIGN_SIZE - 1)); + + size_t nram_usable = NRAM_MAX_SIZE - ((char *)nram_x - nram_buffer); + size_t max_chunk_elements = (nram_usable - ALIGN_SIZE) / (sizeof(Tdata) + sizeof(float)); + + int align_elements = ALIGN_SIZE / sizeof(Tdata); + if (align_elements == 0) { + align_elements = 1; + } + max_chunk_elements = (max_chunk_elements / align_elements) * align_elements; + float *nram_compute = (float *)(((size_t)(nram_x + max_chunk_elements) + ALIGN_SIZE - 1) & ~(ALIGN_SIZE - 1)); + + int elements_per_core = n / taskDim; + int remain = n % taskDim; + int core_elements = elements_per_core + (taskId < remain ? 1 : 0); + int core_offset = taskId < remain ? 
taskId * core_elements : taskId * elements_per_core + remain; + + int chunks = core_elements / max_chunk_elements; + int chunk_rem = core_elements % max_chunk_elements; + + float partial_sum = 0.0f; + + for (int c = 0; c < chunks; c++) { + size_t current_offset = core_offset + c * max_chunk_elements; + __memcpy(nram_x, x + current_offset, max_chunk_elements * sizeof(Tdata), GDRAM2NRAM); + + nrm2ToCompute(nram_compute, nram_x, max_chunk_elements); + __bang_square(nram_compute, nram_compute, max_chunk_elements); + + partial_sum += __bang_sum(nram_compute, max_chunk_elements); + } + + if (chunk_rem > 0) { + size_t current_offset = core_offset + chunks * max_chunk_elements; + + __memcpy(nram_x, x + current_offset, chunk_rem * sizeof(Tdata), GDRAM2NRAM); + + nrm2ToCompute(nram_compute, nram_x, chunk_rem); + __bang_square(nram_compute, nram_compute, chunk_rem); + + partial_sum += __bang_sum(nram_compute, chunk_rem); + } + + shared_partial_sum[coreId] = partial_sum; + + __sync_cluster(); + + if (coreId == 0) { + float cluster_sum = 0.0f; + + for (int i = 0; i < coreDim; i++) { + cluster_sum += shared_partial_sum[i]; + } + + nrm2StoreResult(result, nram_x, nram_compute, std::sqrt(cluster_sum)); + } +} + +template +__mlu_global__ void Nrm2KernelStrided( + size_t n, + const Tdata *x, + size_t incx, + Tdata *result) { + __mlu_shared__ float shared_partial_sum[4]; + + Tdata *nram_x = (Tdata *)(((size_t)nram_buffer + ALIGN_SIZE - 1) & ~(ALIGN_SIZE - 1)); + + size_t nram_usable = NRAM_MAX_SIZE - ((char *)nram_x - nram_buffer); + size_t max_chunk_elements = (nram_usable - ALIGN_SIZE) / (sizeof(Tdata) + sizeof(float)); + + int align_elements = ALIGN_SIZE / sizeof(Tdata); + if (align_elements == 0) { + align_elements = 1; + } + max_chunk_elements = (max_chunk_elements / align_elements) * align_elements; + float *nram_compute = (float *)(((size_t)(nram_x + max_chunk_elements) + ALIGN_SIZE - 1) & ~(ALIGN_SIZE - 1)); + + int elements_per_core = n / taskDim; + int remain = n % taskDim; 
+ int actual_tasks = elements_per_core + (taskId < remain ? 1 : 0); + int start_idx = taskId < remain ? taskId * actual_tasks : taskId * elements_per_core + remain; + + float partial_sum = 0.0f; + + int chunks = actual_tasks / max_chunk_elements; + int chunk_rem = actual_tasks % max_chunk_elements; + + for (int c = 0; c < chunks; c++) { + int current_elements = max_chunk_elements; + int current_start = start_idx + c * max_chunk_elements; + for (int i = 0; i < current_elements; ++i) { + nram_x[i] = x[(current_start + i) * incx]; + } + + nrm2ToCompute(nram_compute, nram_x, current_elements); + __bang_square(nram_compute, nram_compute, current_elements); + + partial_sum += __bang_sum(nram_compute, current_elements); + } + + if (chunk_rem > 0) { + int current_start = start_idx + chunks * max_chunk_elements; + for (int i = 0; i < chunk_rem; ++i) { + nram_x[i] = x[(current_start + i) * incx]; + } + + nrm2ToCompute(nram_compute, nram_x, chunk_rem); + __bang_square(nram_compute, nram_compute, chunk_rem); + + partial_sum += __bang_sum(nram_compute, chunk_rem); + } + + shared_partial_sum[coreId] = partial_sum; + + __sync_cluster(); + + if (coreId == 0) { + float cluster_sum = 0.0f; + + for (int i = 0; i < coreDim; i++) { + cluster_sum += shared_partial_sum[i]; + } + + nrm2StoreResult(result, nram_x, nram_compute, std::sqrt(cluster_sum)); + } +} diff --git a/src/infiniop/ops/nrm2/cpu/nrm2_cpu.cc b/src/infiniop/ops/nrm2/cpu/nrm2_cpu.cc new file mode 100644 index 000000000..f36eb33ce --- /dev/null +++ b/src/infiniop/ops/nrm2/cpu/nrm2_cpu.cc @@ -0,0 +1,164 @@ +#include "nrm2_cpu.h" +#include "../../../devices/cpu/common_cpu.h" +#include +#include +#include + +namespace op::nrm2::cpu { + +Descriptor::~Descriptor() = default; + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t result_desc) { + + auto handle = reinterpret_cast(handle_); + auto info = 
Nrm2Info::createNrm2Info(x_desc, result_desc); + CHECK_RESULT(info); + + // Create descriptor + *desc_ptr = new Descriptor( + info.take(), + 0, + nullptr, + handle->device, + handle->device_id); + + return INFINI_STATUS_SUCCESS; +} + +template +infiniStatus_t calculateNrm2( + const Nrm2Info &info, + const Tdata *x, + Tdata *result) { + + using Tcompute = std::conditional_t, double, float>; + + const ptrdiff_t n = info.getSize(); + const ptrdiff_t incx = info.getIncx(); + + // Blue's scaling constants (float vs double) + constexpr Tcompute tsml = [] { + if constexpr (std::is_same_v) { + return Tcompute(0x1p-63f); // 2^-63 + } else { + return Tcompute(0x1p-511); // 2^-511 + } + }(); + constexpr Tcompute tbig = [] { + if constexpr (std::is_same_v) { + return Tcompute(0x1p52f); // 2^52 + } else { + return Tcompute(0x1p486); // 2^486 + } + }(); + constexpr Tcompute ssml = [] { + if constexpr (std::is_same_v) { + return Tcompute(0x1p75f); // 2^75 + } else { + return Tcompute(0x1p600); // 2^600 + } + }(); + constexpr Tcompute sbig = [] { + if constexpr (std::is_same_v) { + return Tcompute(0x1p-76f); // 2^-76 + } else { + return Tcompute(0x1p-601); // 2^-601 + } + }(); + + Tcompute scl = Tcompute(1); + Tcompute sumsq = Tcompute(0); + + bool notbig = true; + Tcompute asml = Tcompute(0); + Tcompute amed = Tcompute(0); + Tcompute abig = Tcompute(0); + + // 0-based index; handle negative stride + ptrdiff_t ix = (incx < 0) ? 
(ptrdiff_t(1) - n) * incx : 0; + + for (ptrdiff_t i = 0; i < n; ++i) { + Tcompute ax = std::abs(utils::cast(x[ix])); + + if (ax > tbig) { + const Tcompute y = ax * sbig; + abig += y * y; + notbig = false; + } else if (ax < tsml) { + if (notbig) { + const Tcompute y = ax * ssml; + asml += y * y; + } + } else { + amed += ax * ax; + } + + ix += incx; + } + + if (abig > Tcompute(0)) { + if (amed > Tcompute(0) || std::isinf(amed) || std::isnan(amed)) { + abig += (amed * sbig) * sbig; + } + scl = Tcompute(1) / sbig; + sumsq = abig; + } else if (asml > Tcompute(0)) { + if (amed > Tcompute(0) || std::isinf(amed) || std::isnan(amed)) { + amed = std::sqrt(amed); + asml = std::sqrt(asml) / ssml; + + const Tcompute ymin = std::min(amed, asml); + const Tcompute ymax = std::max(amed, asml); + + scl = Tcompute(1); + sumsq = (ymax * ymax) * (Tcompute(1) + (ymin / ymax) * (ymin / ymax)); + } else { + scl = Tcompute(1) / ssml; + sumsq = asml; + } + } else { + scl = Tcompute(1); + sumsq = amed; + } + + result[0] = utils::cast(scl * std::sqrt(sumsq)); + return INFINI_STATUS_SUCCESS; +} + +#define CALCULATE_NRM2(TDATA) \ + calculateNrm2(_info, \ + (const TDATA *)x, \ + (TDATA *)result) + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + const void *x, + void *result, + void *stream) const { + + (void)workspace; + (void)workspace_size; + (void)stream; + + switch (_info.getDtype()) { + case INFINI_DTYPE_F16: + return CALCULATE_NRM2(fp16_t); + case INFINI_DTYPE_F32: + return CALCULATE_NRM2(float); + case INFINI_DTYPE_F64: + return CALCULATE_NRM2(double); + case INFINI_DTYPE_BF16: + return CALCULATE_NRM2(bf16_t); + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } +} + +#undef CALCULATE_NRM2 + +} // namespace op::nrm2::cpu diff --git a/src/infiniop/ops/nrm2/cpu/nrm2_cpu.h b/src/infiniop/ops/nrm2/cpu/nrm2_cpu.h new file mode 100644 index 000000000..320e18f28 --- /dev/null +++ b/src/infiniop/ops/nrm2/cpu/nrm2_cpu.h @@ -0,0 +1,8 @@ +#ifndef 
__NRM2_CPU_H__ +#define __NRM2_CPU_H__ + +#include "../nrm2.h" + +DESCRIPTOR(cpu) + +#endif // __NRM2_CPU_H__ \ No newline at end of file diff --git a/src/infiniop/ops/nrm2/metax/nrm2_metax.cc b/src/infiniop/ops/nrm2/metax/nrm2_metax.cc new file mode 100644 index 000000000..4b3b528d2 --- /dev/null +++ b/src/infiniop/ops/nrm2/metax/nrm2_metax.cc @@ -0,0 +1,71 @@ +#include "nrm2_metax.h" +#include "../../../devices/metax/metax_common.h" +#include "../../../devices/metax/metax_handle.h" + +namespace op::nrm2::metax { + +struct Descriptor::Opaque { + std::shared_ptr internal; +}; + +Descriptor::~Descriptor() { + delete _opaque; +} + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t result_desc) { + + auto handle = reinterpret_cast(handle_); + auto info = Nrm2Info::createNrm2Info(x_desc, result_desc); + CHECK_RESULT(info); + + *desc_ptr = new Descriptor( + info.take(), + 0, + new Opaque{handle->internal()}, + handle->device, + handle->device_id); + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + const void *x, + void *result, + void *stream) const { + + (void)workspace; + (void)workspace_size; + + const size_t size = _info.getSize(); + const ptrdiff_t incx = _info.getIncx(); + const infiniDtype_t data_type = _info.getDtype(); + + CHECK_STATUS(_opaque->internal->useMcblas( + (hcStream_t)stream, + [&](hcblasHandle_t handle) { + CHECK_MCBLAS(hcblasSetPointerMode(handle, HCBLAS_POINTER_MODE_DEVICE)); + + switch (data_type) { + case INFINI_DTYPE_F32: + CHECK_MCBLAS(hcblasSnrm2(handle, size, (const float *)x, incx, (float *)result)); + break; + case INFINI_DTYPE_F64: + CHECK_MCBLAS(hcblasDnrm2(handle, size, (const double *)x, incx, (double *)result)); + break; + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + return INFINI_STATUS_SUCCESS; + })); + + return INFINI_STATUS_SUCCESS; +} + +} // 
namespace op::nrm2::metax diff --git a/src/infiniop/ops/nrm2/metax/nrm2_metax.h b/src/infiniop/ops/nrm2/metax/nrm2_metax.h new file mode 100644 index 000000000..b398b5755 --- /dev/null +++ b/src/infiniop/ops/nrm2/metax/nrm2_metax.h @@ -0,0 +1,8 @@ +#ifndef __NRM2_METAX_H__ +#define __NRM2_METAX_H__ + +#include "../nrm2.h" + +DESCRIPTOR(metax) + +#endif // __NRM2_METAX_H__ \ No newline at end of file diff --git a/src/infiniop/ops/nrm2/nrm2.h b/src/infiniop/ops/nrm2/nrm2.h new file mode 100644 index 000000000..82fb58e0b --- /dev/null +++ b/src/infiniop/ops/nrm2/nrm2.h @@ -0,0 +1,90 @@ +#ifndef __NRM2_H__ +#define __NRM2_H__ + +#include "../../../utils.h" +#include "../../operator.h" +#include "../../tensor.h" +#include "infiniop/ops/nrm2.h" + +#define DESCRIPTOR(NAMESPACE) \ + \ + namespace op::nrm2::NAMESPACE { \ + class Descriptor final : public InfiniopDescriptor { \ + struct Opaque; \ + Opaque *_opaque; \ + Nrm2Info _info; \ + size_t _workspace_size; \ + \ + Descriptor( \ + Nrm2Info info, \ + size_t workspace_size_, \ + Opaque *opaque, \ + infiniDevice_t device_type, \ + int device_id) \ + : InfiniopDescriptor{device_type, device_id}, \ + _opaque(opaque), \ + _info(std::move(info)), \ + _workspace_size(workspace_size_) {} \ + \ + public: \ + ~Descriptor(); \ + \ + size_t workspaceSize() const { return _workspace_size; } \ + \ + static infiniStatus_t create( \ + infiniopHandle_t handle, \ + Descriptor **desc_ptr, \ + infiniopTensorDescriptor_t x_desc, \ + infiniopTensorDescriptor_t result_desc); \ + \ + infiniStatus_t calculate( \ + void *workspace, \ + size_t workspace_size, \ + const void *x, \ + void *result, \ + void *stream) const; \ + }; \ + } + +class Nrm2Info { +private: + size_t _size; + ptrdiff_t _incx; + infiniDtype_t _dtype; + + Nrm2Info(size_t size, + ptrdiff_t incx, + infiniDtype_t dtype) + : _size(size), _incx(incx), _dtype(dtype) {} + +public: + inline size_t getSize() const { return _size; } + inline ptrdiff_t getIncx() const { return _incx; } + 
inline infiniDtype_t getDtype() const { return _dtype; } + + using ResultType = utils::Result; + + static ResultType createNrm2Info( + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t result_desc) { + + CHECK_OR_RETURN(x_desc != nullptr, INFINI_STATUS_NULL_POINTER); + CHECK_OR_RETURN(result_desc != nullptr, INFINI_STATUS_NULL_POINTER); + + auto dtype = x_desc->dtype(); + + CHECK_OR_RETURN(result_desc->dtype() == dtype, INFINI_STATUS_BAD_TENSOR_DTYPE); + CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32, INFINI_DTYPE_F64); + + CHECK_OR_RETURN(x_desc->ndim() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(result_desc->numel() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); + + auto size = x_desc->numel(); + auto incx = x_desc->stride(0); + + Nrm2Info info(size, incx, dtype); + return ResultType(std::move(info)); + } +}; + +#endif // __NRM2_H__ diff --git a/src/infiniop/ops/nrm2/operator.cc b/src/infiniop/ops/nrm2/operator.cc new file mode 100644 index 000000000..8f9e00c0b --- /dev/null +++ b/src/infiniop/ops/nrm2/operator.cc @@ -0,0 +1,121 @@ +#include "../../operator.h" +#include "../../handle.h" +#include "infiniop/ops/nrm2.h" + +#ifdef ENABLE_CPU_API +#include "cpu/nrm2_cpu.h" +#endif +#ifdef ENABLE_METAX_API +#include "metax/nrm2_metax.h" +#endif +#ifdef ENABLE_CAMBRICON_API +#include "bang/nrm2_bang.h" +#endif + +__INFINI_C infiniStatus_t infiniopCreateNrm2Descriptor( + infiniopHandle_t handle, + infiniopNrm2Descriptor_t *desc_ptr, + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t result_desc) { + +#define CREATE(CASE, NAMESPACE) \ + case CASE: \ + return op::nrm2::NAMESPACE::Descriptor::create( \ + handle, \ + reinterpret_cast(desc_ptr), \ + x_desc, result_desc) + + switch (handle->device) { +#ifdef ENABLE_CPU_API + CREATE(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_METAX_API + CREATE(INFINI_DEVICE_METAX, metax); +#endif +#ifdef ENABLE_CAMBRICON_API + CREATE(INFINI_DEVICE_CAMBRICON, bang); +#endif + 
default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CREATE +} + +__INFINI_C infiniStatus_t infiniopGetNrm2WorkspaceSize(infiniopNrm2Descriptor_t desc, size_t *size) { + +#define GET(CASE, NAMESPACE) \ + case CASE: \ + *size = reinterpret_cast(desc)->workspaceSize(); \ + return INFINI_STATUS_SUCCESS + + switch (desc->device_type) { +#ifdef ENABLE_CPU_API + GET(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_METAX_API + GET(INFINI_DEVICE_METAX, metax); +#endif +#ifdef ENABLE_CAMBRICON_API + GET(INFINI_DEVICE_CAMBRICON, bang); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef GET +} + +__INFINI_C infiniStatus_t infiniopNrm2( + infiniopNrm2Descriptor_t desc, + void *workspace, + size_t workspace_size, + const void *x, + void *result, + void *stream) { + +#define CALCULATE(CASE, NAMESPACE) \ + case CASE: \ + return reinterpret_cast(desc) \ + ->calculate(workspace, workspace_size, x, result, stream) + + switch (desc->device_type) { +#ifdef ENABLE_CPU_API + CALCULATE(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_METAX_API + CALCULATE(INFINI_DEVICE_METAX, metax); +#endif +#ifdef ENABLE_CAMBRICON_API + CALCULATE(INFINI_DEVICE_CAMBRICON, bang); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CALCULATE +} + +__INFINI_C infiniStatus_t infiniopDestroyNrm2Descriptor(infiniopNrm2Descriptor_t desc) { + +#define DELETE(CASE, NAMESPACE) \ + case CASE: \ + delete reinterpret_cast(desc); \ + return INFINI_STATUS_SUCCESS + + switch (desc->device_type) { +#ifdef ENABLE_CPU_API + DELETE(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_METAX_API + DELETE(INFINI_DEVICE_METAX, metax); +#endif +#ifdef ENABLE_CAMBRICON_API + DELETE(INFINI_DEVICE_CAMBRICON, bang); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef DELETE +} \ No newline at end of file diff --git a/test/infiniop/libinfiniop/op_register.py b/test/infiniop/libinfiniop/op_register.py index 974ec4165..caa9efa73 
100644 --- a/test/infiniop/libinfiniop/op_register.py +++ b/test/infiniop/libinfiniop/op_register.py @@ -2355,3 +2355,35 @@ def blas_dot_(lib): lib.infiniopDestroyBlasDotDescriptor.argtypes = [ infiniopOperatorDescriptor_t, ] + + +@OpRegister.operator +def nrm2_(lib): + lib.infiniopCreateNrm2Descriptor.restype = c_int32 + lib.infiniopCreateNrm2Descriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopOperatorDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + ] + + lib.infiniopGetNrm2WorkspaceSize.restype = c_int32 + lib.infiniopGetNrm2WorkspaceSize.argtypes = [ + infiniopOperatorDescriptor_t, + POINTER(c_size_t), + ] + + lib.infiniopNrm2.restype = c_int32 + lib.infiniopNrm2.argtypes = [ + infiniopOperatorDescriptor_t, + c_void_p, + c_size_t, + c_void_p, + c_void_p, + c_void_p, + ] + + lib.infiniopDestroyNrm2Descriptor.restype = c_int32 + lib.infiniopDestroyNrm2Descriptor.argtypes = [ + infiniopOperatorDescriptor_t, + ] diff --git a/test/infiniop/nrm2.py b/test/infiniop/nrm2.py new file mode 100644 index 000000000..1523ce082 --- /dev/null +++ b/test/infiniop/nrm2.py @@ -0,0 +1,152 @@ +import ctypes +from ctypes import c_uint64 + +import torch +from libinfiniop import ( + LIBINFINIOP, + InfiniDeviceNames, + InfiniDtype, + InfiniDtypeNames, + TestTensor, + TestWorkspace, + check_error, + debug, + get_args, + get_test_devices, + get_tolerance, + infiniopOperatorDescriptor_t, + profile_operation, + test_operator, +) + +# ============================================================================== +# Configuration +# ============================================================================== +# Format: (shape, x_stride) +_TEST_CASES = [ + ((13,), None), + ((13,), (10,)), + ((5632,), None), + ((5632,), (5,)), + ((16,), (4,)), + ((5632,), (32,)), +] + +_TENSOR_DTYPES = [ + # InfiniDtype.F16, + InfiniDtype.F32, + # InfiniDtype.F64, + # InfiniDtype.BF16, +] + +_TOLERANCE_MAP = { + InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-3}, + 
InfiniDtype.F32: {"atol": 1e-5, "rtol": 1e-5}, + InfiniDtype.F64: {"atol": 1e-7, "rtol": 1e-7}, + InfiniDtype.BF16: {"atol": 1e-2, "rtol": 1e-2}, +} + +DEBUG = False +PROFILE = False +NUM_PRERUN = 10 +NUM_ITERATIONS = 1000 + + +def torch_nrm2(x): + return torch.norm(x, p=2) + + +def test( + handle, + device, + shape, + x_stride=None, + dtype=InfiniDtype.F32, + sync=None, +): + x = TestTensor(shape, x_stride, dtype, device) + + result = TestTensor(tuple(), None, dtype, device, mode="zeros") + + print( + f"Testing Nrm2 on {InfiniDeviceNames[device]} with shape:{shape} x_stride:{x_stride} " + f"dtype:{InfiniDtypeNames[dtype]}" + ) + + result_ref = torch_nrm2(x.torch_tensor()) + result.update_torch_tensor(result_ref) + + if sync is not None: + sync() + + descriptor = infiniopOperatorDescriptor_t() + check_error( + LIBINFINIOP.infiniopCreateNrm2Descriptor( + handle, + ctypes.byref(descriptor), + x.descriptor, + result.descriptor, + ) + ) + + for tensor in [x, result]: + tensor.destroy_desc() + + workspace_size = c_uint64(0) + check_error( + LIBINFINIOP.infiniopGetNrm2WorkspaceSize( + descriptor, ctypes.byref(workspace_size) + ) + ) + workspace = TestWorkspace(workspace_size.value, x.device) + + def lib_nrm2(): + check_error( + LIBINFINIOP.infiniopNrm2( + descriptor, + workspace.data(), + workspace.size(), + x.data(), + result.data(), + None, + ) + ) + + lib_nrm2() + + atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) + + if DEBUG: + debug(result.actual_tensor(), result.torch_tensor(), atol=atol, rtol=rtol) + + assert torch.allclose( + result.actual_tensor(), result.torch_tensor(), atol=atol, rtol=rtol + ) + + if PROFILE: + profile_operation( + "PyTorch", + lambda: torch_nrm2(x.torch_tensor()), + device, + NUM_PRERUN, + NUM_ITERATIONS, + ) + profile_operation( + " lib", lambda: lib_nrm2(), device, NUM_PRERUN, NUM_ITERATIONS + ) + + check_error(LIBINFINIOP.infiniopDestroyNrm2Descriptor(descriptor)) + + +if __name__ == "__main__": + args = get_args() + + DEBUG = args.debug 
+ PROFILE = args.profile + NUM_PRERUN = args.num_prerun + NUM_ITERATIONS = args.num_iterations + + for device in get_test_devices(args): + test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES) + + print("\033[92mTest passed!\033[0m") From 33fbb95a4cfadc8a5099d6573a5f7385a36e8306 Mon Sep 17 00:00:00 2001 From: xuzhengzhong Date: Tue, 28 Apr 2026 09:13:15 +0000 Subject: [PATCH 07/25] Add `rot` operator --- include/infiniop.h | 1 + include/infiniop/ops/rot.h | 28 +++ src/infiniop/devices/metax/metax_ht2mc.h | 2 + src/infiniop/ops/rot/bang/rot_bang.h | 8 + src/infiniop/ops/rot/bang/rot_bang.mlu | 103 +++++++++++ src/infiniop/ops/rot/bang/rot_bang_kernel.mlu | 102 +++++++++++ src/infiniop/ops/rot/cpu/rot_cpu.cc | 104 ++++++++++++ src/infiniop/ops/rot/cpu/rot_cpu.h | 8 + src/infiniop/ops/rot/metax/rot_metax.cc | 76 +++++++++ src/infiniop/ops/rot/metax/rot_metax.h | 8 + src/infiniop/ops/rot/operator.cc | 128 ++++++++++++++ src/infiniop/ops/rot/rot.h | 107 ++++++++++++ test/infiniop/libinfiniop/op_register.py | 36 ++++ test/infiniop/rot.py | 160 ++++++++++++++++++ 14 files changed, 871 insertions(+) create mode 100644 include/infiniop/ops/rot.h create mode 100644 src/infiniop/ops/rot/bang/rot_bang.h create mode 100644 src/infiniop/ops/rot/bang/rot_bang.mlu create mode 100644 src/infiniop/ops/rot/bang/rot_bang_kernel.mlu create mode 100644 src/infiniop/ops/rot/cpu/rot_cpu.cc create mode 100644 src/infiniop/ops/rot/cpu/rot_cpu.h create mode 100644 src/infiniop/ops/rot/metax/rot_metax.cc create mode 100644 src/infiniop/ops/rot/metax/rot_metax.h create mode 100644 src/infiniop/ops/rot/operator.cc create mode 100644 src/infiniop/ops/rot/rot.h create mode 100644 test/infiniop/rot.py diff --git a/include/infiniop.h b/include/infiniop.h index 6ddde667d..bfdfe1f88 100644 --- a/include/infiniop.h +++ b/include/infiniop.h @@ -100,6 +100,7 @@ #include "infiniop/ops/relu.h" #include "infiniop/ops/rms_norm.h" #include "infiniop/ops/rope.h" +#include "infiniop/ops/rot.h" #include 
"infiniop/ops/scatter.h" #include "infiniop/ops/selu.h" #include "infiniop/ops/sigmoid.h" diff --git a/include/infiniop/ops/rot.h b/include/infiniop/ops/rot.h new file mode 100644 index 000000000..0ecbae52d --- /dev/null +++ b/include/infiniop/ops/rot.h @@ -0,0 +1,28 @@ +#ifndef __INFINIOP_ROT_API_H__ +#define __INFINIOP_ROT_API_H__ + +#include "../operator_descriptor.h" + +typedef struct InfiniopDescriptor *infiniopRotDescriptor_t; + +__INFINI_C __export infiniStatus_t infiniopCreateRotDescriptor(infiniopHandle_t handle, + infiniopRotDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t x, + infiniopTensorDescriptor_t y, + infiniopTensorDescriptor_t c, + infiniopTensorDescriptor_t s); + +__INFINI_C __export infiniStatus_t infiniopGetRotWorkspaceSize(infiniopRotDescriptor_t desc, size_t *size); + +__INFINI_C __export infiniStatus_t infiniopRot(infiniopRotDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *x, + void *y, + const void *c, + const void *s, + void *stream); + +__INFINI_C __export infiniStatus_t infiniopDestroyRotDescriptor(infiniopRotDescriptor_t desc); + +#endif // __INFINIOP_ROT_API_H__ diff --git a/src/infiniop/devices/metax/metax_ht2mc.h b/src/infiniop/devices/metax/metax_ht2mc.h index b2dfae91f..fd80f4f49 100644 --- a/src/infiniop/devices/metax/metax_ht2mc.h +++ b/src/infiniop/devices/metax/metax_ht2mc.h @@ -133,6 +133,8 @@ #define hcblasDdot mcblasDdot #define hcblasSnrm2 mcblasSnrm2 #define hcblasDnrm2 mcblasDnrm2 +#define hcblasSrot mcblasSrot +#define hcblasDrot mcblasDrot #define HCBLAS_STATUS_SUCCESS MCBLAS_STATUS_SUCCESS #define HCBLAS_OP_T MCBLAS_OP_T #define HCBLAS_OP_N MCBLAS_OP_N diff --git a/src/infiniop/ops/rot/bang/rot_bang.h b/src/infiniop/ops/rot/bang/rot_bang.h new file mode 100644 index 000000000..86ce84f58 --- /dev/null +++ b/src/infiniop/ops/rot/bang/rot_bang.h @@ -0,0 +1,8 @@ +#ifndef __ROT_BANG_H__ +#define __ROT_BANG_H__ + +#include "../rot.h" + +DESCRIPTOR(bang) + +#endif // __ROT_BANG_H__ \ No newline at 
end of file diff --git a/src/infiniop/ops/rot/bang/rot_bang.mlu b/src/infiniop/ops/rot/bang/rot_bang.mlu new file mode 100644 index 000000000..755e8b638 --- /dev/null +++ b/src/infiniop/ops/rot/bang/rot_bang.mlu @@ -0,0 +1,103 @@ +#include "../../../devices/bang/common_bang.h" +#include "rot_bang.h" +#include "rot_bang_kernel.mlu" + +namespace op::rot::bang { + +Descriptor::~Descriptor() = default; + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t c_desc, + infiniopTensorDescriptor_t s_desc) { + + auto handle = reinterpret_cast(handle_); + auto info = RotInfo::createRotInfo(x_desc, y_desc, c_desc, s_desc); + CHECK_RESULT(info); + + *desc_ptr = new Descriptor( + info.take(), + 0, + nullptr, + handle->device, + handle->device_id); + + return INFINI_STATUS_SUCCESS; +} + +template +infiniStatus_t calculateRot( + const RotInfo &info, + Tdata *x, + Tdata *y, + const Tdata *c, + const Tdata *s, + cnrtQueue_t queue) { + + cnrtDim3_t k_dim; + cnrtFunctionType_t k_type; + k_dim.x = 4; + k_dim.y = 1; + k_dim.z = 1; + k_type = cnrtFuncTypeUnion1; + + if (info.getIncx() == 1 && info.getIncy() == 1) { + rotKernelContiguous<<>>( + info.getSize(), + x, + y, + c, + s); + } else { + rotKernelStrided<<>>( + info.getSize(), + x, + info.getIncx(), + y, + info.getIncy(), + c, + s); + } + + cnrtQueueSync(queue); + return INFINI_STATUS_SUCCESS; +} + +#define CALCULATE_ROT(TDATA) \ + calculateRot(_info, \ + (TDATA *)x, \ + (TDATA *)y, \ + (const TDATA *)c, \ + (const TDATA *)s, \ + (cnrtQueue_t)stream) + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *x, + void *y, + const void *c, + const void *s, + void *stream) const { + + (void)workspace; + (void)workspace_size; + + switch (_info.getDtype()) { + case INFINI_DTYPE_F16: + return CALCULATE_ROT(half); + case INFINI_DTYPE_F32: + return CALCULATE_ROT(float); 
+ case INFINI_DTYPE_BF16: + return CALCULATE_ROT(bfloat16_t); + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } +} + +#undef CALCULATE_ROT + +} // namespace op::rot::bang diff --git a/src/infiniop/ops/rot/bang/rot_bang_kernel.mlu b/src/infiniop/ops/rot/bang/rot_bang_kernel.mlu new file mode 100644 index 000000000..27f279e14 --- /dev/null +++ b/src/infiniop/ops/rot/bang/rot_bang_kernel.mlu @@ -0,0 +1,102 @@ +#include "../../../devices/bang/common_bang.h" + +__nram__ char nram_buffer[NRAM_MAX_SIZE]; + +template +__mlu_global__ void rotKernelContiguous( + size_t n, + Tdata *x, + Tdata *y, + const Tdata *c, + const Tdata *s) { + + Tdata *nram_align = (Tdata *)(((size_t)nram_buffer + ALIGN_SIZE - 1) & ~(ALIGN_SIZE - 1)); + + size_t nram_usable = NRAM_MAX_SIZE - ((char *)nram_align - nram_buffer); + size_t max_chunk_elements = nram_usable / (4 * sizeof(Tdata)); + + int align_elements = ALIGN_SIZE / sizeof(Tdata); + if (align_elements == 0) { + align_elements = 1; + } + max_chunk_elements = (max_chunk_elements / align_elements) * align_elements; + + Tdata *nram_x = nram_align; + Tdata *nram_y = nram_align + max_chunk_elements; + Tdata *nram_x_out = nram_align + 2 * max_chunk_elements; + Tdata *nram_y_out = nram_align + 3 * max_chunk_elements; + + int elements_per_core = n / taskDim; + int remain = n % taskDim; + int core_elements = elements_per_core + (taskId < remain ? 1 : 0); + int core_offset = taskId < remain ? 
taskId * core_elements : taskId * elements_per_core + remain; + + if (core_elements <= 0) { + return; + } + + int chunks = core_elements / max_chunk_elements; + int chunk_rem = core_elements % max_chunk_elements; + + for (int ck = 0; ck < chunks; ck++) { + size_t current_offset = core_offset + ck * max_chunk_elements; + __memcpy(nram_x, x + current_offset, max_chunk_elements * sizeof(Tdata), GDRAM2NRAM); + __memcpy(nram_y, y + current_offset, max_chunk_elements * sizeof(Tdata), GDRAM2NRAM); + + __bang_mul_scalar(nram_x_out, nram_x, c[0], max_chunk_elements); + __bang_mul_scalar(nram_y_out, nram_y, s[0], max_chunk_elements); + __bang_add(nram_x_out, nram_x_out, nram_y_out, max_chunk_elements); + + __memcpy(x + current_offset, nram_x_out, max_chunk_elements * sizeof(Tdata), NRAM2GDRAM); + + __bang_mul_scalar(nram_y_out, nram_y, c[0], max_chunk_elements); + __bang_mul_scalar(nram_x_out, nram_x, s[0], max_chunk_elements); + __bang_sub(nram_y_out, nram_y_out, nram_x_out, max_chunk_elements); + + __memcpy(y + current_offset, nram_y_out, max_chunk_elements * sizeof(Tdata), NRAM2GDRAM); + } + + if (chunk_rem > 0) { + size_t current_offset = core_offset + chunks * max_chunk_elements; + __memcpy(nram_x, x + current_offset, chunk_rem * sizeof(Tdata), GDRAM2NRAM); + __memcpy(nram_y, y + current_offset, chunk_rem * sizeof(Tdata), GDRAM2NRAM); + + __bang_mul_scalar(nram_x_out, nram_x, c[0], chunk_rem); + __bang_mul_scalar(nram_y_out, nram_y, s[0], chunk_rem); + __bang_add(nram_x_out, nram_x_out, nram_y_out, chunk_rem); + + __memcpy(x + current_offset, nram_x_out, chunk_rem * sizeof(Tdata), NRAM2GDRAM); + + __bang_mul_scalar(nram_y_out, nram_y, c[0], chunk_rem); + __bang_mul_scalar(nram_x_out, nram_x, s[0], chunk_rem); + __bang_sub(nram_y_out, nram_y_out, nram_x_out, chunk_rem); + + __memcpy(y + current_offset, nram_y_out, chunk_rem * sizeof(Tdata), NRAM2GDRAM); + } +} + +template +__mlu_global__ void rotKernelStrided( + size_t n, + Tdata *x, + size_t incx, + Tdata *y, + size_t 
incy, + const Tdata *c, + const Tdata *s) { + + int elements_per_core = n / taskDim; + int remain = n % taskDim; + int actual_tasks = elements_per_core + (taskId < remain ? 1 : 0); + int start_idx = taskId < remain ? taskId * actual_tasks : taskId * elements_per_core + remain; + + for (int i = start_idx; i < start_idx + actual_tasks; ++i) { + size_t x_idx = i * incx; + size_t y_idx = i * incy; + Tdata x_val = x[x_idx]; + Tdata y_val = y[y_idx]; + + x[x_idx] = c[0] * x_val + s[0] * y_val; + y[y_idx] = c[0] * y_val - s[0] * x_val; + } +} diff --git a/src/infiniop/ops/rot/cpu/rot_cpu.cc b/src/infiniop/ops/rot/cpu/rot_cpu.cc new file mode 100644 index 000000000..b0eb9dfcb --- /dev/null +++ b/src/infiniop/ops/rot/cpu/rot_cpu.cc @@ -0,0 +1,104 @@ +#include "rot_cpu.h" +#include "../../../devices/cpu/common_cpu.h" + +namespace op::rot::cpu { + +Descriptor::~Descriptor() = default; + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t c_desc, + infiniopTensorDescriptor_t s_desc) { + + auto handle = reinterpret_cast(handle_); + auto info = RotInfo::createRotInfo(x_desc, y_desc, c_desc, s_desc); + CHECK_RESULT(info); + + *desc_ptr = new Descriptor( + info.take(), + 0, + nullptr, + handle->device, + handle->device_id); + + return INFINI_STATUS_SUCCESS; +} + +template +infiniStatus_t calculateRot( + const RotInfo &info, + Tdata *x, + Tdata *y, + const Tdata *c, + const Tdata *s) { + + using Tcompute = std::conditional_t, double, float>; + + const Tcompute c_val = utils::cast(c[0]); + const Tcompute s_val = utils::cast(s[0]); + + const ptrdiff_t size = static_cast(info.getSize()); + const ptrdiff_t incx = info.getIncx(); + const ptrdiff_t incy = info.getIncy(); + + if (size <= 0) { + return INFINI_STATUS_SUCCESS; + } + + const ptrdiff_t ix = incx >= 0 ? 0 : (size - 1) * (-incx); + const ptrdiff_t iy = incy >= 0 ? 
0 : (size - 1) * (-incy); + + for (ptrdiff_t i = 0; i < size; ++i) { + const ptrdiff_t x_idx = ix + i * incx; + const ptrdiff_t y_idx = iy + i * incy; + + const Tcompute x_val = utils::cast(x[x_idx]); + const Tcompute y_val = utils::cast(y[y_idx]); + const Tcompute temp = c_val * x_val + s_val * y_val; + y[y_idx] = utils::cast(c_val * y_val - s_val * x_val); + x[x_idx] = utils::cast(temp); + } + + return INFINI_STATUS_SUCCESS; +} + +#define CALCULATE_ROT(TDATA) \ + calculateRot(_info, \ + (TDATA *)x, \ + (TDATA *)y, \ + (const TDATA *)c, \ + (const TDATA *)s) + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *x, + void *y, + const void *c, + const void *s, + void *stream) const { + + (void)workspace; + (void)workspace_size; + (void)stream; + + switch (_info.getDtype()) { + case INFINI_DTYPE_F16: + return CALCULATE_ROT(fp16_t); + case INFINI_DTYPE_F32: + return CALCULATE_ROT(float); + case INFINI_DTYPE_F64: + return CALCULATE_ROT(double); + case INFINI_DTYPE_BF16: + return CALCULATE_ROT(bf16_t); + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } +} + +#undef CALCULATE_ROT + +} // namespace op::rot::cpu diff --git a/src/infiniop/ops/rot/cpu/rot_cpu.h b/src/infiniop/ops/rot/cpu/rot_cpu.h new file mode 100644 index 000000000..57d0a054e --- /dev/null +++ b/src/infiniop/ops/rot/cpu/rot_cpu.h @@ -0,0 +1,8 @@ +#ifndef __ROT_CPU_H__ +#define __ROT_CPU_H__ + +#include "../rot.h" + +DESCRIPTOR(cpu) + +#endif // __ROT_CPU_H__ \ No newline at end of file diff --git a/src/infiniop/ops/rot/metax/rot_metax.cc b/src/infiniop/ops/rot/metax/rot_metax.cc new file mode 100644 index 000000000..0b0c3e567 --- /dev/null +++ b/src/infiniop/ops/rot/metax/rot_metax.cc @@ -0,0 +1,76 @@ +#include "rot_metax.h" +#include "../../../devices/metax/metax_common.h" +#include "../../../devices/metax/metax_handle.h" + +namespace op::rot::metax { + +struct Descriptor::Opaque { + std::shared_ptr internal; +}; + +Descriptor::~Descriptor() { + delete _opaque; 
+} + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t c_desc, + infiniopTensorDescriptor_t s_desc) { + + auto handle = reinterpret_cast(handle_); + auto info = RotInfo::createRotInfo(x_desc, y_desc, c_desc, s_desc); + CHECK_RESULT(info); + + *desc_ptr = new Descriptor( + info.take(), + 0, + new Opaque{handle->internal()}, + handle->device, + handle->device_id); + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *x, + void *y, + const void *c, + const void *s, + void *stream) const { + + (void)workspace; + (void)workspace_size; + + const size_t size = _info.getSize(); + const ptrdiff_t incx = _info.getIncx(); + const ptrdiff_t incy = _info.getIncy(); + const infiniDtype_t data_type = _info.getDtype(); + + CHECK_STATUS(_opaque->internal->useMcblas( + (hcStream_t)stream, + [&](hcblasHandle_t handle) { + CHECK_MCBLAS(hcblasSetPointerMode(handle, HCBLAS_POINTER_MODE_DEVICE)); + + switch (data_type) { + case INFINI_DTYPE_F32: + CHECK_MCBLAS(hcblasSrot(handle, size, (float *)x, incx, (float *)y, incy, (const float *)c, (const float *)s)); + break; + case INFINI_DTYPE_F64: + CHECK_MCBLAS(hcblasDrot(handle, size, (double *)x, incx, (double *)y, incy, (const double *)c, (const double *)s)); + break; + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + return INFINI_STATUS_SUCCESS; + })); + + return INFINI_STATUS_SUCCESS; +} + +} // namespace op::rot::metax diff --git a/src/infiniop/ops/rot/metax/rot_metax.h b/src/infiniop/ops/rot/metax/rot_metax.h new file mode 100644 index 000000000..2c64a3a0b --- /dev/null +++ b/src/infiniop/ops/rot/metax/rot_metax.h @@ -0,0 +1,8 @@ +#ifndef __ROT_METAX_H__ +#define __ROT_METAX_H__ + +#include "../rot.h" + +DESCRIPTOR(metax) + +#endif // __ROT_METAX_H__ \ No newline at end of file diff --git 
a/src/infiniop/ops/rot/operator.cc b/src/infiniop/ops/rot/operator.cc new file mode 100644 index 000000000..6c1345d7a --- /dev/null +++ b/src/infiniop/ops/rot/operator.cc @@ -0,0 +1,128 @@ +#include "../../operator.h" +#include "../../handle.h" +#include "infiniop/ops/rot.h" + +#ifdef ENABLE_CPU_API +#include "cpu/rot_cpu.h" +#endif +#ifdef ENABLE_METAX_API +#include "metax/rot_metax.h" +#endif +#ifdef ENABLE_CAMBRICON_API +#include "bang/rot_bang.h" +#endif + +__INFINI_C infiniStatus_t infiniopCreateRotDescriptor( + infiniopHandle_t handle, + infiniopRotDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t c_desc, + infiniopTensorDescriptor_t s_desc) { + +#define CREATE(CASE, NAMESPACE) \ + case CASE: \ + return op::rot::NAMESPACE::Descriptor::create( \ + handle, \ + reinterpret_cast(desc_ptr), \ + x_desc, \ + y_desc, \ + c_desc, \ + s_desc) + + switch (handle->device) { +#ifdef ENABLE_CPU_API + CREATE(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_METAX_API + CREATE(INFINI_DEVICE_METAX, metax); +#endif +#ifdef ENABLE_CAMBRICON_API + CREATE(INFINI_DEVICE_CAMBRICON, bang); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CREATE +} + +__INFINI_C infiniStatus_t infiniopGetRotWorkspaceSize(infiniopRotDescriptor_t desc, size_t *size) { + +#define GET(CASE, NAMESPACE) \ + case CASE: \ + *size = reinterpret_cast(desc)->workspaceSize(); \ + return INFINI_STATUS_SUCCESS + + switch (desc->device_type) { +#ifdef ENABLE_CPU_API + GET(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_METAX_API + GET(INFINI_DEVICE_METAX, metax); +#endif +#ifdef ENABLE_CAMBRICON_API + GET(INFINI_DEVICE_CAMBRICON, bang); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef GET +} + +__INFINI_C infiniStatus_t infiniopRot( + infiniopRotDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *x, + void *y, + const void *c, + const void *s, + void 
*stream) { + +#define CALCULATE(CASE, NAMESPACE) \ + case CASE: \ + return reinterpret_cast(desc) \ + ->calculate(workspace, workspace_size, x, y, c, s, stream) + + switch (desc->device_type) { +#ifdef ENABLE_CPU_API + CALCULATE(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_METAX_API + CALCULATE(INFINI_DEVICE_METAX, metax); +#endif +#ifdef ENABLE_CAMBRICON_API + CALCULATE(INFINI_DEVICE_CAMBRICON, bang); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CALCULATE +} + +__INFINI_C infiniStatus_t infiniopDestroyRotDescriptor(infiniopRotDescriptor_t desc) { + +#define DELETE(CASE, NAMESPACE) \ + case CASE: \ + delete reinterpret_cast(desc); \ + return INFINI_STATUS_SUCCESS + + switch (desc->device_type) { +#ifdef ENABLE_CPU_API + DELETE(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_METAX_API + DELETE(INFINI_DEVICE_METAX, metax); +#endif +#ifdef ENABLE_CAMBRICON_API + DELETE(INFINI_DEVICE_CAMBRICON, bang); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef DELETE +} diff --git a/src/infiniop/ops/rot/rot.h b/src/infiniop/ops/rot/rot.h new file mode 100644 index 000000000..fc73a38b9 --- /dev/null +++ b/src/infiniop/ops/rot/rot.h @@ -0,0 +1,107 @@ +#ifndef __ROT_H__ +#define __ROT_H__ + +#include "../../../utils.h" +#include "../../operator.h" +#include "../../tensor.h" +#include "infiniop/ops/rot.h" + +#define DESCRIPTOR(NAMESPACE) \ + \ + namespace op::rot::NAMESPACE { \ + class Descriptor final : public InfiniopDescriptor { \ + struct Opaque; \ + Opaque *_opaque; \ + RotInfo _info; \ + size_t _workspace_size; \ + \ + Descriptor( \ + RotInfo info, \ + size_t workspace_size_, \ + Opaque *opaque, \ + infiniDevice_t device_type, \ + int device_id) \ + : InfiniopDescriptor{device_type, device_id}, \ + _opaque(opaque), \ + _info(std::move(info)), \ + _workspace_size(workspace_size_) {} \ + \ + public: \ + ~Descriptor(); \ + \ + size_t workspaceSize() const { return _workspace_size; } \ + \ + static 
infiniStatus_t create( \ + infiniopHandle_t handle, \ + Descriptor **desc_ptr, \ + infiniopTensorDescriptor_t x_desc, \ + infiniopTensorDescriptor_t y_desc, \ + infiniopTensorDescriptor_t c_desc, \ + infiniopTensorDescriptor_t s_desc); \ + \ + infiniStatus_t calculate( \ + void *workspace, \ + size_t workspace_size, \ + void *x, \ + void *y, \ + const void *c, \ + const void *s, \ + void *stream) const; \ + }; \ + } + +class RotInfo { +private: + size_t _size; + ptrdiff_t _incx; + ptrdiff_t _incy; + infiniDtype_t _dtype; + + RotInfo(size_t size, + ptrdiff_t incx, + ptrdiff_t incy, + infiniDtype_t dtype) + : _size(size), _incx(incx), _incy(incy), _dtype(dtype) {} + +public: + inline size_t getSize() const { return _size; } + inline ptrdiff_t getIncx() const { return _incx; } + inline ptrdiff_t getIncy() const { return _incy; } + inline infiniDtype_t getDtype() const { return _dtype; } + + using ResultType = utils::Result; + + static ResultType createRotInfo( + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t c_desc, + infiniopTensorDescriptor_t s_desc) { + + CHECK_OR_RETURN(c_desc != nullptr, INFINI_STATUS_NULL_POINTER); + CHECK_OR_RETURN(s_desc != nullptr, INFINI_STATUS_NULL_POINTER); + CHECK_OR_RETURN(x_desc != nullptr, INFINI_STATUS_NULL_POINTER); + CHECK_OR_RETURN(y_desc != nullptr, INFINI_STATUS_NULL_POINTER); + + auto dtype = x_desc->dtype(); + + CHECK_OR_RETURN(c_desc->dtype() == dtype, INFINI_STATUS_BAD_TENSOR_DTYPE); + CHECK_OR_RETURN(s_desc->dtype() == dtype, INFINI_STATUS_BAD_TENSOR_DTYPE); + CHECK_OR_RETURN(y_desc->dtype() == dtype, INFINI_STATUS_BAD_TENSOR_DTYPE); + CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32, INFINI_DTYPE_F64); + + CHECK_OR_RETURN(c_desc->numel() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(s_desc->numel() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(x_desc->ndim() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); + 
CHECK_OR_RETURN(y_desc->ndim() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(x_desc->numel() == y_desc->numel(), INFINI_STATUS_BAD_TENSOR_SHAPE); + + auto size = x_desc->numel(); + auto incx = x_desc->stride(0); + auto incy = y_desc->stride(0); + + RotInfo info(size, incx, incy, dtype); + return ResultType(std::move(info)); + } +}; + +#endif // __ROT_H__ diff --git a/test/infiniop/libinfiniop/op_register.py b/test/infiniop/libinfiniop/op_register.py index caa9efa73..5998825de 100644 --- a/test/infiniop/libinfiniop/op_register.py +++ b/test/infiniop/libinfiniop/op_register.py @@ -2387,3 +2387,39 @@ def nrm2_(lib): lib.infiniopDestroyNrm2Descriptor.argtypes = [ infiniopOperatorDescriptor_t, ] + + +@OpRegister.operator +def rot_(lib): + lib.infiniopCreateRotDescriptor.restype = c_int32 + lib.infiniopCreateRotDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopOperatorDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + ] + + lib.infiniopGetRotWorkspaceSize.restype = c_int32 + lib.infiniopGetRotWorkspaceSize.argtypes = [ + infiniopOperatorDescriptor_t, + POINTER(c_size_t), + ] + + lib.infiniopRot.restype = c_int32 + lib.infiniopRot.argtypes = [ + infiniopOperatorDescriptor_t, + c_void_p, + c_size_t, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + ] + + lib.infiniopDestroyRotDescriptor.restype = c_int32 + lib.infiniopDestroyRotDescriptor.argtypes = [ + infiniopOperatorDescriptor_t, + ] diff --git a/test/infiniop/rot.py b/test/infiniop/rot.py new file mode 100644 index 000000000..9eff59601 --- /dev/null +++ b/test/infiniop/rot.py @@ -0,0 +1,160 @@ +import ctypes +from ctypes import c_uint64 + +import torch +from libinfiniop import ( + LIBINFINIOP, + InfiniDeviceNames, + InfiniDtype, + InfiniDtypeNames, + TestTensor, + TestWorkspace, + check_error, + debug, + get_args, + get_test_devices, + get_tolerance, + infiniopOperatorDescriptor_t, + 
profile_operation, + test_operator, +) + +_TEST_CASES = [ + ((13,), None, None), + ((13,), (10,), (10,)), + ((5632,), None, None), + ((5632,), (5,), (5,)), + ((16,), (4,), (4,)), + ((5632,), (32,), (32,)), +] + +_TENSOR_DTYPES = [ + # InfiniDtype.F16, + InfiniDtype.F32, + # InfiniDtype.F64, + # InfiniDtype.BF16, +] + +_TOLERANCE_MAP = { + InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-3}, + InfiniDtype.F32: {"atol": 1e-5, "rtol": 1e-5}, + InfiniDtype.F64: {"atol": 1e-9, "rtol": 1e-9}, + InfiniDtype.BF16: {"atol": 1e-2, "rtol": 1e-2}, +} + +DEBUG = False +PROFILE = False +NUM_PRERUN = 10 +NUM_ITERATIONS = 1000 + + +def torch_rot(x, y, c, s): + x0 = x.clone() + y0 = y.clone() + x.copy_(c * x0 + s * y0) + y.copy_(c * y0 - s * x0) + + +def test( + handle, + device, + shape, + x_stride=None, + y_stride=None, + dtype=torch.float32, + sync=None, +): + x = TestTensor(shape, x_stride, dtype, device) + y = TestTensor(shape, y_stride, dtype, device) + c = TestTensor(tuple(), None, dtype, device) + s = TestTensor(tuple(), None, dtype, device) + + if x.is_broadcast() or y.is_broadcast(): + return + + print( + f"Testing Rot on {InfiniDeviceNames[device]} with shape:{shape} x_stride:{x_stride} " + f"y_stride:{y_stride} dtype:{InfiniDtypeNames[dtype]}" + ) + + torch_rot(x.torch_tensor(), y.torch_tensor(), c.torch_tensor(), s.torch_tensor()) + + if sync is not None: + sync() + + descriptor = infiniopOperatorDescriptor_t() + check_error( + LIBINFINIOP.infiniopCreateRotDescriptor( + handle, + ctypes.byref(descriptor), + x.descriptor, + y.descriptor, + c.descriptor, + s.descriptor, + ) + ) + + for tensor in [c, s, x, y]: + tensor.destroy_desc() + + workspace_size = c_uint64(0) + check_error( + LIBINFINIOP.infiniopGetRotWorkspaceSize( + descriptor, ctypes.byref(workspace_size) + ) + ) + workspace = TestWorkspace(workspace_size.value, x.device) + + def lib_rot(): + check_error( + LIBINFINIOP.infiniopRot( + descriptor, + workspace.data(), + workspace.size(), + x.data(), + y.data(), + 
c.data(), + s.data(), + None, + ) + ) + + lib_rot() + + atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) + if DEBUG: + debug(x.actual_tensor(), x.torch_tensor(), atol=atol, rtol=rtol) + debug(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol) + + assert torch.allclose(x.actual_tensor(), x.torch_tensor(), atol=atol, rtol=rtol) + assert torch.allclose(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol) + + if PROFILE: + profile_operation( + "PyTorch", + lambda: torch_rot( + x.torch_tensor(), y.torch_tensor(), c.torch_tensor(), s.torch_tensor() + ), + device, + NUM_PRERUN, + NUM_ITERATIONS, + ) + profile_operation( + " lib", lambda: lib_rot(), device, NUM_PRERUN, NUM_ITERATIONS + ) + + check_error(LIBINFINIOP.infiniopDestroyRotDescriptor(descriptor)) + + +if __name__ == "__main__": + args = get_args() + + DEBUG = args.debug + PROFILE = args.profile + NUM_PRERUN = args.num_prerun + NUM_ITERATIONS = args.num_iterations + + for device in get_test_devices(args): + test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES) + + print("\033[92mTest passed!\033[0m") From 07f47227d8bda5dcbaa613ff09ce70e557d38fa9 Mon Sep 17 00:00:00 2001 From: xuzhengzhong Date: Tue, 28 Apr 2026 09:13:47 +0000 Subject: [PATCH 08/25] Add `rotg` operator --- include/infiniop.h | 1 + include/infiniop/ops/rotg.h | 28 ++++ src/infiniop/devices/metax/metax_ht2mc.h | 2 + src/infiniop/ops/rotg/bang/rotg_bang.h | 8 + src/infiniop/ops/rotg/bang/rotg_bang.mlu | 89 +++++++++++ .../ops/rotg/bang/rotg_bang_kernel.mlu | 76 +++++++++ src/infiniop/ops/rotg/cpu/rotg_cpu.cc | 121 ++++++++++++++ src/infiniop/ops/rotg/cpu/rotg_cpu.h | 8 + src/infiniop/ops/rotg/metax/rotg_metax.cc | 73 +++++++++ src/infiniop/ops/rotg/metax/rotg_metax.h | 8 + src/infiniop/ops/rotg/operator.cc | 125 +++++++++++++++ src/infiniop/ops/rotg/rotg.h | 91 +++++++++++ test/infiniop/libinfiniop/op_register.py | 36 +++++ test/infiniop/rotg.py | 147 ++++++++++++++++++ 14 files changed, 813 insertions(+) create mode 100644 
include/infiniop/ops/rotg.h create mode 100644 src/infiniop/ops/rotg/bang/rotg_bang.h create mode 100644 src/infiniop/ops/rotg/bang/rotg_bang.mlu create mode 100644 src/infiniop/ops/rotg/bang/rotg_bang_kernel.mlu create mode 100644 src/infiniop/ops/rotg/cpu/rotg_cpu.cc create mode 100644 src/infiniop/ops/rotg/cpu/rotg_cpu.h create mode 100644 src/infiniop/ops/rotg/metax/rotg_metax.cc create mode 100644 src/infiniop/ops/rotg/metax/rotg_metax.h create mode 100644 src/infiniop/ops/rotg/operator.cc create mode 100644 src/infiniop/ops/rotg/rotg.h create mode 100644 test/infiniop/rotg.py diff --git a/include/infiniop.h b/include/infiniop.h index bfdfe1f88..0008a04d8 100644 --- a/include/infiniop.h +++ b/include/infiniop.h @@ -101,6 +101,7 @@ #include "infiniop/ops/rms_norm.h" #include "infiniop/ops/rope.h" #include "infiniop/ops/rot.h" +#include "infiniop/ops/rotg.h" #include "infiniop/ops/scatter.h" #include "infiniop/ops/selu.h" #include "infiniop/ops/sigmoid.h" diff --git a/include/infiniop/ops/rotg.h b/include/infiniop/ops/rotg.h new file mode 100644 index 000000000..a9121b873 --- /dev/null +++ b/include/infiniop/ops/rotg.h @@ -0,0 +1,28 @@ +#ifndef __INFINIOP_ROTG_API_H__ +#define __INFINIOP_ROTG_API_H__ + +#include "../operator_descriptor.h" + +typedef struct InfiniopDescriptor *infiniopRotgDescriptor_t; + +__INFINI_C __export infiniStatus_t infiniopCreateRotgDescriptor(infiniopHandle_t handle, + infiniopRotgDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t x, + infiniopTensorDescriptor_t y, + infiniopTensorDescriptor_t c, + infiniopTensorDescriptor_t s); + +__INFINI_C __export infiniStatus_t infiniopGetRotgWorkspaceSize(infiniopRotgDescriptor_t desc, size_t *size); + +__INFINI_C __export infiniStatus_t infiniopRotg(infiniopRotgDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *x, + void *y, + void *c, + void *s, + void *stream); + +__INFINI_C __export infiniStatus_t infiniopDestroyRotgDescriptor(infiniopRotgDescriptor_t desc); + +#endif // 
__INFINIOP_ROTG_API_H__ \ No newline at end of file diff --git a/src/infiniop/devices/metax/metax_ht2mc.h b/src/infiniop/devices/metax/metax_ht2mc.h index fd80f4f49..0f0e8d3d3 100644 --- a/src/infiniop/devices/metax/metax_ht2mc.h +++ b/src/infiniop/devices/metax/metax_ht2mc.h @@ -135,6 +135,8 @@ #define hcblasDnrm2 mcblasDnrm2 #define hcblasSrot mcblasSrot #define hcblasDrot mcblasDrot +#define hcblasSrotg mcblasSrotg +#define hcblasDrotg mcblasDrotg #define HCBLAS_STATUS_SUCCESS MCBLAS_STATUS_SUCCESS #define HCBLAS_OP_T MCBLAS_OP_T #define HCBLAS_OP_N MCBLAS_OP_N diff --git a/src/infiniop/ops/rotg/bang/rotg_bang.h b/src/infiniop/ops/rotg/bang/rotg_bang.h new file mode 100644 index 000000000..e74696b9e --- /dev/null +++ b/src/infiniop/ops/rotg/bang/rotg_bang.h @@ -0,0 +1,8 @@ +#ifndef __ROTG_BANG_H__ +#define __ROTG_BANG_H__ + +#include "../rotg.h" + +DESCRIPTOR(bang) + +#endif // __ROTG_BANG_H__ diff --git a/src/infiniop/ops/rotg/bang/rotg_bang.mlu b/src/infiniop/ops/rotg/bang/rotg_bang.mlu new file mode 100644 index 000000000..9a7f56fdb --- /dev/null +++ b/src/infiniop/ops/rotg/bang/rotg_bang.mlu @@ -0,0 +1,89 @@ +#include "../../../devices/bang/common_bang.h" +#include "rotg_bang.h" +#include "rotg_bang_kernel.mlu" + +namespace op::rotg::bang { + +Descriptor::~Descriptor() = default; + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t c_desc, + infiniopTensorDescriptor_t s_desc) { + + auto handle = reinterpret_cast(handle_); + auto info = RotgInfo::createRotgInfo(x_desc, y_desc, c_desc, s_desc); + CHECK_RESULT(info); + + *desc_ptr = new Descriptor( + info.take(), + 0, + nullptr, + handle->device, + handle->device_id); + + return INFINI_STATUS_SUCCESS; +} + +template +infiniStatus_t calculateRotg( + Tdata *x, + Tdata *y, + Tdata *c, + Tdata *s, + cnrtQueue_t queue) { + + cnrtDim3_t k_dim; + cnrtFunctionType_t k_type; + 
k_dim.x = 1; + k_dim.y = 1; + k_dim.z = 1; + k_type = cnrtFuncTypeBlock; + + rotgKernel<<>>( + x, + y, + c, + s); + + cnrtQueueSync(queue); + return INFINI_STATUS_SUCCESS; +} + +#define CALCULATE_ROTG(TDATA) \ + calculateRotg((TDATA *)x, \ + (TDATA *)y, \ + (TDATA *)c, \ + (TDATA *)s, \ + (cnrtQueue_t)stream) + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *x, + void *y, + void *c, + void *s, + void *stream) const { + + (void)workspace; + (void)workspace_size; + + switch (_info.getDtype()) { + case INFINI_DTYPE_F16: + return CALCULATE_ROTG(half); + case INFINI_DTYPE_BF16: + return CALCULATE_ROTG(bfloat16_t); + case INFINI_DTYPE_F32: + return CALCULATE_ROTG(float); + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } +} + +#undef CALCULATE_ROTG + +} // namespace op::rotg::bang diff --git a/src/infiniop/ops/rotg/bang/rotg_bang_kernel.mlu b/src/infiniop/ops/rotg/bang/rotg_bang_kernel.mlu new file mode 100644 index 000000000..25fb6a8f4 --- /dev/null +++ b/src/infiniop/ops/rotg/bang/rotg_bang_kernel.mlu @@ -0,0 +1,76 @@ +#include "../../../devices/bang/common_bang.h" + +#include +#include + +template +__mlu_func__ float rotgToCompute(Tdata value) { + if constexpr (std::is_same_v) { + return __half2float(value); + } else if constexpr (std::is_same_v) { + return __bfloat162float(value); + } else { + return static_cast(value); + } +} + +template +__mlu_func__ Tdata rotgFromCompute(float value) { + if constexpr (std::is_same_v) { + return __float2half(value); + } else if constexpr (std::is_same_v) { + return __float2bfloat16(value); + } else { + return static_cast(value); + } +} + +template +__mlu_global__ void rotgKernel( + Tdata *x, + Tdata *y, + Tdata *c, + Tdata *s) { + + const float zero = 0.0f; + const float one = 1.0f; + const float safmin = std::numeric_limits::min(); + const float safmax = std::numeric_limits::max(); + + const float x_val = rotgToCompute(*x); + const float y_val = rotgToCompute(*y); + + const float 
xnorm = std::fabs(x_val); + const float ynorm = std::fabs(y_val); + + if (ynorm == zero) { + *c = rotgFromCompute(one); + *s = rotgFromCompute(zero); + *y = rotgFromCompute(zero); + } else if (xnorm == zero) { + *c = rotgFromCompute(zero); + *s = rotgFromCompute(one); + *x = rotgFromCompute(y_val); + *y = rotgFromCompute(one); + } else { + const float scl = std::min(safmax, std::max(safmin, std::max(xnorm, ynorm))); + const float sigma = xnorm > ynorm ? std::copysign(one, x_val) : std::copysign(one, y_val); + const float r = sigma * (scl * std::sqrt((x_val / scl) * (x_val / scl) + (y_val / scl) * (y_val / scl))); + const float c_val = x_val / r; + const float s_val = y_val / r; + + float z; + if (xnorm > ynorm) { + z = s_val; + } else if (c_val != zero) { + z = one / c_val; + } else { + z = one; + } + + *x = rotgFromCompute(r); + *y = rotgFromCompute(z); + *c = rotgFromCompute(c_val); + *s = rotgFromCompute(s_val); + } +} diff --git a/src/infiniop/ops/rotg/cpu/rotg_cpu.cc b/src/infiniop/ops/rotg/cpu/rotg_cpu.cc new file mode 100644 index 000000000..66f7f1640 --- /dev/null +++ b/src/infiniop/ops/rotg/cpu/rotg_cpu.cc @@ -0,0 +1,121 @@ +#include "rotg_cpu.h" +#include "../../../devices/cpu/common_cpu.h" + +#include + +namespace op::rotg::cpu { + +Descriptor::~Descriptor() = default; + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t c_desc, + infiniopTensorDescriptor_t s_desc) { + + auto handle = reinterpret_cast(handle_); + auto info = RotgInfo::createRotgInfo(x_desc, y_desc, c_desc, s_desc); + CHECK_RESULT(info); + + *desc_ptr = new Descriptor( + info.take(), + 0, + nullptr, + handle->device, + handle->device_id); + + return INFINI_STATUS_SUCCESS; +} + +template +infiniStatus_t calculateRotg( + Tdata *x, + Tdata *y, + Tdata *c, + Tdata *s) { + + using Tcompute = std::conditional_t, double, float>; + + const Tcompute zero = 
utils::cast(0.0f); + const Tcompute one = utils::cast(1.0f); + + Tcompute x_val = utils::cast(x[0]); + Tcompute y_val = utils::cast(y[0]); + + const Tcompute anorm = std::abs(x_val); + const Tcompute bnorm = std::abs(y_val); + + if (bnorm == zero) { + c[0] = utils::cast(one); + s[0] = utils::cast(zero); + y[0] = utils::cast(zero); + return INFINI_STATUS_SUCCESS; + } + + if (anorm == zero) { + c[0] = utils::cast(zero); + s[0] = utils::cast(one); + x[0] = utils::cast(y_val); + y[0] = utils::cast(one); + return INFINI_STATUS_SUCCESS; + } + + const Tcompute sigma = anorm > bnorm ? std::copysign(one, x_val) : std::copysign(one, y_val); + const Tcompute r = sigma * std::hypot(x_val, y_val); + const Tcompute c_val = x_val / r; + const Tcompute s_val = y_val / r; + + Tcompute z; + if (anorm > bnorm) { + z = s_val; + } else if (c_val != zero) { + z = one / c_val; + } else { + z = one; + } + + x[0] = utils::cast(r); + y[0] = utils::cast(z); + c[0] = utils::cast(c_val); + s[0] = utils::cast(s_val); + return INFINI_STATUS_SUCCESS; +} + +#define CALCULATE_ROTG(TDATA) \ + calculateRotg((TDATA *)x, \ + (TDATA *)y, \ + (TDATA *)c, \ + (TDATA *)s) + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *x, + void *y, + void *c, + void *s, + void *stream) const { + + (void)workspace; + (void)workspace_size; + (void)stream; + + switch (_info.getDtype()) { + case INFINI_DTYPE_F16: + return CALCULATE_ROTG(fp16_t); + case INFINI_DTYPE_F32: + return CALCULATE_ROTG(float); + case INFINI_DTYPE_F64: + return CALCULATE_ROTG(double); + case INFINI_DTYPE_BF16: + return CALCULATE_ROTG(bf16_t); + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } +} + +#undef CALCULATE_ROTG + +} // namespace op::rotg::cpu diff --git a/src/infiniop/ops/rotg/cpu/rotg_cpu.h b/src/infiniop/ops/rotg/cpu/rotg_cpu.h new file mode 100644 index 000000000..a83cb2612 --- /dev/null +++ b/src/infiniop/ops/rotg/cpu/rotg_cpu.h @@ -0,0 +1,8 @@ +#ifndef __ROTG_CPU_H__ +#define 
__ROTG_CPU_H__ + +#include "../rotg.h" + +DESCRIPTOR(cpu) + +#endif // __ROTG_CPU_H__ diff --git a/src/infiniop/ops/rotg/metax/rotg_metax.cc b/src/infiniop/ops/rotg/metax/rotg_metax.cc new file mode 100644 index 000000000..18e64368b --- /dev/null +++ b/src/infiniop/ops/rotg/metax/rotg_metax.cc @@ -0,0 +1,73 @@ +#include "rotg_metax.h" +#include "../../../devices/metax/metax_common.h" +#include "../../../devices/metax/metax_handle.h" + +namespace op::rotg::metax { + +struct Descriptor::Opaque { + std::shared_ptr internal; +}; + +Descriptor::~Descriptor() { + delete _opaque; +} + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t c_desc, + infiniopTensorDescriptor_t s_desc) { + + auto handle = reinterpret_cast(handle_); + auto info = RotgInfo::createRotgInfo(x_desc, y_desc, c_desc, s_desc); + CHECK_RESULT(info); + + *desc_ptr = new Descriptor( + info.take(), + 0, + new Opaque{handle->internal()}, + handle->device, + handle->device_id); + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *x, + void *y, + void *c, + void *s, + void *stream) const { + + (void)workspace; + (void)workspace_size; + + const infiniDtype_t data_type = _info.getDtype(); + + CHECK_STATUS(_opaque->internal->useMcblas( + (hcStream_t)stream, + [&](hcblasHandle_t handle) { + CHECK_MCBLAS(hcblasSetPointerMode(handle, HCBLAS_POINTER_MODE_DEVICE)); + + switch (data_type) { + case INFINI_DTYPE_F32: + CHECK_MCBLAS(hcblasSrotg(handle, (float *)x, (float *)y, (float *)c, (float *)s)); + break; + case INFINI_DTYPE_F64: + CHECK_MCBLAS(hcblasDrotg(handle, (double *)x, (double *)y, (double *)c, (double *)s)); + break; + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + return INFINI_STATUS_SUCCESS; + })); + + return INFINI_STATUS_SUCCESS; +} + +} // namespace op::rotg::metax diff --git 
a/src/infiniop/ops/rotg/metax/rotg_metax.h b/src/infiniop/ops/rotg/metax/rotg_metax.h new file mode 100644 index 000000000..aaf5f2612 --- /dev/null +++ b/src/infiniop/ops/rotg/metax/rotg_metax.h @@ -0,0 +1,8 @@ +#ifndef __ROTG_METAX_H__ +#define __ROTG_METAX_H__ + +#include "../rotg.h" + +DESCRIPTOR(metax) + +#endif // __ROTG_METAX_H__ diff --git a/src/infiniop/ops/rotg/operator.cc b/src/infiniop/ops/rotg/operator.cc new file mode 100644 index 000000000..0fa83d664 --- /dev/null +++ b/src/infiniop/ops/rotg/operator.cc @@ -0,0 +1,125 @@ +#include "../../operator.h" +#include "../../handle.h" +#include "infiniop/ops/rotg.h" + +#ifdef ENABLE_CPU_API +#include "cpu/rotg_cpu.h" +#endif +#ifdef ENABLE_METAX_API +#include "metax/rotg_metax.h" +#endif +#ifdef ENABLE_CAMBRICON_API +#include "bang/rotg_bang.h" +#endif + +__INFINI_C infiniStatus_t infiniopCreateRotgDescriptor( + infiniopHandle_t handle, + infiniopRotgDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t c_desc, + infiniopTensorDescriptor_t s_desc) { + +#define CREATE(CASE, NAMESPACE) \ + case CASE: \ + return op::rotg::NAMESPACE::Descriptor::create( \ + handle, \ + reinterpret_cast(desc_ptr), \ + x_desc, y_desc, c_desc, s_desc) + + switch (handle->device) { +#ifdef ENABLE_CPU_API + CREATE(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_METAX_API + CREATE(INFINI_DEVICE_METAX, metax); +#endif +#ifdef ENABLE_CAMBRICON_API + CREATE(INFINI_DEVICE_CAMBRICON, bang); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CREATE +} + +__INFINI_C infiniStatus_t infiniopGetRotgWorkspaceSize(infiniopRotgDescriptor_t desc, size_t *size) { + +#define GET(CASE, NAMESPACE) \ + case CASE: \ + *size = reinterpret_cast(desc)->workspaceSize(); \ + return INFINI_STATUS_SUCCESS + + switch (desc->device_type) { +#ifdef ENABLE_CPU_API + GET(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_METAX_API + GET(INFINI_DEVICE_METAX, metax); 
+#endif +#ifdef ENABLE_CAMBRICON_API + GET(INFINI_DEVICE_CAMBRICON, bang); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef GET +} + +__INFINI_C infiniStatus_t infiniopRotg( + infiniopRotgDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *x, + void *y, + void *c, + void *s, + void *stream) { + +#define CALCULATE(CASE, NAMESPACE) \ + case CASE: \ + return reinterpret_cast(desc) \ + ->calculate(workspace, workspace_size, x, y, c, s, stream) + + switch (desc->device_type) { +#ifdef ENABLE_CPU_API + CALCULATE(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_METAX_API + CALCULATE(INFINI_DEVICE_METAX, metax); +#endif +#ifdef ENABLE_CAMBRICON_API + CALCULATE(INFINI_DEVICE_CAMBRICON, bang); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CALCULATE +} + +__INFINI_C infiniStatus_t infiniopDestroyRotgDescriptor(infiniopRotgDescriptor_t desc) { + +#define DELETE(CASE, NAMESPACE) \ + case CASE: \ + delete reinterpret_cast(desc); \ + return INFINI_STATUS_SUCCESS + + switch (desc->device_type) { +#ifdef ENABLE_CPU_API + DELETE(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_METAX_API + DELETE(INFINI_DEVICE_METAX, metax); +#endif +#ifdef ENABLE_CAMBRICON_API + DELETE(INFINI_DEVICE_CAMBRICON, bang); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef DELETE +} diff --git a/src/infiniop/ops/rotg/rotg.h b/src/infiniop/ops/rotg/rotg.h new file mode 100644 index 000000000..7156cdf1a --- /dev/null +++ b/src/infiniop/ops/rotg/rotg.h @@ -0,0 +1,91 @@ +#ifndef __ROTG_H__ +#define __ROTG_H__ + +#include "../../../utils.h" +#include "../../operator.h" +#include "../../tensor.h" +#include "infiniop/ops/rotg.h" + +#define DESCRIPTOR(NAMESPACE) \ + \ + namespace op::rotg::NAMESPACE { \ + class Descriptor final : public InfiniopDescriptor { \ + struct Opaque; \ + Opaque *_opaque; \ + RotgInfo _info; \ + size_t _workspace_size; \ + \ + Descriptor( \ + RotgInfo info, \ + size_t 
workspace_size_, \ + Opaque *opaque, \ + infiniDevice_t device_type, \ + int device_id) \ + : InfiniopDescriptor{device_type, device_id}, \ + _opaque(opaque), \ + _info(std::move(info)), \ + _workspace_size(workspace_size_) {} \ + \ + public: \ + ~Descriptor(); \ + \ + size_t workspaceSize() const { return _workspace_size; } \ + \ + static infiniStatus_t create( \ + infiniopHandle_t handle, \ + Descriptor **desc_ptr, \ + infiniopTensorDescriptor_t x_desc, \ + infiniopTensorDescriptor_t y_desc, \ + infiniopTensorDescriptor_t c_desc, \ + infiniopTensorDescriptor_t s_desc); \ + \ + infiniStatus_t calculate( \ + void *workspace, \ + size_t workspace_size, \ + void *x, \ + void *y, \ + void *c, \ + void *s, \ + void *stream) const; \ + }; \ + } + +class RotgInfo { +private: + infiniDtype_t _dtype; + + explicit RotgInfo(infiniDtype_t dtype) : _dtype(dtype) {} + +public: + inline infiniDtype_t getDtype() const { return _dtype; } + + using ResultType = utils::Result; + + static ResultType createRotgInfo( + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t c_desc, + infiniopTensorDescriptor_t s_desc) { + + CHECK_OR_RETURN(x_desc != nullptr, INFINI_STATUS_NULL_POINTER); + CHECK_OR_RETURN(y_desc != nullptr, INFINI_STATUS_NULL_POINTER); + CHECK_OR_RETURN(c_desc != nullptr, INFINI_STATUS_NULL_POINTER); + CHECK_OR_RETURN(s_desc != nullptr, INFINI_STATUS_NULL_POINTER); + + auto dtype = x_desc->dtype(); + + CHECK_OR_RETURN(y_desc->dtype() == dtype, INFINI_STATUS_BAD_TENSOR_DTYPE); + CHECK_OR_RETURN(c_desc->dtype() == dtype, INFINI_STATUS_BAD_TENSOR_DTYPE); + CHECK_OR_RETURN(s_desc->dtype() == dtype, INFINI_STATUS_BAD_TENSOR_DTYPE); + CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32, INFINI_DTYPE_F64); + CHECK_OR_RETURN(x_desc->numel() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(y_desc->numel() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(c_desc->numel() == 1, 
INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(s_desc->numel() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); + + RotgInfo info(dtype); + return ResultType(std::move(info)); + } +}; + +#endif // __ROTG_H__ diff --git a/test/infiniop/libinfiniop/op_register.py b/test/infiniop/libinfiniop/op_register.py index 5998825de..5822b0527 100644 --- a/test/infiniop/libinfiniop/op_register.py +++ b/test/infiniop/libinfiniop/op_register.py @@ -2423,3 +2423,39 @@ def rot_(lib): lib.infiniopDestroyRotDescriptor.argtypes = [ infiniopOperatorDescriptor_t, ] + + +@OpRegister.operator +def rotg_(lib): + lib.infiniopCreateRotgDescriptor.restype = c_int32 + lib.infiniopCreateRotgDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopOperatorDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + ] + + lib.infiniopGetRotgWorkspaceSize.restype = c_int32 + lib.infiniopGetRotgWorkspaceSize.argtypes = [ + infiniopOperatorDescriptor_t, + POINTER(c_size_t), + ] + + lib.infiniopRotg.restype = c_int32 + lib.infiniopRotg.argtypes = [ + infiniopOperatorDescriptor_t, + c_void_p, + c_size_t, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + ] + + lib.infiniopDestroyRotgDescriptor.restype = c_int32 + lib.infiniopDestroyRotgDescriptor.argtypes = [ + infiniopOperatorDescriptor_t, + ] diff --git a/test/infiniop/rotg.py b/test/infiniop/rotg.py new file mode 100644 index 000000000..da71dbe2b --- /dev/null +++ b/test/infiniop/rotg.py @@ -0,0 +1,147 @@ +import ctypes +import math +from ctypes import c_uint64 + +import torch +from libinfiniop import ( + LIBINFINIOP, + InfiniDeviceNames, + InfiniDtype, + InfiniDtypeNames, + TestTensor, + TestWorkspace, + check_error, + get_args, + get_test_devices, + get_tolerance, + infiniopOperatorDescriptor_t, + test_operator, +) + +_TEST_CASES = [ + (0.0, 0.0), + (3.0, 4.0), + (-2.5, 5.0), + (7.0, -1.5), + (-3.2, -8.4), +] + +_TENSOR_DTYPES = [ + # InfiniDtype.F16, + 
InfiniDtype.F32, + # InfiniDtype.F64, + # InfiniDtype.BF16, +] + +_TOLERANCE_MAP = { + InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-3}, + InfiniDtype.F32: {"atol": 1e-5, "rtol": 1e-5}, + InfiniDtype.F64: {"atol": 1e-7, "rtol": 1e-7}, + InfiniDtype.BF16: {"atol": 1e-2, "rtol": 1e-2}, +} + + +def torch_rotg(a, b): + anorm = abs(a) + bnorm = abs(b) + if bnorm == 0.0: + return a, 0.0, 1.0, 0.0 + if anorm == 0.0: + return b, 1.0, 0.0, 1.0 + + sigma = math.copysign(1.0, a if anorm > bnorm else b) + r = sigma * math.hypot(a, b) + c = a / r + s = b / r + if anorm > bnorm: + z = s + elif c != 0.0: + z = 1.0 / c + else: + z = 1.0 + return r, z, c, s + + +def test(handle, device, a0, b0, dtype=torch.float32, sync=None): + a_torch = torch.tensor([a0]) + b_torch = torch.tensor([b0]) + a = TestTensor( + a_torch.shape, + a_torch.stride(), + dtype, + device, + mode="manual", + set_tensor=a_torch, + ) + b = TestTensor( + b_torch.shape, + b_torch.stride(), + dtype, + device, + mode="manual", + set_tensor=b_torch, + ) + c = TestTensor(tuple(), None, dtype, device, mode="zeros") + s = TestTensor(tuple(), None, dtype, device, mode="zeros") + + exp_a, exp_b, exp_c, exp_s = torch_rotg(a0, b0) + + print( + f"Testing Rotg on {InfiniDeviceNames[device]} with a:{a0} b:{b0} dtype:{InfiniDtypeNames[dtype]}" + ) + + if sync is not None: + sync() + + descriptor = infiniopOperatorDescriptor_t() + check_error( + LIBINFINIOP.infiniopCreateRotgDescriptor( + handle, + ctypes.byref(descriptor), + a.descriptor, + b.descriptor, + c.descriptor, + s.descriptor, + ) + ) + + a.destroy_desc() + b.destroy_desc() + c.destroy_desc() + s.destroy_desc() + + workspace_size = c_uint64(0) + check_error( + LIBINFINIOP.infiniopGetRotgWorkspaceSize( + descriptor, ctypes.byref(workspace_size) + ) + ) + workspace = TestWorkspace(workspace_size.value, device) + + check_error( + LIBINFINIOP.infiniopRotg( + descriptor, + workspace.data(), + workspace.size(), + a.data(), + b.data(), + c.data(), + s.data(), + None, + ) + ) + + 
atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) + assert math.isclose(a.actual_tensor().item(), exp_a, rel_tol=rtol, abs_tol=atol) + assert math.isclose(b.actual_tensor().item(), exp_b, rel_tol=rtol, abs_tol=atol) + assert math.isclose(c.actual_tensor().item(), exp_c, rel_tol=rtol, abs_tol=atol) + assert math.isclose(s.actual_tensor().item(), exp_s, rel_tol=rtol, abs_tol=atol) + + check_error(LIBINFINIOP.infiniopDestroyRotgDescriptor(descriptor)) + + +if __name__ == "__main__": + args = get_args() + for device in get_test_devices(args): + test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES) + print("\033[92mTest passed!\033[0m") From 9259321db66f0f4e8c5a266122e613f14dc27d0f Mon Sep 17 00:00:00 2001 From: xuzhengzhong Date: Tue, 28 Apr 2026 09:14:24 +0000 Subject: [PATCH 09/25] Add `rotm` operator --- include/infiniop.h | 1 + include/infiniop/ops/rotm.h | 26 +++ src/infiniop/devices/metax/metax_ht2mc.h | 2 + src/infiniop/ops/rotm/bang/rotm_bang.h | 8 + src/infiniop/ops/rotm/bang/rotm_bang.mlu | 97 ++++++++++ .../ops/rotm/bang/rotm_bang_kernel.mlu | 183 ++++++++++++++++++ src/infiniop/ops/rotm/cpu/rotm_cpu.cc | 169 ++++++++++++++++ src/infiniop/ops/rotm/cpu/rotm_cpu.h | 8 + src/infiniop/ops/rotm/metax/rotm_metax.cc | 74 +++++++ src/infiniop/ops/rotm/metax/rotm_metax.h | 8 + src/infiniop/ops/rotm/operator.cc | 129 ++++++++++++ src/infiniop/ops/rotm/rotm.h | 102 ++++++++++ test/infiniop/libinfiniop/op_register.py | 34 ++++ test/infiniop/rotm.py | 180 +++++++++++++++++ 14 files changed, 1021 insertions(+) create mode 100644 include/infiniop/ops/rotm.h create mode 100644 src/infiniop/ops/rotm/bang/rotm_bang.h create mode 100644 src/infiniop/ops/rotm/bang/rotm_bang.mlu create mode 100644 src/infiniop/ops/rotm/bang/rotm_bang_kernel.mlu create mode 100644 src/infiniop/ops/rotm/cpu/rotm_cpu.cc create mode 100644 src/infiniop/ops/rotm/cpu/rotm_cpu.h create mode 100644 src/infiniop/ops/rotm/metax/rotm_metax.cc create mode 100644 src/infiniop/ops/rotm/metax/rotm_metax.h 
create mode 100644 src/infiniop/ops/rotm/operator.cc create mode 100644 src/infiniop/ops/rotm/rotm.h create mode 100644 test/infiniop/rotm.py diff --git a/include/infiniop.h b/include/infiniop.h index 0008a04d8..8115deb61 100644 --- a/include/infiniop.h +++ b/include/infiniop.h @@ -102,6 +102,7 @@ #include "infiniop/ops/rope.h" #include "infiniop/ops/rot.h" #include "infiniop/ops/rotg.h" +#include "infiniop/ops/rotm.h" #include "infiniop/ops/scatter.h" #include "infiniop/ops/selu.h" #include "infiniop/ops/sigmoid.h" diff --git a/include/infiniop/ops/rotm.h b/include/infiniop/ops/rotm.h new file mode 100644 index 000000000..279d18129 --- /dev/null +++ b/include/infiniop/ops/rotm.h @@ -0,0 +1,26 @@ +#ifndef __INFINIOP_ROTM_API_H__ +#define __INFINIOP_ROTM_API_H__ + +#include "../operator_descriptor.h" + +typedef struct InfiniopDescriptor *infiniopRotmDescriptor_t; + +__INFINI_C __export infiniStatus_t infiniopCreateRotmDescriptor(infiniopHandle_t handle, + infiniopRotmDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t x, + infiniopTensorDescriptor_t y, + infiniopTensorDescriptor_t param); + +__INFINI_C __export infiniStatus_t infiniopGetRotmWorkspaceSize(infiniopRotmDescriptor_t desc, size_t *size); + +__INFINI_C __export infiniStatus_t infiniopRotm(infiniopRotmDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *x, + void *y, + const void *param, + void *stream); + +__INFINI_C __export infiniStatus_t infiniopDestroyRotmDescriptor(infiniopRotmDescriptor_t desc); + +#endif // __INFINIOP_ROTM_API_H__ \ No newline at end of file diff --git a/src/infiniop/devices/metax/metax_ht2mc.h b/src/infiniop/devices/metax/metax_ht2mc.h index 0f0e8d3d3..af7d995d4 100644 --- a/src/infiniop/devices/metax/metax_ht2mc.h +++ b/src/infiniop/devices/metax/metax_ht2mc.h @@ -137,6 +137,8 @@ #define hcblasDrot mcblasDrot #define hcblasSrotg mcblasSrotg #define hcblasDrotg mcblasDrotg +#define hcblasSrotm mcblasSrotm +#define hcblasDrotm mcblasDrotm #define 
HCBLAS_STATUS_SUCCESS MCBLAS_STATUS_SUCCESS #define HCBLAS_OP_T MCBLAS_OP_T #define HCBLAS_OP_N MCBLAS_OP_N diff --git a/src/infiniop/ops/rotm/bang/rotm_bang.h b/src/infiniop/ops/rotm/bang/rotm_bang.h new file mode 100644 index 000000000..4d2473b20 --- /dev/null +++ b/src/infiniop/ops/rotm/bang/rotm_bang.h @@ -0,0 +1,8 @@ +#ifndef __ROTM_BANG_H__ +#define __ROTM_BANG_H__ + +#include "../rotm.h" + +DESCRIPTOR(bang) + +#endif // __ROTM_BANG_H__ \ No newline at end of file diff --git a/src/infiniop/ops/rotm/bang/rotm_bang.mlu b/src/infiniop/ops/rotm/bang/rotm_bang.mlu new file mode 100644 index 000000000..8064953b3 --- /dev/null +++ b/src/infiniop/ops/rotm/bang/rotm_bang.mlu @@ -0,0 +1,97 @@ +#include "../../../devices/bang/common_bang.h" +#include "rotm_bang.h" +#include "rotm_bang_kernel.mlu" + +namespace op::rotm::bang { + +Descriptor::~Descriptor() = default; + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t param_desc) { + + auto handle = reinterpret_cast(handle_); + auto info = RotmInfo::createRotmInfo(x_desc, y_desc, param_desc); + CHECK_RESULT(info); + + *desc_ptr = new Descriptor( + info.take(), + 0, + nullptr, + handle->device, + handle->device_id); + + return INFINI_STATUS_SUCCESS; +} + +template +infiniStatus_t calculateRotm( + const RotmInfo &info, + Tdata *x, + Tdata *y, + const Tdata *param, + cnrtQueue_t queue) { + + cnrtDim3_t k_dim; + cnrtFunctionType_t k_type; + k_dim.x = 4; + k_dim.y = 1; + k_dim.z = 1; + k_type = cnrtFuncTypeUnion1; + + if (info.getIncx() == 1 && info.getIncy() == 1) { + rotmKernelContiguous<<>>( + info.getSize(), + x, + y, + param); + } else { + rotmKernelStrided<<>>( + info.getSize(), + x, + info.getIncx(), + y, + info.getIncy(), + param); + } + + cnrtQueueSync(queue); + return INFINI_STATUS_SUCCESS; +} + +#define CALCULATE_ROTM(TDATA) \ + calculateRotm(_info, \ + (TDATA *)x, \ + (TDATA 
*)y, \ + (const TDATA *)param, \ + (cnrtQueue_t)stream) + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *x, + void *y, + const void *param, + void *stream) const { + + (void)workspace; + (void)workspace_size; + + switch (_info.getDtype()) { + case INFINI_DTYPE_F16: + return CALCULATE_ROTM(half); + case INFINI_DTYPE_BF16: + return CALCULATE_ROTM(bfloat16_t); + case INFINI_DTYPE_F32: + return CALCULATE_ROTM(float); + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } +} + +#undef CALCULATE_ROTM + +} // namespace op::rotm::bang diff --git a/src/infiniop/ops/rotm/bang/rotm_bang_kernel.mlu b/src/infiniop/ops/rotm/bang/rotm_bang_kernel.mlu new file mode 100644 index 000000000..297c7bea7 --- /dev/null +++ b/src/infiniop/ops/rotm/bang/rotm_bang_kernel.mlu @@ -0,0 +1,183 @@ +#include "../../../devices/bang/common_bang.h" + +__nram__ char nram_buffer[NRAM_MAX_SIZE]; + +template +__mlu_global__ void rotmKernelContiguous( + size_t n, + Tdata *x, + Tdata *y, + const Tdata *param) { + + const Tdata flag = param[0]; + if (n == 0 || (flag + static_cast(2) == static_cast(0))) { + return; + } + + Tdata h11 = static_cast(0); + Tdata h12 = static_cast(0); + Tdata h21 = static_cast(0); + Tdata h22 = static_cast(0); + + if (flag < static_cast(0)) { + h11 = param[1]; + h12 = param[3]; + h21 = param[2]; + h22 = param[4]; + } else if (flag == static_cast(0)) { + h12 = param[3]; + h21 = param[2]; + } else { + h11 = param[1]; + h22 = param[4]; + } + + Tdata *nram_align = (Tdata *)(((size_t)nram_buffer + ALIGN_SIZE - 1) & ~(ALIGN_SIZE - 1)); + + size_t nram_usable = NRAM_MAX_SIZE - ((char *)nram_align - nram_buffer); + size_t max_chunk_elements = nram_usable / (4 * sizeof(Tdata)); + + int align_elements = ALIGN_SIZE / sizeof(Tdata); + if (align_elements == 0) { + align_elements = 1; + } + max_chunk_elements = (max_chunk_elements / align_elements) * align_elements; + + Tdata *nram_x = nram_align; + Tdata *nram_y = nram_align + max_chunk_elements; 
+ Tdata *nram_x_out = nram_align + 2 * max_chunk_elements; + Tdata *nram_y_out = nram_align + 3 * max_chunk_elements; + + int elements_per_core = n / taskDim; + int remain = n % taskDim; + int core_elements = elements_per_core + (taskId < remain ? 1 : 0); + int core_offset = taskId < remain ? taskId * core_elements : taskId * elements_per_core + remain; + + if (core_elements <= 0) { + return; + } + + int chunks = core_elements / max_chunk_elements; + int chunk_rem = core_elements % max_chunk_elements; + + for (int c = 0; c < chunks; c++) { + size_t current_offset = core_offset + c * max_chunk_elements; + __memcpy(nram_x, x + current_offset, max_chunk_elements * sizeof(Tdata), GDRAM2NRAM); + __memcpy(nram_y, y + current_offset, max_chunk_elements * sizeof(Tdata), GDRAM2NRAM); + + if (flag < static_cast(0)) { + __bang_mul_scalar(nram_x_out, nram_x, h11, max_chunk_elements); + __bang_mul_scalar(nram_y_out, nram_y, h12, max_chunk_elements); + __bang_add(nram_x_out, nram_x_out, nram_y_out, max_chunk_elements); + + __bang_mul_scalar(nram_y_out, nram_x, h21, max_chunk_elements); + __bang_mul_scalar(nram_y, nram_y, h22, max_chunk_elements); + __bang_add(nram_y_out, nram_y_out, nram_y, max_chunk_elements); + } else if (flag == static_cast(0)) { + __bang_mul_scalar(nram_x_out, nram_y, h12, max_chunk_elements); + __bang_add(nram_x_out, nram_x, nram_x_out, max_chunk_elements); + + __bang_mul_scalar(nram_y_out, nram_x, h21, max_chunk_elements); + __bang_add(nram_y_out, nram_y_out, nram_y, max_chunk_elements); + } else { + __bang_mul_scalar(nram_x_out, nram_x, h11, max_chunk_elements); + __bang_add(nram_x_out, nram_x_out, nram_y, max_chunk_elements); + + __bang_mul_scalar(nram_y_out, nram_y, h22, max_chunk_elements); + __bang_sub(nram_y_out, nram_y_out, nram_x, max_chunk_elements); + } + + __memcpy(x + current_offset, nram_x_out, max_chunk_elements * sizeof(Tdata), NRAM2GDRAM); + __memcpy(y + current_offset, nram_y_out, max_chunk_elements * sizeof(Tdata), NRAM2GDRAM); + } + + if 
(chunk_rem > 0) { + size_t current_offset = core_offset + chunks * max_chunk_elements; + __memcpy(nram_x, x + current_offset, chunk_rem * sizeof(Tdata), GDRAM2NRAM); + __memcpy(nram_y, y + current_offset, chunk_rem * sizeof(Tdata), GDRAM2NRAM); + + if (flag < static_cast(0)) { + __bang_mul_scalar(nram_x_out, nram_x, h11, chunk_rem); + __bang_mul_scalar(nram_y_out, nram_y, h12, chunk_rem); + __bang_add(nram_x_out, nram_x_out, nram_y_out, chunk_rem); + + __bang_mul_scalar(nram_y_out, nram_x, h21, chunk_rem); + __bang_mul_scalar(nram_y, nram_y, h22, chunk_rem); + __bang_add(nram_y_out, nram_y_out, nram_y, chunk_rem); + } else if (flag == static_cast(0)) { + __bang_mul_scalar(nram_x_out, nram_y, h12, chunk_rem); + __bang_add(nram_x_out, nram_x, nram_x_out, chunk_rem); + + __bang_mul_scalar(nram_y_out, nram_x, h21, chunk_rem); + __bang_add(nram_y_out, nram_y_out, nram_y, chunk_rem); + } else { + __bang_mul_scalar(nram_x_out, nram_x, h11, chunk_rem); + __bang_add(nram_x_out, nram_x_out, nram_y, chunk_rem); + + __bang_mul_scalar(nram_y_out, nram_y, h22, chunk_rem); + __bang_sub(nram_y_out, nram_y_out, nram_x, chunk_rem); + } + + __memcpy(x + current_offset, nram_x_out, chunk_rem * sizeof(Tdata), NRAM2GDRAM); + __memcpy(y + current_offset, nram_y_out, chunk_rem * sizeof(Tdata), NRAM2GDRAM); + } +} + +template +__mlu_global__ void rotmKernelStrided( + size_t n, + Tdata *x, + size_t incx, + Tdata *y, + size_t incy, + const Tdata *param) { + + const Tdata flag = param[0]; + if (n == 0 || (flag + static_cast(2) == static_cast(0))) { + return; + } + + Tdata h11 = static_cast(0); + Tdata h12 = static_cast(0); + Tdata h21 = static_cast(0); + Tdata h22 = static_cast(0); + + if (flag < static_cast(0)) { + h11 = param[1]; + h12 = param[3]; + h21 = param[2]; + h22 = param[4]; + } else if (flag == static_cast(0)) { + h12 = param[3]; + h21 = param[2]; + } else { + h11 = param[1]; + h22 = param[4]; + } + + const size_t task = taskId; + const size_t tasks = taskDim; + const size_t 
per_task = n / tasks; + const size_t remain = n % tasks; + const size_t begin = task < remain ? task * (per_task + 1) : task * per_task + remain; + const size_t count = per_task + (task < remain ? 1 : 0); + + for (size_t i = 0; i < count; ++i) { + const size_t index = begin + i; + const size_t x_idx = index * incx; + const size_t y_idx = index * incy; + const Tdata w = x[x_idx]; + const Tdata z = y[y_idx]; + + if (flag < static_cast(0)) { + x[x_idx] = w * h11 + z * h12; + y[y_idx] = w * h21 + z * h22; + } else if (flag == static_cast(0)) { + x[x_idx] = w + z * h12; + y[y_idx] = w * h21 + z; + } else { + x[x_idx] = w * h11 + z; + y[y_idx] = -w + h22 * z; + } + } +} \ No newline at end of file diff --git a/src/infiniop/ops/rotm/cpu/rotm_cpu.cc b/src/infiniop/ops/rotm/cpu/rotm_cpu.cc new file mode 100644 index 000000000..51f11ad74 --- /dev/null +++ b/src/infiniop/ops/rotm/cpu/rotm_cpu.cc @@ -0,0 +1,169 @@ +#include "rotm_cpu.h" +#include "../../../devices/cpu/common_cpu.h" + +namespace op::rotm::cpu { + +Descriptor::~Descriptor() = default; + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t param_desc) { + + auto handle = reinterpret_cast(handle_); + auto info = RotmInfo::createRotmInfo(x_desc, y_desc, param_desc); + CHECK_RESULT(info); + + *desc_ptr = new Descriptor( + info.take(), + 0, + nullptr, + handle->device, + handle->device_id); + + return INFINI_STATUS_SUCCESS; +} + +template +infiniStatus_t calculateRotm( + const RotmInfo &info, + Tdata *x, + Tdata *y, + const Tdata *param) { + + using Tcompute = std::conditional_t, double, float>; + + const Tcompute zero = utils::cast(0.0f); + const Tcompute two = utils::cast(2.0f); + + Tcompute sflag = utils::cast(param[0]); + + if (info.getSize() == 0 || (sflag + two == zero)) { + return INFINI_STATUS_SUCCESS; + } + + const ptrdiff_t size = static_cast(info.getSize()); + const 
ptrdiff_t incx = info.getIncx(); + const ptrdiff_t incy = info.getIncy(); + const ptrdiff_t kx = incx >= 0 ? 0 : (size - 1) * (-incx); + const ptrdiff_t ky = incy >= 0 ? 0 : (size - 1) * (-incy); + + Tcompute sh11 = zero; + Tcompute sh12 = zero; + Tcompute sh21 = zero; + Tcompute sh22 = zero; + + if (incx == incy && incx > 0) { + const ptrdiff_t nsteps = size * incx; + if (sflag < zero) { + sh11 = utils::cast(param[1]); + sh12 = utils::cast(param[3]); + sh21 = utils::cast(param[2]); + sh22 = utils::cast(param[4]); + for (ptrdiff_t i = 0; i < nsteps; i += incx) { + const Tcompute w = utils::cast(x[i]); + const Tcompute z = utils::cast(y[i]); + x[i] = utils::cast(w * sh11 + z * sh12); + y[i] = utils::cast(w * sh21 + z * sh22); + } + } else if (sflag == zero) { + sh12 = utils::cast(param[3]); + sh21 = utils::cast(param[2]); + for (ptrdiff_t i = 0; i < nsteps; i += incx) { + const Tcompute w = utils::cast(x[i]); + const Tcompute z = utils::cast(y[i]); + x[i] = utils::cast(w + z * sh12); + y[i] = utils::cast(w * sh21 + z); + } + } else { + sh11 = utils::cast(param[1]); + sh22 = utils::cast(param[4]); + for (ptrdiff_t i = 0; i < nsteps; i += incx) { + const Tcompute w = utils::cast(x[i]); + const Tcompute z = utils::cast(y[i]); + x[i] = utils::cast(w * sh11 + z); + y[i] = utils::cast(-w + sh22 * z); + } + } + } else { + ptrdiff_t ix = kx; + ptrdiff_t iy = ky; + + if (sflag < zero) { + sh11 = utils::cast(param[1]); + sh12 = utils::cast(param[3]); + sh21 = utils::cast(param[2]); + sh22 = utils::cast(param[4]); + for (ptrdiff_t i = 0; i < size; ++i) { + const Tcompute w = utils::cast(x[ix]); + const Tcompute z = utils::cast(y[iy]); + x[ix] = utils::cast(w * sh11 + z * sh12); + y[iy] = utils::cast(w * sh21 + z * sh22); + ix += incx; + iy += incy; + } + } else if (sflag == zero) { + sh12 = utils::cast(param[3]); + sh21 = utils::cast(param[2]); + for (ptrdiff_t i = 0; i < size; ++i) { + const Tcompute w = utils::cast(x[ix]); + const Tcompute z = utils::cast(y[iy]); + x[ix] = 
utils::cast(w + z * sh12); + y[iy] = utils::cast(w * sh21 + z); + ix += incx; + iy += incy; + } + } else { + sh11 = utils::cast(param[1]); + sh22 = utils::cast(param[4]); + for (ptrdiff_t i = 0; i < size; ++i) { + const Tcompute w = utils::cast(x[ix]); + const Tcompute z = utils::cast(y[iy]); + x[ix] = utils::cast(w * sh11 + z); + y[iy] = utils::cast(-w + sh22 * z); + ix += incx; + iy += incy; + } + } + } + + return INFINI_STATUS_SUCCESS; +} + +#define CALCULATE_ROTM(TDATA) \ + calculateRotm(_info, \ + (TDATA *)x, \ + (TDATA *)y, \ + (const TDATA *)param) + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *x, + void *y, + const void *param, + void *stream) const { + + (void)workspace; + (void)workspace_size; + (void)stream; + + switch (_info.getDtype()) { + case INFINI_DTYPE_F16: + return CALCULATE_ROTM(fp16_t); + case INFINI_DTYPE_BF16: + return CALCULATE_ROTM(bf16_t); + case INFINI_DTYPE_F32: + return CALCULATE_ROTM(float); + case INFINI_DTYPE_F64: + return CALCULATE_ROTM(double); + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } +} + +#undef CALCULATE_ROTM + +} // namespace op::rotm::cpu diff --git a/src/infiniop/ops/rotm/cpu/rotm_cpu.h b/src/infiniop/ops/rotm/cpu/rotm_cpu.h new file mode 100644 index 000000000..740d741ad --- /dev/null +++ b/src/infiniop/ops/rotm/cpu/rotm_cpu.h @@ -0,0 +1,8 @@ +#ifndef __ROTM_CPU_H__ +#define __ROTM_CPU_H__ + +#include "../rotm.h" + +DESCRIPTOR(cpu) + +#endif // __ROTM_CPU_H__ \ No newline at end of file diff --git a/src/infiniop/ops/rotm/metax/rotm_metax.cc b/src/infiniop/ops/rotm/metax/rotm_metax.cc new file mode 100644 index 000000000..13def4197 --- /dev/null +++ b/src/infiniop/ops/rotm/metax/rotm_metax.cc @@ -0,0 +1,74 @@ +#include "rotm_metax.h" +#include "../../../devices/metax/metax_common.h" +#include "../../../devices/metax/metax_handle.h" + +namespace op::rotm::metax { + +struct Descriptor::Opaque { + std::shared_ptr internal; +}; + +Descriptor::~Descriptor() { + delete 
_opaque; +} + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t param_desc) { + + auto handle = reinterpret_cast(handle_); + auto info = RotmInfo::createRotmInfo(x_desc, y_desc, param_desc); + CHECK_RESULT(info); + + *desc_ptr = new Descriptor( + info.take(), + 0, + new Opaque{handle->internal()}, + handle->device, + handle->device_id); + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *x, + void *y, + const void *param, + void *stream) const { + + (void)workspace; + (void)workspace_size; + + const size_t size = _info.getSize(); + const ptrdiff_t incx = _info.getIncx(); + const ptrdiff_t incy = _info.getIncy(); + const infiniDtype_t data_type = _info.getDtype(); + + CHECK_STATUS(_opaque->internal->useMcblas( + (hcStream_t)stream, + [&](hcblasHandle_t handle) { + CHECK_MCBLAS(hcblasSetPointerMode(handle, HCBLAS_POINTER_MODE_DEVICE)); + + switch (data_type) { + case INFINI_DTYPE_F32: + CHECK_MCBLAS(hcblasSrotm(handle, size, (float *)x, incx, (float *)y, incy, (const float *)param)); + break; + case INFINI_DTYPE_F64: + CHECK_MCBLAS(hcblasDrotm(handle, size, (double *)x, incx, (double *)y, incy, (const double *)param)); + break; + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + return INFINI_STATUS_SUCCESS; + })); + + return INFINI_STATUS_SUCCESS; +} + +} // namespace op::rotm::metax diff --git a/src/infiniop/ops/rotm/metax/rotm_metax.h b/src/infiniop/ops/rotm/metax/rotm_metax.h new file mode 100644 index 000000000..7dfe84cab --- /dev/null +++ b/src/infiniop/ops/rotm/metax/rotm_metax.h @@ -0,0 +1,8 @@ +#ifndef __ROTM_METAX_H__ +#define __ROTM_METAX_H__ + +#include "../rotm.h" + +DESCRIPTOR(metax) + +#endif // __ROTM_METAX_H__ \ No newline at end of file diff --git a/src/infiniop/ops/rotm/operator.cc b/src/infiniop/ops/rotm/operator.cc new 
file mode 100644 index 000000000..fb8eda223 --- /dev/null +++ b/src/infiniop/ops/rotm/operator.cc @@ -0,0 +1,129 @@ +#include "../../operator.h" +#include "../../handle.h" +#include "infiniop/ops/rotm.h" + +#ifdef ENABLE_CPU_API +#include "cpu/rotm_cpu.h" +#endif +#ifdef ENABLE_METAX_API +#include "metax/rotm_metax.h" +#endif +#ifdef ENABLE_CAMBRICON_API +#include "bang/rotm_bang.h" +#endif + +__INFINI_C infiniStatus_t infiniopCreateRotmDescriptor( + infiniopHandle_t handle, + infiniopRotmDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t param_desc) { + +#define CREATE(CASE, NAMESPACE) \ + case CASE: \ + return op::rotm::NAMESPACE::Descriptor::create( \ + handle, \ + reinterpret_cast(desc_ptr), \ + x_desc, y_desc, param_desc) + + switch (handle->device) { + +#ifdef ENABLE_CPU_API + CREATE(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_METAX_API + CREATE(INFINI_DEVICE_METAX, metax); +#endif +#ifdef ENABLE_CAMBRICON_API + CREATE(INFINI_DEVICE_CAMBRICON, bang); +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CREATE +} + +__INFINI_C infiniStatus_t infiniopGetRotmWorkspaceSize(infiniopRotmDescriptor_t desc, size_t *size) { + +#define GET(CASE, NAMESPACE) \ + case CASE: \ + *size = reinterpret_cast(desc)->workspaceSize(); \ + return INFINI_STATUS_SUCCESS + + switch (desc->device_type) { +#ifdef ENABLE_CPU_API + GET(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_METAX_API + GET(INFINI_DEVICE_METAX, metax); +#endif +#ifdef ENABLE_CAMBRICON_API + GET(INFINI_DEVICE_CAMBRICON, bang); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef GET +} + +__INFINI_C infiniStatus_t infiniopRotm( + infiniopRotmDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *x, + void *y, + const void *param, + void *stream) { + +#define CALCULATE(CASE, NAMESPACE) \ + case CASE: \ + return reinterpret_cast(desc) \ + ->calculate(workspace, 
workspace_size, x, y, param, stream) + + switch (desc->device_type) { + +#ifdef ENABLE_CPU_API + CALCULATE(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_METAX_API + CALCULATE(INFINI_DEVICE_METAX, metax); +#endif +#ifdef ENABLE_CAMBRICON_API + CALCULATE(INFINI_DEVICE_CAMBRICON, bang); +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CALCULATE +} + +__INFINI_C infiniStatus_t infiniopDestroyRotmDescriptor(infiniopRotmDescriptor_t desc) { + +#define DELETE(CASE, NAMESPACE) \ + case CASE: \ + delete reinterpret_cast(desc); \ + return INFINI_STATUS_SUCCESS + + switch (desc->device_type) { + +#ifdef ENABLE_CPU_API + DELETE(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_METAX_API + DELETE(INFINI_DEVICE_METAX, metax); +#endif +#ifdef ENABLE_CAMBRICON_API + DELETE(INFINI_DEVICE_CAMBRICON, bang); +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef DELETE +} \ No newline at end of file diff --git a/src/infiniop/ops/rotm/rotm.h b/src/infiniop/ops/rotm/rotm.h new file mode 100644 index 000000000..7699b6b3b --- /dev/null +++ b/src/infiniop/ops/rotm/rotm.h @@ -0,0 +1,102 @@ +#ifndef __ROTM_H__ +#define __ROTM_H__ + +#include "../../../utils.h" +#include "../../operator.h" +#include "../../tensor.h" +#include "infiniop/ops/rotm.h" + +#define DESCRIPTOR(NAMESPACE) \ + \ + namespace op::rotm::NAMESPACE { \ + class Descriptor final : public InfiniopDescriptor { \ + struct Opaque; \ + Opaque *_opaque; \ + RotmInfo _info; \ + size_t _workspace_size; \ + \ + Descriptor( \ + RotmInfo info, \ + size_t workspace_size_, \ + Opaque *opaque, \ + infiniDevice_t device_type, \ + int device_id) \ + : InfiniopDescriptor{device_type, device_id}, \ + _opaque(opaque), \ + _info(std::move(info)), \ + _workspace_size(workspace_size_) {} \ + \ + public: \ + ~Descriptor(); \ + \ + size_t workspaceSize() const { return _workspace_size; } \ + \ + static infiniStatus_t create( \ + infiniopHandle_t handle, \ + Descriptor **desc_ptr, \ + 
infiniopTensorDescriptor_t x_desc, \ + infiniopTensorDescriptor_t y_desc, \ + infiniopTensorDescriptor_t param_desc); \ + \ + infiniStatus_t calculate( \ + void *workspace, \ + size_t workspace_size, \ + void *x, \ + void *y, \ + const void *param, \ + void *stream) const; \ + }; \ + } + +class RotmInfo { +private: + size_t _size; + ptrdiff_t _incx; + ptrdiff_t _incy; + infiniDtype_t _dtype; + + RotmInfo(size_t size, + ptrdiff_t incx, + ptrdiff_t incy, + infiniDtype_t dtype) + : _size(size), _incx(incx), _incy(incy), _dtype(dtype) {} + +public: + inline size_t getSize() const { return _size; } + inline ptrdiff_t getIncx() const { return _incx; } + inline ptrdiff_t getIncy() const { return _incy; } + inline infiniDtype_t getDtype() const { return _dtype; } + + using ResultType = utils::Result; + + static ResultType createRotmInfo( + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t param_desc) { + + CHECK_OR_RETURN(x_desc != nullptr, INFINI_STATUS_NULL_POINTER); + CHECK_OR_RETURN(y_desc != nullptr, INFINI_STATUS_NULL_POINTER); + CHECK_OR_RETURN(param_desc != nullptr, INFINI_STATUS_NULL_POINTER); + + auto dtype = x_desc->dtype(); + + CHECK_OR_RETURN(y_desc->dtype() == dtype, INFINI_STATUS_BAD_TENSOR_DTYPE); + CHECK_OR_RETURN(param_desc->dtype() == dtype, INFINI_STATUS_BAD_TENSOR_DTYPE); + CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32, INFINI_DTYPE_F64); + CHECK_OR_RETURN(x_desc->ndim() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(y_desc->ndim() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(param_desc->ndim() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(x_desc->numel() == y_desc->numel(), INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(param_desc->numel() == 5, INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(param_desc->stride(0) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES); + + auto size = x_desc->numel(); + auto incx = x_desc->stride(0); + auto incy = 
y_desc->stride(0); + + RotmInfo info(size, incx, incy, dtype); + return ResultType(std::move(info)); + } +}; + +#endif // __ROTM_H__ diff --git a/test/infiniop/libinfiniop/op_register.py b/test/infiniop/libinfiniop/op_register.py index 5822b0527..9afe97100 100644 --- a/test/infiniop/libinfiniop/op_register.py +++ b/test/infiniop/libinfiniop/op_register.py @@ -2459,3 +2459,37 @@ def rotg_(lib): lib.infiniopDestroyRotgDescriptor.argtypes = [ infiniopOperatorDescriptor_t, ] + + +@OpRegister.operator +def rotm_(lib): + lib.infiniopCreateRotmDescriptor.restype = c_int32 + lib.infiniopCreateRotmDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopOperatorDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + ] + + lib.infiniopGetRotmWorkspaceSize.restype = c_int32 + lib.infiniopGetRotmWorkspaceSize.argtypes = [ + infiniopOperatorDescriptor_t, + POINTER(c_size_t), + ] + + lib.infiniopRotm.restype = c_int32 + lib.infiniopRotm.argtypes = [ + infiniopOperatorDescriptor_t, + c_void_p, + c_size_t, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + ] + + lib.infiniopDestroyRotmDescriptor.restype = c_int32 + lib.infiniopDestroyRotmDescriptor.argtypes = [ + infiniopOperatorDescriptor_t, + ] diff --git a/test/infiniop/rotm.py b/test/infiniop/rotm.py new file mode 100644 index 000000000..79637f671 --- /dev/null +++ b/test/infiniop/rotm.py @@ -0,0 +1,180 @@ +import ctypes +from ctypes import c_uint64 + +import torch +from libinfiniop import ( + LIBINFINIOP, + InfiniDeviceNames, + InfiniDtype, + InfiniDtypeNames, + TestTensor, + TestWorkspace, + check_error, + debug, + get_args, + get_test_devices, + get_tolerance, + infiniopOperatorDescriptor_t, + profile_operation, + test_operator, +) + +_TEST_CASES = [ + ((13,), None, None, (-1.0, 1.2, -0.3, 0.4, 0.8)), + ((13,), (10,), (10,), (0.0, 0.0, -0.25, 0.5, 0.0)), + ((5632,), None, None, (1.0, 1.1, 0.0, 0.0, 0.9)), + ((5632,), (5,), (5,), (-2.0, 0.0, 0.0, 0.0, 0.0)), +] + 
+_TENSOR_DTYPES = [ + # InfiniDtype.F16, + InfiniDtype.F32, + # InfiniDtype.F64, + # InfiniDtype.BF16, +] + +_TOLERANCE_MAP = { + InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-3}, + InfiniDtype.F32: {"atol": 1e-5, "rtol": 1e-5}, + InfiniDtype.F64: {"atol": 1e-9, "rtol": 1e-9}, + InfiniDtype.BF16: {"atol": 1e-2, "rtol": 1e-2}, +} + +DEBUG = False +PROFILE = False +NUM_PRERUN = 10 +NUM_ITERATIONS = 1000 + + +def torch_rotm(x, y, param): + sflag, sh11, sh21, sh12, sh22 = param + if sflag == -2.0: + return + + w = x.clone() + z = y.clone() + + if sflag < 0.0: + x.copy_(w * sh11 + z * sh12) + y.copy_(w * sh21 + z * sh22) + elif sflag == 0.0: + x.copy_(w + z * sh12) + y.copy_(w * sh21 + z) + else: + x.copy_(w * sh11 + z) + y.copy_(-w + sh22 * z) + + +def _torch_dtype(dtype): + if dtype == InfiniDtype.F64: + return torch.float64 + return torch.float32 + + +def test( + handle, + device, + shape, + x_stride=None, + y_stride=None, + param=(-1.0, 1.2, -0.3, 0.4, 0.8), + dtype=torch.float32, + sync=None, +): + x = TestTensor(shape, x_stride, dtype, device) + y = TestTensor(shape, y_stride, dtype, device) + param_tensor = TestTensor( + (5,), + (1,), + dtype, + device, + mode="manual", + set_tensor=torch.tensor(param, dtype=_torch_dtype(dtype)), + ) + + if x.is_broadcast() or y.is_broadcast(): + return + + print( + f"Testing Rotm on {InfiniDeviceNames[device]} with shape:{shape} x_stride:{x_stride} " + f"y_stride:{y_stride} param:{param} dtype:{InfiniDtypeNames[dtype]}" + ) + + torch_rotm(x.torch_tensor(), y.torch_tensor(), param) + + if sync is not None: + sync() + + descriptor = infiniopOperatorDescriptor_t() + check_error( + LIBINFINIOP.infiniopCreateRotmDescriptor( + handle, + ctypes.byref(descriptor), + x.descriptor, + y.descriptor, + param_tensor.descriptor, + ) + ) + + x.destroy_desc() + y.destroy_desc() + param_tensor.destroy_desc() + + workspace_size = c_uint64(0) + check_error( + LIBINFINIOP.infiniopGetRotmWorkspaceSize( + descriptor, ctypes.byref(workspace_size) + ) + ) 
+ workspace = TestWorkspace(workspace_size.value, x.device) + + def lib_rotm(): + check_error( + LIBINFINIOP.infiniopRotm( + descriptor, + workspace.data(), + workspace.size(), + x.data(), + y.data(), + param_tensor.data(), + None, + ) + ) + + lib_rotm() + + atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) + if DEBUG: + debug(x.actual_tensor(), x.torch_tensor(), atol=atol, rtol=rtol) + debug(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol) + + assert torch.allclose(x.actual_tensor(), x.torch_tensor(), atol=atol, rtol=rtol) + assert torch.allclose(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol) + + if PROFILE: + profile_operation( + "PyTorch", + lambda: torch_rotm(x.torch_tensor(), y.torch_tensor(), param), + device, + NUM_PRERUN, + NUM_ITERATIONS, + ) + profile_operation( + " lib", lambda: lib_rotm(), device, NUM_PRERUN, NUM_ITERATIONS + ) + + check_error(LIBINFINIOP.infiniopDestroyRotmDescriptor(descriptor)) + + +if __name__ == "__main__": + args = get_args() + + DEBUG = args.debug + PROFILE = args.profile + NUM_PRERUN = args.num_prerun + NUM_ITERATIONS = args.num_iterations + + for device in get_test_devices(args): + test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES) + + print("\033[92mTest passed!\033[0m") From 079f5f4aa4aff721d7d3cee9b421909c4897c702 Mon Sep 17 00:00:00 2001 From: xuzhengzhong Date: Tue, 28 Apr 2026 09:31:14 +0000 Subject: [PATCH 10/25] Add `rotmg` operator --- include/infiniop.h | 1 + include/infiniop/ops/rotmg.h | 30 ++ src/infiniop/devices/metax/metax_ht2mc.h | 2 + src/infiniop/ops/rotmg/bang/rotmg_bang.h | 8 + src/infiniop/ops/rotmg/bang/rotmg_bang.mlu | 95 +++++++ .../ops/rotmg/bang/rotmg_bang_kernel.mlu | 156 +++++++++++ src/infiniop/ops/rotmg/cpu/rotmg_cpu.cc | 221 +++++++++++++++ src/infiniop/ops/rotmg/cpu/rotmg_cpu.h | 8 + src/infiniop/ops/rotmg/metax/rotmg_metax.cc | 75 +++++ src/infiniop/ops/rotmg/metax/rotmg_metax.h | 8 + src/infiniop/ops/rotmg/operator.cc | 127 +++++++++ src/infiniop/ops/rotmg/rotmg.h | 
100 +++++++ test/infiniop/libinfiniop/op_register.py | 38 +++ test/infiniop/rotmg.py | 262 ++++++++++++++++++ 14 files changed, 1131 insertions(+) create mode 100644 include/infiniop/ops/rotmg.h create mode 100644 src/infiniop/ops/rotmg/bang/rotmg_bang.h create mode 100644 src/infiniop/ops/rotmg/bang/rotmg_bang.mlu create mode 100644 src/infiniop/ops/rotmg/bang/rotmg_bang_kernel.mlu create mode 100644 src/infiniop/ops/rotmg/cpu/rotmg_cpu.cc create mode 100644 src/infiniop/ops/rotmg/cpu/rotmg_cpu.h create mode 100644 src/infiniop/ops/rotmg/metax/rotmg_metax.cc create mode 100644 src/infiniop/ops/rotmg/metax/rotmg_metax.h create mode 100644 src/infiniop/ops/rotmg/operator.cc create mode 100644 src/infiniop/ops/rotmg/rotmg.h create mode 100644 test/infiniop/rotmg.py diff --git a/include/infiniop.h b/include/infiniop.h index 8115deb61..33ed4e315 100644 --- a/include/infiniop.h +++ b/include/infiniop.h @@ -103,6 +103,7 @@ #include "infiniop/ops/rot.h" #include "infiniop/ops/rotg.h" #include "infiniop/ops/rotm.h" +#include "infiniop/ops/rotmg.h" #include "infiniop/ops/scatter.h" #include "infiniop/ops/selu.h" #include "infiniop/ops/sigmoid.h" diff --git a/include/infiniop/ops/rotmg.h b/include/infiniop/ops/rotmg.h new file mode 100644 index 000000000..93f425216 --- /dev/null +++ b/include/infiniop/ops/rotmg.h @@ -0,0 +1,30 @@ +#ifndef __INFINIOP_ROTMG_API_H__ +#define __INFINIOP_ROTMG_API_H__ + +#include "../operator_descriptor.h" + +typedef struct InfiniopDescriptor *infiniopRotmgDescriptor_t; + +__INFINI_C __export infiniStatus_t infiniopCreateRotmgDescriptor(infiniopHandle_t handle, + infiniopRotmgDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t d1, + infiniopTensorDescriptor_t d2, + infiniopTensorDescriptor_t x1, + infiniopTensorDescriptor_t y1, + infiniopTensorDescriptor_t param); + +__INFINI_C __export infiniStatus_t infiniopGetRotmgWorkspaceSize(infiniopRotmgDescriptor_t desc, size_t *size); + +__INFINI_C __export infiniStatus_t 
infiniopRotmg(infiniopRotmgDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *d1, + void *d2, + void *x1, + const void *y1, + void *param, + void *stream); + +__INFINI_C __export infiniStatus_t infiniopDestroyRotmgDescriptor(infiniopRotmgDescriptor_t desc); + +#endif // __INFINIOP_ROTMG_API_H__ \ No newline at end of file diff --git a/src/infiniop/devices/metax/metax_ht2mc.h b/src/infiniop/devices/metax/metax_ht2mc.h index af7d995d4..d2be2ab26 100644 --- a/src/infiniop/devices/metax/metax_ht2mc.h +++ b/src/infiniop/devices/metax/metax_ht2mc.h @@ -139,6 +139,8 @@ #define hcblasDrotg mcblasDrotg #define hcblasSrotm mcblasSrotm #define hcblasDrotm mcblasDrotm +#define hcblasSrotmg mcblasSrotmg +#define hcblasDrotmg mcblasDrotmg #define HCBLAS_STATUS_SUCCESS MCBLAS_STATUS_SUCCESS #define HCBLAS_OP_T MCBLAS_OP_T #define HCBLAS_OP_N MCBLAS_OP_N diff --git a/src/infiniop/ops/rotmg/bang/rotmg_bang.h b/src/infiniop/ops/rotmg/bang/rotmg_bang.h new file mode 100644 index 000000000..8d1c678d6 --- /dev/null +++ b/src/infiniop/ops/rotmg/bang/rotmg_bang.h @@ -0,0 +1,8 @@ +#ifndef __ROTMG_BANG_H__ +#define __ROTMG_BANG_H__ + +#include "../rotmg.h" + +DESCRIPTOR(bang) + +#endif // __ROTMG_BANG_H__ diff --git a/src/infiniop/ops/rotmg/bang/rotmg_bang.mlu b/src/infiniop/ops/rotmg/bang/rotmg_bang.mlu new file mode 100644 index 000000000..367fa9fff --- /dev/null +++ b/src/infiniop/ops/rotmg/bang/rotmg_bang.mlu @@ -0,0 +1,95 @@ +#include "../../../devices/bang/common_bang.h" +#include "rotmg_bang.h" +#include "rotmg_bang_kernel.mlu" + +namespace op::rotmg::bang { + +Descriptor::~Descriptor() = default; + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t d1_desc, + infiniopTensorDescriptor_t d2_desc, + infiniopTensorDescriptor_t x1_desc, + infiniopTensorDescriptor_t y1_desc, + infiniopTensorDescriptor_t param_desc) { + + auto handle = reinterpret_cast(handle_); + auto info = 
RotmgInfo::createRotmgInfo(d1_desc, d2_desc, x1_desc, y1_desc, param_desc); + CHECK_RESULT(info); + + *desc_ptr = new Descriptor( + info.take(), + 0, + nullptr, + handle->device, + handle->device_id); + + return INFINI_STATUS_SUCCESS; +} + +template +infiniStatus_t calculateRotmg( + Tdata *d1, + Tdata *d2, + Tdata *x1, + const Tdata *y1, + Tdata *param, + cnrtQueue_t queue) { + + cnrtDim3_t k_dim; + cnrtFunctionType_t k_type; + k_dim.x = 1; + k_dim.y = 1; + k_dim.z = 1; + k_type = cnrtFuncTypeBlock; + + rotmgKernel<<>>( + d1, + d2, + x1, + y1, + param); + + cnrtQueueSync(queue); + + return INFINI_STATUS_SUCCESS; +} + +#define CALCULATE_ROTMG(TDATA) \ + calculateRotmg((TDATA *)d1, \ + (TDATA *)d2, \ + (TDATA *)x1, \ + (const TDATA *)y1, \ + (TDATA *)param, \ + (cnrtQueue_t)stream) + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *d1, + void *d2, + void *x1, + const void *y1, + void *param, + void *stream) const { + + (void)workspace; + (void)workspace_size; + + switch (_info.getDtype()) { + case INFINI_DTYPE_F16: + return CALCULATE_ROTMG(half); + case INFINI_DTYPE_F32: + return CALCULATE_ROTMG(float); + case INFINI_DTYPE_BF16: + return CALCULATE_ROTMG(bfloat16_t); + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } +} + +#undef CALCULATE_ROTMG + +} // namespace op::rotmg::bang diff --git a/src/infiniop/ops/rotmg/bang/rotmg_bang_kernel.mlu b/src/infiniop/ops/rotmg/bang/rotmg_bang_kernel.mlu new file mode 100644 index 000000000..f45d131da --- /dev/null +++ b/src/infiniop/ops/rotmg/bang/rotmg_bang_kernel.mlu @@ -0,0 +1,156 @@ +#include "../../../devices/bang/common_bang.h" + +#include +#include + +template +__mlu_global__ void rotmgKernel( + Tdata *d1, + Tdata *d2, + Tdata *x1, + const Tdata *y1, + Tdata *param) { + + using Tcompute = std::conditional_t, double, float>; + + const Tcompute zero = static_cast(0.0f); + const Tcompute one = static_cast(1.0f); + const Tcompute two = static_cast(2.0f); + const Tcompute gam = 
static_cast(4096.0f); + const Tcompute gamsq = static_cast(1.67772e7f); + const Tcompute rgamsq = static_cast(5.96046e-8f); + + Tcompute d1_val = static_cast(*d1); + Tcompute d2_val = static_cast(*d2); + Tcompute x1_val = static_cast(*x1); + const Tcompute y1_val = static_cast(*y1); + + Tcompute sflag = zero; + Tcompute sh11 = zero; + Tcompute sh12 = zero; + Tcompute sh21 = zero; + Tcompute sh22 = zero; + + if (d1_val < zero) { + sflag = -one; + d1_val = zero; + d2_val = zero; + x1_val = zero; + } else { + const Tcompute sp2 = d2_val * y1_val; + if (sp2 == zero) { + param[0] = static_cast(-two); + return; + } + + const Tcompute sp1 = d1_val * x1_val; + const Tcompute sq2 = sp2 * y1_val; + const Tcompute sq1 = sp1 * x1_val; + + if (std::fabs(sq1) > std::fabs(sq2)) { + sh21 = -y1_val / x1_val; + sh12 = sp2 / sp1; + const Tcompute su = one - sh12 * sh21; + + if (su > zero) { + sflag = zero; + d1_val = d1_val / su; + d2_val = d2_val / su; + x1_val = x1_val * su; + } else { + sflag = -one; + sh11 = zero; + sh12 = zero; + sh21 = zero; + sh22 = zero; + d1_val = zero; + d2_val = zero; + x1_val = zero; + } + } else { + if (sq2 < zero) { + sflag = -one; + d1_val = zero; + d2_val = zero; + x1_val = zero; + } else { + sflag = one; + sh11 = sp1 / sp2; + sh22 = x1_val / y1_val; + const Tcompute su = one + sh11 * sh22; + const Tcompute stemp = d2_val / su; + d2_val = d1_val / su; + d1_val = stemp; + x1_val = y1_val * su; + } + } + + if (d1_val != zero) { + while (d1_val <= rgamsq || d1_val >= gamsq) { + if (sflag == zero) { + sh11 = one; + sh22 = one; + sflag = -one; + } else { + sh21 = -one; + sh12 = one; + sflag = -one; + } + + if (d1_val <= rgamsq) { + d1_val = d1_val * gam * gam; + x1_val = x1_val / gam; + sh11 = sh11 / gam; + sh12 = sh12 / gam; + } else { + d1_val = d1_val / (gam * gam); + x1_val = x1_val * gam; + sh11 = sh11 * gam; + sh12 = sh12 * gam; + } + } + } + + if (d2_val != zero) { + while (std::fabs(d2_val) <= rgamsq || std::fabs(d2_val) >= gamsq) { + if (sflag == 
zero) { + sh11 = one; + sh22 = one; + sflag = -one; + } else { + sh21 = -one; + sh12 = one; + sflag = -one; + } + + if (std::fabs(d2_val) <= rgamsq) { + d2_val = d2_val * gam * gam; + sh21 = sh21 / gam; + sh22 = sh22 / gam; + } else { + d2_val = d2_val / (gam * gam); + sh21 = sh21 * gam; + sh22 = sh22 * gam; + } + } + } + } + + if (sflag < zero) { + param[1] = static_cast(sh11); + param[2] = static_cast(sh21); + param[3] = static_cast(sh12); + param[4] = static_cast(sh22); + } else if (sflag == zero) { + param[2] = static_cast(sh21); + param[3] = static_cast(sh12); + } else { + param[1] = static_cast(sh11); + param[4] = static_cast(sh22); + } + + param[0] = static_cast(sflag); + *d1 = static_cast(d1_val); + *d2 = static_cast(d2_val); + *x1 = static_cast(x1_val); +} \ No newline at end of file diff --git a/src/infiniop/ops/rotmg/cpu/rotmg_cpu.cc b/src/infiniop/ops/rotmg/cpu/rotmg_cpu.cc new file mode 100644 index 000000000..0d291b74c --- /dev/null +++ b/src/infiniop/ops/rotmg/cpu/rotmg_cpu.cc @@ -0,0 +1,221 @@ +#include "rotmg_cpu.h" +#include "../../../devices/cpu/common_cpu.h" + +#include + +namespace op::rotmg::cpu { + +Descriptor::~Descriptor() = default; + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t d1_desc, + infiniopTensorDescriptor_t d2_desc, + infiniopTensorDescriptor_t x1_desc, + infiniopTensorDescriptor_t y1_desc, + infiniopTensorDescriptor_t param_desc) { + + auto handle = reinterpret_cast(handle_); + auto info = RotmgInfo::createRotmgInfo(d1_desc, d2_desc, x1_desc, y1_desc, param_desc); + CHECK_RESULT(info); + + *desc_ptr = new Descriptor( + info.take(), + 0, + nullptr, + handle->device, + handle->device_id); + + return INFINI_STATUS_SUCCESS; +} + +template +infiniStatus_t calculateRotmg( + Tdata *d1, + Tdata *d2, + Tdata *x1, + const Tdata *y1, + Tdata *param) { + + using Tcompute = std::conditional_t, double, float>; + + const Tcompute zero = utils::cast(0.0f); + const Tcompute 
one = utils::cast(1.0f); + const Tcompute two = utils::cast(2.0f); + const Tcompute gam = utils::cast(4096.0f); + const Tcompute gamsq = utils::cast(1.67772e7f); + const Tcompute rgamsq = utils::cast(5.96046e-8f); + + Tcompute d1_val = utils::cast(d1[0]); + Tcompute d2_val = utils::cast(d2[0]); + Tcompute x1_val = utils::cast(x1[0]); + const Tcompute y1_val = utils::cast(y1[0]); + + Tcompute sflag; + Tcompute sh11 = zero; + Tcompute sh12 = zero; + Tcompute sh21 = zero; + Tcompute sh22 = zero; + + if (d1_val < zero) { + sflag = -one; + d1_val = zero; + d2_val = zero; + x1_val = zero; + } else { + const Tcompute sp2 = d2_val * y1_val; + if (sp2 == zero) { + param[0] = utils::cast(-two); + return INFINI_STATUS_SUCCESS; + } + + const Tcompute sp1 = d1_val * x1_val; + const Tcompute sq2 = sp2 * y1_val; + const Tcompute sq1 = sp1 * x1_val; + + if (std::abs(sq1) > std::abs(sq2)) { + sh21 = -y1_val / x1_val; + sh12 = sp2 / sp1; + const Tcompute su = one - sh12 * sh21; + + if (su > zero) { + sflag = zero; + d1_val = d1_val / su; + d2_val = d2_val / su; + x1_val = x1_val * su; + } else { + sflag = -one; + sh11 = zero; + sh12 = zero; + sh21 = zero; + sh22 = zero; + d1_val = zero; + d2_val = zero; + x1_val = zero; + } + } else { + if (sq2 < zero) { + sflag = -one; + d1_val = zero; + d2_val = zero; + x1_val = zero; + } else { + sflag = one; + sh11 = sp1 / sp2; + sh22 = x1_val / y1_val; + const Tcompute su = one + sh11 * sh22; + const Tcompute stemp = d2_val / su; + d2_val = d1_val / su; + d1_val = stemp; + x1_val = y1_val * su; + } + } + + if (d1_val != zero) { + while (d1_val <= rgamsq || d1_val >= gamsq) { + if (sflag == zero) { + sh11 = one; + sh22 = one; + sflag = -one; + } else { + sh21 = -one; + sh12 = one; + sflag = -one; + } + if (d1_val <= rgamsq) { + d1_val = d1_val * gam * gam; + x1_val = x1_val / gam; + sh11 = sh11 / gam; + sh12 = sh12 / gam; + } else { + d1_val = d1_val / (gam * gam); + x1_val = x1_val * gam; + sh11 = sh11 * gam; + sh12 = sh12 * gam; + } + } + } + 
+ if (d2_val != zero) { + while (std::abs(d2_val) <= rgamsq || std::abs(d2_val) >= gamsq) { + if (sflag == zero) { + sh11 = one; + sh22 = one; + sflag = -one; + } else { + sh21 = -one; + sh12 = one; + sflag = -one; + } + if (std::abs(d2_val) <= rgamsq) { + d2_val = d2_val * gam * gam; + sh21 = sh21 / gam; + sh22 = sh22 / gam; + } else { + d2_val = d2_val / (gam * gam); + sh21 = sh21 * gam; + sh22 = sh22 * gam; + } + } + } + } + + if (sflag < zero) { + param[1] = utils::cast(sh11); + param[2] = utils::cast(sh21); + param[3] = utils::cast(sh12); + param[4] = utils::cast(sh22); + } else if (sflag == zero) { + param[2] = utils::cast(sh21); + param[3] = utils::cast(sh12); + } else { + param[1] = utils::cast(sh11); + param[4] = utils::cast(sh22); + } + + param[0] = utils::cast(sflag); + d1[0] = utils::cast(d1_val); + d2[0] = utils::cast(d2_val); + x1[0] = utils::cast(x1_val); + return INFINI_STATUS_SUCCESS; +} + +#define CALCULATE_ROTMG(TDATA) \ + calculateRotmg((TDATA *)d1, \ + (TDATA *)d2, \ + (TDATA *)x1, \ + (const TDATA *)y1, \ + (TDATA *)param) + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *d1, + void *d2, + void *x1, + const void *y1, + void *param, + void *stream) const { + + (void)workspace; + (void)workspace_size; + (void)stream; + + switch (_info.getDtype()) { + case INFINI_DTYPE_F16: + return CALCULATE_ROTMG(fp16_t); + case INFINI_DTYPE_F32: + return CALCULATE_ROTMG(float); + case INFINI_DTYPE_F64: + return CALCULATE_ROTMG(double); + case INFINI_DTYPE_BF16: + return CALCULATE_ROTMG(bf16_t); + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } +} + +#undef CALCULATE_ROTMG + +} // namespace op::rotmg::cpu diff --git a/src/infiniop/ops/rotmg/cpu/rotmg_cpu.h b/src/infiniop/ops/rotmg/cpu/rotmg_cpu.h new file mode 100644 index 000000000..49acbc062 --- /dev/null +++ b/src/infiniop/ops/rotmg/cpu/rotmg_cpu.h @@ -0,0 +1,8 @@ +#ifndef __ROTMG_CPU_H__ +#define __ROTMG_CPU_H__ + +#include "../rotmg.h" + +DESCRIPTOR(cpu) + 
+#endif // __ROTMG_CPU_H__ diff --git a/src/infiniop/ops/rotmg/metax/rotmg_metax.cc b/src/infiniop/ops/rotmg/metax/rotmg_metax.cc new file mode 100644 index 000000000..5136711bd --- /dev/null +++ b/src/infiniop/ops/rotmg/metax/rotmg_metax.cc @@ -0,0 +1,75 @@ +#include "rotmg_metax.h" +#include "../../../devices/metax/metax_common.h" +#include "../../../devices/metax/metax_handle.h" + +namespace op::rotmg::metax { + +struct Descriptor::Opaque { + std::shared_ptr internal; +}; + +Descriptor::~Descriptor() { + delete _opaque; +} + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t d1_desc, + infiniopTensorDescriptor_t d2_desc, + infiniopTensorDescriptor_t x1_desc, + infiniopTensorDescriptor_t y1_desc, + infiniopTensorDescriptor_t param_desc) { + + auto handle = reinterpret_cast(handle_); + auto info = RotmgInfo::createRotmgInfo(d1_desc, d2_desc, x1_desc, y1_desc, param_desc); + CHECK_RESULT(info); + + *desc_ptr = new Descriptor( + info.take(), + 0, + new Opaque{handle->internal()}, + handle->device, + handle->device_id); + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *d1, + void *d2, + void *x1, + const void *y1, + void *param, + void *stream) const { + + (void)workspace; + (void)workspace_size; + + const infiniDtype_t data_type = _info.getDtype(); + + CHECK_STATUS(_opaque->internal->useMcblas( + (hcStream_t)stream, + [&](hcblasHandle_t handle) { + CHECK_MCBLAS(hcblasSetPointerMode(handle, HCBLAS_POINTER_MODE_DEVICE)); + + switch (data_type) { + case INFINI_DTYPE_F32: + CHECK_MCBLAS(hcblasSrotmg(handle, (float *)d1, (float *)d2, (float *)x1, (const float *)y1, (float *)param)); + break; + case INFINI_DTYPE_F64: + CHECK_MCBLAS(hcblasDrotmg(handle, (double *)d1, (double *)d2, (double *)x1, (const double *)y1, (double *)param)); + break; + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + return 
INFINI_STATUS_SUCCESS; + })); + + return INFINI_STATUS_SUCCESS; +} + +} // namespace op::rotmg::metax diff --git a/src/infiniop/ops/rotmg/metax/rotmg_metax.h b/src/infiniop/ops/rotmg/metax/rotmg_metax.h new file mode 100644 index 000000000..37c26c968 --- /dev/null +++ b/src/infiniop/ops/rotmg/metax/rotmg_metax.h @@ -0,0 +1,8 @@ +#ifndef __ROTMG_METAX_H__ +#define __ROTMG_METAX_H__ + +#include "../rotmg.h" + +DESCRIPTOR(metax) + +#endif // __ROTMG_METAX_H__ diff --git a/src/infiniop/ops/rotmg/operator.cc b/src/infiniop/ops/rotmg/operator.cc new file mode 100644 index 000000000..98c25bd5b --- /dev/null +++ b/src/infiniop/ops/rotmg/operator.cc @@ -0,0 +1,127 @@ +#include "../../operator.h" +#include "../../handle.h" +#include "infiniop/ops/rotmg.h" + +#ifdef ENABLE_CPU_API +#include "cpu/rotmg_cpu.h" +#endif +#ifdef ENABLE_METAX_API +#include "metax/rotmg_metax.h" +#endif +#ifdef ENABLE_CAMBRICON_API +#include "bang/rotmg_bang.h" +#endif + +__INFINI_C infiniStatus_t infiniopCreateRotmgDescriptor( + infiniopHandle_t handle, + infiniopRotmgDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t d1_desc, + infiniopTensorDescriptor_t d2_desc, + infiniopTensorDescriptor_t x1_desc, + infiniopTensorDescriptor_t y1_desc, + infiniopTensorDescriptor_t param_desc) { + +#define CREATE(CASE, NAMESPACE) \ + case CASE: \ + return op::rotmg::NAMESPACE::Descriptor::create( \ + handle, \ + reinterpret_cast(desc_ptr), \ + d1_desc, d2_desc, x1_desc, y1_desc, param_desc) + + switch (handle->device) { +#ifdef ENABLE_CPU_API + CREATE(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_METAX_API + CREATE(INFINI_DEVICE_METAX, metax); +#endif +#ifdef ENABLE_CAMBRICON_API + CREATE(INFINI_DEVICE_CAMBRICON, bang); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CREATE +} + +__INFINI_C infiniStatus_t infiniopGetRotmgWorkspaceSize(infiniopRotmgDescriptor_t desc, size_t *size) { + +#define GET(CASE, NAMESPACE) \ + case CASE: \ + *size = 
reinterpret_cast(desc)->workspaceSize(); \ + return INFINI_STATUS_SUCCESS + + switch (desc->device_type) { +#ifdef ENABLE_CPU_API + GET(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_METAX_API + GET(INFINI_DEVICE_METAX, metax); +#endif +#ifdef ENABLE_CAMBRICON_API + GET(INFINI_DEVICE_CAMBRICON, bang); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef GET +} + +__INFINI_C infiniStatus_t infiniopRotmg( + infiniopRotmgDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *d1, + void *d2, + void *x1, + const void *y1, + void *param, + void *stream) { + +#define CALCULATE(CASE, NAMESPACE) \ + case CASE: \ + return reinterpret_cast(desc) \ + ->calculate(workspace, workspace_size, d1, d2, x1, y1, param, stream) + + switch (desc->device_type) { +#ifdef ENABLE_CPU_API + CALCULATE(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_METAX_API + CALCULATE(INFINI_DEVICE_METAX, metax); +#endif +#ifdef ENABLE_CAMBRICON_API + CALCULATE(INFINI_DEVICE_CAMBRICON, bang); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CALCULATE +} + +__INFINI_C infiniStatus_t infiniopDestroyRotmgDescriptor(infiniopRotmgDescriptor_t desc) { + +#define DELETE(CASE, NAMESPACE) \ + case CASE: \ + delete reinterpret_cast(desc); \ + return INFINI_STATUS_SUCCESS + + switch (desc->device_type) { +#ifdef ENABLE_CPU_API + DELETE(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_METAX_API + DELETE(INFINI_DEVICE_METAX, metax); +#endif +#ifdef ENABLE_CAMBRICON_API + DELETE(INFINI_DEVICE_CAMBRICON, bang); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef DELETE +} diff --git a/src/infiniop/ops/rotmg/rotmg.h b/src/infiniop/ops/rotmg/rotmg.h new file mode 100644 index 000000000..8b3e93a08 --- /dev/null +++ b/src/infiniop/ops/rotmg/rotmg.h @@ -0,0 +1,100 @@ +#ifndef __ROTMG_H__ +#define __ROTMG_H__ + +#include "../../../utils.h" +#include "../../operator.h" +#include "../../tensor.h" +#include 
"infiniop/ops/rotmg.h" + +#define DESCRIPTOR(NAMESPACE) \ + \ + namespace op::rotmg::NAMESPACE { \ + class Descriptor final : public InfiniopDescriptor { \ + struct Opaque; \ + Opaque *_opaque; \ + RotmgInfo _info; \ + size_t _workspace_size; \ + \ + Descriptor( \ + RotmgInfo info, \ + size_t workspace_size_, \ + Opaque *opaque, \ + infiniDevice_t device_type, \ + int device_id) \ + : InfiniopDescriptor{device_type, device_id}, \ + _opaque(opaque), \ + _info(std::move(info)), \ + _workspace_size(workspace_size_) {} \ + \ + public: \ + ~Descriptor(); \ + \ + size_t workspaceSize() const { return _workspace_size; } \ + \ + static infiniStatus_t create( \ + infiniopHandle_t handle, \ + Descriptor **desc_ptr, \ + infiniopTensorDescriptor_t d1_desc, \ + infiniopTensorDescriptor_t d2_desc, \ + infiniopTensorDescriptor_t x1_desc, \ + infiniopTensorDescriptor_t y1_desc, \ + infiniopTensorDescriptor_t param_desc); \ + \ + infiniStatus_t calculate( \ + void *workspace, \ + size_t workspace_size, \ + void *d1, \ + void *d2, \ + void *x1, \ + const void *y1, \ + void *param, \ + void *stream) const; \ + }; \ + } + +class RotmgInfo { +private: + infiniDtype_t _dtype; + + explicit RotmgInfo(infiniDtype_t dtype) : _dtype(dtype) {} + +public: + inline infiniDtype_t getDtype() const { return _dtype; } + + using ResultType = utils::Result; + + static ResultType createRotmgInfo( + infiniopTensorDescriptor_t d1_desc, + infiniopTensorDescriptor_t d2_desc, + infiniopTensorDescriptor_t x1_desc, + infiniopTensorDescriptor_t y1_desc, + infiniopTensorDescriptor_t param_desc) { + + CHECK_OR_RETURN(d1_desc != nullptr, INFINI_STATUS_NULL_POINTER); + CHECK_OR_RETURN(d2_desc != nullptr, INFINI_STATUS_NULL_POINTER); + CHECK_OR_RETURN(x1_desc != nullptr, INFINI_STATUS_NULL_POINTER); + CHECK_OR_RETURN(y1_desc != nullptr, INFINI_STATUS_NULL_POINTER); + CHECK_OR_RETURN(param_desc != nullptr, INFINI_STATUS_NULL_POINTER); + + auto dtype = d1_desc->dtype(); + + CHECK_OR_RETURN(d2_desc->dtype() == dtype, 
INFINI_STATUS_BAD_TENSOR_DTYPE); + CHECK_OR_RETURN(x1_desc->dtype() == dtype, INFINI_STATUS_BAD_TENSOR_DTYPE); + CHECK_OR_RETURN(y1_desc->dtype() == dtype, INFINI_STATUS_BAD_TENSOR_DTYPE); + CHECK_OR_RETURN(param_desc->dtype() == dtype, INFINI_STATUS_BAD_TENSOR_DTYPE); + CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32, INFINI_DTYPE_F64); + CHECK_OR_RETURN(param_desc->ndim() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); + + CHECK_OR_RETURN(d1_desc->numel() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(d2_desc->numel() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(x1_desc->numel() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(y1_desc->numel() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(param_desc->numel() == 5, INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(param_desc->stride(0) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES); + + RotmgInfo info(dtype); + return ResultType(std::move(info)); + } +}; + +#endif // __ROTMG_H__ diff --git a/test/infiniop/libinfiniop/op_register.py b/test/infiniop/libinfiniop/op_register.py index 9afe97100..18ba30e24 100644 --- a/test/infiniop/libinfiniop/op_register.py +++ b/test/infiniop/libinfiniop/op_register.py @@ -2493,3 +2493,41 @@ def rotm_(lib): lib.infiniopDestroyRotmDescriptor.argtypes = [ infiniopOperatorDescriptor_t, ] + + +@OpRegister.operator +def rotmg_(lib): + lib.infiniopCreateRotmgDescriptor.restype = c_int32 + lib.infiniopCreateRotmgDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopOperatorDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + ] + + lib.infiniopGetRotmgWorkspaceSize.restype = c_int32 + lib.infiniopGetRotmgWorkspaceSize.argtypes = [ + infiniopOperatorDescriptor_t, + POINTER(c_size_t), + ] + + lib.infiniopRotmg.restype = c_int32 + lib.infiniopRotmg.argtypes = [ + infiniopOperatorDescriptor_t, + c_void_p, + c_size_t, + 
c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + ] + + lib.infiniopDestroyRotmgDescriptor.restype = c_int32 + lib.infiniopDestroyRotmgDescriptor.argtypes = [ + infiniopOperatorDescriptor_t, + ] diff --git a/test/infiniop/rotmg.py b/test/infiniop/rotmg.py new file mode 100644 index 000000000..ff0dc8976 --- /dev/null +++ b/test/infiniop/rotmg.py @@ -0,0 +1,262 @@ +import ctypes +import math +from ctypes import c_uint64 + +import torch +from libinfiniop import ( + LIBINFINIOP, + InfiniDeviceNames, + InfiniDtype, + InfiniDtypeNames, + TestTensor, + TestWorkspace, + check_error, + get_args, + get_test_devices, + get_tolerance, + infiniopOperatorDescriptor_t, + test_operator, +) + +_TEST_CASES = [ + (1.0, 2.0, 3.0, 4.0), + (2.5, 0.5, -1.2, 0.8), + (3.0, 4.0, 0.0, 2.0), + (1.5, 1.5, 2.0, -3.0), +] + +_TENSOR_DTYPES = [ + # InfiniDtype.F16, + InfiniDtype.F32, + # InfiniDtype.F64, + # InfiniDtype.BF16, +] + +_TOLERANCE_MAP = { + InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-3}, + InfiniDtype.F32: {"atol": 1e-7, "rtol": 1e-7}, + InfiniDtype.F64: {"atol": 1e-12, "rtol": 1e-12}, + InfiniDtype.BF16: {"atol": 1e-2, "rtol": 1e-2}, +} + + +def _torch_dtype(dtype): + if dtype == InfiniDtype.F64: + return torch.float64 + return torch.float32 + + +def torch_rotmg(d1, d2, x1, y1): + zero = 0.0 + one = 1.0 + two = 2.0 + gam = 4096.0 + gamsq = 1.67772e7 + rgamsq = 5.96046e-8 + + sparam = [0.0] * 5 + sh11 = sh12 = sh21 = sh22 = 0.0 + + if d1 < zero: + sflag = -one + d1 = d2 = x1 = zero + else: + sp2 = d2 * y1 + if sp2 == zero: + sparam[0] = -two + return d1, d2, x1, sparam + + sp1 = d1 * x1 + sq2 = sp2 * y1 + sq1 = sp1 * x1 + + if abs(sq1) > abs(sq2): + sh21 = -y1 / x1 + sh12 = sp2 / sp1 + su = one - sh12 * sh21 + if su > zero: + sflag = zero + d1 = d1 / su + d2 = d2 / su + x1 = x1 * su + else: + sflag = -one + sh11 = sh12 = sh21 = sh22 = zero + d1 = d2 = x1 = zero + else: + if sq2 < zero: + sflag = -one + d1 = d2 = x1 = zero + else: + sflag = one + sh11 = sp1 / sp2 + 
sh22 = x1 / y1 + su = one + sh11 * sh22 + stemp = d2 / su + d2 = d1 / su + d1 = stemp + x1 = y1 * su + + if d1 != zero: + while d1 <= rgamsq or d1 >= gamsq: + if sflag == zero: + sh11 = one + sh22 = one + sflag = -one + else: + sh21 = -one + sh12 = one + sflag = -one + if d1 <= rgamsq: + d1 = d1 * gam * gam + x1 = x1 / gam + sh11 = sh11 / gam + sh12 = sh12 / gam + else: + d1 = d1 / (gam * gam) + x1 = x1 * gam + sh11 = sh11 * gam + sh12 = sh12 * gam + + if d2 != zero: + while abs(d2) <= rgamsq or abs(d2) >= gamsq: + if sflag == zero: + sh11 = one + sh22 = one + sflag = -one + else: + sh21 = -one + sh12 = one + sflag = -one + if abs(d2) <= rgamsq: + d2 = d2 * gam * gam + sh21 = sh21 / gam + sh22 = sh22 / gam + else: + d2 = d2 / (gam * gam) + sh21 = sh21 * gam + sh22 = sh22 * gam + + if sflag < zero: + sparam[1] = sh11 + sparam[2] = sh21 + sparam[3] = sh12 + sparam[4] = sh22 + elif sflag == zero: + sparam[2] = sh21 + sparam[3] = sh12 + else: + sparam[1] = sh11 + sparam[4] = sh22 + + sparam[0] = sflag + return d1, d2, x1, sparam + + +def test(handle, device, d1_0, d2_0, x1_0, y1_0, dtype=torch.float32, sync=None): + exp_d1, exp_d2, exp_x1, exp_sparam = torch_rotmg(d1_0, d2_0, x1_0, y1_0) + + scalar_dtype = _torch_dtype(dtype) + d1_torch = torch.tensor([d1_0], dtype=scalar_dtype) + d2_torch = torch.tensor([d2_0], dtype=scalar_dtype) + x1_torch = torch.tensor([x1_0], dtype=scalar_dtype) + y1_torch = torch.tensor([y1_0], dtype=scalar_dtype) + d1 = TestTensor( + d1_torch.shape, + d1_torch.stride(), + dtype, + device, + mode="manual", + set_tensor=d1_torch, + ) + d2 = TestTensor( + d2_torch.shape, + d2_torch.stride(), + dtype, + device, + mode="manual", + set_tensor=d2_torch, + ) + x1 = TestTensor( + x1_torch.shape, + x1_torch.stride(), + dtype, + device, + mode="manual", + set_tensor=x1_torch, + ) + y1 = TestTensor( + y1_torch.shape, + y1_torch.stride(), + dtype, + device, + mode="manual", + set_tensor=y1_torch, + ) + param = TestTensor((5,), (1,), dtype, device, 
mode="zeros") + + print( + f"Testing Rotmg on {InfiniDeviceNames[device]} with d1:{d1_0} d2:{d2_0} x1:{x1_0} y1:{y1_0} dtype:{InfiniDtypeNames[dtype]}" + ) + + if sync is not None: + sync() + + descriptor = infiniopOperatorDescriptor_t() + check_error( + LIBINFINIOP.infiniopCreateRotmgDescriptor( + handle, + ctypes.byref(descriptor), + d1.descriptor, + d2.descriptor, + x1.descriptor, + y1.descriptor, + param.descriptor, + ) + ) + + d1.destroy_desc() + d2.destroy_desc() + x1.destroy_desc() + y1.destroy_desc() + param.destroy_desc() + + workspace_size = c_uint64(0) + check_error( + LIBINFINIOP.infiniopGetRotmgWorkspaceSize( + descriptor, ctypes.byref(workspace_size) + ) + ) + workspace = TestWorkspace(workspace_size.value, device) + + check_error( + LIBINFINIOP.infiniopRotmg( + descriptor, + workspace.data(), + workspace.size(), + d1.data(), + d2.data(), + x1.data(), + y1.data(), + param.data(), + None, + ) + ) + + atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) + assert math.isclose(d1.actual_tensor().item(), exp_d1, rel_tol=rtol, abs_tol=atol) + assert math.isclose(d2.actual_tensor().item(), exp_d2, rel_tol=rtol, abs_tol=atol) + assert math.isclose(x1.actual_tensor().item(), exp_x1, rel_tol=rtol, abs_tol=atol) + for i in range(5): + assert math.isclose( + param.actual_tensor()[i].item(), exp_sparam[i], rel_tol=rtol, abs_tol=atol + ) + + check_error(LIBINFINIOP.infiniopDestroyRotmgDescriptor(descriptor)) + + +if __name__ == "__main__": + args = get_args() + for device in get_test_devices(args): + test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES) + print("\033[92mTest passed!\033[0m") From 0153263c2006e2fcb855861a47ff3373101e0801 Mon Sep 17 00:00:00 2001 From: xuzhengzhong Date: Tue, 28 Apr 2026 10:15:33 +0000 Subject: [PATCH 11/25] Add `scal` operator --- include/infiniop.h | 1 + include/infiniop/ops/scal.h | 24 +++ src/infiniop/devices/metax/metax_ht2mc.h | 2 + src/infiniop/ops/scal/bang/scal_bang.h | 8 + src/infiniop/ops/scal/bang/scal_bang.mlu | 92 
++++++++++ .../ops/scal/bang/scal_bang_kernel.mlu | 73 ++++++++ src/infiniop/ops/scal/cpu/scal_cpu.cc | 81 +++++++++ src/infiniop/ops/scal/cpu/scal_cpu.h | 8 + src/infiniop/ops/scal/metax/scal_metax.cc | 71 ++++++++ src/infiniop/ops/scal/metax/scal_metax.h | 8 + src/infiniop/ops/scal/operator.cc | 121 +++++++++++++ src/infiniop/ops/scal/scal.h | 89 ++++++++++ test/infiniop/libinfiniop/op_register.py | 32 ++++ test/infiniop/scal.py | 164 ++++++++++++++++++ 14 files changed, 774 insertions(+) create mode 100644 include/infiniop/ops/scal.h create mode 100644 src/infiniop/ops/scal/bang/scal_bang.h create mode 100644 src/infiniop/ops/scal/bang/scal_bang.mlu create mode 100644 src/infiniop/ops/scal/bang/scal_bang_kernel.mlu create mode 100644 src/infiniop/ops/scal/cpu/scal_cpu.cc create mode 100644 src/infiniop/ops/scal/cpu/scal_cpu.h create mode 100644 src/infiniop/ops/scal/metax/scal_metax.cc create mode 100644 src/infiniop/ops/scal/metax/scal_metax.h create mode 100644 src/infiniop/ops/scal/operator.cc create mode 100644 src/infiniop/ops/scal/scal.h create mode 100644 test/infiniop/scal.py diff --git a/include/infiniop.h b/include/infiniop.h index 33ed4e315..bbf837bea 100644 --- a/include/infiniop.h +++ b/include/infiniop.h @@ -104,6 +104,7 @@ #include "infiniop/ops/rotg.h" #include "infiniop/ops/rotm.h" #include "infiniop/ops/rotmg.h" +#include "infiniop/ops/scal.h" #include "infiniop/ops/scatter.h" #include "infiniop/ops/selu.h" #include "infiniop/ops/sigmoid.h" diff --git a/include/infiniop/ops/scal.h b/include/infiniop/ops/scal.h new file mode 100644 index 000000000..f7903de56 --- /dev/null +++ b/include/infiniop/ops/scal.h @@ -0,0 +1,24 @@ +#ifndef __INFINIOP_SCAL_API_H__ +#define __INFINIOP_SCAL_API_H__ + +#include "../operator_descriptor.h" + +typedef struct InfiniopDescriptor *infiniopScalDescriptor_t; + +__INFINI_C __export infiniStatus_t infiniopCreateScalDescriptor(infiniopHandle_t handle, + infiniopScalDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t 
alpha, + infiniopTensorDescriptor_t x); + +__INFINI_C __export infiniStatus_t infiniopGetScalWorkspaceSize(infiniopScalDescriptor_t desc, size_t *size); + +__INFINI_C __export infiniStatus_t infiniopScal(infiniopScalDescriptor_t desc, + void *workspace, + size_t workspace_size, + const void *alpha, + void *x, + void *stream); + +__INFINI_C __export infiniStatus_t infiniopDestroyScalDescriptor(infiniopScalDescriptor_t desc); + +#endif // __INFINIOP_SCAL_API_H__ diff --git a/src/infiniop/devices/metax/metax_ht2mc.h b/src/infiniop/devices/metax/metax_ht2mc.h index d2be2ab26..6ef455a14 100644 --- a/src/infiniop/devices/metax/metax_ht2mc.h +++ b/src/infiniop/devices/metax/metax_ht2mc.h @@ -141,6 +141,8 @@ #define hcblasDrotm mcblasDrotm #define hcblasSrotmg mcblasSrotmg #define hcblasDrotmg mcblasDrotmg +#define hcblasSscal mcblasSscal +#define hcblasDscal mcblasDscal #define HCBLAS_STATUS_SUCCESS MCBLAS_STATUS_SUCCESS #define HCBLAS_OP_T MCBLAS_OP_T #define HCBLAS_OP_N MCBLAS_OP_N diff --git a/src/infiniop/ops/scal/bang/scal_bang.h b/src/infiniop/ops/scal/bang/scal_bang.h new file mode 100644 index 000000000..c61532ce7 --- /dev/null +++ b/src/infiniop/ops/scal/bang/scal_bang.h @@ -0,0 +1,8 @@ +#ifndef __SCAL_BANG_H__ +#define __SCAL_BANG_H__ + +#include "../scal.h" + +DESCRIPTOR(bang) + +#endif // __SCAL_BANG_H__ \ No newline at end of file diff --git a/src/infiniop/ops/scal/bang/scal_bang.mlu b/src/infiniop/ops/scal/bang/scal_bang.mlu new file mode 100644 index 000000000..b9407ca34 --- /dev/null +++ b/src/infiniop/ops/scal/bang/scal_bang.mlu @@ -0,0 +1,92 @@ +#include "../../../devices/bang/common_bang.h" +#include "scal_bang.h" +#include "scal_bang_kernel.mlu" + +namespace op::scal::bang { + +Descriptor::~Descriptor() = default; + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t alpha_desc, + infiniopTensorDescriptor_t x_desc) { + + auto handle = reinterpret_cast(handle_); + auto info = 
ScalInfo::createScalInfo(alpha_desc, x_desc); + CHECK_RESULT(info); + + *desc_ptr = new Descriptor( + info.take(), + 0, + nullptr, + handle->device, + handle->device_id); + + return INFINI_STATUS_SUCCESS; +} + +template +infiniStatus_t calculateScal( + const ScalInfo &info, + const Tdata *alpha, + Tdata *x, + cnrtQueue_t queue) { + + cnrtDim3_t k_dim; + cnrtFunctionType_t k_type; + + k_dim.x = 4; + k_dim.y = 1; + k_dim.z = 1; + k_type = cnrtFuncTypeUnion1; + + if (info.getIncx() == 1) { + scalKernelContiguous<<>>( + info.getSize(), + alpha, + x); + } else { + scalKernelStrided<<>>( + info.getSize(), + alpha, + x, + info.getIncx()); + } + + cnrtQueueSync(queue); + + return INFINI_STATUS_SUCCESS; +} + +#define CALCULATE_SCAL(TDATA) \ + calculateScal(_info, \ + (const TDATA *)alpha, \ + (TDATA *)x, \ + (cnrtQueue_t)stream) + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + const void *alpha, + void *x, + void *stream) const { + + (void)workspace; + (void)workspace_size; + + switch (_info.getDtype()) { + case INFINI_DTYPE_F16: + return CALCULATE_SCAL(half); + case INFINI_DTYPE_F32: + return CALCULATE_SCAL(float); + case INFINI_DTYPE_BF16: + return CALCULATE_SCAL(bfloat16_t); + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } +} + +#undef CALCULATE_SCAL + +} // namespace op::scal::bang diff --git a/src/infiniop/ops/scal/bang/scal_bang_kernel.mlu b/src/infiniop/ops/scal/bang/scal_bang_kernel.mlu new file mode 100644 index 000000000..1e2f8c9a7 --- /dev/null +++ b/src/infiniop/ops/scal/bang/scal_bang_kernel.mlu @@ -0,0 +1,73 @@ +#include "../../../devices/bang/common_bang.h" +#include "scal_bang.h" + +__nram__ char nram_buffer[NRAM_MAX_SIZE]; + +template +__mlu_global__ void scalKernelContiguous( + size_t n, + const Tdata *alpha, + Tdata *x) { + + Tdata *nram_x = (Tdata *)(((size_t)nram_buffer + ALIGN_SIZE - 1) & ~(ALIGN_SIZE - 1)); + + size_t nram_usable = NRAM_MAX_SIZE - ((char *)nram_x - nram_buffer); + size_t max_chunk_elements 
= nram_usable / sizeof(Tdata); + + int align_elements = ALIGN_SIZE / sizeof(Tdata); + if (align_elements == 0) { + align_elements = 1; + } + max_chunk_elements = (max_chunk_elements / align_elements) * align_elements; + + int elements_per_core = n / taskDim; + int remain = n % taskDim; + int core_elements = elements_per_core + (taskId < remain ? 1 : 0); + int core_offset = taskId < remain ? taskId * core_elements : taskId * elements_per_core + remain; + + if (core_elements <= 0) { + return; + } + + int chunks = core_elements / max_chunk_elements; + int chunk_rem = core_elements % max_chunk_elements; + + for (int c = 0; c < chunks; c++) { + size_t current_offset = core_offset + c * max_chunk_elements; + __memcpy(nram_x, x + current_offset, max_chunk_elements * sizeof(Tdata), GDRAM2NRAM); + + __bang_mul_scalar(nram_x, nram_x, alpha[0], max_chunk_elements); + + __memcpy(x + current_offset, nram_x, max_chunk_elements * sizeof(Tdata), NRAM2GDRAM); + } + + if (chunk_rem > 0) { + size_t current_offset = core_offset + chunks * max_chunk_elements; + int align_rem = ((chunk_rem + align_elements - 1) / align_elements) * align_elements; + + __memcpy(nram_x, x + current_offset, chunk_rem * sizeof(Tdata), GDRAM2NRAM); + + __bang_mul_scalar(nram_x, nram_x, alpha[0], align_rem); + + __memcpy(x + current_offset, nram_x, chunk_rem * sizeof(Tdata), NRAM2GDRAM); + } +} + +template +__mlu_global__ void scalKernelStrided( + size_t n, + const Tdata *alpha, + Tdata *x, + size_t incx) { + + int elements_per_core = n / taskDim; + int remain = n % taskDim; + int actual_tasks = elements_per_core + (taskId < remain ? 1 : 0); + int start_idx = taskId < remain ? 
taskId * actual_tasks : taskId * elements_per_core + remain; + + for (int i = start_idx; i < start_idx + actual_tasks; ++i) { + size_t offset = i * incx; + + x[offset] *= alpha[0]; + } +} diff --git a/src/infiniop/ops/scal/cpu/scal_cpu.cc b/src/infiniop/ops/scal/cpu/scal_cpu.cc new file mode 100644 index 000000000..42678afeb --- /dev/null +++ b/src/infiniop/ops/scal/cpu/scal_cpu.cc @@ -0,0 +1,81 @@ +#include "scal_cpu.h" +#include "../../../devices/cpu/common_cpu.h" + +namespace op::scal::cpu { + +Descriptor::~Descriptor() = default; + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t alpha_desc, + infiniopTensorDescriptor_t x_desc) { + + auto handle = reinterpret_cast(handle_); + auto info = ScalInfo::createScalInfo(alpha_desc, x_desc); + CHECK_RESULT(info); + + *desc_ptr = new Descriptor( + info.take(), + 0, + nullptr, + handle->device, + handle->device_id); + + return INFINI_STATUS_SUCCESS; +} + +template +infiniStatus_t calculateScal( + const ScalInfo &info, + const Tdata *alpha, + Tdata *x) { + + const ptrdiff_t size = info.getSize(); + const ptrdiff_t incx = info.getIncx(); + + for (ptrdiff_t i = 0; i < size; ++i) { + const ptrdiff_t idx = i * incx; + + if constexpr (std::is_same_v || std::is_same_v) { + x[idx] = utils::cast(utils::cast(x[idx]) * utils::cast(alpha[0])); + } else { + x[idx] = x[idx] * alpha[0]; + } + } + + return INFINI_STATUS_SUCCESS; +} + +#define CALCULATE_SCAL(TDATA) \ + calculateScal(_info, \ + (const TDATA *)alpha, \ + (TDATA *)x) + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + const void *alpha, + void *x, + void *stream) const { + + (void)workspace; + (void)workspace_size; + + switch (_info.getDtype()) { + case INFINI_DTYPE_F16: + return CALCULATE_SCAL(fp16_t); + case INFINI_DTYPE_F32: + return CALCULATE_SCAL(float); + case INFINI_DTYPE_F64: + return CALCULATE_SCAL(double); + case INFINI_DTYPE_BF16: + return 
CALCULATE_SCAL(bf16_t); + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } +} + +#undef CALCULATE_SCAL + +} // namespace op::scal::cpu diff --git a/src/infiniop/ops/scal/cpu/scal_cpu.h b/src/infiniop/ops/scal/cpu/scal_cpu.h new file mode 100644 index 000000000..efbba93ba --- /dev/null +++ b/src/infiniop/ops/scal/cpu/scal_cpu.h @@ -0,0 +1,8 @@ +#ifndef __SCAL_CPU_H__ +#define __SCAL_CPU_H__ + +#include "../scal.h" + +DESCRIPTOR(cpu) + +#endif // __SCAL_CPU_H__ \ No newline at end of file diff --git a/src/infiniop/ops/scal/metax/scal_metax.cc b/src/infiniop/ops/scal/metax/scal_metax.cc new file mode 100644 index 000000000..b534a61aa --- /dev/null +++ b/src/infiniop/ops/scal/metax/scal_metax.cc @@ -0,0 +1,71 @@ +#include "scal_metax.h" +#include "../../../devices/metax/metax_common.h" +#include "../../../devices/metax/metax_handle.h" + +namespace op::scal::metax { + +struct Descriptor::Opaque { + std::shared_ptr internal; +}; + +Descriptor::~Descriptor() { + delete _opaque; +} + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t alpha_desc, + infiniopTensorDescriptor_t x_desc) { + + auto handle = reinterpret_cast(handle_); + auto info = ScalInfo::createScalInfo(alpha_desc, x_desc); + CHECK_RESULT(info); + + *desc_ptr = new Descriptor( + info.take(), + 0, + new Opaque{handle->internal()}, + handle->device, + handle->device_id); + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + const void *alpha, + void *x, + void *stream) const { + + (void)workspace; + (void)workspace_size; + + const size_t size = _info.getSize(); + const ptrdiff_t incx = _info.getIncx(); + const infiniDtype_t data_type = _info.getDtype(); + + CHECK_STATUS(_opaque->internal->useMcblas( + (hcStream_t)stream, + [&](hcblasHandle_t handle) { + CHECK_MCBLAS(hcblasSetPointerMode(handle, HCBLAS_POINTER_MODE_DEVICE)); + + switch (data_type) { + case INFINI_DTYPE_F32: 
+ CHECK_MCBLAS(hcblasSscal(handle, size, (const float *)alpha, (float *)x, incx)); + break; + case INFINI_DTYPE_F64: + CHECK_MCBLAS(hcblasDscal(handle, size, (const double *)alpha, (double *)x, incx)); + break; + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; /* unsupported dtype (not device): same status the cpu/bang scal and metax rotmg backends return */ + } + + return INFINI_STATUS_SUCCESS; + })); + + return INFINI_STATUS_SUCCESS; +} + +} // namespace op::scal::metax diff --git a/src/infiniop/ops/scal/metax/scal_metax.h b/src/infiniop/ops/scal/metax/scal_metax.h new file mode 100644 index 000000000..8e63ac399 --- /dev/null +++ b/src/infiniop/ops/scal/metax/scal_metax.h @@ -0,0 +1,8 @@ +#ifndef __SCAL_METAX_H__ +#define __SCAL_METAX_H__ + +#include "../scal.h" + +DESCRIPTOR(metax) + +#endif // __SCAL_METAX_H__ \ No newline at end of file diff --git a/src/infiniop/ops/scal/operator.cc b/src/infiniop/ops/scal/operator.cc new file mode 100644 index 000000000..33fa8618a --- /dev/null +++ b/src/infiniop/ops/scal/operator.cc @@ -0,0 +1,121 @@ +#include "../../operator.h" +#include "../../handle.h" +#include "infiniop/ops/scal.h" + +#ifdef ENABLE_CPU_API +#include "cpu/scal_cpu.h" +#endif +#ifdef ENABLE_METAX_API +#include "metax/scal_metax.h" +#endif +#ifdef ENABLE_CAMBRICON_API +#include "bang/scal_bang.h" +#endif + +__INFINI_C infiniStatus_t infiniopCreateScalDescriptor( + infiniopHandle_t handle, + infiniopScalDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t alpha_desc, + infiniopTensorDescriptor_t x_desc) { + +#define CREATE(CASE, NAMESPACE) \ + case CASE: \ + return op::scal::NAMESPACE::Descriptor::create( \ + handle, \ + reinterpret_cast(desc_ptr), \ + alpha_desc, x_desc) + + switch (handle->device) { +#ifdef ENABLE_CPU_API + CREATE(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_METAX_API + CREATE(INFINI_DEVICE_METAX, metax); +#endif +#ifdef ENABLE_CAMBRICON_API + CREATE(INFINI_DEVICE_CAMBRICON, bang); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CREATE +} + +__INFINI_C infiniStatus_t 
infiniopGetScalWorkspaceSize(infiniopScalDescriptor_t desc, size_t *size) { + +#define GET(CASE, NAMESPACE) \ + case CASE: \ + *size = reinterpret_cast(desc)->workspaceSize(); \ + return INFINI_STATUS_SUCCESS + + switch (desc->device_type) { +#ifdef ENABLE_CPU_API + GET(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_METAX_API + GET(INFINI_DEVICE_METAX, metax); +#endif +#ifdef ENABLE_CAMBRICON_API + GET(INFINI_DEVICE_CAMBRICON, bang); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef GET +} + +__INFINI_C infiniStatus_t infiniopScal( + infiniopScalDescriptor_t desc, + void *workspace, + size_t workspace_size, + const void *alpha, + void *x, + void *stream) { + +#define CALCULATE(CASE, NAMESPACE) \ + case CASE: \ + return reinterpret_cast(desc) \ + ->calculate(workspace, workspace_size, alpha, x, stream) + + switch (desc->device_type) { +#ifdef ENABLE_CPU_API + CALCULATE(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_METAX_API + CALCULATE(INFINI_DEVICE_METAX, metax); +#endif +#ifdef ENABLE_CAMBRICON_API + CALCULATE(INFINI_DEVICE_CAMBRICON, bang); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CALCULATE +} + +__INFINI_C infiniStatus_t infiniopDestroyScalDescriptor(infiniopScalDescriptor_t desc) { + +#define DELETE(CASE, NAMESPACE) \ + case CASE: \ + delete reinterpret_cast(desc); \ + return INFINI_STATUS_SUCCESS + + switch (desc->device_type) { +#ifdef ENABLE_CPU_API + DELETE(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_METAX_API + DELETE(INFINI_DEVICE_METAX, metax); +#endif +#ifdef ENABLE_CAMBRICON_API + DELETE(INFINI_DEVICE_CAMBRICON, bang); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef DELETE +} diff --git a/src/infiniop/ops/scal/scal.h b/src/infiniop/ops/scal/scal.h new file mode 100644 index 000000000..25551d0fa --- /dev/null +++ b/src/infiniop/ops/scal/scal.h @@ -0,0 +1,89 @@ +#ifndef __SCAL_H__ +#define __SCAL_H__ + +#include "../../../utils.h" 
+#include "../../operator.h" +#include "../../tensor.h" +#include "infiniop/ops/scal.h" + +#define DESCRIPTOR(NAMESPACE) \ + \ + namespace op::scal::NAMESPACE { \ + class Descriptor final : public InfiniopDescriptor { \ + struct Opaque; \ + Opaque *_opaque; \ + ScalInfo _info; \ + size_t _workspace_size; \ + \ + Descriptor( \ + ScalInfo info, \ + size_t workspace_size_, \ + Opaque *opaque, \ + infiniDevice_t device_type, \ + int device_id) \ + : InfiniopDescriptor{device_type, device_id}, \ + _opaque(opaque), \ + _info(std::move(info)), \ + _workspace_size(workspace_size_) {} \ + \ + public: \ + ~Descriptor(); \ + \ + size_t workspaceSize() const { return _workspace_size; } \ + \ + static infiniStatus_t create( \ + infiniopHandle_t handle, \ + Descriptor **desc_ptr, \ + infiniopTensorDescriptor_t alpha_desc, \ + infiniopTensorDescriptor_t x_desc); \ + \ + infiniStatus_t calculate( \ + void *workspace, \ + size_t workspace_size, \ + const void *alpha, \ + void *x, \ + void *stream) const; \ + }; \ + } + +class ScalInfo { +private: + size_t _size; + ptrdiff_t _incx; + infiniDtype_t _dtype; + + ScalInfo(size_t size, + ptrdiff_t incx, + infiniDtype_t dtype) + : _size(size), _incx(incx), _dtype(dtype) {} + +public: + inline size_t getSize() const { return _size; } + inline ptrdiff_t getIncx() const { return _incx; } + inline infiniDtype_t getDtype() const { return _dtype; } + + using ResultType = utils::Result; + + static ResultType createScalInfo( + infiniopTensorDescriptor_t alpha_desc, + infiniopTensorDescriptor_t x_desc) { + + CHECK_OR_RETURN(alpha_desc != nullptr, INFINI_STATUS_NULL_POINTER); + CHECK_OR_RETURN(x_desc != nullptr, INFINI_STATUS_NULL_POINTER); + + auto dtype = x_desc->dtype(); + + CHECK_OR_RETURN(alpha_desc->dtype() == dtype, INFINI_STATUS_BAD_TENSOR_DTYPE); + CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32, INFINI_DTYPE_F64); + CHECK_OR_RETURN(alpha_desc->numel() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); + 
CHECK_OR_RETURN(x_desc->ndim() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); + + auto size = x_desc->numel(); + auto incx = x_desc->stride(0); + + ScalInfo info(size, incx, dtype); + return ResultType(std::move(info)); + } +}; + +#endif // __SCAL_H__ diff --git a/test/infiniop/libinfiniop/op_register.py b/test/infiniop/libinfiniop/op_register.py index 18ba30e24..3e71d669d 100644 --- a/test/infiniop/libinfiniop/op_register.py +++ b/test/infiniop/libinfiniop/op_register.py @@ -2531,3 +2531,35 @@ def rotmg_(lib): lib.infiniopDestroyRotmgDescriptor.argtypes = [ infiniopOperatorDescriptor_t, ] + + +@OpRegister.operator +def scal_(lib): + lib.infiniopCreateScalDescriptor.restype = c_int32 + lib.infiniopCreateScalDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopOperatorDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + ] + + lib.infiniopGetScalWorkspaceSize.restype = c_int32 + lib.infiniopGetScalWorkspaceSize.argtypes = [ + infiniopOperatorDescriptor_t, + POINTER(c_size_t), + ] + + lib.infiniopScal.restype = c_int32 + lib.infiniopScal.argtypes = [ + infiniopOperatorDescriptor_t, + c_void_p, + c_size_t, + c_void_p, + c_void_p, + c_void_p, + ] + + lib.infiniopDestroyScalDescriptor.restype = c_int32 + lib.infiniopDestroyScalDescriptor.argtypes = [ + infiniopOperatorDescriptor_t, + ] diff --git a/test/infiniop/scal.py b/test/infiniop/scal.py new file mode 100644 index 000000000..69279c070 --- /dev/null +++ b/test/infiniop/scal.py @@ -0,0 +1,164 @@ +import ctypes +from ctypes import c_uint64 + +import torch +from libinfiniop import ( + LIBINFINIOP, + InfiniDeviceNames, + InfiniDtype, + InfiniDtypeNames, + TestTensor, + TestWorkspace, + check_error, + debug, + get_args, + get_test_devices, + get_tolerance, + infiniopOperatorDescriptor_t, + profile_operation, + test_operator, +) + +# ============================================================================== +# Configuration (Internal Use Only) +# 
============================================================================== +# These are not meant to be imported from other modules +# Format: (shape, x_stride, alpha) +_TEST_CASES = [ + ((13,), None, 2.5), + ((13,), (10,), 2.5), + ((5632,), None, 2.5), + ((5632,), (5,), 2.5), + ((16,), (4,), 2.5), + ((5632,), (32,), 2.5), +] + +# Data types used for testing +_TENSOR_DTYPES = [ + # InfiniDtype.F16, + InfiniDtype.F32, + # InfiniDtype.F64, + # InfiniDtype.BF16, +] + +# Tolerance map for different data types +_TOLERANCE_MAP = { + InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-3}, + InfiniDtype.F32: {"atol": 1e-7, "rtol": 1e-7}, + InfiniDtype.F64: {"atol": 1e-15, "rtol": 1e-15}, + InfiniDtype.BF16: {"atol": 5e-3, "rtol": 1e-2}, +} + +DEBUG = False +PROFILE = False +NUM_PRERUN = 10 +NUM_ITERATIONS = 1000 + + +def scal(x, alpha): + x.mul_(alpha) + + +def test( + handle, + device, + shape, + x_stride=None, + alpha_value=2.5, + dtype=torch.float16, + sync=None, +): + alpha_torch = torch.tensor([alpha_value]) + alpha = TestTensor( + alpha_torch.shape, + alpha_torch.stride(), + dtype, + device, + mode="manual", + set_tensor=alpha_torch, + ) + x = TestTensor(shape, x_stride, dtype, device) + + if x.is_broadcast(): + return + + print( + f"Testing Scal on {InfiniDeviceNames[device]} with shape:{shape} x_stride:{x_stride} " + f"dtype:{InfiniDtypeNames[dtype]}" + ) + + # Compute PyTorch reference + scal(x.torch_tensor(), alpha.torch_tensor()) + + if sync is not None: + sync() + + # Create Descriptor + descriptor = infiniopOperatorDescriptor_t() + check_error( + LIBINFINIOP.infiniopCreateScalDescriptor( + handle, + ctypes.byref(descriptor), + alpha.descriptor, + x.descriptor, + ) + ) + + # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel + alpha.destroy_desc() + x.destroy_desc() + + # Allocate Workspace + workspace_size = c_uint64(0) + check_error( + LIBINFINIOP.infiniopGetScalWorkspaceSize( + descriptor, 
ctypes.byref(workspace_size) + ) + ) + workspace = TestWorkspace(workspace_size.value, x.device) + + # Execute C library op + def lib_scal(): + check_error( + LIBINFINIOP.infiniopScal( + descriptor, + workspace.data(), + workspace.size(), + alpha.data(), + x.data(), + None, + ) + ) + + lib_scal() + + # Compare results + atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) + if DEBUG: + debug(x.actual_tensor(), x.torch_tensor(), atol=atol, rtol=rtol) + + assert torch.allclose(x.actual_tensor(), x.torch_tensor(), atol=atol, rtol=rtol) + + # Profiling workflow + if PROFILE: + # fmt: off + profile_operation("PyTorch", lambda: scal(x.torch_tensor(), alpha.torch_tensor()), device, NUM_PRERUN, NUM_ITERATIONS) + profile_operation(" lib", lambda: lib_scal(), device, NUM_PRERUN, NUM_ITERATIONS) + # fmt: on + + check_error(LIBINFINIOP.infiniopDestroyScalDescriptor(descriptor)) + + +if __name__ == "__main__": + args = get_args() + + # Configure testing options + DEBUG = args.debug + PROFILE = args.profile + NUM_PRERUN = args.num_prerun + NUM_ITERATIONS = args.num_iterations + + for device in get_test_devices(args): + test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES) + + print("\033[92mTest passed!\033[0m") From e64afb999a94c977781364909c96c39b40c96f2a Mon Sep 17 00:00:00 2001 From: xuzhengzhong Date: Tue, 28 Apr 2026 10:24:30 +0000 Subject: [PATCH 12/25] Add `swap` operator --- include/infiniop.h | 1 + include/infiniop/ops/swap.h | 24 +++ src/infiniop/devices/metax/metax_ht2mc.h | 2 + src/infiniop/ops/swap/bang/swap_bang.h | 8 + src/infiniop/ops/swap/bang/swap_bang.mlu | 92 +++++++++++ .../ops/swap/bang/swap_bang_kernel.mlu | 75 +++++++++ src/infiniop/ops/swap/cpu/swap_cpu.cc | 82 ++++++++++ src/infiniop/ops/swap/cpu/swap_cpu.h | 8 + src/infiniop/ops/swap/metax/swap_metax.cc | 72 +++++++++ src/infiniop/ops/swap/metax/swap_metax.h | 8 + src/infiniop/ops/swap/operator.cc | 121 ++++++++++++++ src/infiniop/ops/swap/swap.h | 94 +++++++++++ 
test/infiniop/libinfiniop/op_register.py | 32 ++++ test/infiniop/swap.py | 151 ++++++++++++++++++ 14 files changed, 770 insertions(+) create mode 100644 include/infiniop/ops/swap.h create mode 100644 src/infiniop/ops/swap/bang/swap_bang.h create mode 100644 src/infiniop/ops/swap/bang/swap_bang.mlu create mode 100644 src/infiniop/ops/swap/bang/swap_bang_kernel.mlu create mode 100644 src/infiniop/ops/swap/cpu/swap_cpu.cc create mode 100644 src/infiniop/ops/swap/cpu/swap_cpu.h create mode 100644 src/infiniop/ops/swap/metax/swap_metax.cc create mode 100644 src/infiniop/ops/swap/metax/swap_metax.h create mode 100644 src/infiniop/ops/swap/operator.cc create mode 100644 src/infiniop/ops/swap/swap.h create mode 100644 test/infiniop/swap.py diff --git a/include/infiniop.h b/include/infiniop.h index bbf837bea..de33a7a4b 100644 --- a/include/infiniop.h +++ b/include/infiniop.h @@ -117,6 +117,7 @@ #include "infiniop/ops/softsign.h" #include "infiniop/ops/sub.h" #include "infiniop/ops/sum.h" +#include "infiniop/ops/swap.h" #include "infiniop/ops/swiglu.h" #include "infiniop/ops/take.h" #include "infiniop/ops/tan.h" diff --git a/include/infiniop/ops/swap.h b/include/infiniop/ops/swap.h new file mode 100644 index 000000000..7eb14b57a --- /dev/null +++ b/include/infiniop/ops/swap.h @@ -0,0 +1,24 @@ +#ifndef __INFINIOP_SWAP_API_H__ +#define __INFINIOP_SWAP_API_H__ + +#include "../operator_descriptor.h" + +typedef struct InfiniopDescriptor *infiniopSwapDescriptor_t; + +__INFINI_C __export infiniStatus_t infiniopCreateSwapDescriptor(infiniopHandle_t handle, + infiniopSwapDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t x, + infiniopTensorDescriptor_t y); + +__INFINI_C __export infiniStatus_t infiniopGetSwapWorkspaceSize(infiniopSwapDescriptor_t desc, size_t *size); + +__INFINI_C __export infiniStatus_t infiniopSwap(infiniopSwapDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *x, + void *y, + void *stream); + +__INFINI_C __export infiniStatus_t 
infiniopDestroySwapDescriptor(infiniopSwapDescriptor_t desc); + +#endif // __INFINIOP_SWAP_API_H__ diff --git a/src/infiniop/devices/metax/metax_ht2mc.h b/src/infiniop/devices/metax/metax_ht2mc.h index 6ef455a14..7135658be 100644 --- a/src/infiniop/devices/metax/metax_ht2mc.h +++ b/src/infiniop/devices/metax/metax_ht2mc.h @@ -143,6 +143,8 @@ #define hcblasDrotmg mcblasDrotmg #define hcblasSscal mcblasSscal #define hcblasDscal mcblasDscal +#define hcblasSswap mcblasSswap +#define hcblasDswap mcblasDswap #define HCBLAS_STATUS_SUCCESS MCBLAS_STATUS_SUCCESS #define HCBLAS_OP_T MCBLAS_OP_T #define HCBLAS_OP_N MCBLAS_OP_N diff --git a/src/infiniop/ops/swap/bang/swap_bang.h b/src/infiniop/ops/swap/bang/swap_bang.h new file mode 100644 index 000000000..dd908caa2 --- /dev/null +++ b/src/infiniop/ops/swap/bang/swap_bang.h @@ -0,0 +1,8 @@ +#ifndef __SWAP_BANG_H__ +#define __SWAP_BANG_H__ + +#include "../swap.h" + +DESCRIPTOR(bang) + +#endif // __SWAP_BANG_H__ diff --git a/src/infiniop/ops/swap/bang/swap_bang.mlu b/src/infiniop/ops/swap/bang/swap_bang.mlu new file mode 100644 index 000000000..041d09654 --- /dev/null +++ b/src/infiniop/ops/swap/bang/swap_bang.mlu @@ -0,0 +1,92 @@ +#include "../../../devices/bang/common_bang.h" +#include "swap_bang.h" +#include "swap_bang_kernel.mlu" + +namespace op::swap::bang { + +Descriptor::~Descriptor() = default; + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t y_desc) { + + auto handle = reinterpret_cast(handle_); + auto info = SwapInfo::createSwapInfo(x_desc, y_desc); + CHECK_RESULT(info); + + *desc_ptr = new Descriptor( + info.take(), + 0, + nullptr, + handle->device, + handle->device_id); + + return INFINI_STATUS_SUCCESS; +} + +template +infiniStatus_t calculateSwap( + const SwapInfo &info, + Tdata *x, + Tdata *y, + cnrtQueue_t queue) { + + cnrtDim3_t k_dim; + cnrtFunctionType_t k_type; + + k_dim.x = 4; + k_dim.y = 1; + k_dim.z = 
1; + k_type = cnrtFuncTypeUnion1; + + if (info.getIncx() == 1 && info.getIncy() == 1) { + swapKernelContiguous<<>>( + info.getSize(), + x, + y); + } else { + swapKernelStrided<<>>( + info.getSize(), + x, + info.getIncx(), + y, + info.getIncy()); + } + + cnrtQueueSync(queue); + return INFINI_STATUS_SUCCESS; +} + +#define CALCULATE_SWAP(TDATA) \ + calculateSwap(_info, \ + (TDATA *)x, \ + (TDATA *)y, \ + (cnrtQueue_t)stream) + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *x, + void *y, + void *stream) const { + + (void)workspace; + (void)workspace_size; + + switch (_info.getDtype()) { + case INFINI_DTYPE_F16: + return CALCULATE_SWAP(half); + case INFINI_DTYPE_BF16: + return CALCULATE_SWAP(bfloat16_t); + case INFINI_DTYPE_F32: + return CALCULATE_SWAP(float); + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } +} + +#undef CALCULATE_SWAP + +} // namespace op::swap::bang diff --git a/src/infiniop/ops/swap/bang/swap_bang_kernel.mlu b/src/infiniop/ops/swap/bang/swap_bang_kernel.mlu new file mode 100644 index 000000000..d5f44292d --- /dev/null +++ b/src/infiniop/ops/swap/bang/swap_bang_kernel.mlu @@ -0,0 +1,75 @@ +#include "../../../devices/bang/common_bang.h" +#include "swap_bang.h" + +__nram__ char nram_buffer[NRAM_MAX_SIZE]; + +template +__mlu_global__ void swapKernelContiguous( + size_t n, + Tdata *x, + Tdata *y) { + + Tdata *nram_align = (Tdata *)(((size_t)nram_buffer + ALIGN_SIZE - 1) & ~(ALIGN_SIZE - 1)); + + size_t nram_usable = NRAM_MAX_SIZE - ((char *)nram_align - nram_buffer); + size_t max_chunk_elements = nram_usable / (2 * sizeof(Tdata)); + + int align_elements = ALIGN_SIZE / sizeof(Tdata); + if (align_elements == 0) { + align_elements = 1; + } + max_chunk_elements = (max_chunk_elements / align_elements) * align_elements; + + Tdata *nram_x = nram_align; + Tdata *nram_y = nram_align + max_chunk_elements; + + int elements_per_core = n / taskDim; + int remain = n % taskDim; + int core_elements = elements_per_core + 
(taskId < remain ? 1 : 0); + int core_offset = taskId < remain ? taskId * core_elements : taskId * elements_per_core + remain; + + if (core_elements <= 0) { + return; + } + + int chunks = core_elements / max_chunk_elements; + int chunk_rem = core_elements % max_chunk_elements; + + for (int c = 0; c < chunks; c++) { + size_t current_offset = core_offset + c * max_chunk_elements; + __memcpy(nram_x, x + current_offset, max_chunk_elements * sizeof(Tdata), GDRAM2NRAM); + __memcpy(nram_y, y + current_offset, max_chunk_elements * sizeof(Tdata), GDRAM2NRAM); + __memcpy(x + current_offset, nram_y, max_chunk_elements * sizeof(Tdata), NRAM2GDRAM); + __memcpy(y + current_offset, nram_x, max_chunk_elements * sizeof(Tdata), NRAM2GDRAM); + } + + if (chunk_rem > 0) { + size_t current_offset = core_offset + chunks * max_chunk_elements; + __memcpy(nram_x, x + current_offset, chunk_rem * sizeof(Tdata), GDRAM2NRAM); + __memcpy(nram_y, y + current_offset, chunk_rem * sizeof(Tdata), GDRAM2NRAM); + __memcpy(x + current_offset, nram_y, chunk_rem * sizeof(Tdata), NRAM2GDRAM); + __memcpy(y + current_offset, nram_x, chunk_rem * sizeof(Tdata), NRAM2GDRAM); + } +} + +template +__mlu_global__ void swapKernelStrided( + size_t n, + Tdata *x, + size_t incx, + Tdata *y, + size_t incy) { + + int elements_per_core = n / taskDim; + int remain = n % taskDim; + int actual_tasks = elements_per_core + (taskId < remain ? 1 : 0); + int start_idx = taskId < remain ? 
taskId * actual_tasks : taskId * elements_per_core + remain; + + for (int i = start_idx; i < start_idx + actual_tasks; ++i) { + size_t x_idx = i * incx; + size_t y_idx = i * incy; + Tdata temp = x[x_idx]; + x[x_idx] = y[y_idx]; + y[y_idx] = temp; + } +} diff --git a/src/infiniop/ops/swap/cpu/swap_cpu.cc b/src/infiniop/ops/swap/cpu/swap_cpu.cc new file mode 100644 index 000000000..4a59f96c4 --- /dev/null +++ b/src/infiniop/ops/swap/cpu/swap_cpu.cc @@ -0,0 +1,82 @@ +#include "swap_cpu.h" +#include "../../../devices/cpu/common_cpu.h" + +namespace op::swap::cpu { + +Descriptor::~Descriptor() = default; + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t y_desc) { + + auto handle = reinterpret_cast(handle_); + auto info = SwapInfo::createSwapInfo(x_desc, y_desc); + CHECK_RESULT(info); + + *desc_ptr = new Descriptor( + info.take(), + 0, + nullptr, + handle->device, + handle->device_id); + + return INFINI_STATUS_SUCCESS; +} + +template +infiniStatus_t calculateSwap( + const SwapInfo &info, + Tdata *x, + Tdata *y) { + + const ptrdiff_t size = info.getSize(); + const ptrdiff_t incx = info.getIncx(); + const ptrdiff_t incy = info.getIncy(); + +#pragma omp parallel for if (size > 1024) + for (ptrdiff_t i = 0; i < size; ++i) { + const ptrdiff_t x_idx = i * incx; + const ptrdiff_t y_idx = i * incy; + Tdata temp = x[x_idx]; + x[x_idx] = y[y_idx]; + y[y_idx] = temp; + } + + return INFINI_STATUS_SUCCESS; +} + +#define CALCULATE_SWAP(TDATA) \ + calculateSwap(_info, \ + (TDATA *)x, \ + (TDATA *)y) + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *x, + void *y, + void *stream) const { + + (void)workspace; + (void)workspace_size; + (void)stream; + + switch (_info.getDtype()) { + case INFINI_DTYPE_F16: + return CALCULATE_SWAP(fp16_t); + case INFINI_DTYPE_BF16: + return CALCULATE_SWAP(bf16_t); + case INFINI_DTYPE_F32: + return 
CALCULATE_SWAP(float); + case INFINI_DTYPE_F64: + return CALCULATE_SWAP(double); + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } +} + +#undef CALCULATE_SWAP + +} // namespace op::swap::cpu diff --git a/src/infiniop/ops/swap/cpu/swap_cpu.h b/src/infiniop/ops/swap/cpu/swap_cpu.h new file mode 100644 index 000000000..b37295614 --- /dev/null +++ b/src/infiniop/ops/swap/cpu/swap_cpu.h @@ -0,0 +1,8 @@ +#ifndef __SWAP_CPU_H__ +#define __SWAP_CPU_H__ + +#include "../swap.h" + +DESCRIPTOR(cpu) + +#endif // __SWAP_CPU_H__ diff --git a/src/infiniop/ops/swap/metax/swap_metax.cc b/src/infiniop/ops/swap/metax/swap_metax.cc new file mode 100644 index 000000000..4cc7d336a --- /dev/null +++ b/src/infiniop/ops/swap/metax/swap_metax.cc @@ -0,0 +1,72 @@ +#include "swap_metax.h" +#include "../../../devices/metax/metax_common.h" +#include "../../../devices/metax/metax_handle.h" + +namespace op::swap::metax { + +struct Descriptor::Opaque { + std::shared_ptr internal; +}; + +Descriptor::~Descriptor() { + delete _opaque; +} + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t y_desc) { + + auto handle = reinterpret_cast(handle_); + auto info = SwapInfo::createSwapInfo(x_desc, y_desc); + CHECK_RESULT(info); + + *desc_ptr = new Descriptor( + info.take(), + 0, + new Opaque{handle->internal()}, + handle->device, + handle->device_id); + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *x, + void *y, + void *stream) const { + + (void)workspace; + (void)workspace_size; + + const size_t size = _info.getSize(); + const ptrdiff_t incx = _info.getIncx(); + const ptrdiff_t incy = _info.getIncy(); + const infiniDtype_t data_type = _info.getDtype(); + + CHECK_STATUS(_opaque->internal->useMcblas( + (hcStream_t)stream, + [&](hcblasHandle_t handle) { + CHECK_MCBLAS(hcblasSetPointerMode(handle, 
HCBLAS_POINTER_MODE_DEVICE)); + + switch (data_type) { + case INFINI_DTYPE_F32: + CHECK_MCBLAS(hcblasSswap(handle, size, (float *)x, incx, (float *)y, incy)); + break; + case INFINI_DTYPE_F64: + CHECK_MCBLAS(hcblasDswap(handle, size, (double *)x, incx, (double *)y, incy)); + break; + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + return INFINI_STATUS_SUCCESS; + })); + + return INFINI_STATUS_SUCCESS; +} + +} // namespace op::swap::metax diff --git a/src/infiniop/ops/swap/metax/swap_metax.h b/src/infiniop/ops/swap/metax/swap_metax.h new file mode 100644 index 000000000..db8817c9a --- /dev/null +++ b/src/infiniop/ops/swap/metax/swap_metax.h @@ -0,0 +1,8 @@ +#ifndef __SWAP_METAX_H__ +#define __SWAP_METAX_H__ + +#include "../swap.h" + +DESCRIPTOR(metax) + +#endif // __SWAP_METAX_H__ diff --git a/src/infiniop/ops/swap/operator.cc b/src/infiniop/ops/swap/operator.cc new file mode 100644 index 000000000..22d688021 --- /dev/null +++ b/src/infiniop/ops/swap/operator.cc @@ -0,0 +1,121 @@ +#include "../../operator.h" +#include "../../handle.h" +#include "infiniop/ops/swap.h" + +#ifdef ENABLE_CPU_API +#include "cpu/swap_cpu.h" +#endif +#ifdef ENABLE_METAX_API +#include "metax/swap_metax.h" +#endif +#ifdef ENABLE_CAMBRICON_API +#include "bang/swap_bang.h" +#endif + +__INFINI_C infiniStatus_t infiniopCreateSwapDescriptor( + infiniopHandle_t handle, + infiniopSwapDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t y_desc) { + +#define CREATE(CASE, NAMESPACE) \ + case CASE: \ + return op::swap::NAMESPACE::Descriptor::create( \ + handle, \ + reinterpret_cast(desc_ptr), \ + x_desc, y_desc) + + switch (handle->device) { +#ifdef ENABLE_CPU_API + CREATE(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_METAX_API + CREATE(INFINI_DEVICE_METAX, metax); +#endif +#ifdef ENABLE_CAMBRICON_API + CREATE(INFINI_DEVICE_CAMBRICON, bang); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CREATE +} + +__INFINI_C 
infiniStatus_t infiniopGetSwapWorkspaceSize(infiniopSwapDescriptor_t desc, size_t *size) { + +#define GET(CASE, NAMESPACE) \ + case CASE: \ + *size = reinterpret_cast(desc)->workspaceSize(); \ + return INFINI_STATUS_SUCCESS + + switch (desc->device_type) { +#ifdef ENABLE_CPU_API + GET(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_METAX_API + GET(INFINI_DEVICE_METAX, metax); +#endif +#ifdef ENABLE_CAMBRICON_API + GET(INFINI_DEVICE_CAMBRICON, bang); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef GET +} + +__INFINI_C infiniStatus_t infiniopSwap( + infiniopSwapDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *x, + void *y, + void *stream) { + +#define CALCULATE(CASE, NAMESPACE) \ + case CASE: \ + return reinterpret_cast(desc) \ + ->calculate(workspace, workspace_size, x, y, stream) + + switch (desc->device_type) { +#ifdef ENABLE_CPU_API + CALCULATE(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_METAX_API + CALCULATE(INFINI_DEVICE_METAX, metax); +#endif +#ifdef ENABLE_CAMBRICON_API + CALCULATE(INFINI_DEVICE_CAMBRICON, bang); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CALCULATE +} + +__INFINI_C infiniStatus_t infiniopDestroySwapDescriptor(infiniopSwapDescriptor_t desc) { + +#define DELETE(CASE, NAMESPACE) \ + case CASE: \ + delete reinterpret_cast(desc); \ + return INFINI_STATUS_SUCCESS + + switch (desc->device_type) { +#ifdef ENABLE_CPU_API + DELETE(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_METAX_API + DELETE(INFINI_DEVICE_METAX, metax); +#endif +#ifdef ENABLE_CAMBRICON_API + DELETE(INFINI_DEVICE_CAMBRICON, bang); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef DELETE +} diff --git a/src/infiniop/ops/swap/swap.h b/src/infiniop/ops/swap/swap.h new file mode 100644 index 000000000..7ace7d8ee --- /dev/null +++ b/src/infiniop/ops/swap/swap.h @@ -0,0 +1,94 @@ +#ifndef __SWAP_H__ +#define __SWAP_H__ + +#include "../../../utils.h" 
+#include "../../operator.h" +#include "../../tensor.h" +#include "infiniop/ops/swap.h" + +#define DESCRIPTOR(NAMESPACE) \ + \ + namespace op::swap::NAMESPACE { \ + class Descriptor final : public InfiniopDescriptor { \ + struct Opaque; \ + Opaque *_opaque; \ + SwapInfo _info; \ + size_t _workspace_size; \ + \ + Descriptor( \ + SwapInfo info, \ + size_t workspace_size_, \ + Opaque *opaque, \ + infiniDevice_t device_type, \ + int device_id) \ + : InfiniopDescriptor{device_type, device_id}, \ + _opaque(opaque), \ + _info(std::move(info)), \ + _workspace_size(workspace_size_) {} \ + \ + public: \ + ~Descriptor(); \ + \ + size_t workspaceSize() const { return _workspace_size; } \ + \ + static infiniStatus_t create( \ + infiniopHandle_t handle, \ + Descriptor **desc_ptr, \ + infiniopTensorDescriptor_t x_desc, \ + infiniopTensorDescriptor_t y_desc); \ + \ + infiniStatus_t calculate( \ + void *workspace, \ + size_t workspace_size, \ + void *x, \ + void *y, \ + void *stream) const; \ + }; \ + } + +class SwapInfo { +private: + size_t _size; + ptrdiff_t _incx; + ptrdiff_t _incy; + infiniDtype_t _dtype; + + SwapInfo(size_t size, + ptrdiff_t incx, + ptrdiff_t incy, + infiniDtype_t dtype) + : _size(size), _incx(incx), _incy(incy), _dtype(dtype) {} + +public: + inline size_t getSize() const { return _size; } + inline ptrdiff_t getIncx() const { return _incx; } + inline ptrdiff_t getIncy() const { return _incy; } + inline infiniDtype_t getDtype() const { return _dtype; } + + using ResultType = utils::Result; + + static ResultType createSwapInfo( + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t y_desc) { + + CHECK_OR_RETURN(x_desc != nullptr, INFINI_STATUS_NULL_POINTER); + CHECK_OR_RETURN(y_desc != nullptr, INFINI_STATUS_NULL_POINTER); + + auto dtype = x_desc->dtype(); + + CHECK_OR_RETURN(y_desc->dtype() == dtype, INFINI_STATUS_BAD_TENSOR_DTYPE); + CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32, INFINI_DTYPE_F64); + 
CHECK_OR_RETURN(x_desc->ndim() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(y_desc->ndim() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(x_desc->numel() == y_desc->numel(), INFINI_STATUS_BAD_TENSOR_SHAPE); + + auto size = x_desc->numel(); + auto incx = x_desc->stride(0); + auto incy = y_desc->stride(0); + + SwapInfo info(size, incx, incy, dtype); + return ResultType(std::move(info)); + } +}; + +#endif // __SWAP_H__ diff --git a/test/infiniop/libinfiniop/op_register.py b/test/infiniop/libinfiniop/op_register.py index 3e71d669d..0eeb99098 100644 --- a/test/infiniop/libinfiniop/op_register.py +++ b/test/infiniop/libinfiniop/op_register.py @@ -2563,3 +2563,35 @@ def scal_(lib): lib.infiniopDestroyScalDescriptor.argtypes = [ infiniopOperatorDescriptor_t, ] + + +@OpRegister.operator +def swap_(lib): + lib.infiniopCreateSwapDescriptor.restype = c_int32 + lib.infiniopCreateSwapDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopOperatorDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + ] + + lib.infiniopGetSwapWorkspaceSize.restype = c_int32 + lib.infiniopGetSwapWorkspaceSize.argtypes = [ + infiniopOperatorDescriptor_t, + POINTER(c_size_t), + ] + + lib.infiniopSwap.restype = c_int32 + lib.infiniopSwap.argtypes = [ + infiniopOperatorDescriptor_t, + c_void_p, + c_size_t, + c_void_p, + c_void_p, + c_void_p, + ] + + lib.infiniopDestroySwapDescriptor.restype = c_int32 + lib.infiniopDestroySwapDescriptor.argtypes = [ + infiniopOperatorDescriptor_t, + ] diff --git a/test/infiniop/swap.py b/test/infiniop/swap.py new file mode 100644 index 000000000..4d92668d7 --- /dev/null +++ b/test/infiniop/swap.py @@ -0,0 +1,151 @@ +import ctypes +from ctypes import c_uint64 + +import torch +from libinfiniop import ( + LIBINFINIOP, + InfiniDeviceNames, + InfiniDtype, + InfiniDtypeNames, + TestTensor, + TestWorkspace, + check_error, + debug, + get_args, + get_test_devices, + get_tolerance, + infiniopOperatorDescriptor_t, + 
profile_operation, + test_operator, +) + +_TEST_CASES = [ + ((13,), None, None), + ((13,), (10,), (10,)), + ((5632,), None, None), + ((5632,), (5,), (5,)), + ((16,), (4,), (4,)), + ((5632,), (32,), (32,)), +] + +_TENSOR_DTYPES = [ + # InfiniDtype.F16, + InfiniDtype.F32, + # InfiniDtype.F64, + # InfiniDtype.BF16, +] + +_TOLERANCE_MAP = { + InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-3}, + InfiniDtype.F32: {"atol": 1e-7, "rtol": 1e-7}, + InfiniDtype.F64: {"atol": 1e-15, "rtol": 1e-15}, + InfiniDtype.BF16: {"atol": 1e-2, "rtol": 1e-2}, +} + +DEBUG = False +PROFILE = False +NUM_PRERUN = 10 +NUM_ITERATIONS = 1000 + + +def torch_swap(x, y): + tmp = x.clone() + x.copy_(y) + y.copy_(tmp) + + +def test( + handle, + device, + shape, + x_stride=None, + y_stride=None, + dtype=torch.float16, + sync=None, +): + x = TestTensor(shape, x_stride, dtype, device) + y = TestTensor(shape, y_stride, dtype, device) + + if x.is_broadcast() or y.is_broadcast(): + return + + print( + f"Testing Swap on {InfiniDeviceNames[device]} with shape:{shape} x_stride:{x_stride} " + f"y_stride:{y_stride} dtype:{InfiniDtypeNames[dtype]}" + ) + + torch_swap(x.torch_tensor(), y.torch_tensor()) + + if sync is not None: + sync() + + descriptor = infiniopOperatorDescriptor_t() + check_error( + LIBINFINIOP.infiniopCreateSwapDescriptor( + handle, + ctypes.byref(descriptor), + x.descriptor, + y.descriptor, + ) + ) + + x.destroy_desc() + y.destroy_desc() + + workspace_size = c_uint64(0) + check_error( + LIBINFINIOP.infiniopGetSwapWorkspaceSize( + descriptor, ctypes.byref(workspace_size) + ) + ) + workspace = TestWorkspace(workspace_size.value, x.device) + + def lib_swap(): + check_error( + LIBINFINIOP.infiniopSwap( + descriptor, + workspace.data(), + workspace.size(), + x.data(), + y.data(), + None, + ) + ) + + lib_swap() + + atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) + if DEBUG: + debug(x.actual_tensor(), x.torch_tensor(), atol=atol, rtol=rtol) + debug(y.actual_tensor(), y.torch_tensor(), atol=atol, 
rtol=rtol) + + assert torch.allclose(x.actual_tensor(), x.torch_tensor(), atol=atol, rtol=rtol) + assert torch.allclose(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol) + + if PROFILE: + profile_operation( + "PyTorch", + lambda: torch_swap(x.torch_tensor(), y.torch_tensor()), + device, + NUM_PRERUN, + NUM_ITERATIONS, + ) + profile_operation( + " lib", lambda: lib_swap(), device, NUM_PRERUN, NUM_ITERATIONS + ) + + check_error(LIBINFINIOP.infiniopDestroySwapDescriptor(descriptor)) + + +if __name__ == "__main__": + args = get_args() + + DEBUG = args.debug + PROFILE = args.profile + NUM_PRERUN = args.num_prerun + NUM_ITERATIONS = args.num_iterations + + for device in get_test_devices(args): + test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES) + + print("\033[92mTest passed!\033[0m") From f8e9cc27be0ea142254ecc0429dd0b7c1189b8dd Mon Sep 17 00:00:00 2001 From: xuzhengzhong Date: Wed, 29 Apr 2026 09:00:45 +0000 Subject: [PATCH 13/25] Use mcBLAS Ex APIs for `axpy`, `blas_dot`, `nrm2`, `rot` and `scal` operators --- src/infiniop/devices/metax/metax_ht2mc.h | 16 ++---- src/infiniop/ops/axpy/metax/axpy_metax.cc | 57 +++++++++++++++---- .../ops/blas_dot/metax/blas_dot_metax.cc | 46 +++++++++++---- src/infiniop/ops/nrm2/metax/nrm2_metax.cc | 43 ++++++++++---- src/infiniop/ops/rot/metax/rot_metax.cc | 47 +++++++++++---- src/infiniop/ops/scal/metax/scal_metax.cc | 43 ++++++++++---- test/infiniop/axpy.py | 4 +- test/infiniop/blas_dot.py | 4 +- test/infiniop/nrm2.py | 4 +- test/infiniop/rot.py | 4 +- test/infiniop/scal.py | 4 +- 11 files changed, 202 insertions(+), 70 deletions(-) diff --git a/src/infiniop/devices/metax/metax_ht2mc.h b/src/infiniop/devices/metax/metax_ht2mc.h index 7135658be..888f1c72f 100644 --- a/src/infiniop/devices/metax/metax_ht2mc.h +++ b/src/infiniop/devices/metax/metax_ht2mc.h @@ -1,6 +1,7 @@ #ifdef ENABLE_METAX_MC_API #define hpccDataType macaDataType #define HPCC_R_32F MACA_R_32F +#define HPCC_R_64F MACA_R_64F #define HPCC_R_16F MACA_R_16F 
#define HPCC_R_16BF MACA_R_16BF #define hpcc_bfloat162 maca_bfloat162 @@ -125,24 +126,19 @@ #define hcblasIdamin mcblasIdamin #define hcblasSasum mcblasSasum #define hcblasDasum mcblasDasum -#define hcblasSaxpy mcblasSaxpy -#define hcblasDaxpy mcblasDaxpy +#define hcblasAxpyEx mcblasAxpyEx #define hcblasScopy mcblasScopy #define hcblasDcopy mcblasDcopy -#define hcblasSdot mcblasSdot -#define hcblasDdot mcblasDdot -#define hcblasSnrm2 mcblasSnrm2 -#define hcblasDnrm2 mcblasDnrm2 -#define hcblasSrot mcblasSrot -#define hcblasDrot mcblasDrot +#define hcblasDotEx mcblasDotEx +#define hcblasNrm2Ex mcblasNrm2Ex +#define hcblasRotEx mcblasRotEx #define hcblasSrotg mcblasSrotg #define hcblasDrotg mcblasDrotg #define hcblasSrotm mcblasSrotm #define hcblasDrotm mcblasDrotm #define hcblasSrotmg mcblasSrotmg #define hcblasDrotmg mcblasDrotmg -#define hcblasSscal mcblasSscal -#define hcblasDscal mcblasDscal +#define hcblasScalEx mcblasScalEx #define hcblasSswap mcblasSswap #define hcblasDswap mcblasDswap #define HCBLAS_STATUS_SUCCESS MCBLAS_STATUS_SUCCESS diff --git a/src/infiniop/ops/axpy/metax/axpy_metax.cc b/src/infiniop/ops/axpy/metax/axpy_metax.cc index b31f586f7..a41d3cf62 100644 --- a/src/infiniop/ops/axpy/metax/axpy_metax.cc +++ b/src/infiniop/ops/axpy/metax/axpy_metax.cc @@ -49,21 +49,58 @@ infiniStatus_t Descriptor::calculate( const ptrdiff_t incy = _info.getIncy(); const infiniDtype_t data_type = _info.getDtype(); + hpccDataType alpha_type, x_type, y_type; + hpccDataType execution_type; + + switch (data_type) { + case INFINI_DTYPE_F16: + alpha_type = x_type = y_type = HPCC_R_16F; + execution_type = HPCC_R_32F; + break; + case INFINI_DTYPE_BF16: + alpha_type = x_type = y_type = HPCC_R_16BF; + execution_type = HPCC_R_32F; + break; + case INFINI_DTYPE_F32: + alpha_type = x_type = y_type = HPCC_R_32F; + execution_type = HPCC_R_32F; + break; + case INFINI_DTYPE_F64: + alpha_type = x_type = y_type = HPCC_R_64F; + execution_type = HPCC_R_64F; + break; + default: + return 
INFINI_STATUS_BAD_TENSOR_DTYPE; + } + CHECK_STATUS(_opaque->internal->useMcblas( (hcStream_t)stream, [&](hcblasHandle_t handle) { CHECK_MCBLAS(hcblasSetPointerMode(handle, HCBLAS_POINTER_MODE_DEVICE)); - switch (data_type) { - case INFINI_DTYPE_F32: - CHECK_MCBLAS(hcblasSaxpy(handle, size, (const float *)alpha, (const float *)x, incx, (float *)y, incy)); - break; - case INFINI_DTYPE_F64: - CHECK_MCBLAS(hcblasDaxpy(handle, size, (const double *)alpha, (const double *)x, incx, (double *)y, incy)); - break; - default: - return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; - } + // switch (data_type) { + // case INFINI_DTYPE_F32: + // CHECK_MCBLAS(hcblasSaxpy(handle, size, (const float *)alpha, (const float *)x, incx, (float *)y, incy)); + // break; + // case INFINI_DTYPE_F64: + // CHECK_MCBLAS(hcblasDaxpy(handle, size, (const double *)alpha, (const double *)x, incx, (double *)y, incy)); + // break; + // default: + // return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + // } + + hcblasAxpyEx( + handle, + size, + alpha, + alpha_type, + x, + x_type, + incx, + y, + y_type, + incy, + execution_type); return INFINI_STATUS_SUCCESS; })); diff --git a/src/infiniop/ops/blas_dot/metax/blas_dot_metax.cc b/src/infiniop/ops/blas_dot/metax/blas_dot_metax.cc index 16e0bd623..a70ba48c2 100644 --- a/src/infiniop/ops/blas_dot/metax/blas_dot_metax.cc +++ b/src/infiniop/ops/blas_dot/metax/blas_dot_metax.cc @@ -49,21 +49,47 @@ infiniStatus_t Descriptor::calculate( const ptrdiff_t incy = _info.getIncy(); const infiniDtype_t data_type = _info.getDtype(); + hpccDataType x_type, y_type, result_type; + hpccDataType execution_type; + + switch (data_type) { + case INFINI_DTYPE_F16: + x_type = y_type = result_type = HPCC_R_16F; + execution_type = HPCC_R_32F; + break; + case INFINI_DTYPE_BF16: + x_type = y_type = result_type = HPCC_R_16BF; + execution_type = HPCC_R_32F; + break; + case INFINI_DTYPE_F32: + x_type = y_type = result_type = HPCC_R_32F; + execution_type = HPCC_R_32F; + break; + case 
INFINI_DTYPE_F64: + x_type = y_type = result_type = HPCC_R_64F; + execution_type = HPCC_R_64F; + break; + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + CHECK_STATUS(_opaque->internal->useMcblas( (hcStream_t)stream, [&](hcblasHandle_t handle) { CHECK_MCBLAS(hcblasSetPointerMode(handle, HCBLAS_POINTER_MODE_DEVICE)); - switch (data_type) { - case INFINI_DTYPE_F32: - CHECK_MCBLAS(hcblasSdot(handle, size, (const float *)x, incx, (const float *)y, incy, (float *)result)); - break; - case INFINI_DTYPE_F64: - CHECK_MCBLAS(hcblasDdot(handle, size, (const double *)x, incx, (const double *)y, incy, (double *)result)); - break; - default: - return INFINI_STATUS_BAD_TENSOR_DTYPE; - } + CHECK_MCBLAS(hcblasDotEx( + handle, + size, + x, + x_type, + incx, + y, + y_type, + incy, + result, + result_type, + execution_type)); return INFINI_STATUS_SUCCESS; })); diff --git a/src/infiniop/ops/nrm2/metax/nrm2_metax.cc b/src/infiniop/ops/nrm2/metax/nrm2_metax.cc index 4b3b528d2..12319182b 100644 --- a/src/infiniop/ops/nrm2/metax/nrm2_metax.cc +++ b/src/infiniop/ops/nrm2/metax/nrm2_metax.cc @@ -46,21 +46,44 @@ infiniStatus_t Descriptor::calculate( const ptrdiff_t incx = _info.getIncx(); const infiniDtype_t data_type = _info.getDtype(); + hpccDataType x_type, result_type; + hpccDataType execution_type; + + switch (data_type) { + case INFINI_DTYPE_F16: + x_type = result_type = HPCC_R_16F; + execution_type = HPCC_R_32F; + break; + case INFINI_DTYPE_BF16: + x_type = result_type = HPCC_R_16BF; + execution_type = HPCC_R_32F; + break; + case INFINI_DTYPE_F32: + x_type = result_type = HPCC_R_32F; + execution_type = HPCC_R_32F; + break; + case INFINI_DTYPE_F64: + x_type = result_type = HPCC_R_64F; + execution_type = HPCC_R_64F; + break; + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + CHECK_STATUS(_opaque->internal->useMcblas( (hcStream_t)stream, [&](hcblasHandle_t handle) { CHECK_MCBLAS(hcblasSetPointerMode(handle, HCBLAS_POINTER_MODE_DEVICE)); - switch (data_type) { - case 
INFINI_DTYPE_F32: - CHECK_MCBLAS(hcblasSnrm2(handle, size, (const float *)x, incx, (float *)result)); - break; - case INFINI_DTYPE_F64: - CHECK_MCBLAS(hcblasDnrm2(handle, size, (const double *)x, incx, (double *)result)); - break; - default: - return INFINI_STATUS_BAD_TENSOR_DTYPE; - } + CHECK_MCBLAS(hcblasNrm2Ex( + handle, + size, + x, + x_type, + incx, + result, + result_type, + execution_type)); return INFINI_STATUS_SUCCESS; })); diff --git a/src/infiniop/ops/rot/metax/rot_metax.cc b/src/infiniop/ops/rot/metax/rot_metax.cc index 0b0c3e567..2d958f20d 100644 --- a/src/infiniop/ops/rot/metax/rot_metax.cc +++ b/src/infiniop/ops/rot/metax/rot_metax.cc @@ -51,21 +51,48 @@ infiniStatus_t Descriptor::calculate( const ptrdiff_t incy = _info.getIncy(); const infiniDtype_t data_type = _info.getDtype(); + hpccDataType x_type, y_type, cs_type; + hpccDataType execution_type; + + switch (data_type) { + case INFINI_DTYPE_F16: + x_type = y_type = cs_type = HPCC_R_16F; + execution_type = HPCC_R_32F; + break; + case INFINI_DTYPE_BF16: + x_type = y_type = cs_type = HPCC_R_16BF; + execution_type = HPCC_R_32F; + break; + case INFINI_DTYPE_F32: + x_type = y_type = cs_type = HPCC_R_32F; + execution_type = HPCC_R_32F; + break; + case INFINI_DTYPE_F64: + x_type = y_type = cs_type = HPCC_R_64F; + execution_type = HPCC_R_64F; + break; + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + CHECK_STATUS(_opaque->internal->useMcblas( (hcStream_t)stream, [&](hcblasHandle_t handle) { CHECK_MCBLAS(hcblasSetPointerMode(handle, HCBLAS_POINTER_MODE_DEVICE)); - switch (data_type) { - case INFINI_DTYPE_F32: - CHECK_MCBLAS(hcblasSrot(handle, size, (float *)x, incx, (float *)y, incy, (const float *)c, (const float *)s)); - break; - case INFINI_DTYPE_F64: - CHECK_MCBLAS(hcblasDrot(handle, size, (double *)x, incx, (double *)y, incy, (const double *)c, (const double *)s)); - break; - default: - return INFINI_STATUS_BAD_TENSOR_DTYPE; - } + CHECK_MCBLAS(hcblasRotEx( + handle, + size, + x, + x_type, + 
incx, + y, + y_type, + incy, + c, + s, + cs_type, + execution_type)); return INFINI_STATUS_SUCCESS; })); diff --git a/src/infiniop/ops/scal/metax/scal_metax.cc b/src/infiniop/ops/scal/metax/scal_metax.cc index b534a61aa..c5542e39d 100644 --- a/src/infiniop/ops/scal/metax/scal_metax.cc +++ b/src/infiniop/ops/scal/metax/scal_metax.cc @@ -46,21 +46,44 @@ infiniStatus_t Descriptor::calculate( const ptrdiff_t incx = _info.getIncx(); const infiniDtype_t data_type = _info.getDtype(); + hpccDataType alpha_type, x_type; + hpccDataType execution_type; + + switch (data_type) { + case INFINI_DTYPE_F16: + alpha_type = x_type = HPCC_R_16F; + execution_type = HPCC_R_32F; + break; + case INFINI_DTYPE_BF16: + alpha_type = x_type = HPCC_R_16BF; + execution_type = HPCC_R_32F; + break; + case INFINI_DTYPE_F32: + alpha_type = x_type = HPCC_R_32F; + execution_type = HPCC_R_32F; + break; + case INFINI_DTYPE_F64: + alpha_type = x_type = HPCC_R_64F; + execution_type = HPCC_R_64F; + break; + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + CHECK_STATUS(_opaque->internal->useMcblas( (hcStream_t)stream, [&](hcblasHandle_t handle) { CHECK_MCBLAS(hcblasSetPointerMode(handle, HCBLAS_POINTER_MODE_DEVICE)); - switch (data_type) { - case INFINI_DTYPE_F32: - CHECK_MCBLAS(hcblasSscal(handle, size, (const float *)alpha, (float *)x, incx)); - break; - case INFINI_DTYPE_F64: - CHECK_MCBLAS(hcblasDscal(handle, size, (const double *)alpha, (double *)x, incx)); - break; - default: - return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; - } + CHECK_MCBLAS(hcblasScalEx( + handle, + size, + alpha, + alpha_type, + x, + x_type, + incx, + execution_type)); return INFINI_STATUS_SUCCESS; })); diff --git a/test/infiniop/axpy.py b/test/infiniop/axpy.py index 278e44cd3..d0258b39b 100644 --- a/test/infiniop/axpy.py +++ b/test/infiniop/axpy.py @@ -33,10 +33,10 @@ ] _TENSOR_DTYPES = [ - # InfiniDtype.F16, + InfiniDtype.F16, InfiniDtype.F32, # InfiniDtype.F64, - # InfiniDtype.BF16, + InfiniDtype.BF16, ] _TOLERANCE_MAP = 
{ diff --git a/test/infiniop/blas_dot.py b/test/infiniop/blas_dot.py index 1f29a2c65..2d51b883d 100644 --- a/test/infiniop/blas_dot.py +++ b/test/infiniop/blas_dot.py @@ -33,10 +33,10 @@ ] _TENSOR_DTYPES = [ - # InfiniDtype.F16, + InfiniDtype.F16, InfiniDtype.F32, # InfiniDtype.F64, - # InfiniDtype.BF16, + InfiniDtype.BF16, ] _TOLERANCE_MAP = { diff --git a/test/infiniop/nrm2.py b/test/infiniop/nrm2.py index 1523ce082..7c45fda1f 100644 --- a/test/infiniop/nrm2.py +++ b/test/infiniop/nrm2.py @@ -33,10 +33,10 @@ ] _TENSOR_DTYPES = [ - # InfiniDtype.F16, + InfiniDtype.F16, InfiniDtype.F32, # InfiniDtype.F64, - # InfiniDtype.BF16, + InfiniDtype.BF16, ] _TOLERANCE_MAP = { diff --git a/test/infiniop/rot.py b/test/infiniop/rot.py index 9eff59601..e9d6f63f3 100644 --- a/test/infiniop/rot.py +++ b/test/infiniop/rot.py @@ -29,10 +29,10 @@ ] _TENSOR_DTYPES = [ - # InfiniDtype.F16, + InfiniDtype.F16, InfiniDtype.F32, # InfiniDtype.F64, - # InfiniDtype.BF16, + InfiniDtype.BF16, ] _TOLERANCE_MAP = { diff --git a/test/infiniop/scal.py b/test/infiniop/scal.py index 69279c070..ff5c6ed32 100644 --- a/test/infiniop/scal.py +++ b/test/infiniop/scal.py @@ -35,10 +35,10 @@ # Data types used for testing _TENSOR_DTYPES = [ - # InfiniDtype.F16, + InfiniDtype.F16, InfiniDtype.F32, # InfiniDtype.F64, - # InfiniDtype.BF16, + InfiniDtype.BF16, ] # Tolerance map for different data types From 9da0b531d1531025a2d5a8c5e2e8f36dad0262df Mon Sep 17 00:00:00 2001 From: xuzhengzhong Date: Wed, 6 May 2026 06:25:17 +0000 Subject: [PATCH 14/25] Split BLAS op info into local headers --- src/infiniop/ops/asum/asum.h | 45 +------------- src/infiniop/ops/asum/bang/asum_bang.mlu | 16 ++--- src/infiniop/ops/asum/cpu/asum_cpu.cc | 16 ++--- src/infiniop/ops/asum/info.h | 41 +++++++++++++ src/infiniop/ops/asum/metax/asum_metax.cc | 16 ++--- src/infiniop/ops/axpy/axpy.h | 54 +---------------- src/infiniop/ops/axpy/bang/axpy_bang.mlu | 14 ++--- src/infiniop/ops/axpy/cpu/axpy_cpu.cc | 14 ++--- 
src/infiniop/ops/axpy/info.h | 49 ++++++++++++++++ src/infiniop/ops/axpy/metax/axpy_metax.cc | 35 +++++------ .../ops/blas_amax/bang/blas_amax_bang.mlu | 12 ++-- src/infiniop/ops/blas_amax/blas_amax.h | 45 +------------- .../ops/blas_amax/cpu/blas_amax_cpu.cc | 12 ++-- src/infiniop/ops/blas_amax/info.h | 41 +++++++++++++ .../ops/blas_amax/metax/blas_amax_metax.cc | 12 ++-- .../ops/blas_amin/bang/blas_amin_bang.mlu | 12 ++-- src/infiniop/ops/blas_amin/blas_amin.h | 45 +------------- .../ops/blas_amin/cpu/blas_amin_cpu.cc | 12 ++-- src/infiniop/ops/blas_amin/info.h | 41 +++++++++++++ .../ops/blas_amin/metax/blas_amin_metax.cc | 12 ++-- .../ops/blas_copy/bang/blas_copy_bang.mlu | 18 +++--- src/infiniop/ops/blas_copy/blas_copy.h | 50 +--------------- .../ops/blas_copy/cpu/blas_copy_cpu.cc | 14 ++--- src/infiniop/ops/blas_copy/info.h | 45 ++++++++++++++ .../ops/blas_copy/metax/blas_copy_metax.cc | 14 ++--- .../ops/blas_dot/bang/blas_dot_bang.mlu | 18 +++--- src/infiniop/ops/blas_dot/blas_dot.h | 54 +---------------- src/infiniop/ops/blas_dot/cpu/blas_dot_cpu.cc | 14 ++--- src/infiniop/ops/blas_dot/info.h | 49 ++++++++++++++++ .../ops/blas_dot/metax/blas_dot_metax.cc | 18 +++--- src/infiniop/ops/nrm2/bang/nrm2_bang.mlu | 16 ++--- src/infiniop/ops/nrm2/cpu/nrm2_cpu.cc | 12 ++-- src/infiniop/ops/nrm2/info.h | 41 +++++++++++++ src/infiniop/ops/nrm2/metax/nrm2_metax.cc | 16 ++--- src/infiniop/ops/nrm2/nrm2.h | 45 +------------- src/infiniop/ops/rot/bang/rot_bang.mlu | 18 +++--- src/infiniop/ops/rot/cpu/rot_cpu.cc | 14 ++--- src/infiniop/ops/rot/info.h | 53 +++++++++++++++++ src/infiniop/ops/rot/metax/rot_metax.cc | 18 +++--- src/infiniop/ops/rot/rot.h | 58 +------------------ src/infiniop/ops/rotg/bang/rotg_bang.mlu | 8 +-- src/infiniop/ops/rotg/cpu/rotg_cpu.cc | 8 +-- src/infiniop/ops/rotg/info.h | 41 +++++++++++++ src/infiniop/ops/rotg/metax/rotg_metax.cc | 8 +-- src/infiniop/ops/rotg/rotg.h | 42 +------------- src/infiniop/ops/rotm/bang/rotm_bang.mlu | 18 +++--- 
src/infiniop/ops/rotm/cpu/rotm_cpu.cc | 16 ++--- src/infiniop/ops/rotm/info.h | 50 ++++++++++++++++ src/infiniop/ops/rotm/metax/rotm_metax.cc | 14 ++--- src/infiniop/ops/rotm/rotm.h | 55 +----------------- src/infiniop/ops/rotmg/bang/rotmg_bang.mlu | 8 +-- src/infiniop/ops/rotmg/cpu/rotmg_cpu.cc | 8 +-- src/infiniop/ops/rotmg/info.h | 48 +++++++++++++++ src/infiniop/ops/rotmg/metax/rotmg_metax.cc | 8 +-- src/infiniop/ops/rotmg/rotmg.h | 49 +--------------- src/infiniop/ops/scal/bang/scal_bang.mlu | 16 ++--- src/infiniop/ops/scal/cpu/scal_cpu.cc | 12 ++-- src/infiniop/ops/scal/info.h | 40 +++++++++++++ src/infiniop/ops/scal/metax/scal_metax.cc | 16 ++--- src/infiniop/ops/scal/scal.h | 44 +------------- src/infiniop/ops/swap/bang/swap_bang.mlu | 18 +++--- src/infiniop/ops/swap/cpu/swap_cpu.cc | 14 ++--- src/infiniop/ops/swap/info.h | 44 ++++++++++++++ src/infiniop/ops/swap/metax/swap_metax.cc | 14 ++--- src/infiniop/ops/swap/swap.h | 49 +--------------- 65 files changed, 875 insertions(+), 902 deletions(-) create mode 100644 src/infiniop/ops/asum/info.h create mode 100644 src/infiniop/ops/axpy/info.h create mode 100644 src/infiniop/ops/blas_amax/info.h create mode 100644 src/infiniop/ops/blas_amin/info.h create mode 100644 src/infiniop/ops/blas_copy/info.h create mode 100644 src/infiniop/ops/blas_dot/info.h create mode 100644 src/infiniop/ops/nrm2/info.h create mode 100644 src/infiniop/ops/rot/info.h create mode 100644 src/infiniop/ops/rotg/info.h create mode 100644 src/infiniop/ops/rotm/info.h create mode 100644 src/infiniop/ops/rotmg/info.h create mode 100644 src/infiniop/ops/scal/info.h create mode 100644 src/infiniop/ops/swap/info.h diff --git a/src/infiniop/ops/asum/asum.h b/src/infiniop/ops/asum/asum.h index 0174e3591..71cc97cad 100644 --- a/src/infiniop/ops/asum/asum.h +++ b/src/infiniop/ops/asum/asum.h @@ -1,10 +1,8 @@ #ifndef __ASUM_H__ #define __ASUM_H__ -#include "../../../utils.h" #include "../../operator.h" -#include "../../tensor.h" -#include 
"infiniop/ops/asum.h" +#include "info.h" #define DESCRIPTOR(NAMESPACE) \ \ @@ -46,45 +44,4 @@ }; \ } -class AsumInfo { -private: - size_t _size; - ptrdiff_t _incx; - infiniDtype_t _dtype; - - AsumInfo(size_t size, - ptrdiff_t incx, - infiniDtype_t dtype) - : _size(size), _incx(incx), _dtype(dtype) {} - -public: - inline size_t getSize() const { return _size; } - inline ptrdiff_t getIncx() const { return _incx; } - inline infiniDtype_t getDtype() const { return _dtype; } - - using ResultType = utils::Result; - - static ResultType createAsumInfo( - infiniopTensorDescriptor_t x_desc, - infiniopTensorDescriptor_t result_desc) { - - CHECK_OR_RETURN(x_desc != nullptr, INFINI_STATUS_NULL_POINTER); - CHECK_OR_RETURN(result_desc != nullptr, INFINI_STATUS_NULL_POINTER); - - auto dtype = x_desc->dtype(); - - CHECK_OR_RETURN(result_desc->dtype() == dtype, INFINI_STATUS_BAD_TENSOR_DTYPE); - CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32, INFINI_DTYPE_F64); - - CHECK_OR_RETURN(x_desc->ndim() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); - CHECK_OR_RETURN(result_desc->numel() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); - - auto size = x_desc->numel(); - auto incx = x_desc->stride(0); - - AsumInfo info(size, incx, dtype); - return ResultType(std::move(info)); - } -}; - #endif // __ASUM_H__ \ No newline at end of file diff --git a/src/infiniop/ops/asum/bang/asum_bang.mlu b/src/infiniop/ops/asum/bang/asum_bang.mlu index f9658b072..4ea5ef5b6 100644 --- a/src/infiniop/ops/asum/bang/asum_bang.mlu +++ b/src/infiniop/ops/asum/bang/asum_bang.mlu @@ -13,11 +13,11 @@ infiniStatus_t Descriptor::create( infiniopTensorDescriptor_t result_desc) { auto handle = reinterpret_cast(handle_); - auto info = AsumInfo::createAsumInfo(x_desc, result_desc); - CHECK_RESULT(info); + auto result = AsumInfo::createAsumInfo(x_desc, result_desc); + CHECK_RESULT(result); *desc_ptr = new Descriptor( - info.take(), + result.take(), 0, nullptr, handle->device, @@ -33,8 +33,8 @@ infiniStatus_t 
calculateAsum( Tdata *result, cnrtQueue_t queue) { - const size_t size = info.getSize(); - const ptrdiff_t incx = info.getIncx(); + const size_t n = info.n; + const ptrdiff_t incx = info.incx; cnrtDim3_t k_dim; cnrtFunctionType_t k_type; @@ -46,12 +46,12 @@ infiniStatus_t calculateAsum( if (incx == 1) { asumKernelContiguous<<>>( - size, + n, x, result); } else { asumKernelStrided<<>>( - size, + n, x, incx, result); @@ -78,7 +78,7 @@ infiniStatus_t Descriptor::calculate( (void)workspace; (void)workspace_size; - switch (_info.getDtype()) { + switch (_info.data_type) { case INFINI_DTYPE_F16: return CALCULATE_ASUM(half); case INFINI_DTYPE_F32: diff --git a/src/infiniop/ops/asum/cpu/asum_cpu.cc b/src/infiniop/ops/asum/cpu/asum_cpu.cc index 52692d984..bae85f6c2 100644 --- a/src/infiniop/ops/asum/cpu/asum_cpu.cc +++ b/src/infiniop/ops/asum/cpu/asum_cpu.cc @@ -12,11 +12,11 @@ infiniStatus_t Descriptor::create( infiniopTensorDescriptor_t result_desc) { auto handle = reinterpret_cast(handle_); - auto info = AsumInfo::createAsumInfo(x_desc, result_desc); - CHECK_RESULT(info); + auto result = AsumInfo::createAsumInfo(x_desc, result_desc); + CHECK_RESULT(result); *desc_ptr = new Descriptor( - info.take(), + result.take(), 0, nullptr, handle->device, @@ -31,13 +31,13 @@ infiniStatus_t calculateAsum( const Tdata *x, Tdata *result) { - const ptrdiff_t size = info.getSize(); - const ptrdiff_t incx = info.getIncx(); + const ptrdiff_t n = info.n; + const ptrdiff_t incx = info.incx; if constexpr (std::is_same::value || std::is_same::value) { float total_sum = 0.0; - for (ptrdiff_t i = 0; i < size; ++i) { + for (ptrdiff_t i = 0; i < n; ++i) { total_sum += std::abs(utils::cast(x[i * incx])); } @@ -45,7 +45,7 @@ infiniStatus_t calculateAsum( } else { Tdata total_sum = 0.0; - for (ptrdiff_t i = 0; i < size; ++i) { + for (ptrdiff_t i = 0; i < n; ++i) { total_sum += std::abs(x[i * incx]); } @@ -70,7 +70,7 @@ infiniStatus_t Descriptor::calculate( (void)workspace; (void)workspace_size; - 
switch (_info.getDtype()) { + switch (_info.data_type) { case INFINI_DTYPE_F16: return CALCULATE_ASUM(fp16_t); case INFINI_DTYPE_BF16: diff --git a/src/infiniop/ops/asum/info.h b/src/infiniop/ops/asum/info.h new file mode 100644 index 000000000..17841f48f --- /dev/null +++ b/src/infiniop/ops/asum/info.h @@ -0,0 +1,41 @@ +#ifndef __ASUM_INFO_H__ +#define __ASUM_INFO_H__ + +#include "../../../utils.h" +#include "../../tensor.h" + +class AsumInfo { +private: + AsumInfo() = default; + +public: + size_t n; + ptrdiff_t incx; + infiniDtype_t data_type; + + static utils::Result createAsumInfo( + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t result_desc) { + + CHECK_OR_RETURN(x_desc != nullptr, INFINI_STATUS_NULL_POINTER); + CHECK_OR_RETURN(result_desc != nullptr, INFINI_STATUS_NULL_POINTER); + + auto data_type = x_desc->dtype(); + + CHECK_OR_RETURN(result_desc->dtype() == data_type, INFINI_STATUS_BAD_TENSOR_DTYPE); + CHECK_DTYPE(data_type, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32, INFINI_DTYPE_F64); + + CHECK_OR_RETURN(x_desc->ndim() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(result_desc->numel() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); + + auto n = x_desc->numel(); + auto incx = x_desc->stride(0); + + return utils::Result(AsumInfo{ + n, + incx, + data_type}); + } +}; + +#endif // __ASUM_INFO_H__ \ No newline at end of file diff --git a/src/infiniop/ops/asum/metax/asum_metax.cc b/src/infiniop/ops/asum/metax/asum_metax.cc index edd3febed..61098eed2 100644 --- a/src/infiniop/ops/asum/metax/asum_metax.cc +++ b/src/infiniop/ops/asum/metax/asum_metax.cc @@ -19,11 +19,11 @@ infiniStatus_t Descriptor::create( infiniopTensorDescriptor_t result_desc) { auto handle = reinterpret_cast(handle_); - auto info = AsumInfo::createAsumInfo(x_desc, result_desc); - CHECK_RESULT(info); + auto result = AsumInfo::createAsumInfo(x_desc, result_desc); + CHECK_RESULT(result); *desc_ptr = new Descriptor( - info.take(), + result.take(), 0, new 
Opaque{handle->internal()}, handle->device, @@ -42,9 +42,9 @@ infiniStatus_t Descriptor::calculate( (void)workspace; (void)workspace_size; - const size_t size = _info.getSize(); - const ptrdiff_t incx = _info.getIncx(); - const infiniDtype_t data_type = _info.getDtype(); + const size_t n = _info.n; + const ptrdiff_t incx = _info.incx; + const infiniDtype_t data_type = _info.data_type; CHECK_STATUS(_opaque->internal->useMcblas( (hcStream_t)stream, @@ -53,10 +53,10 @@ infiniStatus_t Descriptor::calculate( switch (data_type) { case INFINI_DTYPE_F32: - CHECK_MCBLAS(hcblasSasum(handle, size, (const float *)x, incx, (float *)result)); + CHECK_MCBLAS(hcblasSasum(handle, n, (const float *)x, incx, (float *)result)); break; case INFINI_DTYPE_F64: - CHECK_MCBLAS(hcblasDasum(handle, size, (const double *)x, incx, (double *)result)); + CHECK_MCBLAS(hcblasDasum(handle, n, (const double *)x, incx, (double *)result)); break; default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; diff --git a/src/infiniop/ops/axpy/axpy.h b/src/infiniop/ops/axpy/axpy.h index 9d25ef29c..7dedb68b4 100644 --- a/src/infiniop/ops/axpy/axpy.h +++ b/src/infiniop/ops/axpy/axpy.h @@ -1,10 +1,8 @@ #ifndef __AXPY_H__ #define __AXPY_H__ -#include "../../../utils.h" #include "../../operator.h" -#include "../../tensor.h" -#include "infiniop/ops/axpy.h" +#include "info.h" #define DESCRIPTOR(NAMESPACE) \ \ @@ -48,54 +46,4 @@ }; \ } -class AxpyInfo { -private: - size_t _size; - ptrdiff_t _incx; - ptrdiff_t _incy; - infiniDtype_t _dtype; - - AxpyInfo(size_t size, - ptrdiff_t incx, - ptrdiff_t incy, - infiniDtype_t dtype) - : _size(size), _incx(incx), _incy(incy), _dtype(dtype) {} - -public: - inline size_t getSize() const { return _size; } - inline ptrdiff_t getIncx() const { return _incx; } - inline ptrdiff_t getIncy() const { return _incy; } - inline infiniDtype_t getDtype() const { return _dtype; } - - using ResultType = utils::Result; - - static ResultType createAxpyInfo( - infiniopTensorDescriptor_t 
alpha_desc, - infiniopTensorDescriptor_t x_desc, - infiniopTensorDescriptor_t y_desc) { - - CHECK_OR_RETURN(alpha_desc != nullptr, INFINI_STATUS_NULL_POINTER); - CHECK_OR_RETURN(x_desc != nullptr, INFINI_STATUS_NULL_POINTER); - CHECK_OR_RETURN(y_desc != nullptr, INFINI_STATUS_NULL_POINTER); - - auto dtype = x_desc->dtype(); - - CHECK_OR_RETURN(alpha_desc->dtype() == dtype, INFINI_STATUS_BAD_TENSOR_DTYPE); - CHECK_OR_RETURN(x_desc->dtype() == dtype, INFINI_STATUS_BAD_TENSOR_DTYPE); - CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32, INFINI_DTYPE_F64); - - CHECK_OR_RETURN(alpha_desc->numel() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); - CHECK_OR_RETURN(x_desc->ndim() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); - CHECK_OR_RETURN(y_desc->ndim() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); - CHECK_OR_RETURN(x_desc->numel() == y_desc->numel(), INFINI_STATUS_BAD_TENSOR_SHAPE); - - auto size = x_desc->numel(); - auto incx = x_desc->stride(0); - auto incy = y_desc->stride(0); - - AxpyInfo info(size, incx, incy, dtype); - return ResultType(std::move(info)); - } -}; - #endif // __AXPY_H__ \ No newline at end of file diff --git a/src/infiniop/ops/axpy/bang/axpy_bang.mlu b/src/infiniop/ops/axpy/bang/axpy_bang.mlu index 004897e52..e74d07dbf 100644 --- a/src/infiniop/ops/axpy/bang/axpy_bang.mlu +++ b/src/infiniop/ops/axpy/bang/axpy_bang.mlu @@ -14,12 +14,12 @@ infiniStatus_t Descriptor::create( infiniopTensorDescriptor_t y_desc) { auto handle = reinterpret_cast(handle_); - auto info = AxpyInfo::createAxpyInfo(alpha_desc, x_desc, y_desc); - CHECK_RESULT(info); + auto result = AxpyInfo::createAxpyInfo(alpha_desc, x_desc, y_desc); + CHECK_RESULT(result); // Create descriptor *desc_ptr = new Descriptor( - info.take(), + result.take(), 0, nullptr, handle->device, @@ -36,9 +36,9 @@ infiniStatus_t calculateAxpy( Tdata *y, cnrtQueue_t queue) { - const size_t size = info.getSize(); - const ptrdiff_t incx = info.getIncx(); - const ptrdiff_t incy = info.getIncy(); + const size_t 
size = info.n; + const ptrdiff_t incx = info.incx; + const ptrdiff_t incy = info.incy; cnrtDim3_t k_dim; cnrtFunctionType_t k_type; @@ -87,7 +87,7 @@ infiniStatus_t Descriptor::calculate( (void)workspace; (void)workspace_size; - switch (_info.getDtype()) { + switch (_info.data_type) { case INFINI_DTYPE_F16: return CALCULATE_AXPY(half); case INFINI_DTYPE_F32: diff --git a/src/infiniop/ops/axpy/cpu/axpy_cpu.cc b/src/infiniop/ops/axpy/cpu/axpy_cpu.cc index ab2b5e673..8c42c4855 100644 --- a/src/infiniop/ops/axpy/cpu/axpy_cpu.cc +++ b/src/infiniop/ops/axpy/cpu/axpy_cpu.cc @@ -13,11 +13,11 @@ infiniStatus_t Descriptor::create( infiniopTensorDescriptor_t y_desc) { auto handle = reinterpret_cast(handle_); - auto info = AxpyInfo::createAxpyInfo(alpha_desc, x_desc, y_desc); - CHECK_RESULT(info); + auto result = AxpyInfo::createAxpyInfo(alpha_desc, x_desc, y_desc); + CHECK_RESULT(result); *desc_ptr = new Descriptor( - info.take(), + result.take(), 0, nullptr, handle->device, @@ -33,9 +33,9 @@ infiniStatus_t calculateAxpy( const Tdata *x, Tdata *y) { - const ptrdiff_t size = info.getSize(); - const ptrdiff_t incx = info.getIncx(); - const ptrdiff_t incy = info.getIncy(); + const ptrdiff_t size = info.n; + const ptrdiff_t incx = info.incx; + const ptrdiff_t incy = info.incy; if constexpr (std::is_same::value || std::is_same::value) { const float alpha_f = utils::cast(alpha[0]); @@ -72,7 +72,7 @@ infiniStatus_t Descriptor::calculate( (void)workspace_size; (void)stream; - switch (_info.getDtype()) { + switch (_info.data_type) { case INFINI_DTYPE_F16: return CALCULATE_AXPY(fp16_t); case INFINI_DTYPE_F32: diff --git a/src/infiniop/ops/axpy/info.h b/src/infiniop/ops/axpy/info.h new file mode 100644 index 000000000..7155c846b --- /dev/null +++ b/src/infiniop/ops/axpy/info.h @@ -0,0 +1,49 @@ +#ifndef __AXPY_INFO_H__ +#define __AXPY_INFO_H__ + +#include "../../../utils.h" +#include "../../tensor.h" + +class AxpyInfo { +private: + AxpyInfo() = default; + +public: + size_t n; + ptrdiff_t 
incx; + ptrdiff_t incy; + infiniDtype_t data_type; + + static utils::Result createAxpyInfo( + infiniopTensorDescriptor_t alpha_desc, + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t y_desc) { + + CHECK_OR_RETURN(alpha_desc != nullptr, INFINI_STATUS_NULL_POINTER); + CHECK_OR_RETURN(x_desc != nullptr, INFINI_STATUS_NULL_POINTER); + CHECK_OR_RETURN(y_desc != nullptr, INFINI_STATUS_NULL_POINTER); + + auto data_type = x_desc->dtype(); + + CHECK_OR_RETURN(alpha_desc->dtype() == data_type, INFINI_STATUS_BAD_TENSOR_DTYPE); + CHECK_OR_RETURN(x_desc->dtype() == data_type, INFINI_STATUS_BAD_TENSOR_DTYPE); + CHECK_DTYPE(data_type, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32, INFINI_DTYPE_F64); + + CHECK_OR_RETURN(alpha_desc->numel() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(x_desc->ndim() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(y_desc->ndim() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(x_desc->numel() == y_desc->numel(), INFINI_STATUS_BAD_TENSOR_SHAPE); + + auto n = x_desc->numel(); + auto incx = x_desc->stride(0); + auto incy = y_desc->stride(0); + + return utils::Result(AxpyInfo{ + n, + incx, + incy, + data_type}); + } +}; + +#endif // __AXPY_INFO_H__ \ No newline at end of file diff --git a/src/infiniop/ops/axpy/metax/axpy_metax.cc b/src/infiniop/ops/axpy/metax/axpy_metax.cc index a41d3cf62..f1fd491cf 100644 --- a/src/infiniop/ops/axpy/metax/axpy_metax.cc +++ b/src/infiniop/ops/axpy/metax/axpy_metax.cc @@ -20,11 +20,11 @@ infiniStatus_t Descriptor::create( infiniopTensorDescriptor_t y_desc) { auto handle = reinterpret_cast(handle_); - auto info = AxpyInfo::createAxpyInfo(alpha_desc, x_desc, y_desc); - CHECK_RESULT(info); + auto result = AxpyInfo::createAxpyInfo(alpha_desc, x_desc, y_desc); + CHECK_RESULT(result); *desc_ptr = new Descriptor( - info.take(), + result.take(), 0, new Opaque{handle->internal()}, handle->device, @@ -44,10 +44,10 @@ infiniStatus_t Descriptor::calculate( (void)workspace; 
(void)workspace_size; - const size_t size = _info.getSize(); - const ptrdiff_t incx = _info.getIncx(); - const ptrdiff_t incy = _info.getIncy(); - const infiniDtype_t data_type = _info.getDtype(); + const size_t size = _info.n; + const ptrdiff_t incx = _info.incx; + const ptrdiff_t incy = _info.incy; + const infiniDtype_t data_type = _info.data_type; hpccDataType alpha_type, x_type, y_type; hpccDataType execution_type; @@ -76,20 +76,11 @@ infiniStatus_t Descriptor::calculate( CHECK_STATUS(_opaque->internal->useMcblas( (hcStream_t)stream, [&](hcblasHandle_t handle) { - CHECK_MCBLAS(hcblasSetPointerMode(handle, HCBLAS_POINTER_MODE_DEVICE)); - - // switch (data_type) { - // case INFINI_DTYPE_F32: - // CHECK_MCBLAS(hcblasSaxpy(handle, size, (const float *)alpha, (const float *)x, incx, (float *)y, incy)); - // break; - // case INFINI_DTYPE_F64: - // CHECK_MCBLAS(hcblasDaxpy(handle, size, (const double *)alpha, (const double *)x, incx, (double *)y, incy)); - // break; - // default: - // return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; - // } - - hcblasAxpyEx( + CHECK_MCBLAS(hcblasSetPointerMode( + handle, + HCBLAS_POINTER_MODE_DEVICE)); + + CHECK_MCBLAS(hcblasAxpyEx( handle, size, alpha, @@ -100,7 +91,7 @@ infiniStatus_t Descriptor::calculate( y, y_type, incy, - execution_type); + execution_type)); return INFINI_STATUS_SUCCESS; })); diff --git a/src/infiniop/ops/blas_amax/bang/blas_amax_bang.mlu b/src/infiniop/ops/blas_amax/bang/blas_amax_bang.mlu index 38b450caf..5e0e6ac46 100644 --- a/src/infiniop/ops/blas_amax/bang/blas_amax_bang.mlu +++ b/src/infiniop/ops/blas_amax/bang/blas_amax_bang.mlu @@ -13,12 +13,12 @@ infiniStatus_t Descriptor::create( infiniopTensorDescriptor_t result_desc) { auto handle = reinterpret_cast(handle_); - auto info = BlasAmaxInfo::createBlasAmaxInfo(x_desc, result_desc); - CHECK_RESULT(info); + auto result = BlasAmaxInfo::createBlasAmaxInfo(x_desc, result_desc); + CHECK_RESULT(result); // Create descriptor *desc_ptr = new Descriptor( - 
info.take(), + result.take(), 0, nullptr, handle->device, @@ -34,8 +34,8 @@ infiniStatus_t calculateBlasAmax( int *result, cnrtQueue_t queue) { - const size_t size = info.getSize(); - const ptrdiff_t incx = info.getIncx(); + const size_t size = info.n; + const ptrdiff_t incx = info.incx; cnrtDim3_t k_dim; cnrtFunctionType_t k_type; @@ -79,7 +79,7 @@ infiniStatus_t Descriptor::calculate( (void)workspace; (void)workspace_size; - switch (_info.getDtype()) { + switch (_info.data_type) { case INFINI_DTYPE_F16: return CALCULATE_BLAS_AMAX(half); case INFINI_DTYPE_F32: diff --git a/src/infiniop/ops/blas_amax/blas_amax.h b/src/infiniop/ops/blas_amax/blas_amax.h index 06247fe29..e5b3400c3 100644 --- a/src/infiniop/ops/blas_amax/blas_amax.h +++ b/src/infiniop/ops/blas_amax/blas_amax.h @@ -1,10 +1,8 @@ #ifndef __BLAS_AMAX_H__ #define __BLAS_AMAX_H__ -#include "../../../utils.h" #include "../../operator.h" -#include "../../tensor.h" -#include "infiniop/ops/blas_amax.h" +#include "info.h" #define DESCRIPTOR(NAMESPACE) \ \ @@ -46,45 +44,4 @@ }; \ } -class BlasAmaxInfo { -private: - size_t _size; - ptrdiff_t _incx; - infiniDtype_t _dtype; - - BlasAmaxInfo(size_t size, - ptrdiff_t incx, - infiniDtype_t dtype) - : _size(size), _incx(incx), _dtype(dtype) {} - -public: - inline size_t getSize() const { return _size; } - inline ptrdiff_t getIncx() const { return _incx; } - inline infiniDtype_t getDtype() const { return _dtype; } - - using ResultType = utils::Result; - - static utils::Result createBlasAmaxInfo( - infiniopTensorDescriptor_t x_desc, - infiniopTensorDescriptor_t result_desc) { - - CHECK_OR_RETURN(x_desc != nullptr, INFINI_STATUS_NULL_POINTER); - CHECK_OR_RETURN(result_desc != nullptr, INFINI_STATUS_NULL_POINTER); - - auto dtype = x_desc->dtype(); - auto itype = result_desc->dtype(); - CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32, INFINI_DTYPE_F64); - CHECK_DTYPE(itype, INFINI_DTYPE_I32); - - CHECK_OR_RETURN(x_desc->ndim() == 1, 
INFINI_STATUS_BAD_TENSOR_SHAPE); - CHECK_OR_RETURN(result_desc->numel() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); - - auto size = x_desc->numel(); - auto incx = x_desc->stride(0); - - BlasAmaxInfo info(size, incx, dtype); - return ResultType(std::move(info)); - } -}; - #endif // __BLAS_AMAX_H__ \ No newline at end of file diff --git a/src/infiniop/ops/blas_amax/cpu/blas_amax_cpu.cc b/src/infiniop/ops/blas_amax/cpu/blas_amax_cpu.cc index 237dfde64..8d66d71f2 100644 --- a/src/infiniop/ops/blas_amax/cpu/blas_amax_cpu.cc +++ b/src/infiniop/ops/blas_amax/cpu/blas_amax_cpu.cc @@ -12,12 +12,12 @@ infiniStatus_t Descriptor::create( infiniopTensorDescriptor_t result_desc) { auto handle = reinterpret_cast(handle_); - auto info = BlasAmaxInfo::createBlasAmaxInfo(x_desc, result_desc); - CHECK_RESULT(info); + auto result = BlasAmaxInfo::createBlasAmaxInfo(x_desc, result_desc); + CHECK_RESULT(result); // Create descriptor *desc_ptr = new Descriptor( - info.take(), + result.take(), 0, nullptr, handle->device, @@ -32,8 +32,8 @@ infiniStatus_t calculateBlasAmax( const Tdata *x, int *result) { - const ptrdiff_t size = info.getSize(); - const ptrdiff_t incx = info.getIncx(); + const ptrdiff_t size = info.n; + const ptrdiff_t incx = info.incx; if (size < 1 || incx == 0) { result[0] = 0; @@ -84,7 +84,7 @@ infiniStatus_t Descriptor::calculate( (void)workspace_size; (void)stream; - switch (_info.getDtype()) { + switch (_info.data_type) { case INFINI_DTYPE_F16: return CALCULATE_BLAS_AMAX(fp16_t); case INFINI_DTYPE_F32: diff --git a/src/infiniop/ops/blas_amax/info.h b/src/infiniop/ops/blas_amax/info.h new file mode 100644 index 000000000..b9b1924c8 --- /dev/null +++ b/src/infiniop/ops/blas_amax/info.h @@ -0,0 +1,41 @@ +#ifndef __BLAS_AMAX_INFO_H__ +#define __BLAS_AMAX_INFO_H__ + +#include "../../../utils.h" +#include "../../tensor.h" + +class BlasAmaxInfo { +private: + BlasAmaxInfo() = default; + +public: + size_t n; + ptrdiff_t incx; + infiniDtype_t data_type; + + static utils::Result 
createBlasAmaxInfo( + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t result_desc) { + + CHECK_OR_RETURN(x_desc != nullptr, INFINI_STATUS_NULL_POINTER); + CHECK_OR_RETURN(result_desc != nullptr, INFINI_STATUS_NULL_POINTER); + + auto data_type = x_desc->dtype(); + auto itype = result_desc->dtype(); + CHECK_DTYPE(data_type, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32, INFINI_DTYPE_F64); + CHECK_DTYPE(itype, INFINI_DTYPE_I32); + + CHECK_OR_RETURN(x_desc->ndim() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(result_desc->numel() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); + + auto n = x_desc->numel(); + auto incx = x_desc->stride(0); + + return utils::Result(BlasAmaxInfo{ + n, + incx, + data_type}); + } +}; + +#endif // __BLAS_AMAX_INFO_H__ \ No newline at end of file diff --git a/src/infiniop/ops/blas_amax/metax/blas_amax_metax.cc b/src/infiniop/ops/blas_amax/metax/blas_amax_metax.cc index 298ad5510..cb73bdd67 100644 --- a/src/infiniop/ops/blas_amax/metax/blas_amax_metax.cc +++ b/src/infiniop/ops/blas_amax/metax/blas_amax_metax.cc @@ -19,11 +19,11 @@ infiniStatus_t Descriptor::create( infiniopTensorDescriptor_t result_desc) { auto handle = reinterpret_cast(handle_); - auto info = BlasAmaxInfo::createBlasAmaxInfo(x_desc, result_desc); - CHECK_RESULT(info); + auto result = BlasAmaxInfo::createBlasAmaxInfo(x_desc, result_desc); + CHECK_RESULT(result); *desc_ptr = new Descriptor( - info.take(), + result.take(), 0, new Opaque{handle->internal()}, handle->device, @@ -42,9 +42,9 @@ infiniStatus_t Descriptor::calculate( (void)workspace; (void)workspace_size; - const size_t size = _info.getSize(); - const ptrdiff_t incx = _info.getIncx(); - const infiniDtype_t data_type = _info.getDtype(); + const size_t size = _info.n; + const ptrdiff_t incx = _info.incx; + const infiniDtype_t data_type = _info.data_type; CHECK_STATUS(_opaque->internal->useMcblas( (hcStream_t)stream, diff --git a/src/infiniop/ops/blas_amin/bang/blas_amin_bang.mlu 
b/src/infiniop/ops/blas_amin/bang/blas_amin_bang.mlu index 4b727093c..7c0f795f0 100644 --- a/src/infiniop/ops/blas_amin/bang/blas_amin_bang.mlu +++ b/src/infiniop/ops/blas_amin/bang/blas_amin_bang.mlu @@ -13,12 +13,12 @@ infiniStatus_t Descriptor::create( infiniopTensorDescriptor_t result_desc) { auto handle = reinterpret_cast(handle_); - auto info = BlasAminInfo::createBlasAminInfo(x_desc, result_desc); - CHECK_RESULT(info); + auto result = BlasAminInfo::createBlasAminInfo(x_desc, result_desc); + CHECK_RESULT(result); // Create descriptor *desc_ptr = new Descriptor( - info.take(), + result.take(), 0, nullptr, handle->device, @@ -34,8 +34,8 @@ infiniStatus_t calculateBlasAmin( int *result, cnrtQueue_t queue) { - const size_t size = info.getSize(); - const ptrdiff_t incx = info.getIncx(); + const size_t size = info.n; + const ptrdiff_t incx = info.incx; cnrtDim3_t k_dim; cnrtFunctionType_t k_type; @@ -79,7 +79,7 @@ infiniStatus_t Descriptor::calculate( (void)workspace; (void)workspace_size; - switch (_info.getDtype()) { + switch (_info.data_type) { case INFINI_DTYPE_F16: return CALCULATE_BLAS_AMIN(half); case INFINI_DTYPE_F32: diff --git a/src/infiniop/ops/blas_amin/blas_amin.h b/src/infiniop/ops/blas_amin/blas_amin.h index a05db711e..8cf7ceca6 100644 --- a/src/infiniop/ops/blas_amin/blas_amin.h +++ b/src/infiniop/ops/blas_amin/blas_amin.h @@ -1,10 +1,8 @@ #ifndef __BLAS_AMIN_H__ #define __BLAS_AMIN_H__ -#include "../../../utils.h" #include "../../operator.h" -#include "../../tensor.h" -#include "infiniop/ops/blas_amin.h" +#include "info.h" #define DESCRIPTOR(NAMESPACE) \ \ @@ -46,45 +44,4 @@ }; \ } -class BlasAminInfo { -private: - size_t _size; - ptrdiff_t _incx; - infiniDtype_t _dtype; - - BlasAminInfo(size_t size, - ptrdiff_t incx, - infiniDtype_t dtype) - : _size(size), _incx(incx), _dtype(dtype) {} - -public: - inline size_t getSize() const { return _size; } - inline ptrdiff_t getIncx() const { return _incx; } - inline infiniDtype_t getDtype() const { return 
_dtype; } - - using ResultType = utils::Result; - - static utils::Result createBlasAminInfo( - infiniopTensorDescriptor_t x_desc, - infiniopTensorDescriptor_t result_desc) { - - CHECK_OR_RETURN(x_desc != nullptr, INFINI_STATUS_NULL_POINTER); - CHECK_OR_RETURN(result_desc != nullptr, INFINI_STATUS_NULL_POINTER); - - auto dtype = x_desc->dtype(); - auto itype = result_desc->dtype(); - CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32, INFINI_DTYPE_F64); - CHECK_DTYPE(itype, INFINI_DTYPE_I32); - - CHECK_OR_RETURN(x_desc->ndim() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); - CHECK_OR_RETURN(result_desc->numel() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); - - auto size = x_desc->numel(); - auto incx = x_desc->stride(0); - - BlasAminInfo info(size, incx, dtype); - return ResultType(std::move(info)); - } -}; - #endif // __BLAS_AMIN_H__ \ No newline at end of file diff --git a/src/infiniop/ops/blas_amin/cpu/blas_amin_cpu.cc b/src/infiniop/ops/blas_amin/cpu/blas_amin_cpu.cc index ce6b567c0..57c060c63 100644 --- a/src/infiniop/ops/blas_amin/cpu/blas_amin_cpu.cc +++ b/src/infiniop/ops/blas_amin/cpu/blas_amin_cpu.cc @@ -12,12 +12,12 @@ infiniStatus_t Descriptor::create( infiniopTensorDescriptor_t result_desc) { auto handle = reinterpret_cast(handle_); - auto info = BlasAminInfo::createBlasAminInfo(x_desc, result_desc); - CHECK_RESULT(info); + auto result = BlasAminInfo::createBlasAminInfo(x_desc, result_desc); + CHECK_RESULT(result); // Create descriptor *desc_ptr = new Descriptor( - info.take(), + result.take(), 0, nullptr, handle->device, @@ -32,8 +32,8 @@ infiniStatus_t calculateBlasAmin( const Tdata *x, int *result) { - const ptrdiff_t size = info.getSize(); - const ptrdiff_t incx = info.getIncx(); + const ptrdiff_t size = info.n; + const ptrdiff_t incx = info.incx; if (size < 1 || incx == 0) { result[0] = 0; @@ -84,7 +84,7 @@ infiniStatus_t Descriptor::calculate( (void)workspace_size; (void)stream; - switch (_info.getDtype()) { + switch (_info.data_type) { 
case INFINI_DTYPE_F16: return CALCULATE_BLAS_AMIN(fp16_t); case INFINI_DTYPE_F32: diff --git a/src/infiniop/ops/blas_amin/info.h b/src/infiniop/ops/blas_amin/info.h new file mode 100644 index 000000000..a9b187f92 --- /dev/null +++ b/src/infiniop/ops/blas_amin/info.h @@ -0,0 +1,41 @@ +#ifndef __BLAS_AMIN_INFO_H__ +#define __BLAS_AMIN_INFO_H__ + +#include "../../../utils.h" +#include "../../tensor.h" + +class BlasAminInfo { +private: + BlasAminInfo() = default; + +public: + size_t n; + ptrdiff_t incx; + infiniDtype_t data_type; + + static utils::Result createBlasAminInfo( + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t result_desc) { + + CHECK_OR_RETURN(x_desc != nullptr, INFINI_STATUS_NULL_POINTER); + CHECK_OR_RETURN(result_desc != nullptr, INFINI_STATUS_NULL_POINTER); + + auto data_type = x_desc->dtype(); + auto itype = result_desc->dtype(); + CHECK_DTYPE(data_type, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32, INFINI_DTYPE_F64); + CHECK_DTYPE(itype, INFINI_DTYPE_I32); + + CHECK_OR_RETURN(x_desc->ndim() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(result_desc->numel() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); + + auto n = x_desc->numel(); + auto incx = x_desc->stride(0); + + return utils::Result(BlasAminInfo{ + n, + incx, + data_type}); + } +}; + +#endif // __BLAS_AMIN_INFO_H__ \ No newline at end of file diff --git a/src/infiniop/ops/blas_amin/metax/blas_amin_metax.cc b/src/infiniop/ops/blas_amin/metax/blas_amin_metax.cc index d924f7f63..8217eacf6 100644 --- a/src/infiniop/ops/blas_amin/metax/blas_amin_metax.cc +++ b/src/infiniop/ops/blas_amin/metax/blas_amin_metax.cc @@ -19,11 +19,11 @@ infiniStatus_t Descriptor::create( infiniopTensorDescriptor_t result_desc) { auto handle = reinterpret_cast(handle_); - auto info = BlasAminInfo::createBlasAminInfo(x_desc, result_desc); - CHECK_RESULT(info); + auto result = BlasAminInfo::createBlasAminInfo(x_desc, result_desc); + CHECK_RESULT(result); *desc_ptr = new Descriptor( - info.take(), 
+ result.take(), 0, new Opaque{handle->internal()}, handle->device, @@ -42,9 +42,9 @@ infiniStatus_t Descriptor::calculate( (void)workspace; (void)workspace_size; - const size_t size = _info.getSize(); - const ptrdiff_t incx = _info.getIncx(); - const infiniDtype_t data_type = _info.getDtype(); + const size_t size = _info.n; + const ptrdiff_t incx = _info.incx; + const infiniDtype_t data_type = _info.data_type; CHECK_STATUS(_opaque->internal->useMcblas( (hcStream_t)stream, diff --git a/src/infiniop/ops/blas_copy/bang/blas_copy_bang.mlu b/src/infiniop/ops/blas_copy/bang/blas_copy_bang.mlu index 20a787792..65e68b631 100644 --- a/src/infiniop/ops/blas_copy/bang/blas_copy_bang.mlu +++ b/src/infiniop/ops/blas_copy/bang/blas_copy_bang.mlu @@ -13,11 +13,11 @@ infiniStatus_t Descriptor::create( infiniopTensorDescriptor_t y_desc) { auto handle = reinterpret_cast(handle_); - auto info = BlasCopyInfo::createBlasCopyInfo(x_desc, y_desc); - CHECK_RESULT(info); + auto result = BlasCopyInfo::createBlasCopyInfo(x_desc, y_desc); + CHECK_RESULT(result); *desc_ptr = new Descriptor( - info.take(), + result.take(), 0, nullptr, handle->device, @@ -41,18 +41,18 @@ infiniStatus_t calculateBlasCopy( k_dim.z = 1; k_type = cnrtFuncTypeUnion1; - if (info.getIncx() == 1 && info.getIncy() == 1) { + if (info.incx == 1 && info.incy == 1) { blasCopyKernelContiguous<<>>( - info.getSize(), + info.n, x, y); } else { blasCopyKernelStrided<<>>( - info.getSize(), + info.n, x, - info.getIncx(), + info.incx, y, - info.getIncy()); + info.incy); } cnrtQueueSync(queue); @@ -75,7 +75,7 @@ infiniStatus_t Descriptor::calculate( (void)workspace; (void)workspace_size; - switch (_info.getDtype()) { + switch (_info.data_type) { case INFINI_DTYPE_F16: return CALCULATE_BLAS_COPY(half); case INFINI_DTYPE_F32: diff --git a/src/infiniop/ops/blas_copy/blas_copy.h b/src/infiniop/ops/blas_copy/blas_copy.h index b2fa23970..3670ba204 100644 --- a/src/infiniop/ops/blas_copy/blas_copy.h +++ 
b/src/infiniop/ops/blas_copy/blas_copy.h @@ -1,10 +1,8 @@ #ifndef __BLAS_COPY_H__ #define __BLAS_COPY_H__ -#include "../../../utils.h" #include "../../operator.h" -#include "../../tensor.h" -#include "infiniop/ops/blas_copy.h" +#include "info.h" #define DESCRIPTOR(NAMESPACE) \ \ @@ -46,50 +44,4 @@ }; \ } -class BlasCopyInfo { -private: - size_t _size; - ptrdiff_t _incx; - ptrdiff_t _incy; - infiniDtype_t _dtype; - - BlasCopyInfo(size_t size, - ptrdiff_t incx, - ptrdiff_t incy, - infiniDtype_t dtype) - : _size(size), _incx(incx), _incy(incy), _dtype(dtype) {} - -public: - inline size_t getSize() const { return _size; } - inline ptrdiff_t getIncx() const { return _incx; } - inline ptrdiff_t getIncy() const { return _incy; } - inline infiniDtype_t getDtype() const { return _dtype; } - - using ResultType = utils::Result; - - static utils::Result createBlasCopyInfo( - infiniopTensorDescriptor_t x_desc, - infiniopTensorDescriptor_t y_desc) { - - CHECK_OR_RETURN(x_desc != nullptr, INFINI_STATUS_NULL_POINTER); - CHECK_OR_RETURN(y_desc != nullptr, INFINI_STATUS_NULL_POINTER); - - auto dtype = x_desc->dtype(); - - CHECK_OR_RETURN(y_desc->dtype() == dtype, INFINI_STATUS_BAD_TENSOR_DTYPE); - CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32, INFINI_DTYPE_F64); - - CHECK_OR_RETURN(x_desc->ndim() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); - CHECK_OR_RETURN(y_desc->ndim() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); - CHECK_OR_RETURN(x_desc->numel() == y_desc->numel(), INFINI_STATUS_BAD_TENSOR_SHAPE); - - auto size = x_desc->numel(); - auto incx = x_desc->stride(0); - auto incy = y_desc->stride(0); - - BlasCopyInfo info(size, incx, incy, dtype); - return ResultType(std::move(info)); - } -}; - #endif // __BLAS_COPY_H__ diff --git a/src/infiniop/ops/blas_copy/cpu/blas_copy_cpu.cc b/src/infiniop/ops/blas_copy/cpu/blas_copy_cpu.cc index 979aa9df0..d4671fb72 100644 --- a/src/infiniop/ops/blas_copy/cpu/blas_copy_cpu.cc +++ b/src/infiniop/ops/blas_copy/cpu/blas_copy_cpu.cc 
@@ -12,11 +12,11 @@ infiniStatus_t Descriptor::create( infiniopTensorDescriptor_t y_desc) { auto handle = reinterpret_cast(handle_); - auto info = BlasCopyInfo::createBlasCopyInfo(x_desc, y_desc); - CHECK_RESULT(info); + auto result = BlasCopyInfo::createBlasCopyInfo(x_desc, y_desc); + CHECK_RESULT(result); *desc_ptr = new Descriptor( - info.take(), + result.take(), 0, nullptr, handle->device, @@ -31,11 +31,11 @@ infiniStatus_t calculateBlasCopy( const Tdata *x, Tdata *y) { - const ptrdiff_t size = info.getSize(); + const ptrdiff_t size = info.n; for (ptrdiff_t i = 0; i < size; ++i) { - size_t x_idx = i * info.getIncx(); - size_t y_idx = i * info.getIncy(); + size_t x_idx = i * info.incx; + size_t y_idx = i * info.incy; y[y_idx] = x[x_idx]; } @@ -58,7 +58,7 @@ infiniStatus_t Descriptor::calculate( (void)workspace_size; (void)stream; - switch (_info.getDtype()) { + switch (_info.data_type) { case INFINI_DTYPE_F16: return CALCULATE_BLAS_COPY(fp16_t); case INFINI_DTYPE_F32: diff --git a/src/infiniop/ops/blas_copy/info.h b/src/infiniop/ops/blas_copy/info.h new file mode 100644 index 000000000..ec936b1bd --- /dev/null +++ b/src/infiniop/ops/blas_copy/info.h @@ -0,0 +1,45 @@ +#ifndef __BLAS_COPY_INFO_H__ +#define __BLAS_COPY_INFO_H__ + +#include "../../../utils.h" +#include "../../tensor.h" + +class BlasCopyInfo { +private: + BlasCopyInfo() = default; + +public: + size_t n; + ptrdiff_t incx; + ptrdiff_t incy; + infiniDtype_t data_type; + + static utils::Result createBlasCopyInfo( + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t y_desc) { + + CHECK_OR_RETURN(x_desc != nullptr, INFINI_STATUS_NULL_POINTER); + CHECK_OR_RETURN(y_desc != nullptr, INFINI_STATUS_NULL_POINTER); + + auto data_type = x_desc->dtype(); + + CHECK_OR_RETURN(y_desc->dtype() == data_type, INFINI_STATUS_BAD_TENSOR_DTYPE); + CHECK_DTYPE(data_type, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32, INFINI_DTYPE_F64); + + CHECK_OR_RETURN(x_desc->ndim() == 1, 
INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(y_desc->ndim() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(x_desc->numel() == y_desc->numel(), INFINI_STATUS_BAD_TENSOR_SHAPE); + + auto n = x_desc->numel(); + auto incx = x_desc->stride(0); + auto incy = y_desc->stride(0); + + return utils::Result(BlasCopyInfo{ + n, + incx, + incy, + data_type}); + } +}; + +#endif // __BLAS_COPY_INFO_H__ \ No newline at end of file diff --git a/src/infiniop/ops/blas_copy/metax/blas_copy_metax.cc b/src/infiniop/ops/blas_copy/metax/blas_copy_metax.cc index 8e8a8db10..398688a87 100644 --- a/src/infiniop/ops/blas_copy/metax/blas_copy_metax.cc +++ b/src/infiniop/ops/blas_copy/metax/blas_copy_metax.cc @@ -19,11 +19,11 @@ infiniStatus_t Descriptor::create( infiniopTensorDescriptor_t y_desc) { auto handle = reinterpret_cast(handle_); - auto info = BlasCopyInfo::createBlasCopyInfo(x_desc, y_desc); - CHECK_RESULT(info); + auto result = BlasCopyInfo::createBlasCopyInfo(x_desc, y_desc); + CHECK_RESULT(result); *desc_ptr = new Descriptor( - info.take(), + result.take(), 0, new Opaque{handle->internal()}, handle->device, @@ -42,10 +42,10 @@ infiniStatus_t Descriptor::calculate( (void)workspace; (void)workspace_size; - const size_t size = _info.getSize(); - const ptrdiff_t incx = _info.getIncx(); - const ptrdiff_t incy = _info.getIncy(); - const infiniDtype_t data_type = _info.getDtype(); + const size_t size = _info.n; + const ptrdiff_t incx = _info.incx; + const ptrdiff_t incy = _info.incy; + const infiniDtype_t data_type = _info.data_type; CHECK_STATUS(_opaque->internal->useMcblas( (hcStream_t)stream, diff --git a/src/infiniop/ops/blas_dot/bang/blas_dot_bang.mlu b/src/infiniop/ops/blas_dot/bang/blas_dot_bang.mlu index 481905be1..45c93d69c 100644 --- a/src/infiniop/ops/blas_dot/bang/blas_dot_bang.mlu +++ b/src/infiniop/ops/blas_dot/bang/blas_dot_bang.mlu @@ -14,11 +14,11 @@ infiniStatus_t Descriptor::create( infiniopTensorDescriptor_t result_desc) { auto handle = 
reinterpret_cast(handle_); - auto info = BlasDotInfo::createBlasDotInfo(x_desc, y_desc, result_desc); - CHECK_RESULT(info); + auto result = BlasDotInfo::createBlasDotInfo(x_desc, y_desc, result_desc); + CHECK_RESULT(result); *desc_ptr = new Descriptor( - info.take(), + result.take(), 0, nullptr, handle->device, @@ -43,19 +43,19 @@ infiniStatus_t calculateBlasDot( k_dim.z = 1; k_type = cnrtFuncTypeUnion1; - if (info.getIncx() == 1 && info.getIncy() == 1) { + if (info.incx == 1 && info.incy == 1) { blasDotKernelContiguous<<>>( - info.getSize(), + info.n, x, y, result); } else { blasDotKernelStrided<<>>( - info.getSize(), + info.n, x, - info.getIncx(), + info.incx, y, - info.getIncy(), + info.incy, result); } @@ -82,7 +82,7 @@ infiniStatus_t Descriptor::calculate( (void)workspace; (void)workspace_size; - switch (_info.getDtype()) { + switch (_info.data_type) { case INFINI_DTYPE_F16: return CALCULATE_BLAS_DOT(half); case INFINI_DTYPE_F32: diff --git a/src/infiniop/ops/blas_dot/blas_dot.h b/src/infiniop/ops/blas_dot/blas_dot.h index 00675e55c..09e81a7fc 100644 --- a/src/infiniop/ops/blas_dot/blas_dot.h +++ b/src/infiniop/ops/blas_dot/blas_dot.h @@ -1,10 +1,8 @@ #ifndef __BLAS_DOT_H__ #define __BLAS_DOT_H__ -#include "../../../utils.h" #include "../../operator.h" -#include "../../tensor.h" -#include "infiniop/ops/blas_dot.h" +#include "info.h" #define DESCRIPTOR(NAMESPACE) \ \ @@ -48,54 +46,4 @@ }; \ } -class BlasDotInfo { -private: - size_t _size; - ptrdiff_t _incx; - ptrdiff_t _incy; - infiniDtype_t _dtype; - - BlasDotInfo(size_t size, - ptrdiff_t incx, - ptrdiff_t incy, - infiniDtype_t dtype) - : _size(size), _incx(incx), _incy(incy), _dtype(dtype) {} - -public: - inline size_t getSize() const { return _size; } - inline ptrdiff_t getIncx() const { return _incx; } - inline ptrdiff_t getIncy() const { return _incy; } - inline infiniDtype_t getDtype() const { return _dtype; } - - using ResultType = utils::Result; - - static ResultType createBlasDotInfo( - 
infiniopTensorDescriptor_t x_desc, - infiniopTensorDescriptor_t y_desc, - infiniopTensorDescriptor_t result_desc) { - - CHECK_OR_RETURN(x_desc != nullptr, INFINI_STATUS_NULL_POINTER); - CHECK_OR_RETURN(y_desc != nullptr, INFINI_STATUS_NULL_POINTER); - CHECK_OR_RETURN(result_desc != nullptr, INFINI_STATUS_NULL_POINTER); - - auto dtype = x_desc->dtype(); - - CHECK_OR_RETURN(y_desc->dtype() == dtype, INFINI_STATUS_BAD_TENSOR_DTYPE); - CHECK_OR_RETURN(result_desc->dtype() == dtype, INFINI_STATUS_BAD_TENSOR_DTYPE); - CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32, INFINI_DTYPE_F64); - - CHECK_OR_RETURN(x_desc->ndim() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); - CHECK_OR_RETURN(y_desc->ndim() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); - CHECK_OR_RETURN(x_desc->numel() == y_desc->numel(), INFINI_STATUS_BAD_TENSOR_SHAPE); - CHECK_OR_RETURN(result_desc->numel() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); - - auto size = x_desc->numel(); - auto incx = x_desc->stride(0); - auto incy = y_desc->stride(0); - - BlasDotInfo info(size, incx, incy, dtype); - return ResultType(std::move(info)); - } -}; - #endif // __BLAS_DOT_H__ diff --git a/src/infiniop/ops/blas_dot/cpu/blas_dot_cpu.cc b/src/infiniop/ops/blas_dot/cpu/blas_dot_cpu.cc index c138baa78..e88b2abd9 100644 --- a/src/infiniop/ops/blas_dot/cpu/blas_dot_cpu.cc +++ b/src/infiniop/ops/blas_dot/cpu/blas_dot_cpu.cc @@ -13,11 +13,11 @@ infiniStatus_t Descriptor::create( infiniopTensorDescriptor_t result_desc) { auto handle = reinterpret_cast(handle_); - auto info = BlasDotInfo::createBlasDotInfo(x_desc, y_desc, result_desc); - CHECK_RESULT(info); + auto result = BlasDotInfo::createBlasDotInfo(x_desc, y_desc, result_desc); + CHECK_RESULT(result); *desc_ptr = new Descriptor( - info.take(), + result.take(), 0, nullptr, handle->device, @@ -33,9 +33,9 @@ infiniStatus_t calculateBlasDot( const Tdata *y, Tdata *result) { - const ptrdiff_t n = info.getSize(); - const ptrdiff_t incx = info.getIncx(); - const ptrdiff_t incy = 
info.getIncy(); + const ptrdiff_t n = info.n; + const ptrdiff_t incx = info.incx; + const ptrdiff_t incy = info.incy; ptrdiff_t ix = (incx < 0) ? (1 - n) * incx : 0; ptrdiff_t iy = (incy < 0) ? (1 - n) * incy : 0; @@ -83,7 +83,7 @@ infiniStatus_t Descriptor::calculate( (void)workspace_size; (void)stream; - switch (_info.getDtype()) { + switch (_info.data_type) { case INFINI_DTYPE_F16: return CALCULATE_BLAS_DOT(fp16_t); case INFINI_DTYPE_F32: diff --git a/src/infiniop/ops/blas_dot/info.h b/src/infiniop/ops/blas_dot/info.h new file mode 100644 index 000000000..b06decd6b --- /dev/null +++ b/src/infiniop/ops/blas_dot/info.h @@ -0,0 +1,49 @@ +#ifndef __BLAS_DOT_INFO_H__ +#define __BLAS_DOT_INFO_H__ + +#include "../../../utils.h" +#include "../../tensor.h" + +class BlasDotInfo { +private: + BlasDotInfo() = default; + +public: + size_t n; + ptrdiff_t incx; + ptrdiff_t incy; + infiniDtype_t data_type; + + static utils::Result createBlasDotInfo( + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t result_desc) { + + CHECK_OR_RETURN(x_desc != nullptr, INFINI_STATUS_NULL_POINTER); + CHECK_OR_RETURN(y_desc != nullptr, INFINI_STATUS_NULL_POINTER); + CHECK_OR_RETURN(result_desc != nullptr, INFINI_STATUS_NULL_POINTER); + + auto data_type = x_desc->dtype(); + + CHECK_OR_RETURN(y_desc->dtype() == data_type, INFINI_STATUS_BAD_TENSOR_DTYPE); + CHECK_OR_RETURN(result_desc->dtype() == data_type, INFINI_STATUS_BAD_TENSOR_DTYPE); + CHECK_DTYPE(data_type, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32, INFINI_DTYPE_F64); + + CHECK_OR_RETURN(x_desc->ndim() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(y_desc->ndim() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(x_desc->numel() == y_desc->numel(), INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(result_desc->numel() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); + + auto n = x_desc->numel(); + auto incx = x_desc->stride(0); + auto incy = y_desc->stride(0); + + return 
utils::Result(BlasDotInfo{ + n, + incx, + incy, + data_type}); + } +}; + +#endif // __BLAS_DOT_INFO_H__ \ No newline at end of file diff --git a/src/infiniop/ops/blas_dot/metax/blas_dot_metax.cc b/src/infiniop/ops/blas_dot/metax/blas_dot_metax.cc index a70ba48c2..220886132 100644 --- a/src/infiniop/ops/blas_dot/metax/blas_dot_metax.cc +++ b/src/infiniop/ops/blas_dot/metax/blas_dot_metax.cc @@ -20,11 +20,11 @@ infiniStatus_t Descriptor::create( infiniopTensorDescriptor_t result_desc) { auto handle = reinterpret_cast(handle_); - auto info = BlasDotInfo::createBlasDotInfo(x_desc, y_desc, result_desc); - CHECK_RESULT(info); + auto result = BlasDotInfo::createBlasDotInfo(x_desc, y_desc, result_desc); + CHECK_RESULT(result); *desc_ptr = new Descriptor( - info.take(), + result.take(), 0, new Opaque{handle->internal()}, handle->device, @@ -44,10 +44,10 @@ infiniStatus_t Descriptor::calculate( (void)workspace; (void)workspace_size; - const size_t size = _info.getSize(); - const ptrdiff_t incx = _info.getIncx(); - const ptrdiff_t incy = _info.getIncy(); - const infiniDtype_t data_type = _info.getDtype(); + const size_t size = _info.n; + const ptrdiff_t incx = _info.incx; + const ptrdiff_t incy = _info.incy; + const infiniDtype_t data_type = _info.data_type; hpccDataType x_type, y_type, result_type; hpccDataType execution_type; @@ -76,7 +76,9 @@ infiniStatus_t Descriptor::calculate( CHECK_STATUS(_opaque->internal->useMcblas( (hcStream_t)stream, [&](hcblasHandle_t handle) { - CHECK_MCBLAS(hcblasSetPointerMode(handle, HCBLAS_POINTER_MODE_DEVICE)); + CHECK_MCBLAS(hcblasSetPointerMode( + handle, + HCBLAS_POINTER_MODE_DEVICE)); CHECK_MCBLAS(hcblasDotEx( handle, diff --git a/src/infiniop/ops/nrm2/bang/nrm2_bang.mlu b/src/infiniop/ops/nrm2/bang/nrm2_bang.mlu index d935c688e..6e25f5b68 100644 --- a/src/infiniop/ops/nrm2/bang/nrm2_bang.mlu +++ b/src/infiniop/ops/nrm2/bang/nrm2_bang.mlu @@ -13,12 +13,12 @@ infiniStatus_t Descriptor::create( infiniopTensorDescriptor_t result_desc) { 
auto handle = reinterpret_cast(handle_); - auto info = Nrm2Info::createNrm2Info(x_desc, result_desc); - CHECK_RESULT(info); + auto result = Nrm2Info::createNrm2Info(x_desc, result_desc); + CHECK_RESULT(result); // Create descriptor *desc_ptr = new Descriptor( - info.take(), + result.take(), 0, nullptr, handle->device, @@ -42,16 +42,16 @@ infiniStatus_t calculateNrm2( k_dim.z = 1; k_type = cnrtFuncTypeUnion1; - if (info.getIncx() == 1) { + if (info.incx == 1) { Nrm2KernelContiguous<<>>( - info.getSize(), + info.n, x, result); } else { Nrm2KernelStrided<<>>( - info.getSize(), + info.n, x, - info.getIncx(), + info.incx, result); } @@ -76,7 +76,7 @@ infiniStatus_t Descriptor::calculate( (void)workspace; (void)workspace_size; - switch (_info.getDtype()) { + switch (_info.data_type) { case INFINI_DTYPE_F16: return CALCULATE_NRM2(half); case INFINI_DTYPE_F32: diff --git a/src/infiniop/ops/nrm2/cpu/nrm2_cpu.cc b/src/infiniop/ops/nrm2/cpu/nrm2_cpu.cc index f36eb33ce..d6a45b2d2 100644 --- a/src/infiniop/ops/nrm2/cpu/nrm2_cpu.cc +++ b/src/infiniop/ops/nrm2/cpu/nrm2_cpu.cc @@ -15,12 +15,12 @@ infiniStatus_t Descriptor::create( infiniopTensorDescriptor_t result_desc) { auto handle = reinterpret_cast(handle_); - auto info = Nrm2Info::createNrm2Info(x_desc, result_desc); - CHECK_RESULT(info); + auto result = Nrm2Info::createNrm2Info(x_desc, result_desc); + CHECK_RESULT(result); // Create descriptor *desc_ptr = new Descriptor( - info.take(), + result.take(), 0, nullptr, handle->device, @@ -37,8 +37,8 @@ infiniStatus_t calculateNrm2( using Tcompute = std::conditional_t, double, float>; - const ptrdiff_t n = info.getSize(); - const ptrdiff_t incx = info.getIncx(); + const ptrdiff_t n = info.n; + const ptrdiff_t incx = info.incx; // Blue's scaling constants (float vs double) constexpr Tcompute tsml = [] { @@ -145,7 +145,7 @@ infiniStatus_t Descriptor::calculate( (void)workspace_size; (void)stream; - switch (_info.getDtype()) { + switch (_info.data_type) { case INFINI_DTYPE_F16: 
return CALCULATE_NRM2(fp16_t); case INFINI_DTYPE_F32: diff --git a/src/infiniop/ops/nrm2/info.h b/src/infiniop/ops/nrm2/info.h new file mode 100644 index 000000000..acf11ba25 --- /dev/null +++ b/src/infiniop/ops/nrm2/info.h @@ -0,0 +1,41 @@ +#ifndef __NRM2_INFO_H__ +#define __NRM2_INFO_H__ + +#include "../../../utils.h" +#include "../../tensor.h" + +class Nrm2Info { +private: + Nrm2Info() = default; + +public: + size_t n; + ptrdiff_t incx; + infiniDtype_t data_type; + + static utils::Result createNrm2Info( + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t result_desc) { + + CHECK_OR_RETURN(x_desc != nullptr, INFINI_STATUS_NULL_POINTER); + CHECK_OR_RETURN(result_desc != nullptr, INFINI_STATUS_NULL_POINTER); + + auto data_type = x_desc->dtype(); + + CHECK_OR_RETURN(result_desc->dtype() == data_type, INFINI_STATUS_BAD_TENSOR_DTYPE); + CHECK_DTYPE(data_type, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32, INFINI_DTYPE_F64); + + CHECK_OR_RETURN(x_desc->ndim() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(result_desc->numel() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); + + auto n = x_desc->numel(); + auto incx = x_desc->stride(0); + + return utils::Result(Nrm2Info{ + n, + incx, + data_type}); + } +}; + +#endif // __NRM2_INFO_H__ \ No newline at end of file diff --git a/src/infiniop/ops/nrm2/metax/nrm2_metax.cc b/src/infiniop/ops/nrm2/metax/nrm2_metax.cc index 12319182b..6c7c6ff89 100644 --- a/src/infiniop/ops/nrm2/metax/nrm2_metax.cc +++ b/src/infiniop/ops/nrm2/metax/nrm2_metax.cc @@ -19,11 +19,11 @@ infiniStatus_t Descriptor::create( infiniopTensorDescriptor_t result_desc) { auto handle = reinterpret_cast(handle_); - auto info = Nrm2Info::createNrm2Info(x_desc, result_desc); - CHECK_RESULT(info); + auto result = Nrm2Info::createNrm2Info(x_desc, result_desc); + CHECK_RESULT(result); *desc_ptr = new Descriptor( - info.take(), + result.take(), 0, new Opaque{handle->internal()}, handle->device, @@ -42,9 +42,9 @@ infiniStatus_t 
Descriptor::calculate( (void)workspace; (void)workspace_size; - const size_t size = _info.getSize(); - const ptrdiff_t incx = _info.getIncx(); - const infiniDtype_t data_type = _info.getDtype(); + const size_t size = _info.n; + const ptrdiff_t incx = _info.incx; + const infiniDtype_t data_type = _info.data_type; hpccDataType x_type, result_type; hpccDataType execution_type; @@ -73,7 +73,9 @@ infiniStatus_t Descriptor::calculate( CHECK_STATUS(_opaque->internal->useMcblas( (hcStream_t)stream, [&](hcblasHandle_t handle) { - CHECK_MCBLAS(hcblasSetPointerMode(handle, HCBLAS_POINTER_MODE_DEVICE)); + CHECK_MCBLAS(hcblasSetPointerMode( + handle, + HCBLAS_POINTER_MODE_DEVICE)); CHECK_MCBLAS(hcblasNrm2Ex( handle, diff --git a/src/infiniop/ops/nrm2/nrm2.h b/src/infiniop/ops/nrm2/nrm2.h index 82fb58e0b..a4094cd67 100644 --- a/src/infiniop/ops/nrm2/nrm2.h +++ b/src/infiniop/ops/nrm2/nrm2.h @@ -1,10 +1,8 @@ #ifndef __NRM2_H__ #define __NRM2_H__ -#include "../../../utils.h" #include "../../operator.h" -#include "../../tensor.h" -#include "infiniop/ops/nrm2.h" +#include "info.h" #define DESCRIPTOR(NAMESPACE) \ \ @@ -46,45 +44,4 @@ }; \ } -class Nrm2Info { -private: - size_t _size; - ptrdiff_t _incx; - infiniDtype_t _dtype; - - Nrm2Info(size_t size, - ptrdiff_t incx, - infiniDtype_t dtype) - : _size(size), _incx(incx), _dtype(dtype) {} - -public: - inline size_t getSize() const { return _size; } - inline ptrdiff_t getIncx() const { return _incx; } - inline infiniDtype_t getDtype() const { return _dtype; } - - using ResultType = utils::Result; - - static ResultType createNrm2Info( - infiniopTensorDescriptor_t x_desc, - infiniopTensorDescriptor_t result_desc) { - - CHECK_OR_RETURN(x_desc != nullptr, INFINI_STATUS_NULL_POINTER); - CHECK_OR_RETURN(result_desc != nullptr, INFINI_STATUS_NULL_POINTER); - - auto dtype = x_desc->dtype(); - - CHECK_OR_RETURN(result_desc->dtype() == dtype, INFINI_STATUS_BAD_TENSOR_DTYPE); - CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, 
INFINI_DTYPE_F32, INFINI_DTYPE_F64); - - CHECK_OR_RETURN(x_desc->ndim() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); - CHECK_OR_RETURN(result_desc->numel() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); - - auto size = x_desc->numel(); - auto incx = x_desc->stride(0); - - Nrm2Info info(size, incx, dtype); - return ResultType(std::move(info)); - } -}; - #endif // __NRM2_H__ diff --git a/src/infiniop/ops/rot/bang/rot_bang.mlu b/src/infiniop/ops/rot/bang/rot_bang.mlu index 755e8b638..7da1728e8 100644 --- a/src/infiniop/ops/rot/bang/rot_bang.mlu +++ b/src/infiniop/ops/rot/bang/rot_bang.mlu @@ -15,11 +15,11 @@ infiniStatus_t Descriptor::create( infiniopTensorDescriptor_t s_desc) { auto handle = reinterpret_cast(handle_); - auto info = RotInfo::createRotInfo(x_desc, y_desc, c_desc, s_desc); - CHECK_RESULT(info); + auto result = RotInfo::createRotInfo(x_desc, y_desc, c_desc, s_desc); + CHECK_RESULT(result); *desc_ptr = new Descriptor( - info.take(), + result.take(), 0, nullptr, handle->device, @@ -44,20 +44,20 @@ infiniStatus_t calculateRot( k_dim.z = 1; k_type = cnrtFuncTypeUnion1; - if (info.getIncx() == 1 && info.getIncy() == 1) { + if (info.incx == 1 && info.incy == 1) { rotKernelContiguous<<>>( - info.getSize(), + info.n, x, y, c, s); } else { rotKernelStrided<<>>( - info.getSize(), + info.n, x, - info.getIncx(), + info.incx, y, - info.getIncy(), + info.incy, c, s); } @@ -86,7 +86,7 @@ infiniStatus_t Descriptor::calculate( (void)workspace; (void)workspace_size; - switch (_info.getDtype()) { + switch (_info.data_type) { case INFINI_DTYPE_F16: return CALCULATE_ROT(half); case INFINI_DTYPE_F32: diff --git a/src/infiniop/ops/rot/cpu/rot_cpu.cc b/src/infiniop/ops/rot/cpu/rot_cpu.cc index b0eb9dfcb..6087487d1 100644 --- a/src/infiniop/ops/rot/cpu/rot_cpu.cc +++ b/src/infiniop/ops/rot/cpu/rot_cpu.cc @@ -14,11 +14,11 @@ infiniStatus_t Descriptor::create( infiniopTensorDescriptor_t s_desc) { auto handle = reinterpret_cast(handle_); - auto info = RotInfo::createRotInfo(x_desc, y_desc, 
c_desc, s_desc); - CHECK_RESULT(info); + auto result = RotInfo::createRotInfo(x_desc, y_desc, c_desc, s_desc); + CHECK_RESULT(result); *desc_ptr = new Descriptor( - info.take(), + result.take(), 0, nullptr, handle->device, @@ -40,9 +40,9 @@ infiniStatus_t calculateRot( const Tcompute c_val = utils::cast(c[0]); const Tcompute s_val = utils::cast(s[0]); - const ptrdiff_t size = static_cast(info.getSize()); - const ptrdiff_t incx = info.getIncx(); - const ptrdiff_t incy = info.getIncy(); + const ptrdiff_t size = static_cast(info.n); + const ptrdiff_t incx = info.incx; + const ptrdiff_t incy = info.incy; if (size <= 0) { return INFINI_STATUS_SUCCESS; @@ -85,7 +85,7 @@ infiniStatus_t Descriptor::calculate( (void)workspace_size; (void)stream; - switch (_info.getDtype()) { + switch (_info.data_type) { case INFINI_DTYPE_F16: return CALCULATE_ROT(fp16_t); case INFINI_DTYPE_F32: diff --git a/src/infiniop/ops/rot/info.h b/src/infiniop/ops/rot/info.h new file mode 100644 index 000000000..93de6413b --- /dev/null +++ b/src/infiniop/ops/rot/info.h @@ -0,0 +1,53 @@ +#ifndef __ROT_INFO_H__ +#define __ROT_INFO_H__ + +#include "../../../utils.h" +#include "../../tensor.h" + +class RotInfo { +private: + RotInfo() = default; + +public: + size_t n; + ptrdiff_t incx; + ptrdiff_t incy; + infiniDtype_t data_type; + + static utils::Result createRotInfo( + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t c_desc, + infiniopTensorDescriptor_t s_desc) { + + CHECK_OR_RETURN(c_desc != nullptr, INFINI_STATUS_NULL_POINTER); + CHECK_OR_RETURN(s_desc != nullptr, INFINI_STATUS_NULL_POINTER); + CHECK_OR_RETURN(x_desc != nullptr, INFINI_STATUS_NULL_POINTER); + CHECK_OR_RETURN(y_desc != nullptr, INFINI_STATUS_NULL_POINTER); + + auto data_type = x_desc->dtype(); + + CHECK_OR_RETURN(c_desc->dtype() == data_type, INFINI_STATUS_BAD_TENSOR_DTYPE); + CHECK_OR_RETURN(s_desc->dtype() == data_type, INFINI_STATUS_BAD_TENSOR_DTYPE); + 
CHECK_OR_RETURN(y_desc->dtype() == data_type, INFINI_STATUS_BAD_TENSOR_DTYPE); + CHECK_DTYPE(data_type, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32, INFINI_DTYPE_F64); + + CHECK_OR_RETURN(c_desc->numel() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(s_desc->numel() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(x_desc->ndim() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(y_desc->ndim() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(x_desc->numel() == y_desc->numel(), INFINI_STATUS_BAD_TENSOR_SHAPE); + + auto n = x_desc->numel(); + auto incx = x_desc->stride(0); + auto incy = y_desc->stride(0); + + return utils::Result(RotInfo{ + n, + incx, + incy, + data_type}); + } +}; + +#endif // __ROT_INFO_H__ \ No newline at end of file diff --git a/src/infiniop/ops/rot/metax/rot_metax.cc b/src/infiniop/ops/rot/metax/rot_metax.cc index 2d958f20d..0362643da 100644 --- a/src/infiniop/ops/rot/metax/rot_metax.cc +++ b/src/infiniop/ops/rot/metax/rot_metax.cc @@ -21,11 +21,11 @@ infiniStatus_t Descriptor::create( infiniopTensorDescriptor_t s_desc) { auto handle = reinterpret_cast(handle_); - auto info = RotInfo::createRotInfo(x_desc, y_desc, c_desc, s_desc); - CHECK_RESULT(info); + auto result = RotInfo::createRotInfo(x_desc, y_desc, c_desc, s_desc); + CHECK_RESULT(result); *desc_ptr = new Descriptor( - info.take(), + result.take(), 0, new Opaque{handle->internal()}, handle->device, @@ -46,10 +46,10 @@ infiniStatus_t Descriptor::calculate( (void)workspace; (void)workspace_size; - const size_t size = _info.getSize(); - const ptrdiff_t incx = _info.getIncx(); - const ptrdiff_t incy = _info.getIncy(); - const infiniDtype_t data_type = _info.getDtype(); + const size_t size = _info.n; + const ptrdiff_t incx = _info.incx; + const ptrdiff_t incy = _info.incy; + const infiniDtype_t data_type = _info.data_type; hpccDataType x_type, y_type, cs_type; hpccDataType execution_type; @@ -78,7 +78,9 @@ infiniStatus_t Descriptor::calculate( 
CHECK_STATUS(_opaque->internal->useMcblas( (hcStream_t)stream, [&](hcblasHandle_t handle) { - CHECK_MCBLAS(hcblasSetPointerMode(handle, HCBLAS_POINTER_MODE_DEVICE)); + CHECK_MCBLAS(hcblasSetPointerMode( + handle, + HCBLAS_POINTER_MODE_DEVICE)); CHECK_MCBLAS(hcblasRotEx( handle, diff --git a/src/infiniop/ops/rot/rot.h b/src/infiniop/ops/rot/rot.h index fc73a38b9..8304442d7 100644 --- a/src/infiniop/ops/rot/rot.h +++ b/src/infiniop/ops/rot/rot.h @@ -1,10 +1,8 @@ #ifndef __ROT_H__ #define __ROT_H__ -#include "../../../utils.h" #include "../../operator.h" -#include "../../tensor.h" -#include "infiniop/ops/rot.h" +#include "info.h" #define DESCRIPTOR(NAMESPACE) \ \ @@ -50,58 +48,4 @@ }; \ } -class RotInfo { -private: - size_t _size; - ptrdiff_t _incx; - ptrdiff_t _incy; - infiniDtype_t _dtype; - - RotInfo(size_t size, - ptrdiff_t incx, - ptrdiff_t incy, - infiniDtype_t dtype) - : _size(size), _incx(incx), _incy(incy), _dtype(dtype) {} - -public: - inline size_t getSize() const { return _size; } - inline ptrdiff_t getIncx() const { return _incx; } - inline ptrdiff_t getIncy() const { return _incy; } - inline infiniDtype_t getDtype() const { return _dtype; } - - using ResultType = utils::Result; - - static ResultType createRotInfo( - infiniopTensorDescriptor_t x_desc, - infiniopTensorDescriptor_t y_desc, - infiniopTensorDescriptor_t c_desc, - infiniopTensorDescriptor_t s_desc) { - - CHECK_OR_RETURN(c_desc != nullptr, INFINI_STATUS_NULL_POINTER); - CHECK_OR_RETURN(s_desc != nullptr, INFINI_STATUS_NULL_POINTER); - CHECK_OR_RETURN(x_desc != nullptr, INFINI_STATUS_NULL_POINTER); - CHECK_OR_RETURN(y_desc != nullptr, INFINI_STATUS_NULL_POINTER); - - auto dtype = x_desc->dtype(); - - CHECK_OR_RETURN(c_desc->dtype() == dtype, INFINI_STATUS_BAD_TENSOR_DTYPE); - CHECK_OR_RETURN(s_desc->dtype() == dtype, INFINI_STATUS_BAD_TENSOR_DTYPE); - CHECK_OR_RETURN(y_desc->dtype() == dtype, INFINI_STATUS_BAD_TENSOR_DTYPE); - CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, 
INFINI_DTYPE_F32, INFINI_DTYPE_F64); - - CHECK_OR_RETURN(c_desc->numel() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); - CHECK_OR_RETURN(s_desc->numel() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); - CHECK_OR_RETURN(x_desc->ndim() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); - CHECK_OR_RETURN(y_desc->ndim() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); - CHECK_OR_RETURN(x_desc->numel() == y_desc->numel(), INFINI_STATUS_BAD_TENSOR_SHAPE); - - auto size = x_desc->numel(); - auto incx = x_desc->stride(0); - auto incy = y_desc->stride(0); - - RotInfo info(size, incx, incy, dtype); - return ResultType(std::move(info)); - } -}; - #endif // __ROT_H__ diff --git a/src/infiniop/ops/rotg/bang/rotg_bang.mlu b/src/infiniop/ops/rotg/bang/rotg_bang.mlu index 9a7f56fdb..b0f271786 100644 --- a/src/infiniop/ops/rotg/bang/rotg_bang.mlu +++ b/src/infiniop/ops/rotg/bang/rotg_bang.mlu @@ -15,11 +15,11 @@ infiniStatus_t Descriptor::create( infiniopTensorDescriptor_t s_desc) { auto handle = reinterpret_cast(handle_); - auto info = RotgInfo::createRotgInfo(x_desc, y_desc, c_desc, s_desc); - CHECK_RESULT(info); + auto result = RotgInfo::createRotgInfo(x_desc, y_desc, c_desc, s_desc); + CHECK_RESULT(result); *desc_ptr = new Descriptor( - info.take(), + result.take(), 0, nullptr, handle->device, @@ -72,7 +72,7 @@ infiniStatus_t Descriptor::calculate( (void)workspace; (void)workspace_size; - switch (_info.getDtype()) { + switch (_info.data_type) { case INFINI_DTYPE_F16: return CALCULATE_ROTG(half); case INFINI_DTYPE_BF16: diff --git a/src/infiniop/ops/rotg/cpu/rotg_cpu.cc b/src/infiniop/ops/rotg/cpu/rotg_cpu.cc index 66f7f1640..bf95a5c3e 100644 --- a/src/infiniop/ops/rotg/cpu/rotg_cpu.cc +++ b/src/infiniop/ops/rotg/cpu/rotg_cpu.cc @@ -16,11 +16,11 @@ infiniStatus_t Descriptor::create( infiniopTensorDescriptor_t s_desc) { auto handle = reinterpret_cast(handle_); - auto info = RotgInfo::createRotgInfo(x_desc, y_desc, c_desc, s_desc); - CHECK_RESULT(info); + auto result = RotgInfo::createRotgInfo(x_desc, y_desc, c_desc, 
s_desc); + CHECK_RESULT(result); *desc_ptr = new Descriptor( - info.take(), + result.take(), 0, nullptr, handle->device, @@ -102,7 +102,7 @@ infiniStatus_t Descriptor::calculate( (void)workspace_size; (void)stream; - switch (_info.getDtype()) { + switch (_info.data_type) { case INFINI_DTYPE_F16: return CALCULATE_ROTG(fp16_t); case INFINI_DTYPE_F32: diff --git a/src/infiniop/ops/rotg/info.h b/src/infiniop/ops/rotg/info.h new file mode 100644 index 000000000..e96bea587 --- /dev/null +++ b/src/infiniop/ops/rotg/info.h @@ -0,0 +1,41 @@ +#ifndef __ROTG_INFO_H__ +#define __ROTG_INFO_H__ + +#include "../../../utils.h" +#include "../../tensor.h" + +class RotgInfo { +private: + RotgInfo() = default; + +public: + infiniDtype_t data_type; + + static utils::Result createRotgInfo( + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t c_desc, + infiniopTensorDescriptor_t s_desc) { + + CHECK_OR_RETURN(x_desc != nullptr, INFINI_STATUS_NULL_POINTER); + CHECK_OR_RETURN(y_desc != nullptr, INFINI_STATUS_NULL_POINTER); + CHECK_OR_RETURN(c_desc != nullptr, INFINI_STATUS_NULL_POINTER); + CHECK_OR_RETURN(s_desc != nullptr, INFINI_STATUS_NULL_POINTER); + + auto data_type = x_desc->dtype(); + + CHECK_OR_RETURN(y_desc->dtype() == data_type, INFINI_STATUS_BAD_TENSOR_DTYPE); + CHECK_OR_RETURN(c_desc->dtype() == data_type, INFINI_STATUS_BAD_TENSOR_DTYPE); + CHECK_OR_RETURN(s_desc->dtype() == data_type, INFINI_STATUS_BAD_TENSOR_DTYPE); + CHECK_DTYPE(data_type, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32, INFINI_DTYPE_F64); + CHECK_OR_RETURN(x_desc->numel() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(y_desc->numel() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(c_desc->numel() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(s_desc->numel() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); + + return utils::Result(RotgInfo{ + data_type}); + } +}; + +#endif // __ROTG_INFO_H__ \ No newline at end of file diff --git 
a/src/infiniop/ops/rotg/metax/rotg_metax.cc b/src/infiniop/ops/rotg/metax/rotg_metax.cc index 18e64368b..32b12704b 100644 --- a/src/infiniop/ops/rotg/metax/rotg_metax.cc +++ b/src/infiniop/ops/rotg/metax/rotg_metax.cc @@ -21,11 +21,11 @@ infiniStatus_t Descriptor::create( infiniopTensorDescriptor_t s_desc) { auto handle = reinterpret_cast(handle_); - auto info = RotgInfo::createRotgInfo(x_desc, y_desc, c_desc, s_desc); - CHECK_RESULT(info); + auto result = RotgInfo::createRotgInfo(x_desc, y_desc, c_desc, s_desc); + CHECK_RESULT(result); *desc_ptr = new Descriptor( - info.take(), + result.take(), 0, new Opaque{handle->internal()}, handle->device, @@ -46,7 +46,7 @@ infiniStatus_t Descriptor::calculate( (void)workspace; (void)workspace_size; - const infiniDtype_t data_type = _info.getDtype(); + const infiniDtype_t data_type = _info.data_type; CHECK_STATUS(_opaque->internal->useMcblas( (hcStream_t)stream, diff --git a/src/infiniop/ops/rotg/rotg.h b/src/infiniop/ops/rotg/rotg.h index 7156cdf1a..9aa0a59e0 100644 --- a/src/infiniop/ops/rotg/rotg.h +++ b/src/infiniop/ops/rotg/rotg.h @@ -1,10 +1,8 @@ #ifndef __ROTG_H__ #define __ROTG_H__ -#include "../../../utils.h" #include "../../operator.h" -#include "../../tensor.h" -#include "infiniop/ops/rotg.h" +#include "info.h" #define DESCRIPTOR(NAMESPACE) \ \ @@ -50,42 +48,4 @@ }; \ } -class RotgInfo { -private: - infiniDtype_t _dtype; - - explicit RotgInfo(infiniDtype_t dtype) : _dtype(dtype) {} - -public: - inline infiniDtype_t getDtype() const { return _dtype; } - - using ResultType = utils::Result; - - static ResultType createRotgInfo( - infiniopTensorDescriptor_t x_desc, - infiniopTensorDescriptor_t y_desc, - infiniopTensorDescriptor_t c_desc, - infiniopTensorDescriptor_t s_desc) { - - CHECK_OR_RETURN(x_desc != nullptr, INFINI_STATUS_NULL_POINTER); - CHECK_OR_RETURN(y_desc != nullptr, INFINI_STATUS_NULL_POINTER); - CHECK_OR_RETURN(c_desc != nullptr, INFINI_STATUS_NULL_POINTER); - CHECK_OR_RETURN(s_desc != nullptr, 
INFINI_STATUS_NULL_POINTER); - - auto dtype = x_desc->dtype(); - - CHECK_OR_RETURN(y_desc->dtype() == dtype, INFINI_STATUS_BAD_TENSOR_DTYPE); - CHECK_OR_RETURN(c_desc->dtype() == dtype, INFINI_STATUS_BAD_TENSOR_DTYPE); - CHECK_OR_RETURN(s_desc->dtype() == dtype, INFINI_STATUS_BAD_TENSOR_DTYPE); - CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32, INFINI_DTYPE_F64); - CHECK_OR_RETURN(x_desc->numel() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); - CHECK_OR_RETURN(y_desc->numel() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); - CHECK_OR_RETURN(c_desc->numel() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); - CHECK_OR_RETURN(s_desc->numel() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); - - RotgInfo info(dtype); - return ResultType(std::move(info)); - } -}; - #endif // __ROTG_H__ diff --git a/src/infiniop/ops/rotm/bang/rotm_bang.mlu b/src/infiniop/ops/rotm/bang/rotm_bang.mlu index 8064953b3..5b1f28be7 100644 --- a/src/infiniop/ops/rotm/bang/rotm_bang.mlu +++ b/src/infiniop/ops/rotm/bang/rotm_bang.mlu @@ -14,11 +14,11 @@ infiniStatus_t Descriptor::create( infiniopTensorDescriptor_t param_desc) { auto handle = reinterpret_cast(handle_); - auto info = RotmInfo::createRotmInfo(x_desc, y_desc, param_desc); - CHECK_RESULT(info); + auto result = RotmInfo::createRotmInfo(x_desc, y_desc, param_desc); + CHECK_RESULT(result); *desc_ptr = new Descriptor( - info.take(), + result.take(), 0, nullptr, handle->device, @@ -42,19 +42,19 @@ infiniStatus_t calculateRotm( k_dim.z = 1; k_type = cnrtFuncTypeUnion1; - if (info.getIncx() == 1 && info.getIncy() == 1) { + if (info.incx == 1 && info.incy == 1) { rotmKernelContiguous<<>>( - info.getSize(), + info.n, x, y, param); } else { rotmKernelStrided<<>>( - info.getSize(), + info.n, x, - info.getIncx(), + info.incx, y, - info.getIncy(), + info.incy, param); } @@ -80,7 +80,7 @@ infiniStatus_t Descriptor::calculate( (void)workspace; (void)workspace_size; - switch (_info.getDtype()) { + switch (_info.data_type) { case INFINI_DTYPE_F16: return 
CALCULATE_ROTM(half); case INFINI_DTYPE_BF16: diff --git a/src/infiniop/ops/rotm/cpu/rotm_cpu.cc b/src/infiniop/ops/rotm/cpu/rotm_cpu.cc index 51f11ad74..2f00ba97f 100644 --- a/src/infiniop/ops/rotm/cpu/rotm_cpu.cc +++ b/src/infiniop/ops/rotm/cpu/rotm_cpu.cc @@ -13,11 +13,11 @@ infiniStatus_t Descriptor::create( infiniopTensorDescriptor_t param_desc) { auto handle = reinterpret_cast(handle_); - auto info = RotmInfo::createRotmInfo(x_desc, y_desc, param_desc); - CHECK_RESULT(info); + auto result = RotmInfo::createRotmInfo(x_desc, y_desc, param_desc); + CHECK_RESULT(result); *desc_ptr = new Descriptor( - info.take(), + result.take(), 0, nullptr, handle->device, @@ -40,13 +40,13 @@ infiniStatus_t calculateRotm( Tcompute sflag = utils::cast(param[0]); - if (info.getSize() == 0 || (sflag + two == zero)) { + if (info.n == 0 || (sflag + two == zero)) { return INFINI_STATUS_SUCCESS; } - const ptrdiff_t size = static_cast(info.getSize()); - const ptrdiff_t incx = info.getIncx(); - const ptrdiff_t incy = info.getIncy(); + const ptrdiff_t size = static_cast(info.n); + const ptrdiff_t incx = info.incx; + const ptrdiff_t incy = info.incy; const ptrdiff_t kx = incx >= 0 ? 0 : (size - 1) * (-incx); const ptrdiff_t ky = incy >= 0 ? 
0 : (size - 1) * (-incy); @@ -150,7 +150,7 @@ infiniStatus_t Descriptor::calculate( (void)workspace_size; (void)stream; - switch (_info.getDtype()) { + switch (_info.data_type) { case INFINI_DTYPE_F16: return CALCULATE_ROTM(fp16_t); case INFINI_DTYPE_BF16: diff --git a/src/infiniop/ops/rotm/info.h b/src/infiniop/ops/rotm/info.h new file mode 100644 index 000000000..6509e6d75 --- /dev/null +++ b/src/infiniop/ops/rotm/info.h @@ -0,0 +1,50 @@ +#ifndef __ROTM_INFO_H__ +#define __ROTM_INFO_H__ + +#include "../../../utils.h" +#include "../../tensor.h" + +class RotmInfo { +private: + RotmInfo() = default; + +public: + size_t n; + ptrdiff_t incx; + ptrdiff_t incy; + infiniDtype_t data_type; + + static utils::Result createRotmInfo( + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t param_desc) { + + CHECK_OR_RETURN(x_desc != nullptr, INFINI_STATUS_NULL_POINTER); + CHECK_OR_RETURN(y_desc != nullptr, INFINI_STATUS_NULL_POINTER); + CHECK_OR_RETURN(param_desc != nullptr, INFINI_STATUS_NULL_POINTER); + + auto data_type = x_desc->dtype(); + + CHECK_OR_RETURN(y_desc->dtype() == data_type, INFINI_STATUS_BAD_TENSOR_DTYPE); + CHECK_OR_RETURN(param_desc->dtype() == data_type, INFINI_STATUS_BAD_TENSOR_DTYPE); + CHECK_DTYPE(data_type, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32, INFINI_DTYPE_F64); + CHECK_OR_RETURN(x_desc->ndim() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(y_desc->ndim() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(param_desc->ndim() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(x_desc->numel() == y_desc->numel(), INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(param_desc->numel() == 5, INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(param_desc->stride(0) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES); + + auto n = x_desc->numel(); + auto incx = x_desc->stride(0); + auto incy = y_desc->stride(0); + + return utils::Result(RotmInfo{ + n, + incx, + incy, + data_type}); + } +}; 
+ +#endif // __ROTM_INFO_H__ \ No newline at end of file diff --git a/src/infiniop/ops/rotm/metax/rotm_metax.cc b/src/infiniop/ops/rotm/metax/rotm_metax.cc index 13def4197..60d750912 100644 --- a/src/infiniop/ops/rotm/metax/rotm_metax.cc +++ b/src/infiniop/ops/rotm/metax/rotm_metax.cc @@ -20,11 +20,11 @@ infiniStatus_t Descriptor::create( infiniopTensorDescriptor_t param_desc) { auto handle = reinterpret_cast(handle_); - auto info = RotmInfo::createRotmInfo(x_desc, y_desc, param_desc); - CHECK_RESULT(info); + auto result = RotmInfo::createRotmInfo(x_desc, y_desc, param_desc); + CHECK_RESULT(result); *desc_ptr = new Descriptor( - info.take(), + result.take(), 0, new Opaque{handle->internal()}, handle->device, @@ -44,10 +44,10 @@ infiniStatus_t Descriptor::calculate( (void)workspace; (void)workspace_size; - const size_t size = _info.getSize(); - const ptrdiff_t incx = _info.getIncx(); - const ptrdiff_t incy = _info.getIncy(); - const infiniDtype_t data_type = _info.getDtype(); + const size_t size = _info.n; + const ptrdiff_t incx = _info.incx; + const ptrdiff_t incy = _info.incy; + const infiniDtype_t data_type = _info.data_type; CHECK_STATUS(_opaque->internal->useMcblas( (hcStream_t)stream, diff --git a/src/infiniop/ops/rotm/rotm.h b/src/infiniop/ops/rotm/rotm.h index 7699b6b3b..65b02e3a8 100644 --- a/src/infiniop/ops/rotm/rotm.h +++ b/src/infiniop/ops/rotm/rotm.h @@ -1,10 +1,8 @@ #ifndef __ROTM_H__ #define __ROTM_H__ -#include "../../../utils.h" #include "../../operator.h" -#include "../../tensor.h" -#include "infiniop/ops/rotm.h" +#include "info.h" #define DESCRIPTOR(NAMESPACE) \ \ @@ -48,55 +46,4 @@ }; \ } -class RotmInfo { -private: - size_t _size; - ptrdiff_t _incx; - ptrdiff_t _incy; - infiniDtype_t _dtype; - - RotmInfo(size_t size, - ptrdiff_t incx, - ptrdiff_t incy, - infiniDtype_t dtype) - : _size(size), _incx(incx), _incy(incy), _dtype(dtype) {} - -public: - inline size_t getSize() const { return _size; } - inline ptrdiff_t getIncx() const { return _incx; 
} - inline ptrdiff_t getIncy() const { return _incy; } - inline infiniDtype_t getDtype() const { return _dtype; } - - using ResultType = utils::Result; - - static ResultType createRotmInfo( - infiniopTensorDescriptor_t x_desc, - infiniopTensorDescriptor_t y_desc, - infiniopTensorDescriptor_t param_desc) { - - CHECK_OR_RETURN(x_desc != nullptr, INFINI_STATUS_NULL_POINTER); - CHECK_OR_RETURN(y_desc != nullptr, INFINI_STATUS_NULL_POINTER); - CHECK_OR_RETURN(param_desc != nullptr, INFINI_STATUS_NULL_POINTER); - - auto dtype = x_desc->dtype(); - - CHECK_OR_RETURN(y_desc->dtype() == dtype, INFINI_STATUS_BAD_TENSOR_DTYPE); - CHECK_OR_RETURN(param_desc->dtype() == dtype, INFINI_STATUS_BAD_TENSOR_DTYPE); - CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32, INFINI_DTYPE_F64); - CHECK_OR_RETURN(x_desc->ndim() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); - CHECK_OR_RETURN(y_desc->ndim() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); - CHECK_OR_RETURN(param_desc->ndim() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); - CHECK_OR_RETURN(x_desc->numel() == y_desc->numel(), INFINI_STATUS_BAD_TENSOR_SHAPE); - CHECK_OR_RETURN(param_desc->numel() == 5, INFINI_STATUS_BAD_TENSOR_SHAPE); - CHECK_OR_RETURN(param_desc->stride(0) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES); - - auto size = x_desc->numel(); - auto incx = x_desc->stride(0); - auto incy = y_desc->stride(0); - - RotmInfo info(size, incx, incy, dtype); - return ResultType(std::move(info)); - } -}; - #endif // __ROTM_H__ diff --git a/src/infiniop/ops/rotmg/bang/rotmg_bang.mlu b/src/infiniop/ops/rotmg/bang/rotmg_bang.mlu index 367fa9fff..53f7134b4 100644 --- a/src/infiniop/ops/rotmg/bang/rotmg_bang.mlu +++ b/src/infiniop/ops/rotmg/bang/rotmg_bang.mlu @@ -16,11 +16,11 @@ infiniStatus_t Descriptor::create( infiniopTensorDescriptor_t param_desc) { auto handle = reinterpret_cast(handle_); - auto info = RotmgInfo::createRotmgInfo(d1_desc, d2_desc, x1_desc, y1_desc, param_desc); - CHECK_RESULT(info); + auto result = 
RotmgInfo::createRotmgInfo(d1_desc, d2_desc, x1_desc, y1_desc, param_desc); + CHECK_RESULT(result); *desc_ptr = new Descriptor( - info.take(), + result.take(), 0, nullptr, handle->device, @@ -78,7 +78,7 @@ infiniStatus_t Descriptor::calculate( (void)workspace; (void)workspace_size; - switch (_info.getDtype()) { + switch (_info.data_type) { case INFINI_DTYPE_F16: return CALCULATE_ROTMG(half); case INFINI_DTYPE_F32: diff --git a/src/infiniop/ops/rotmg/cpu/rotmg_cpu.cc b/src/infiniop/ops/rotmg/cpu/rotmg_cpu.cc index 0d291b74c..258a936bf 100644 --- a/src/infiniop/ops/rotmg/cpu/rotmg_cpu.cc +++ b/src/infiniop/ops/rotmg/cpu/rotmg_cpu.cc @@ -17,11 +17,11 @@ infiniStatus_t Descriptor::create( infiniopTensorDescriptor_t param_desc) { auto handle = reinterpret_cast(handle_); - auto info = RotmgInfo::createRotmgInfo(d1_desc, d2_desc, x1_desc, y1_desc, param_desc); - CHECK_RESULT(info); + auto result = RotmgInfo::createRotmgInfo(d1_desc, d2_desc, x1_desc, y1_desc, param_desc); + CHECK_RESULT(result); *desc_ptr = new Descriptor( - info.take(), + result.take(), 0, nullptr, handle->device, @@ -202,7 +202,7 @@ infiniStatus_t Descriptor::calculate( (void)workspace_size; (void)stream; - switch (_info.getDtype()) { + switch (_info.data_type) { case INFINI_DTYPE_F16: return CALCULATE_ROTMG(fp16_t); case INFINI_DTYPE_F32: diff --git a/src/infiniop/ops/rotmg/info.h b/src/infiniop/ops/rotmg/info.h new file mode 100644 index 000000000..cc9f2826b --- /dev/null +++ b/src/infiniop/ops/rotmg/info.h @@ -0,0 +1,48 @@ +#ifndef __ROTMG_INFO_H__ +#define __ROTMG_INFO_H__ + +#include "../../../utils.h" +#include "../../tensor.h" + +class RotmgInfo { +private: + RotmgInfo() = default; + +public: + infiniDtype_t data_type; + + static utils::Result createRotmgInfo( + infiniopTensorDescriptor_t d1_desc, + infiniopTensorDescriptor_t d2_desc, + infiniopTensorDescriptor_t x1_desc, + infiniopTensorDescriptor_t y1_desc, + infiniopTensorDescriptor_t param_desc) { + + CHECK_OR_RETURN(d1_desc != nullptr, 
INFINI_STATUS_NULL_POINTER); + CHECK_OR_RETURN(d2_desc != nullptr, INFINI_STATUS_NULL_POINTER); + CHECK_OR_RETURN(x1_desc != nullptr, INFINI_STATUS_NULL_POINTER); + CHECK_OR_RETURN(y1_desc != nullptr, INFINI_STATUS_NULL_POINTER); + CHECK_OR_RETURN(param_desc != nullptr, INFINI_STATUS_NULL_POINTER); + + auto data_type = d1_desc->dtype(); + + CHECK_OR_RETURN(d2_desc->dtype() == data_type, INFINI_STATUS_BAD_TENSOR_DTYPE); + CHECK_OR_RETURN(x1_desc->dtype() == data_type, INFINI_STATUS_BAD_TENSOR_DTYPE); + CHECK_OR_RETURN(y1_desc->dtype() == data_type, INFINI_STATUS_BAD_TENSOR_DTYPE); + CHECK_OR_RETURN(param_desc->dtype() == data_type, INFINI_STATUS_BAD_TENSOR_DTYPE); + CHECK_DTYPE(data_type, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32, INFINI_DTYPE_F64); + CHECK_OR_RETURN(param_desc->ndim() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); + + CHECK_OR_RETURN(d1_desc->numel() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(d2_desc->numel() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(x1_desc->numel() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(y1_desc->numel() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(param_desc->numel() == 5, INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(param_desc->stride(0) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES); + + return utils::Result(RotmgInfo{ + data_type}); + } +}; + +#endif // __ROTMG_INFO_H__ \ No newline at end of file diff --git a/src/infiniop/ops/rotmg/metax/rotmg_metax.cc b/src/infiniop/ops/rotmg/metax/rotmg_metax.cc index 5136711bd..8fe2a881c 100644 --- a/src/infiniop/ops/rotmg/metax/rotmg_metax.cc +++ b/src/infiniop/ops/rotmg/metax/rotmg_metax.cc @@ -22,11 +22,11 @@ infiniStatus_t Descriptor::create( infiniopTensorDescriptor_t param_desc) { auto handle = reinterpret_cast(handle_); - auto info = RotmgInfo::createRotmgInfo(d1_desc, d2_desc, x1_desc, y1_desc, param_desc); - CHECK_RESULT(info); + auto result = RotmgInfo::createRotmgInfo(d1_desc, d2_desc, x1_desc, y1_desc, param_desc); + 
CHECK_RESULT(result); *desc_ptr = new Descriptor( - info.take(), + result.take(), 0, new Opaque{handle->internal()}, handle->device, @@ -48,7 +48,7 @@ infiniStatus_t Descriptor::calculate( (void)workspace; (void)workspace_size; - const infiniDtype_t data_type = _info.getDtype(); + const infiniDtype_t data_type = _info.data_type; CHECK_STATUS(_opaque->internal->useMcblas( (hcStream_t)stream, diff --git a/src/infiniop/ops/rotmg/rotmg.h b/src/infiniop/ops/rotmg/rotmg.h index 8b3e93a08..4cfc97ef0 100644 --- a/src/infiniop/ops/rotmg/rotmg.h +++ b/src/infiniop/ops/rotmg/rotmg.h @@ -1,10 +1,8 @@ #ifndef __ROTMG_H__ #define __ROTMG_H__ -#include "../../../utils.h" #include "../../operator.h" -#include "../../tensor.h" -#include "infiniop/ops/rotmg.h" +#include "info.h" #define DESCRIPTOR(NAMESPACE) \ \ @@ -52,49 +50,4 @@ }; \ } -class RotmgInfo { -private: - infiniDtype_t _dtype; - - explicit RotmgInfo(infiniDtype_t dtype) : _dtype(dtype) {} - -public: - inline infiniDtype_t getDtype() const { return _dtype; } - - using ResultType = utils::Result; - - static ResultType createRotmgInfo( - infiniopTensorDescriptor_t d1_desc, - infiniopTensorDescriptor_t d2_desc, - infiniopTensorDescriptor_t x1_desc, - infiniopTensorDescriptor_t y1_desc, - infiniopTensorDescriptor_t param_desc) { - - CHECK_OR_RETURN(d1_desc != nullptr, INFINI_STATUS_NULL_POINTER); - CHECK_OR_RETURN(d2_desc != nullptr, INFINI_STATUS_NULL_POINTER); - CHECK_OR_RETURN(x1_desc != nullptr, INFINI_STATUS_NULL_POINTER); - CHECK_OR_RETURN(y1_desc != nullptr, INFINI_STATUS_NULL_POINTER); - CHECK_OR_RETURN(param_desc != nullptr, INFINI_STATUS_NULL_POINTER); - - auto dtype = d1_desc->dtype(); - - CHECK_OR_RETURN(d2_desc->dtype() == dtype, INFINI_STATUS_BAD_TENSOR_DTYPE); - CHECK_OR_RETURN(x1_desc->dtype() == dtype, INFINI_STATUS_BAD_TENSOR_DTYPE); - CHECK_OR_RETURN(y1_desc->dtype() == dtype, INFINI_STATUS_BAD_TENSOR_DTYPE); - CHECK_OR_RETURN(param_desc->dtype() == dtype, INFINI_STATUS_BAD_TENSOR_DTYPE); - 
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32, INFINI_DTYPE_F64); - CHECK_OR_RETURN(param_desc->ndim() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); - - CHECK_OR_RETURN(d1_desc->numel() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); - CHECK_OR_RETURN(d2_desc->numel() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); - CHECK_OR_RETURN(x1_desc->numel() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); - CHECK_OR_RETURN(y1_desc->numel() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); - CHECK_OR_RETURN(param_desc->numel() == 5, INFINI_STATUS_BAD_TENSOR_SHAPE); - CHECK_OR_RETURN(param_desc->stride(0) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES); - - RotmgInfo info(dtype); - return ResultType(std::move(info)); - } -}; - #endif // __ROTMG_H__ diff --git a/src/infiniop/ops/scal/bang/scal_bang.mlu b/src/infiniop/ops/scal/bang/scal_bang.mlu index b9407ca34..31d8a480f 100644 --- a/src/infiniop/ops/scal/bang/scal_bang.mlu +++ b/src/infiniop/ops/scal/bang/scal_bang.mlu @@ -13,11 +13,11 @@ infiniStatus_t Descriptor::create( infiniopTensorDescriptor_t x_desc) { auto handle = reinterpret_cast(handle_); - auto info = ScalInfo::createScalInfo(alpha_desc, x_desc); - CHECK_RESULT(info); + auto result = ScalInfo::createScalInfo(alpha_desc, x_desc); + CHECK_RESULT(result); *desc_ptr = new Descriptor( - info.take(), + result.take(), 0, nullptr, handle->device, @@ -41,17 +41,17 @@ infiniStatus_t calculateScal( k_dim.z = 1; k_type = cnrtFuncTypeUnion1; - if (info.getIncx() == 1) { + if (info.incx == 1) { scalKernelContiguous<<>>( - info.getSize(), + info.n, alpha, x); } else { scalKernelStrided<<>>( - info.getSize(), + info.n, alpha, x, - info.getIncx()); + info.incx); } cnrtQueueSync(queue); @@ -75,7 +75,7 @@ infiniStatus_t Descriptor::calculate( (void)workspace; (void)workspace_size; - switch (_info.getDtype()) { + switch (_info.data_type) { case INFINI_DTYPE_F16: return CALCULATE_SCAL(half); case INFINI_DTYPE_F32: diff --git a/src/infiniop/ops/scal/cpu/scal_cpu.cc b/src/infiniop/ops/scal/cpu/scal_cpu.cc index 
42678afeb..d54aa8543 100644 --- a/src/infiniop/ops/scal/cpu/scal_cpu.cc +++ b/src/infiniop/ops/scal/cpu/scal_cpu.cc @@ -12,11 +12,11 @@ infiniStatus_t Descriptor::create( infiniopTensorDescriptor_t x_desc) { auto handle = reinterpret_cast(handle_); - auto info = ScalInfo::createScalInfo(alpha_desc, x_desc); - CHECK_RESULT(info); + auto result = ScalInfo::createScalInfo(alpha_desc, x_desc); + CHECK_RESULT(result); *desc_ptr = new Descriptor( - info.take(), + result.take(), 0, nullptr, handle->device, @@ -31,8 +31,8 @@ infiniStatus_t calculateScal( const Tdata *alpha, Tdata *x) { - const ptrdiff_t size = info.getSize(); - const ptrdiff_t incx = info.getIncx(); + const ptrdiff_t size = info.n; + const ptrdiff_t incx = info.incx; for (ptrdiff_t i = 0; i < size; ++i) { const ptrdiff_t idx = i * incx; @@ -62,7 +62,7 @@ infiniStatus_t Descriptor::calculate( (void)workspace; (void)workspace_size; - switch (_info.getDtype()) { + switch (_info.data_type) { case INFINI_DTYPE_F16: return CALCULATE_SCAL(fp16_t); case INFINI_DTYPE_F32: diff --git a/src/infiniop/ops/scal/info.h b/src/infiniop/ops/scal/info.h new file mode 100644 index 000000000..fd104e5d5 --- /dev/null +++ b/src/infiniop/ops/scal/info.h @@ -0,0 +1,40 @@ +#ifndef __SCAL_INFO_H__ +#define __SCAL_INFO_H__ + +#include "../../../utils.h" +#include "../../tensor.h" + +class ScalInfo { +private: + ScalInfo() = default; + +public: + size_t n; + ptrdiff_t incx; + infiniDtype_t data_type; + + static utils::Result createScalInfo( + infiniopTensorDescriptor_t alpha_desc, + infiniopTensorDescriptor_t x_desc) { + + CHECK_OR_RETURN(alpha_desc != nullptr, INFINI_STATUS_NULL_POINTER); + CHECK_OR_RETURN(x_desc != nullptr, INFINI_STATUS_NULL_POINTER); + + auto data_type = x_desc->dtype(); + + CHECK_OR_RETURN(alpha_desc->dtype() == data_type, INFINI_STATUS_BAD_TENSOR_DTYPE); + CHECK_DTYPE(data_type, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32, INFINI_DTYPE_F64); + CHECK_OR_RETURN(alpha_desc->numel() == 1, 
INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(x_desc->ndim() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); + + auto n = x_desc->numel(); + auto incx = x_desc->stride(0); + + return utils::Result(ScalInfo{ + n, + incx, + data_type}); + } +}; + +#endif // __SCAL_INFO_H__ \ No newline at end of file diff --git a/src/infiniop/ops/scal/metax/scal_metax.cc b/src/infiniop/ops/scal/metax/scal_metax.cc index c5542e39d..a4320cb23 100644 --- a/src/infiniop/ops/scal/metax/scal_metax.cc +++ b/src/infiniop/ops/scal/metax/scal_metax.cc @@ -19,11 +19,11 @@ infiniStatus_t Descriptor::create( infiniopTensorDescriptor_t x_desc) { auto handle = reinterpret_cast(handle_); - auto info = ScalInfo::createScalInfo(alpha_desc, x_desc); - CHECK_RESULT(info); + auto result = ScalInfo::createScalInfo(alpha_desc, x_desc); + CHECK_RESULT(result); *desc_ptr = new Descriptor( - info.take(), + result.take(), 0, new Opaque{handle->internal()}, handle->device, @@ -42,9 +42,9 @@ infiniStatus_t Descriptor::calculate( (void)workspace; (void)workspace_size; - const size_t size = _info.getSize(); - const ptrdiff_t incx = _info.getIncx(); - const infiniDtype_t data_type = _info.getDtype(); + const size_t size = _info.n; + const ptrdiff_t incx = _info.incx; + const infiniDtype_t data_type = _info.data_type; hpccDataType alpha_type, x_type; hpccDataType execution_type; @@ -73,7 +73,9 @@ infiniStatus_t Descriptor::calculate( CHECK_STATUS(_opaque->internal->useMcblas( (hcStream_t)stream, [&](hcblasHandle_t handle) { - CHECK_MCBLAS(hcblasSetPointerMode(handle, HCBLAS_POINTER_MODE_DEVICE)); + CHECK_MCBLAS(hcblasSetPointerMode( + handle, + HCBLAS_POINTER_MODE_DEVICE)); CHECK_MCBLAS(hcblasScalEx( handle, diff --git a/src/infiniop/ops/scal/scal.h b/src/infiniop/ops/scal/scal.h index 25551d0fa..e94bbe2b2 100644 --- a/src/infiniop/ops/scal/scal.h +++ b/src/infiniop/ops/scal/scal.h @@ -1,10 +1,8 @@ #ifndef __SCAL_H__ #define __SCAL_H__ -#include "../../../utils.h" #include "../../operator.h" -#include "../../tensor.h" 
-#include "infiniop/ops/scal.h" +#include "info.h" #define DESCRIPTOR(NAMESPACE) \ \ @@ -46,44 +44,4 @@ }; \ } -class ScalInfo { -private: - size_t _size; - ptrdiff_t _incx; - infiniDtype_t _dtype; - - ScalInfo(size_t size, - ptrdiff_t incx, - infiniDtype_t dtype) - : _size(size), _incx(incx), _dtype(dtype) {} - -public: - inline size_t getSize() const { return _size; } - inline ptrdiff_t getIncx() const { return _incx; } - inline infiniDtype_t getDtype() const { return _dtype; } - - using ResultType = utils::Result; - - static ResultType createScalInfo( - infiniopTensorDescriptor_t alpha_desc, - infiniopTensorDescriptor_t x_desc) { - - CHECK_OR_RETURN(alpha_desc != nullptr, INFINI_STATUS_NULL_POINTER); - CHECK_OR_RETURN(x_desc != nullptr, INFINI_STATUS_NULL_POINTER); - - auto dtype = x_desc->dtype(); - - CHECK_OR_RETURN(alpha_desc->dtype() == dtype, INFINI_STATUS_BAD_TENSOR_DTYPE); - CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32, INFINI_DTYPE_F64); - CHECK_OR_RETURN(alpha_desc->numel() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); - CHECK_OR_RETURN(x_desc->ndim() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); - - auto size = x_desc->numel(); - auto incx = x_desc->stride(0); - - ScalInfo info(size, incx, dtype); - return ResultType(std::move(info)); - } -}; - #endif // __SCAL_H__ diff --git a/src/infiniop/ops/swap/bang/swap_bang.mlu b/src/infiniop/ops/swap/bang/swap_bang.mlu index 041d09654..855997b38 100644 --- a/src/infiniop/ops/swap/bang/swap_bang.mlu +++ b/src/infiniop/ops/swap/bang/swap_bang.mlu @@ -13,11 +13,11 @@ infiniStatus_t Descriptor::create( infiniopTensorDescriptor_t y_desc) { auto handle = reinterpret_cast(handle_); - auto info = SwapInfo::createSwapInfo(x_desc, y_desc); - CHECK_RESULT(info); + auto result = SwapInfo::createSwapInfo(x_desc, y_desc); + CHECK_RESULT(result); *desc_ptr = new Descriptor( - info.take(), + result.take(), 0, nullptr, handle->device, @@ -41,18 +41,18 @@ infiniStatus_t calculateSwap( k_dim.z = 1; k_type = 
cnrtFuncTypeUnion1; - if (info.getIncx() == 1 && info.getIncy() == 1) { + if (info.incx == 1 && info.incy == 1) { swapKernelContiguous<<>>( - info.getSize(), + info.n, x, y); } else { swapKernelStrided<<>>( - info.getSize(), + info.n, x, - info.getIncx(), + info.incx, y, - info.getIncy()); + info.incy); } cnrtQueueSync(queue); @@ -75,7 +75,7 @@ infiniStatus_t Descriptor::calculate( (void)workspace; (void)workspace_size; - switch (_info.getDtype()) { + switch (_info.data_type) { case INFINI_DTYPE_F16: return CALCULATE_SWAP(half); case INFINI_DTYPE_BF16: diff --git a/src/infiniop/ops/swap/cpu/swap_cpu.cc b/src/infiniop/ops/swap/cpu/swap_cpu.cc index 4a59f96c4..3f804d974 100644 --- a/src/infiniop/ops/swap/cpu/swap_cpu.cc +++ b/src/infiniop/ops/swap/cpu/swap_cpu.cc @@ -12,11 +12,11 @@ infiniStatus_t Descriptor::create( infiniopTensorDescriptor_t y_desc) { auto handle = reinterpret_cast(handle_); - auto info = SwapInfo::createSwapInfo(x_desc, y_desc); - CHECK_RESULT(info); + auto result = SwapInfo::createSwapInfo(x_desc, y_desc); + CHECK_RESULT(result); *desc_ptr = new Descriptor( - info.take(), + result.take(), 0, nullptr, handle->device, @@ -31,9 +31,9 @@ infiniStatus_t calculateSwap( Tdata *x, Tdata *y) { - const ptrdiff_t size = info.getSize(); - const ptrdiff_t incx = info.getIncx(); - const ptrdiff_t incy = info.getIncy(); + const ptrdiff_t size = info.n; + const ptrdiff_t incx = info.incx; + const ptrdiff_t incy = info.incy; #pragma omp parallel for if (size > 1024) for (ptrdiff_t i = 0; i < size; ++i) { @@ -63,7 +63,7 @@ infiniStatus_t Descriptor::calculate( (void)workspace_size; (void)stream; - switch (_info.getDtype()) { + switch (_info.data_type) { case INFINI_DTYPE_F16: return CALCULATE_SWAP(fp16_t); case INFINI_DTYPE_BF16: diff --git a/src/infiniop/ops/swap/info.h b/src/infiniop/ops/swap/info.h new file mode 100644 index 000000000..bb4d84b31 --- /dev/null +++ b/src/infiniop/ops/swap/info.h @@ -0,0 +1,44 @@ +#ifndef __SWAP_INFO_H__ +#define __SWAP_INFO_H__ + 
+#include "../../../utils.h" +#include "../../tensor.h" + +class SwapInfo { +private: + SwapInfo() = default; + +public: + size_t n; + ptrdiff_t incx; + ptrdiff_t incy; + infiniDtype_t data_type; + + static utils::Result createSwapInfo( + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t y_desc) { + + CHECK_OR_RETURN(x_desc != nullptr, INFINI_STATUS_NULL_POINTER); + CHECK_OR_RETURN(y_desc != nullptr, INFINI_STATUS_NULL_POINTER); + + auto data_type = x_desc->dtype(); + + CHECK_OR_RETURN(y_desc->dtype() == data_type, INFINI_STATUS_BAD_TENSOR_DTYPE); + CHECK_DTYPE(data_type, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32, INFINI_DTYPE_F64); + CHECK_OR_RETURN(x_desc->ndim() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(y_desc->ndim() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(x_desc->numel() == y_desc->numel(), INFINI_STATUS_BAD_TENSOR_SHAPE); + + auto n = x_desc->numel(); + auto incx = x_desc->stride(0); + auto incy = y_desc->stride(0); + + return utils::Result(SwapInfo{ + n, + incx, + incy, + data_type}); + } +}; + +#endif // __SWAP_INFO_H__ \ No newline at end of file diff --git a/src/infiniop/ops/swap/metax/swap_metax.cc b/src/infiniop/ops/swap/metax/swap_metax.cc index 4cc7d336a..8c442e533 100644 --- a/src/infiniop/ops/swap/metax/swap_metax.cc +++ b/src/infiniop/ops/swap/metax/swap_metax.cc @@ -19,11 +19,11 @@ infiniStatus_t Descriptor::create( infiniopTensorDescriptor_t y_desc) { auto handle = reinterpret_cast(handle_); - auto info = SwapInfo::createSwapInfo(x_desc, y_desc); - CHECK_RESULT(info); + auto result = SwapInfo::createSwapInfo(x_desc, y_desc); + CHECK_RESULT(result); *desc_ptr = new Descriptor( - info.take(), + result.take(), 0, new Opaque{handle->internal()}, handle->device, @@ -42,10 +42,10 @@ infiniStatus_t Descriptor::calculate( (void)workspace; (void)workspace_size; - const size_t size = _info.getSize(); - const ptrdiff_t incx = _info.getIncx(); - const ptrdiff_t incy = _info.getIncy(); - const 
infiniDtype_t data_type = _info.getDtype(); + const size_t size = _info.n; + const ptrdiff_t incx = _info.incx; + const ptrdiff_t incy = _info.incy; + const infiniDtype_t data_type = _info.data_type; CHECK_STATUS(_opaque->internal->useMcblas( (hcStream_t)stream, diff --git a/src/infiniop/ops/swap/swap.h b/src/infiniop/ops/swap/swap.h index 7ace7d8ee..f9eadfff9 100644 --- a/src/infiniop/ops/swap/swap.h +++ b/src/infiniop/ops/swap/swap.h @@ -1,10 +1,8 @@ #ifndef __SWAP_H__ #define __SWAP_H__ -#include "../../../utils.h" #include "../../operator.h" -#include "../../tensor.h" -#include "infiniop/ops/swap.h" +#include "info.h" #define DESCRIPTOR(NAMESPACE) \ \ @@ -46,49 +44,4 @@ }; \ } -class SwapInfo { -private: - size_t _size; - ptrdiff_t _incx; - ptrdiff_t _incy; - infiniDtype_t _dtype; - - SwapInfo(size_t size, - ptrdiff_t incx, - ptrdiff_t incy, - infiniDtype_t dtype) - : _size(size), _incx(incx), _incy(incy), _dtype(dtype) {} - -public: - inline size_t getSize() const { return _size; } - inline ptrdiff_t getIncx() const { return _incx; } - inline ptrdiff_t getIncy() const { return _incy; } - inline infiniDtype_t getDtype() const { return _dtype; } - - using ResultType = utils::Result; - - static ResultType createSwapInfo( - infiniopTensorDescriptor_t x_desc, - infiniopTensorDescriptor_t y_desc) { - - CHECK_OR_RETURN(x_desc != nullptr, INFINI_STATUS_NULL_POINTER); - CHECK_OR_RETURN(y_desc != nullptr, INFINI_STATUS_NULL_POINTER); - - auto dtype = x_desc->dtype(); - - CHECK_OR_RETURN(y_desc->dtype() == dtype, INFINI_STATUS_BAD_TENSOR_DTYPE); - CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32, INFINI_DTYPE_F64); - CHECK_OR_RETURN(x_desc->ndim() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); - CHECK_OR_RETURN(y_desc->ndim() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE); - CHECK_OR_RETURN(x_desc->numel() == y_desc->numel(), INFINI_STATUS_BAD_TENSOR_SHAPE); - - auto size = x_desc->numel(); - auto incx = x_desc->stride(0); - auto incy = y_desc->stride(0); - - 
SwapInfo info(size, incx, incy, dtype); - return ResultType(std::move(info)); - } -}; - #endif // __SWAP_H__ From b4f5bfa04392ce9003daf0110bc52bafddb1e433 Mon Sep 17 00:00:00 2001 From: xuzhengzhong Date: Wed, 6 May 2026 06:32:01 +0000 Subject: [PATCH 15/25] Add InfiniCore `asum`, `blas_amax`, `blas_amin`, `blas_dot` and `nrm2` wrappers --- include/infinicore/ops.hpp | 5 + include/infinicore/ops/asum.hpp | 18 +++ include/infinicore/ops/blas_amax.hpp | 18 +++ include/infinicore/ops/blas_amin.hpp | 18 +++ include/infinicore/ops/blas_dot.hpp | 18 +++ include/infinicore/ops/nrm2.hpp | 18 +++ python/infinicore/__init__.py | 10 ++ python/infinicore/ops/asum.py | 10 ++ python/infinicore/ops/blas_amax.py | 10 ++ python/infinicore/ops/blas_amin.py | 10 ++ python/infinicore/ops/blas_dot.py | 10 ++ python/infinicore/ops/nrm2.py | 10 ++ src/infinicore/ops/asum/asum.cc | 28 +++++ src/infinicore/ops/asum/asum_infiniop.cc | 56 +++++++++ src/infinicore/ops/blas_amax/blas_amax.cc | 28 +++++ .../ops/blas_amax/blas_amax_infiniop.cc | 56 +++++++++ src/infinicore/ops/blas_amin/blas_amin.cc | 28 +++++ .../ops/blas_amin/blas_amin_infiniop.cc | 56 +++++++++ src/infinicore/ops/blas_dot/blas_dot.cc | 28 +++++ .../ops/blas_dot/blas_dot_infiniop.cc | 56 +++++++++ src/infinicore/ops/nrm2/nrm2.cc | 28 +++++ src/infinicore/ops/nrm2/nrm2_infiniop.cc | 56 +++++++++ src/infinicore/pybind11/ops.hpp | 10 ++ src/infinicore/pybind11/ops/asum.hpp | 24 ++++ src/infinicore/pybind11/ops/blas_amax.hpp | 24 ++++ src/infinicore/pybind11/ops/blas_amin.hpp | 24 ++++ src/infinicore/pybind11/ops/blas_dot.hpp | 26 ++++ src/infinicore/pybind11/ops/nrm2.hpp | 24 ++++ test/infinicore/ops/asum.py | 114 ++++++++++++++++++ test/infinicore/ops/blas_amax.py | 100 +++++++++++++++ test/infinicore/ops/blas_amin.py | 100 +++++++++++++++ test/infinicore/ops/blas_dot.py | 111 +++++++++++++++++ test/infinicore/ops/nrm2.py | 104 ++++++++++++++++ 33 files changed, 1236 insertions(+) create mode 100644 
include/infinicore/ops/asum.hpp create mode 100644 include/infinicore/ops/blas_amax.hpp create mode 100644 include/infinicore/ops/blas_amin.hpp create mode 100644 include/infinicore/ops/blas_dot.hpp create mode 100644 include/infinicore/ops/nrm2.hpp create mode 100644 python/infinicore/ops/asum.py create mode 100644 python/infinicore/ops/blas_amax.py create mode 100644 python/infinicore/ops/blas_amin.py create mode 100644 python/infinicore/ops/blas_dot.py create mode 100644 python/infinicore/ops/nrm2.py create mode 100644 src/infinicore/ops/asum/asum.cc create mode 100644 src/infinicore/ops/asum/asum_infiniop.cc create mode 100644 src/infinicore/ops/blas_amax/blas_amax.cc create mode 100644 src/infinicore/ops/blas_amax/blas_amax_infiniop.cc create mode 100644 src/infinicore/ops/blas_amin/blas_amin.cc create mode 100644 src/infinicore/ops/blas_amin/blas_amin_infiniop.cc create mode 100644 src/infinicore/ops/blas_dot/blas_dot.cc create mode 100644 src/infinicore/ops/blas_dot/blas_dot_infiniop.cc create mode 100644 src/infinicore/ops/nrm2/nrm2.cc create mode 100644 src/infinicore/ops/nrm2/nrm2_infiniop.cc create mode 100644 src/infinicore/pybind11/ops/asum.hpp create mode 100644 src/infinicore/pybind11/ops/blas_amax.hpp create mode 100644 src/infinicore/pybind11/ops/blas_amin.hpp create mode 100644 src/infinicore/pybind11/ops/blas_dot.hpp create mode 100644 src/infinicore/pybind11/ops/nrm2.hpp create mode 100644 test/infinicore/ops/asum.py create mode 100644 test/infinicore/ops/blas_amax.py create mode 100644 test/infinicore/ops/blas_amin.py create mode 100644 test/infinicore/ops/blas_dot.py create mode 100644 test/infinicore/ops/nrm2.py diff --git a/include/infinicore/ops.hpp b/include/infinicore/ops.hpp index 832f48683..6df98da82 100644 --- a/include/infinicore/ops.hpp +++ b/include/infinicore/ops.hpp @@ -6,12 +6,16 @@ #include "ops/addcmul.hpp" #include "ops/asin.hpp" #include "ops/asinh.hpp" +#include "ops/asum.hpp" #include "ops/atanh.hpp" #include 
"ops/attention.hpp" #include "ops/avg_pool1d.hpp" #include "ops/baddbmm.hpp" #include "ops/bilinear.hpp" #include "ops/binary_cross_entropy_with_logits.hpp" +#include "ops/blas_amax.hpp" +#include "ops/blas_amin.hpp" +#include "ops/blas_dot.hpp" #include "ops/causal_softmax.hpp" #include "ops/cdist.hpp" #include "ops/conv2d.hpp" @@ -28,6 +32,7 @@ #include "ops/layer_norm.hpp" #include "ops/linear.hpp" #include "ops/matmul.hpp" +#include "ops/nrm2.hpp" #include "ops/ones.hpp" #include "ops/paged_attention.hpp" #include "ops/paged_attention_prefill.hpp" diff --git a/include/infinicore/ops/asum.hpp b/include/infinicore/ops/asum.hpp new file mode 100644 index 000000000..df94f2183 --- /dev/null +++ b/include/infinicore/ops/asum.hpp @@ -0,0 +1,18 @@ +#pragma once + +#include "../device.hpp" +#include "common/op.hpp" + +namespace infinicore::op { + +class Asum { +public: + using schema = void (*)(Tensor, Tensor); + static void execute(Tensor result, Tensor x); + static common::OpDispatcher &dispatcher(); +}; + +Tensor asum(Tensor x); +void asum_(Tensor result, Tensor x); + +} // namespace infinicore::op diff --git a/include/infinicore/ops/blas_amax.hpp b/include/infinicore/ops/blas_amax.hpp new file mode 100644 index 000000000..b43dff58e --- /dev/null +++ b/include/infinicore/ops/blas_amax.hpp @@ -0,0 +1,18 @@ +#pragma once + +#include "../device.hpp" +#include "common/op.hpp" + +namespace infinicore::op { + +class BlasAmax { +public: + using schema = void (*)(Tensor, Tensor); + static void execute(Tensor result, Tensor x); + static common::OpDispatcher &dispatcher(); +}; + +Tensor blas_amax(Tensor x); +void blas_amax_(Tensor result, Tensor x); + +} // namespace infinicore::op diff --git a/include/infinicore/ops/blas_amin.hpp b/include/infinicore/ops/blas_amin.hpp new file mode 100644 index 000000000..d4161ce40 --- /dev/null +++ b/include/infinicore/ops/blas_amin.hpp @@ -0,0 +1,18 @@ +#pragma once + +#include "../device.hpp" +#include "common/op.hpp" + +namespace 
infinicore::op { + +class BlasAmin { +public: + using schema = void (*)(Tensor, Tensor); + static void execute(Tensor result, Tensor x); + static common::OpDispatcher &dispatcher(); +}; + +Tensor blas_amin(Tensor x); +void blas_amin_(Tensor result, Tensor x); + +} // namespace infinicore::op diff --git a/include/infinicore/ops/blas_dot.hpp b/include/infinicore/ops/blas_dot.hpp new file mode 100644 index 000000000..2eff6f396 --- /dev/null +++ b/include/infinicore/ops/blas_dot.hpp @@ -0,0 +1,18 @@ +#pragma once + +#include "../device.hpp" +#include "common/op.hpp" + +namespace infinicore::op { + +class BlasDot { +public: + using schema = void (*)(Tensor, Tensor, Tensor); + static void execute(Tensor result, Tensor x, Tensor y); + static common::OpDispatcher &dispatcher(); +}; + +Tensor blas_dot(Tensor x, Tensor y); +void blas_dot_(Tensor result, Tensor x, Tensor y); + +} // namespace infinicore::op diff --git a/include/infinicore/ops/nrm2.hpp b/include/infinicore/ops/nrm2.hpp new file mode 100644 index 000000000..ea3c6f8c1 --- /dev/null +++ b/include/infinicore/ops/nrm2.hpp @@ -0,0 +1,18 @@ +#pragma once + +#include "../device.hpp" +#include "common/op.hpp" + +namespace infinicore::op { + +class Nrm2 { +public: + using schema = void (*)(Tensor, Tensor); + static void execute(Tensor result, Tensor x); + static common::OpDispatcher &dispatcher(); +}; + +Tensor nrm2(Tensor x); +void nrm2_(Tensor result, Tensor x); + +} // namespace infinicore::op diff --git a/python/infinicore/__init__.py b/python/infinicore/__init__.py index 9c87a5108..17eb7504d 100644 --- a/python/infinicore/__init__.py +++ b/python/infinicore/__init__.py @@ -58,10 +58,14 @@ from infinicore.ops.argwhere import argwhere from infinicore.ops.asin import asin from infinicore.ops.asinh import asinh +from infinicore.ops.asum import asum from infinicore.ops.atanh import atanh from infinicore.ops.attention import attention from infinicore.ops.baddbmm import baddbmm from infinicore.ops.bilinear import bilinear 
+from infinicore.ops.blas_amax import blas_amax +from infinicore.ops.blas_amin import blas_amin +from infinicore.ops.blas_dot import blas_dot from infinicore.ops.binary_cross_entropy_with_logits import ( binary_cross_entropy_with_logits, ) @@ -102,6 +106,7 @@ from infinicore.ops.mha_varlen import mha_varlen from infinicore.ops.mul import mul from infinicore.ops.narrow import narrow +from infinicore.ops.nrm2 import nrm2 from infinicore.ops.paged_attention import paged_attention from infinicore.ops.paged_attention_prefill import paged_attention_prefill from infinicore.ops.paged_caching import paged_caching @@ -185,6 +190,10 @@ "add_rms_norm", "argwhere", "asin", + "asum", + "blas_amax", + "blas_amin", + "blas_dot", "acos", "addbmm", "floor", @@ -210,6 +219,7 @@ "dist", "logdet", "narrow", + "nrm2", "ldexp", "lerp", "kthvalue", diff --git a/python/infinicore/ops/asum.py b/python/infinicore/ops/asum.py new file mode 100644 index 000000000..5f8cce9af --- /dev/null +++ b/python/infinicore/ops/asum.py @@ -0,0 +1,10 @@ +from infinicore.lib import _infinicore +from infinicore.tensor import Tensor + + +def asum(x: Tensor, *, out=None): + if out is None: + return Tensor(_infinicore.asum(x._underlying)) + + _infinicore.asum_(out._underlying, x._underlying) + return out diff --git a/python/infinicore/ops/blas_amax.py b/python/infinicore/ops/blas_amax.py new file mode 100644 index 000000000..ed7fbaf54 --- /dev/null +++ b/python/infinicore/ops/blas_amax.py @@ -0,0 +1,10 @@ +from infinicore.lib import _infinicore +from infinicore.tensor import Tensor + + +def blas_amax(x: Tensor, *, out=None): + if out is None: + return Tensor(_infinicore.blas_amax(x._underlying)) + + _infinicore.blas_amax_(out._underlying, x._underlying) + return out diff --git a/python/infinicore/ops/blas_amin.py b/python/infinicore/ops/blas_amin.py new file mode 100644 index 000000000..a93ee43a6 --- /dev/null +++ b/python/infinicore/ops/blas_amin.py @@ -0,0 +1,10 @@ +from infinicore.lib import _infinicore +from 
infinicore.tensor import Tensor + + +def blas_amin(x: Tensor, *, out=None): + if out is None: + return Tensor(_infinicore.blas_amin(x._underlying)) + + _infinicore.blas_amin_(out._underlying, x._underlying) + return out diff --git a/python/infinicore/ops/blas_dot.py b/python/infinicore/ops/blas_dot.py new file mode 100644 index 000000000..e96eca838 --- /dev/null +++ b/python/infinicore/ops/blas_dot.py @@ -0,0 +1,10 @@ +from infinicore.lib import _infinicore +from infinicore.tensor import Tensor + + +def blas_dot(x: Tensor, y: Tensor, *, out=None): + if out is None: + return Tensor(_infinicore.blas_dot(x._underlying, y._underlying)) + + _infinicore.blas_dot_(out._underlying, x._underlying, y._underlying) + return out diff --git a/python/infinicore/ops/nrm2.py b/python/infinicore/ops/nrm2.py new file mode 100644 index 000000000..34de0e25a --- /dev/null +++ b/python/infinicore/ops/nrm2.py @@ -0,0 +1,10 @@ +from infinicore.lib import _infinicore +from infinicore.tensor import Tensor + + +def nrm2(x: Tensor, *, out=None): + if out is None: + return Tensor(_infinicore.nrm2(x._underlying)) + + _infinicore.nrm2_(out._underlying, x._underlying) + return out diff --git a/src/infinicore/ops/asum/asum.cc b/src/infinicore/ops/asum/asum.cc new file mode 100644 index 000000000..abf59fd90 --- /dev/null +++ b/src/infinicore/ops/asum/asum.cc @@ -0,0 +1,28 @@ +#include "infinicore/ops/asum.hpp" + +#include "../../utils.hpp" + +namespace infinicore::op { + +common::OpDispatcher &Asum::dispatcher() { + static common::OpDispatcher dispatcher_; + return dispatcher_; +}; + +void Asum::execute(Tensor result, Tensor x) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(result, x); + infinicore::context::setDevice(result->device()); + dispatcher().lookup(result->device().getType())(result, x); +} + +Tensor asum(Tensor x) { + auto result = Tensor::empty({}, x->dtype(), x->device()); + asum_(result, x); + return result; +} + +void asum_(Tensor result, Tensor x) { + Asum::execute(result, x); +} + +} // 
namespace infinicore::op diff --git a/src/infinicore/ops/asum/asum_infiniop.cc b/src/infinicore/ops/asum/asum_infiniop.cc new file mode 100644 index 000000000..65eded161 --- /dev/null +++ b/src/infinicore/ops/asum/asum_infiniop.cc @@ -0,0 +1,56 @@ +#include "../../utils.hpp" +#include "infinicore/common/hash.hpp" +#include "infinicore/ops/asum.hpp" +#include "infinicore/ops/common/cache.hpp" +#include + +namespace infinicore::op::asum_impl::infiniop { + +thread_local common::OpCache caches( + 100, // capacity + [](infiniopAsumDescriptor_t &desc) { + if (desc != nullptr) { + INFINICORE_CHECK_ERROR(infiniopDestroyAsumDescriptor(desc)); + desc = nullptr; + } + }); + +void calculate(Tensor result, Tensor x) { + size_t seed = hash_combine(result, x); + + auto device_type = context::getDevice().getType(); + auto device_index = context::getDevice().getIndex(); + + auto &cache = caches.getCache(device_type, device_index); + + auto desc_opt = cache.get(seed); + infiniopAsumDescriptor_t desc = nullptr; + + if (!desc_opt) { + INFINICORE_CHECK_ERROR(infiniopCreateAsumDescriptor( + context::getInfiniopHandle(result->device()), &desc, + x->desc(), result->desc())); + cache.put(seed, desc); + } else { + desc = *desc_opt; + } + + size_t workspace_size = 0; + INFINICORE_CHECK_ERROR(infiniopGetAsumWorkspaceSize(desc, &workspace_size)); + std::shared_ptr workspace = context::allocateMemory(workspace_size); + + INFINICORE_CHECK_ERROR(infiniopAsum( + desc, workspace->data(), workspace_size, + x->data(), result->data(), context::getStream())); +} + +static bool registered = []() { + Asum::dispatcher().registerDevice({Device::Type::CPU, + Device::Type::CAMBRICON, + Device::Type::METAX}, + &calculate, + false); + return true; +}(); + +} // namespace infinicore::op::asum_impl::infiniop diff --git a/src/infinicore/ops/blas_amax/blas_amax.cc b/src/infinicore/ops/blas_amax/blas_amax.cc new file mode 100644 index 000000000..5ec8825c4 --- /dev/null +++ 
b/src/infinicore/ops/blas_amax/blas_amax.cc @@ -0,0 +1,28 @@ +#include "infinicore/ops/blas_amax.hpp" + +#include "../../utils.hpp" + +namespace infinicore::op { + +common::OpDispatcher &BlasAmax::dispatcher() { + static common::OpDispatcher dispatcher_; + return dispatcher_; +}; + +void BlasAmax::execute(Tensor result, Tensor x) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(result, x); + infinicore::context::setDevice(result->device()); + dispatcher().lookup(result->device().getType())(result, x); +} + +Tensor blas_amax(Tensor x) { + auto result = Tensor::empty({}, DataType::I32, x->device()); + blas_amax_(result, x); + return result; +} + +void blas_amax_(Tensor result, Tensor x) { + BlasAmax::execute(result, x); +} + +} // namespace infinicore::op diff --git a/src/infinicore/ops/blas_amax/blas_amax_infiniop.cc b/src/infinicore/ops/blas_amax/blas_amax_infiniop.cc new file mode 100644 index 000000000..9b66fd164 --- /dev/null +++ b/src/infinicore/ops/blas_amax/blas_amax_infiniop.cc @@ -0,0 +1,56 @@ +#include "../../utils.hpp" +#include "infinicore/common/hash.hpp" +#include "infinicore/ops/blas_amax.hpp" +#include "infinicore/ops/common/cache.hpp" +#include + +namespace infinicore::op::blas_amax_impl::infiniop { + +thread_local common::OpCache caches( + 100, // capacity + [](infiniopBlasAmaxDescriptor_t &desc) { + if (desc != nullptr) { + INFINICORE_CHECK_ERROR(infiniopDestroyBlasAmaxDescriptor(desc)); + desc = nullptr; + } + }); + +void calculate(Tensor result, Tensor x) { + size_t seed = hash_combine(result, x); + + auto device_type = context::getDevice().getType(); + auto device_index = context::getDevice().getIndex(); + + auto &cache = caches.getCache(device_type, device_index); + + auto desc_opt = cache.get(seed); + infiniopBlasAmaxDescriptor_t desc = nullptr; + + if (!desc_opt) { + INFINICORE_CHECK_ERROR(infiniopCreateBlasAmaxDescriptor( + context::getInfiniopHandle(result->device()), &desc, + x->desc(), result->desc())); + cache.put(seed, desc); + } else { + 
desc = *desc_opt; + } + + size_t workspace_size = 0; + INFINICORE_CHECK_ERROR(infiniopGetBlasAmaxWorkspaceSize(desc, &workspace_size)); + std::shared_ptr workspace = context::allocateMemory(workspace_size); + + INFINICORE_CHECK_ERROR(infiniopBlasAmax( + desc, workspace->data(), workspace_size, + x->data(), result->data(), context::getStream())); +} + +static bool registered = []() { + BlasAmax::dispatcher().registerDevice({Device::Type::CPU, + Device::Type::CAMBRICON, + Device::Type::METAX}, + &calculate, + false); + return true; +}(); + +} // namespace infinicore::op::blas_amax_impl::infiniop diff --git a/src/infinicore/ops/blas_amin/blas_amin.cc b/src/infinicore/ops/blas_amin/blas_amin.cc new file mode 100644 index 000000000..058ae1fdc --- /dev/null +++ b/src/infinicore/ops/blas_amin/blas_amin.cc @@ -0,0 +1,28 @@ +#include "infinicore/ops/blas_amin.hpp" + +#include "../../utils.hpp" + +namespace infinicore::op { + +common::OpDispatcher &BlasAmin::dispatcher() { + static common::OpDispatcher dispatcher_; + return dispatcher_; +}; + +void BlasAmin::execute(Tensor result, Tensor x) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(result, x); + infinicore::context::setDevice(result->device()); + dispatcher().lookup(result->device().getType())(result, x); +} + +Tensor blas_amin(Tensor x) { + auto result = Tensor::empty({}, DataType::I32, x->device()); + blas_amin_(result, x); + return result; +} + +void blas_amin_(Tensor result, Tensor x) { + BlasAmin::execute(result, x); +} + +} // namespace infinicore::op diff --git a/src/infinicore/ops/blas_amin/blas_amin_infiniop.cc b/src/infinicore/ops/blas_amin/blas_amin_infiniop.cc new file mode 100644 index 000000000..f93575c6c --- /dev/null +++ b/src/infinicore/ops/blas_amin/blas_amin_infiniop.cc @@ -0,0 +1,56 @@ +#include "../../utils.hpp" +#include "infinicore/common/hash.hpp" +#include "infinicore/ops/blas_amin.hpp" +#include "infinicore/ops/common/cache.hpp" +#include + +namespace infinicore::op::blas_amin_impl::infiniop { + 
+thread_local common::OpCache caches( + 100, // capacity + [](infiniopBlasAminDescriptor_t &desc) { + if (desc != nullptr) { + INFINICORE_CHECK_ERROR(infiniopDestroyBlasAminDescriptor(desc)); + desc = nullptr; + } + }); + +void calculate(Tensor result, Tensor x) { + size_t seed = hash_combine(result, x); + + auto device_type = context::getDevice().getType(); + auto device_index = context::getDevice().getIndex(); + + auto &cache = caches.getCache(device_type, device_index); + + auto desc_opt = cache.get(seed); + infiniopBlasAminDescriptor_t desc = nullptr; + + if (!desc_opt) { + INFINICORE_CHECK_ERROR(infiniopCreateBlasAminDescriptor( + context::getInfiniopHandle(result->device()), &desc, + x->desc(), result->desc())); + cache.put(seed, desc); + } else { + desc = *desc_opt; + } + + size_t workspace_size = 0; + INFINICORE_CHECK_ERROR(infiniopGetBlasAminWorkspaceSize(desc, &workspace_size)); + std::shared_ptr workspace = context::allocateMemory(workspace_size); + + INFINICORE_CHECK_ERROR(infiniopBlasAmin( + desc, workspace->data(), workspace_size, + x->data(), result->data(), context::getStream())); +} + +static bool registered = []() { + BlasAmin::dispatcher().registerDevice({Device::Type::CPU, + Device::Type::CAMBRICON, + Device::Type::METAX}, + &calculate, + false); + return true; +}(); + +} // namespace infinicore::op::blas_amin_impl::infiniop diff --git a/src/infinicore/ops/blas_dot/blas_dot.cc b/src/infinicore/ops/blas_dot/blas_dot.cc new file mode 100644 index 000000000..66465bf27 --- /dev/null +++ b/src/infinicore/ops/blas_dot/blas_dot.cc @@ -0,0 +1,28 @@ +#include "infinicore/ops/blas_dot.hpp" + +#include "../../utils.hpp" + +namespace infinicore::op { + +common::OpDispatcher &BlasDot::dispatcher() { + static common::OpDispatcher dispatcher_; + return dispatcher_; +}; + +void BlasDot::execute(Tensor result, Tensor x, Tensor y) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(result, x, y); + infinicore::context::setDevice(result->device()); + 
dispatcher().lookup(result->device().getType())(result, x, y); +} + +Tensor blas_dot(Tensor x, Tensor y) { + auto result = Tensor::empty({}, x->dtype(), x->device()); + blas_dot_(result, x, y); + return result; +} + +void blas_dot_(Tensor result, Tensor x, Tensor y) { + BlasDot::execute(result, x, y); +} + +} // namespace infinicore::op diff --git a/src/infinicore/ops/blas_dot/blas_dot_infiniop.cc b/src/infinicore/ops/blas_dot/blas_dot_infiniop.cc new file mode 100644 index 000000000..51e13a511 --- /dev/null +++ b/src/infinicore/ops/blas_dot/blas_dot_infiniop.cc @@ -0,0 +1,56 @@ +#include "../../utils.hpp" +#include "infinicore/common/hash.hpp" +#include "infinicore/ops/blas_dot.hpp" +#include "infinicore/ops/common/cache.hpp" +#include + +namespace infinicore::op::blas_dot_impl::infiniop { + +thread_local common::OpCache caches( + 100, // capacity + [](infiniopBlasDotDescriptor_t &desc) { + if (desc != nullptr) { + INFINICORE_CHECK_ERROR(infiniopDestroyBlasDotDescriptor(desc)); + desc = nullptr; + } + }); + +void calculate(Tensor result, Tensor x, Tensor y) { + size_t seed = hash_combine(result, x, y); + + auto device_type = context::getDevice().getType(); + auto device_index = context::getDevice().getIndex(); + + auto &cache = caches.getCache(device_type, device_index); + + auto desc_opt = cache.get(seed); + infiniopBlasDotDescriptor_t desc = nullptr; + + if (!desc_opt) { + INFINICORE_CHECK_ERROR(infiniopCreateBlasDotDescriptor( + context::getInfiniopHandle(result->device()), &desc, + x->desc(), y->desc(), result->desc())); + cache.put(seed, desc); + } else { + desc = *desc_opt; + } + + size_t workspace_size = 0; + INFINICORE_CHECK_ERROR(infiniopGetBlasDotWorkspaceSize(desc, &workspace_size)); + std::shared_ptr workspace = context::allocateMemory(workspace_size); + + INFINICORE_CHECK_ERROR(infiniopBlasDot( + desc, workspace->data(), workspace_size, + x->data(), y->data(), result->data(), context::getStream())); +} + +static bool registered = []() { + 
BlasDot::dispatcher().registerDevice({Device::Type::CPU, + Device::Type::CAMBRICON, + Device::Type::METAX}, + &calculate, + false); + return true; +}(); + +} // namespace infinicore::op::blas_dot_impl::infiniop diff --git a/src/infinicore/ops/nrm2/nrm2.cc b/src/infinicore/ops/nrm2/nrm2.cc new file mode 100644 index 000000000..26e9f1d1b --- /dev/null +++ b/src/infinicore/ops/nrm2/nrm2.cc @@ -0,0 +1,28 @@ +#include "infinicore/ops/nrm2.hpp" + +#include "../../utils.hpp" + +namespace infinicore::op { + +common::OpDispatcher &Nrm2::dispatcher() { + static common::OpDispatcher dispatcher_; + return dispatcher_; +}; + +void Nrm2::execute(Tensor result, Tensor x) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(result, x); + infinicore::context::setDevice(result->device()); + dispatcher().lookup(result->device().getType())(result, x); +} + +Tensor nrm2(Tensor x) { + auto result = Tensor::empty({}, x->dtype(), x->device()); + nrm2_(result, x); + return result; +} + +void nrm2_(Tensor result, Tensor x) { + Nrm2::execute(result, x); +} + +} // namespace infinicore::op diff --git a/src/infinicore/ops/nrm2/nrm2_infiniop.cc b/src/infinicore/ops/nrm2/nrm2_infiniop.cc new file mode 100644 index 000000000..975237037 --- /dev/null +++ b/src/infinicore/ops/nrm2/nrm2_infiniop.cc @@ -0,0 +1,56 @@ +#include "../../utils.hpp" +#include "infinicore/common/hash.hpp" +#include "infinicore/ops/common/cache.hpp" +#include "infinicore/ops/nrm2.hpp" +#include + +namespace infinicore::op::nrm2_impl::infiniop { + +thread_local common::OpCache caches( + 100, // capacity + [](infiniopNrm2Descriptor_t &desc) { + if (desc != nullptr) { + INFINICORE_CHECK_ERROR(infiniopDestroyNrm2Descriptor(desc)); + desc = nullptr; + } + }); + +void calculate(Tensor result, Tensor x) { + size_t seed = hash_combine(result, x); + + auto device_type = context::getDevice().getType(); + auto device_index = context::getDevice().getIndex(); + + auto &cache = caches.getCache(device_type, device_index); + + auto desc_opt = 
cache.get(seed); + infiniopNrm2Descriptor_t desc = nullptr; + + if (!desc_opt) { + INFINICORE_CHECK_ERROR(infiniopCreateNrm2Descriptor( + context::getInfiniopHandle(result->device()), &desc, + x->desc(), result->desc())); + cache.put(seed, desc); + } else { + desc = *desc_opt; + } + + size_t workspace_size = 0; + INFINICORE_CHECK_ERROR(infiniopGetNrm2WorkspaceSize(desc, &workspace_size)); + std::shared_ptr workspace = context::allocateMemory(workspace_size); + + INFINICORE_CHECK_ERROR(infiniopNrm2( + desc, workspace->data(), workspace_size, + x->data(), result->data(), context::getStream())); +} + +static bool registered = []() { + Nrm2::dispatcher().registerDevice({Device::Type::CPU, + Device::Type::CAMBRICON, + Device::Type::METAX}, + &calculate, + false); + return true; +}(); + +} // namespace infinicore::op::nrm2_impl::infiniop diff --git a/src/infinicore/pybind11/ops.hpp b/src/infinicore/pybind11/ops.hpp index 383429f8f..72da952dc 100644 --- a/src/infinicore/pybind11/ops.hpp +++ b/src/infinicore/pybind11/ops.hpp @@ -16,6 +16,7 @@ #include "ops/argwhere.hpp" #include "ops/asin.hpp" #include "ops/asinh.hpp" +#include "ops/asum.hpp" #include "ops/atanh.hpp" #include "ops/attention.hpp" #include "ops/avg_pool1d.hpp" @@ -23,6 +24,9 @@ #include "ops/bilinear.hpp" #include "ops/binary_cross_entropy_with_logits.hpp" #include "ops/bitwise_right_shift.hpp" +#include "ops/blas_amax.hpp" +#include "ops/blas_amin.hpp" +#include "ops/blas_dot.hpp" #include "ops/block_diag.hpp" #include "ops/broadcast_to.hpp" #include "ops/cat.hpp" @@ -72,6 +76,7 @@ #include "ops/mha_varlen.hpp" #include "ops/mul.hpp" #include "ops/multi_margin_loss.hpp" +#include "ops/nrm2.hpp" #include "ops/pad.hpp" #include "ops/paged_attention.hpp" #include "ops/paged_attention_prefill.hpp" @@ -127,8 +132,12 @@ inline void bind(py::module &m) { bind_adaptive_avg_pool1d(m); bind_attention(m); bind_asinh(m); + bind_asum(m); bind_baddbmm(m); bind_bilinear(m); + bind_blas_amax(m); + bind_blas_amin(m); + 
bind_blas_dot(m); bind_block_diag(m); bind_bitwise_right_shift(m); bind_causal_softmax(m); @@ -153,6 +162,7 @@ inline void bind(py::module &m) { bind_matmul(m); bind_kron(m); bind_mul(m); + bind_nrm2(m); bind_mha_kvcache(m); bind_mha_varlen(m); bind_hardswish(m); diff --git a/src/infinicore/pybind11/ops/asum.hpp b/src/infinicore/pybind11/ops/asum.hpp new file mode 100644 index 000000000..40661c562 --- /dev/null +++ b/src/infinicore/pybind11/ops/asum.hpp @@ -0,0 +1,24 @@ +#pragma once + +#include + +#include "infinicore/ops/asum.hpp" + +namespace py = pybind11; + +namespace infinicore::ops { + +inline void bind_asum(py::module &m) { + m.def("asum", + &op::asum, + py::arg("x"), + R"doc(BLAS level-1 asum.)doc"); + + m.def("asum_", + &op::asum_, + py::arg("result"), + py::arg("x"), + R"doc(In-place BLAS level-1 asum.)doc"); +} + +} // namespace infinicore::ops diff --git a/src/infinicore/pybind11/ops/blas_amax.hpp b/src/infinicore/pybind11/ops/blas_amax.hpp new file mode 100644 index 000000000..5afdee62c --- /dev/null +++ b/src/infinicore/pybind11/ops/blas_amax.hpp @@ -0,0 +1,24 @@ +#pragma once + +#include + +#include "infinicore/ops/blas_amax.hpp" + +namespace py = pybind11; + +namespace infinicore::ops { + +inline void bind_blas_amax(py::module &m) { + m.def("blas_amax", + &op::blas_amax, + py::arg("x"), + R"doc(BLAS level-1 amax.)doc"); + + m.def("blas_amax_", + &op::blas_amax_, + py::arg("result"), + py::arg("x"), + R"doc(In-place BLAS level-1 amax.)doc"); +} + +} // namespace infinicore::ops diff --git a/src/infinicore/pybind11/ops/blas_amin.hpp b/src/infinicore/pybind11/ops/blas_amin.hpp new file mode 100644 index 000000000..4f9b316de --- /dev/null +++ b/src/infinicore/pybind11/ops/blas_amin.hpp @@ -0,0 +1,24 @@ +#pragma once + +#include + +#include "infinicore/ops/blas_amin.hpp" + +namespace py = pybind11; + +namespace infinicore::ops { + +inline void bind_blas_amin(py::module &m) { + m.def("blas_amin", + &op::blas_amin, + py::arg("x"), + R"doc(BLAS level-1 
amin.)doc"); + + m.def("blas_amin_", + &op::blas_amin_, + py::arg("result"), + py::arg("x"), + R"doc(In-place BLAS level-1 amin.)doc"); +} + +} // namespace infinicore::ops diff --git a/src/infinicore/pybind11/ops/blas_dot.hpp b/src/infinicore/pybind11/ops/blas_dot.hpp new file mode 100644 index 000000000..9290db5a7 --- /dev/null +++ b/src/infinicore/pybind11/ops/blas_dot.hpp @@ -0,0 +1,26 @@ +#pragma once + +#include + +#include "infinicore/ops/blas_dot.hpp" + +namespace py = pybind11; + +namespace infinicore::ops { + +inline void bind_blas_dot(py::module &m) { + m.def("blas_dot", + &op::blas_dot, + py::arg("x"), + py::arg("y"), + R"doc(BLAS level-1 dot.)doc"); + + m.def("blas_dot_", + &op::blas_dot_, + py::arg("result"), + py::arg("x"), + py::arg("y"), + R"doc(In-place BLAS level-1 dot.)doc"); +} + +} // namespace infinicore::ops diff --git a/src/infinicore/pybind11/ops/nrm2.hpp b/src/infinicore/pybind11/ops/nrm2.hpp new file mode 100644 index 000000000..5431046b7 --- /dev/null +++ b/src/infinicore/pybind11/ops/nrm2.hpp @@ -0,0 +1,24 @@ +#pragma once + +#include + +#include "infinicore/ops/nrm2.hpp" + +namespace py = pybind11; + +namespace infinicore::ops { + +inline void bind_nrm2(py::module &m) { + m.def("nrm2", + &op::nrm2, + py::arg("x"), + R"doc(BLAS level-1 nrm2.)doc"); + + m.def("nrm2_", + &op::nrm2_, + py::arg("result"), + py::arg("x"), + R"doc(In-place BLAS level-1 nrm2.)doc"); +} + +} // namespace infinicore::ops diff --git a/test/infinicore/ops/asum.py b/test/infinicore/ops/asum.py new file mode 100644 index 000000000..26c25d699 --- /dev/null +++ b/test/infinicore/ops/asum.py @@ -0,0 +1,114 @@ +import os +import sys + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +import torch +from framework import ( + BaseOperatorTest, + GenericTestRunner, + TensorSpec, + TestCase, +) + +import infinicore + +# ======================================================================= +# Test cases format: (shape, x_strides_or_None) +# 
======================================================================= + +_TEST_CASES_DATA = [ + ((13,), None), + ((13,), (10,)), + ((16,), None), + ((16,), (4,)), + ((255,), None), + ((5632,), None), +] + +_TOLERANCE_MAP = { + infinicore.float16: {"atol": 1e-3, "rtol": 1e-2}, + infinicore.float32: {"atol": 1e-5, "rtol": 1e-4}, + infinicore.float64: {"atol": 1e-9, "rtol": 1e-6}, + infinicore.bfloat16: {"atol": 1e-2, "rtol": 5e-2}, +} + +_TENSOR_DTYPES = [ + # infinicore.float16, + infinicore.float32, + # infinicore.float64, + # infinicore.bfloat16, +] + + +def torch_asum(x, *, out=None): + def _asum(x, out): + out.copy_(torch.sum(x.abs())) + + if out is None: + out = torch.empty(1, dtype=x.dtype, device=x.device) + + _asum(x, out) + return out + + +def parse_test_cases(): + test_cases = [] + for data in _TEST_CASES_DATA: + shape = data[0] + x_strides = data[1] if len(data) > 1 else None + + for dtype in _TENSOR_DTYPES: + tol = _TOLERANCE_MAP.get(dtype, {"atol": 1e-5, "rtol": 1e-4}) + x_spec = TensorSpec.from_tensor(shape, x_strides, dtype) + out_spec = TensorSpec.from_tensor((), None, dtype) + + test_cases.append( + TestCase( + inputs=[x_spec], + kwargs={}, + output_spec=None, + comparison_target=None, + tolerance=tol, + description="asum - OUT_OF_PLACE", + ) + ) + + test_cases.append( + TestCase( + inputs=[x_spec], + kwargs={}, + output_spec=out_spec, + comparison_target="out", + tolerance=tol, + description="asum - INPLACE(out)", + ) + ) + + return test_cases + + +class OpTest(BaseOperatorTest): + """BLAS Level-1 asum operator test""" + + def __init__(self): + super().__init__("Asum") + + def get_test_cases(self): + return parse_test_cases() + + def torch_operator(self, *args, **kwargs): + return torch_asum(*args, **kwargs) + + def infinicore_operator(self, *args, **kwargs): + return infinicore.asum(*args, **kwargs) + + +def main(): + """Main entry point""" + runner = GenericTestRunner(OpTest) + runner.run_and_exit() + + +if __name__ == "__main__": + main() diff 
--git a/test/infinicore/ops/blas_amax.py b/test/infinicore/ops/blas_amax.py new file mode 100644 index 000000000..7e0412c9a --- /dev/null +++ b/test/infinicore/ops/blas_amax.py @@ -0,0 +1,100 @@ +import os +import sys + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +import torch +from framework import ( + BaseOperatorTest, + GenericTestRunner, + TensorSpec, + TestCase, +) +from framework.tensor import TensorInitializer + +import infinicore + +_TEST_CASES_DATA = [ + ((3,), None), + ((8,), (2,)), + ((32,), None), + ((257,), (3,)), + ((65535,), None), +] + +_TENSOR_DTYPES = [ + # infinicore.float16, + infinicore.float32, + # infinicore.float64, + # infinicore.bfloat16, +] + +_TOLERANCE = {"atol": 0, "rtol": 0} + + +def torch_blas_amax(x, *, out=None): + result = torch.argmax(x.abs()).to(torch.int32) + 1 + if out is None: + return result + + out.copy_(result) + return out + + +def parse_test_cases(): + test_cases = [] + for shape, x_strides in _TEST_CASES_DATA: + for dtype in _TENSOR_DTYPES: + x_spec = TensorSpec.from_tensor(shape, x_strides, dtype) + out_spec = TensorSpec.from_tensor( + (), None, infinicore.int32, init_mode=TensorInitializer.ZEROS + ) + + test_cases.append( + TestCase( + inputs=[x_spec], + kwargs={}, + output_spec=None, + comparison_target=None, + tolerance=_TOLERANCE, + description="blas_amax - OUT_OF_PLACE", + ) + ) + + test_cases.append( + TestCase( + inputs=[x_spec], + kwargs={}, + output_spec=out_spec, + comparison_target="out", + tolerance=_TOLERANCE, + description="blas_amax - INPLACE(out)", + ) + ) + + return test_cases + + +class OpTest(BaseOperatorTest): + """BLAS Level-1 amax operator test""" + + def __init__(self): + super().__init__("BlasAmax") + + def get_test_cases(self): + return parse_test_cases() + + def torch_operator(self, *args, **kwargs): + return torch_blas_amax(*args, **kwargs) + + def infinicore_operator(self, *args, **kwargs): + return infinicore.blas_amax(*args, **kwargs) + + +def main(): + runner = 
GenericTestRunner(OpTest) + runner.run_and_exit() + + +if __name__ == "__main__": + main() diff --git a/test/infinicore/ops/blas_amin.py b/test/infinicore/ops/blas_amin.py new file mode 100644 index 000000000..6063508c5 --- /dev/null +++ b/test/infinicore/ops/blas_amin.py @@ -0,0 +1,100 @@ +import os +import sys + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +import torch +from framework import ( + BaseOperatorTest, + GenericTestRunner, + TensorSpec, + TestCase, +) +from framework.tensor import TensorInitializer + +import infinicore + +_TEST_CASES_DATA = [ + ((3,), None), + ((8,), (2,)), + ((32,), None), + ((257,), (3,)), + ((65535,), None), +] + +_TENSOR_DTYPES = [ + # infinicore.float16, + infinicore.float32, + # infinicore.float64, + # infinicore.bfloat16, +] + +_TOLERANCE = {"atol": 0, "rtol": 0} + + +def torch_blas_amin(x, *, out=None): + result = torch.argmin(x.abs()).to(torch.int32) + 1 + if out is None: + return result + + out.copy_(result) + return out + + +def parse_test_cases(): + test_cases = [] + for shape, x_strides in _TEST_CASES_DATA: + for dtype in _TENSOR_DTYPES: + x_spec = TensorSpec.from_tensor(shape, x_strides, dtype) + out_spec = TensorSpec.from_tensor( + (), None, infinicore.int32, init_mode=TensorInitializer.ZEROS + ) + + test_cases.append( + TestCase( + inputs=[x_spec], + kwargs={}, + output_spec=None, + comparison_target=None, + tolerance=_TOLERANCE, + description="blas_amin - OUT_OF_PLACE", + ) + ) + + test_cases.append( + TestCase( + inputs=[x_spec], + kwargs={}, + output_spec=out_spec, + comparison_target="out", + tolerance=_TOLERANCE, + description="blas_amin - INPLACE(out)", + ) + ) + + return test_cases + + +class OpTest(BaseOperatorTest): + """BLAS Level-1 amin operator test""" + + def __init__(self): + super().__init__("BlasAmin") + + def get_test_cases(self): + return parse_test_cases() + + def torch_operator(self, *args, **kwargs): + return torch_blas_amin(*args, **kwargs) + + def infinicore_operator(self, 
*args, **kwargs): + return infinicore.blas_amin(*args, **kwargs) + + +def main(): + runner = GenericTestRunner(OpTest) + runner.run_and_exit() + + +if __name__ == "__main__": + main() diff --git a/test/infinicore/ops/blas_dot.py b/test/infinicore/ops/blas_dot.py new file mode 100644 index 000000000..edc4e9ff5 --- /dev/null +++ b/test/infinicore/ops/blas_dot.py @@ -0,0 +1,111 @@ +import os +import sys + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +import torch +from framework import ( + BaseOperatorTest, + GenericTestRunner, + TensorSpec, + TestCase, +) +from framework.tensor import TensorInitializer + +import infinicore + +_TEST_CASES_DATA = [ + ((3,), None, None), + ((8,), (2,), (3,)), + ((32,), None, (2,)), + ((257,), (3,), None), + ((65535,), None, None), +] + +_TOLERANCE_MAP = { + infinicore.float16: {"atol": 1e-3, "rtol": 1e-3}, + infinicore.float32: {"atol": 1e-5, "rtol": 1e-5}, + infinicore.float64: {"atol": 1e-9, "rtol": 1e-9}, + infinicore.bfloat16: {"atol": 1e-2, "rtol": 1e-2}, +} + +_TENSOR_DTYPES = [ + infinicore.float16, + infinicore.float32, + # infinicore.float64, + infinicore.bfloat16, +] + + +def torch_blas_dot(x, y, *, out=None): + if x.dtype in (torch.float16, torch.bfloat16): + result = torch.dot(x.float(), y.float()).to(x.dtype) + else: + result = torch.dot(x, y) + + if out is None: + return result + + out.copy_(result) + return out + + +def parse_test_cases(): + test_cases = [] + for shape, x_strides, y_strides in _TEST_CASES_DATA: + for dtype in _TENSOR_DTYPES: + tol = _TOLERANCE_MAP.get(dtype, {"atol": 1e-5, "rtol": 1e-4}) + x_spec = TensorSpec.from_tensor(shape, x_strides, dtype) + y_spec = TensorSpec.from_tensor(shape, y_strides, dtype) + out_spec = TensorSpec.from_tensor( + (), None, dtype, init_mode=TensorInitializer.ZEROS + ) + + test_cases.append( + TestCase( + inputs=[x_spec, y_spec], + kwargs={}, + output_spec=None, + comparison_target=None, + tolerance=tol, + description="blas_dot - OUT_OF_PLACE", + ) + ) + 
+ test_cases.append( + TestCase( + inputs=[x_spec, y_spec], + kwargs={}, + output_spec=out_spec, + comparison_target="out", + tolerance=tol, + description="blas_dot - INPLACE(out)", + ) + ) + + return test_cases + + +class OpTest(BaseOperatorTest): + """BLAS Level-1 dot operator test""" + + def __init__(self): + super().__init__("BlasDot") + + def get_test_cases(self): + return parse_test_cases() + + def torch_operator(self, *args, **kwargs): + return torch_blas_dot(*args, **kwargs) + + def infinicore_operator(self, *args, **kwargs): + return infinicore.blas_dot(*args, **kwargs) + + +def main(): + runner = GenericTestRunner(OpTest) + runner.run_and_exit() + + +if __name__ == "__main__": + main() diff --git a/test/infinicore/ops/nrm2.py b/test/infinicore/ops/nrm2.py new file mode 100644 index 000000000..ca5d4a5fc --- /dev/null +++ b/test/infinicore/ops/nrm2.py @@ -0,0 +1,104 @@ +import os +import sys + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +import torch +from framework import ( + BaseOperatorTest, + GenericTestRunner, + TensorSpec, + TestCase, +) + +import infinicore + +_TEST_CASES_DATA = [ + ((13,), None), + ((13,), (10,)), + ((5632,), None), + ((5632,), (5,)), + ((16,), (4,)), + ((5632,), (32,)), +] + +_TOLERANCE_MAP = { + infinicore.float16: {"atol": 1e-3, "rtol": 1e-3}, + infinicore.float32: {"atol": 1e-5, "rtol": 1e-5}, + infinicore.float64: {"atol": 1e-7, "rtol": 1e-7}, + infinicore.bfloat16: {"atol": 1e-2, "rtol": 1e-2}, +} + +_TENSOR_DTYPES = [ + infinicore.float16, + infinicore.float32, + # infinicore.float64, + infinicore.bfloat16, +] + + +def torch_nrm2(x, *, out=None): + result = torch.norm(x, p=2) + if out is None: + return result + + out.copy_(result) + return out + + +def parse_test_cases(): + test_cases = [] + for shape, x_strides in _TEST_CASES_DATA: + for dtype in _TENSOR_DTYPES: + tol = _TOLERANCE_MAP.get(dtype, {"atol": 1e-5, "rtol": 1e-4}) + x_spec = TensorSpec.from_tensor(shape, x_strides, dtype) + out_spec = 
TensorSpec.from_tensor((), None, dtype) + + test_cases.append( + TestCase( + inputs=[x_spec], + kwargs={}, + output_spec=None, + comparison_target=None, + tolerance=tol, + description="nrm2 - OUT_OF_PLACE", + ) + ) + + test_cases.append( + TestCase( + inputs=[x_spec], + kwargs={}, + output_spec=out_spec, + comparison_target="out", + tolerance=tol, + description="nrm2 - INPLACE(out)", + ) + ) + + return test_cases + + +class OpTest(BaseOperatorTest): + """BLAS Level-1 nrm2 operator test""" + + def __init__(self): + super().__init__("Nrm2") + + def get_test_cases(self): + return parse_test_cases() + + def torch_operator(self, *args, **kwargs): + return torch_nrm2(*args, **kwargs) + + def infinicore_operator(self, *args, **kwargs): + return infinicore.nrm2(*args, **kwargs) + + +def main(): + runner = GenericTestRunner(OpTest) + runner.run_and_exit() + + +if __name__ == "__main__": + main() From 8b57466221885c0595e93a21c6850d9b06938142 Mon Sep 17 00:00:00 2001 From: xuzhengzhong Date: Wed, 6 May 2026 06:36:15 +0000 Subject: [PATCH 16/25] Add InfiniCore `axpy`, `blas_copy`, `scal` and `swap` wrappers --- include/infinicore/ops.hpp | 4 + include/infinicore/ops/axpy.hpp | 17 ++++ include/infinicore/ops/blas_copy.hpp | 17 ++++ include/infinicore/ops/scal.hpp | 17 ++++ include/infinicore/ops/swap.hpp | 17 ++++ python/infinicore/__init__.py | 8 ++ python/infinicore/ops/axpy.py | 7 ++ python/infinicore/ops/blas_copy.py | 7 ++ python/infinicore/ops/scal.py | 7 ++ python/infinicore/ops/swap.py | 7 ++ src/infinicore/ops/axpy/axpy.cc | 22 +++++ src/infinicore/ops/axpy/axpy_infiniop.cc | 56 +++++++++++ src/infinicore/ops/blas_copy/blas_copy.cc | 22 +++++ .../ops/blas_copy/blas_copy_infiniop.cc | 56 +++++++++++ src/infinicore/ops/scal/scal.cc | 22 +++++ src/infinicore/ops/scal/scal_infiniop.cc | 56 +++++++++++ src/infinicore/ops/swap/swap.cc | 22 +++++ src/infinicore/ops/swap/swap_infiniop.cc | 56 +++++++++++ src/infinicore/pybind11/ops.hpp | 8 ++ 
src/infinicore/pybind11/ops/axpy.hpp | 20 ++++ src/infinicore/pybind11/ops/blas_copy.hpp | 19 ++++ src/infinicore/pybind11/ops/scal.hpp | 19 ++++ src/infinicore/pybind11/ops/swap.hpp | 19 ++++ test/infinicore/ops/axpy.py | 92 +++++++++++++++++++ test/infinicore/ops/blas_copy.py | 87 ++++++++++++++++++ test/infinicore/ops/scal.py | 91 ++++++++++++++++++ test/infinicore/ops/swap.py | 90 ++++++++++++++++++ 27 files changed, 865 insertions(+) create mode 100644 include/infinicore/ops/axpy.hpp create mode 100644 include/infinicore/ops/blas_copy.hpp create mode 100644 include/infinicore/ops/scal.hpp create mode 100644 include/infinicore/ops/swap.hpp create mode 100644 python/infinicore/ops/axpy.py create mode 100644 python/infinicore/ops/blas_copy.py create mode 100644 python/infinicore/ops/scal.py create mode 100644 python/infinicore/ops/swap.py create mode 100644 src/infinicore/ops/axpy/axpy.cc create mode 100644 src/infinicore/ops/axpy/axpy_infiniop.cc create mode 100644 src/infinicore/ops/blas_copy/blas_copy.cc create mode 100644 src/infinicore/ops/blas_copy/blas_copy_infiniop.cc create mode 100644 src/infinicore/ops/scal/scal.cc create mode 100644 src/infinicore/ops/scal/scal_infiniop.cc create mode 100644 src/infinicore/ops/swap/swap.cc create mode 100644 src/infinicore/ops/swap/swap_infiniop.cc create mode 100644 src/infinicore/pybind11/ops/axpy.hpp create mode 100644 src/infinicore/pybind11/ops/blas_copy.hpp create mode 100644 src/infinicore/pybind11/ops/scal.hpp create mode 100644 src/infinicore/pybind11/ops/swap.hpp create mode 100644 test/infinicore/ops/axpy.py create mode 100644 test/infinicore/ops/blas_copy.py create mode 100644 test/infinicore/ops/scal.py create mode 100644 test/infinicore/ops/swap.py diff --git a/include/infinicore/ops.hpp b/include/infinicore/ops.hpp index 6df98da82..72d6c311c 100644 --- a/include/infinicore/ops.hpp +++ b/include/infinicore/ops.hpp @@ -10,11 +10,13 @@ #include "ops/atanh.hpp" #include "ops/attention.hpp" #include 
"ops/avg_pool1d.hpp" +#include "ops/axpy.hpp" #include "ops/baddbmm.hpp" #include "ops/bilinear.hpp" #include "ops/binary_cross_entropy_with_logits.hpp" #include "ops/blas_amax.hpp" #include "ops/blas_amin.hpp" +#include "ops/blas_copy.hpp" #include "ops/blas_dot.hpp" #include "ops/causal_softmax.hpp" #include "ops/cdist.hpp" @@ -46,8 +48,10 @@ #include "ops/relu.hpp" #include "ops/rms_norm.hpp" #include "ops/rope.hpp" +#include "ops/scal.hpp" #include "ops/silu.hpp" #include "ops/silu_and_mul.hpp" #include "ops/softmax.hpp" +#include "ops/swap.hpp" #include "ops/swiglu.hpp" #include "ops/topksoftmax.hpp" diff --git a/include/infinicore/ops/axpy.hpp b/include/infinicore/ops/axpy.hpp new file mode 100644 index 000000000..3f9da7c4a --- /dev/null +++ b/include/infinicore/ops/axpy.hpp @@ -0,0 +1,17 @@ +#pragma once + +#include "../device.hpp" +#include "common/op.hpp" + +namespace infinicore::op { + +class Axpy { +public: + using schema = void (*)(Tensor, Tensor, Tensor); + static void execute(Tensor alpha, Tensor x, Tensor y); + static common::OpDispatcher &dispatcher(); +}; + +void axpy_(Tensor alpha, Tensor x, Tensor y); + +} // namespace infinicore::op diff --git a/include/infinicore/ops/blas_copy.hpp b/include/infinicore/ops/blas_copy.hpp new file mode 100644 index 000000000..1ac048a41 --- /dev/null +++ b/include/infinicore/ops/blas_copy.hpp @@ -0,0 +1,17 @@ +#pragma once + +#include "../device.hpp" +#include "common/op.hpp" + +namespace infinicore::op { + +class BlasCopy { +public: + using schema = void (*)(Tensor, Tensor); + static void execute(Tensor x, Tensor y); + static common::OpDispatcher &dispatcher(); +}; + +void blas_copy_(Tensor x, Tensor y); + +} // namespace infinicore::op diff --git a/include/infinicore/ops/scal.hpp b/include/infinicore/ops/scal.hpp new file mode 100644 index 000000000..38045c980 --- /dev/null +++ b/include/infinicore/ops/scal.hpp @@ -0,0 +1,17 @@ +#pragma once + +#include "../device.hpp" +#include "common/op.hpp" + +namespace 
infinicore::op { + +class Scal { +public: + using schema = void (*)(Tensor, Tensor); + static void execute(Tensor alpha, Tensor x); + static common::OpDispatcher &dispatcher(); +}; + +void scal_(Tensor x, Tensor alpha); + +} // namespace infinicore::op diff --git a/include/infinicore/ops/swap.hpp b/include/infinicore/ops/swap.hpp new file mode 100644 index 000000000..9b2f85dea --- /dev/null +++ b/include/infinicore/ops/swap.hpp @@ -0,0 +1,17 @@ +#pragma once + +#include "../device.hpp" +#include "common/op.hpp" + +namespace infinicore::op { + +class Swap { +public: + using schema = void (*)(Tensor, Tensor); + static void execute(Tensor x, Tensor y); + static common::OpDispatcher &dispatcher(); +}; + +void swap_(Tensor x, Tensor y); + +} // namespace infinicore::op diff --git a/python/infinicore/__init__.py b/python/infinicore/__init__.py index 17eb7504d..883bf7098 100644 --- a/python/infinicore/__init__.py +++ b/python/infinicore/__init__.py @@ -61,10 +61,12 @@ from infinicore.ops.asum import asum from infinicore.ops.atanh import atanh from infinicore.ops.attention import attention +from infinicore.ops.axpy import axpy from infinicore.ops.baddbmm import baddbmm from infinicore.ops.bilinear import bilinear from infinicore.ops.blas_amax import blas_amax from infinicore.ops.blas_amin import blas_amin +from infinicore.ops.blas_copy import blas_copy from infinicore.ops.blas_dot import blas_dot from infinicore.ops.binary_cross_entropy_with_logits import ( binary_cross_entropy_with_logits, @@ -112,8 +114,10 @@ from infinicore.ops.paged_caching import paged_caching from infinicore.ops.rearrange import rearrange from infinicore.ops.reciprocal import reciprocal +from infinicore.ops.scal import scal from infinicore.ops.scatter import scatter from infinicore.ops.sinh import sinh +from infinicore.ops.swap import swap from infinicore.ops.squeeze import squeeze from infinicore.ops.sum import sum from infinicore.ops.take import take @@ -191,8 +195,10 @@ "argwhere", "asin", "asum", 
+ "axpy", "blas_amax", "blas_amin", + "blas_copy", "blas_dot", "acos", "addbmm", @@ -241,6 +247,7 @@ "float_power", "flipud", "scatter", + "scal", "logcumsumexp", "logical_not", "logical_and", @@ -253,6 +260,7 @@ "index_add", "take", "sinh", + "swap", "ones", "broadcast_to", "strided_empty", diff --git a/python/infinicore/ops/axpy.py b/python/infinicore/ops/axpy.py new file mode 100644 index 000000000..3457038fb --- /dev/null +++ b/python/infinicore/ops/axpy.py @@ -0,0 +1,7 @@ +from infinicore.lib import _infinicore +from infinicore.tensor import Tensor + + +def axpy(alpha: Tensor, x: Tensor, y: Tensor): + _infinicore.axpy_(alpha._underlying, x._underlying, y._underlying) + return y diff --git a/python/infinicore/ops/blas_copy.py b/python/infinicore/ops/blas_copy.py new file mode 100644 index 000000000..75d4abef3 --- /dev/null +++ b/python/infinicore/ops/blas_copy.py @@ -0,0 +1,7 @@ +from infinicore.lib import _infinicore +from infinicore.tensor import Tensor + + +def blas_copy(x: Tensor, y: Tensor): + _infinicore.blas_copy_(x._underlying, y._underlying) + return y diff --git a/python/infinicore/ops/scal.py b/python/infinicore/ops/scal.py new file mode 100644 index 000000000..86b07eb2b --- /dev/null +++ b/python/infinicore/ops/scal.py @@ -0,0 +1,7 @@ +from infinicore.lib import _infinicore +from infinicore.tensor import Tensor + + +def scal(x: Tensor, alpha: Tensor): + _infinicore.scal_(x._underlying, alpha._underlying) + return x diff --git a/python/infinicore/ops/swap.py b/python/infinicore/ops/swap.py new file mode 100644 index 000000000..f773a34eb --- /dev/null +++ b/python/infinicore/ops/swap.py @@ -0,0 +1,7 @@ +from infinicore.lib import _infinicore +from infinicore.tensor import Tensor + + +def swap(x: Tensor, y: Tensor): + _infinicore.swap_(x._underlying, y._underlying) + return x, y diff --git a/src/infinicore/ops/axpy/axpy.cc b/src/infinicore/ops/axpy/axpy.cc new file mode 100644 index 000000000..957b83ef6 --- /dev/null +++ 
b/src/infinicore/ops/axpy/axpy.cc @@ -0,0 +1,22 @@ +#include "infinicore/ops/axpy.hpp" + +#include "../../utils.hpp" + +namespace infinicore::op { + +common::OpDispatcher &Axpy::dispatcher() { + static common::OpDispatcher dispatcher_; + return dispatcher_; +}; + +void Axpy::execute(Tensor alpha, Tensor x, Tensor y) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(alpha, x, y); + infinicore::context::setDevice(y->device()); + dispatcher().lookup(y->device().getType())(alpha, x, y); +} + +void axpy_(Tensor alpha, Tensor x, Tensor y) { + Axpy::execute(alpha, x, y); +} + +} // namespace infinicore::op diff --git a/src/infinicore/ops/axpy/axpy_infiniop.cc b/src/infinicore/ops/axpy/axpy_infiniop.cc new file mode 100644 index 000000000..013b5b2d9 --- /dev/null +++ b/src/infinicore/ops/axpy/axpy_infiniop.cc @@ -0,0 +1,56 @@ +#include "../../utils.hpp" +#include "infinicore/common/hash.hpp" +#include "infinicore/ops/axpy.hpp" +#include "infinicore/ops/common/cache.hpp" +#include + +namespace infinicore::op::axpy_impl::infiniop { + +thread_local common::OpCache caches( + 100, // capacity + [](infiniopAxpyDescriptor_t &desc) { + if (desc != nullptr) { + INFINICORE_CHECK_ERROR(infiniopDestroyAxpyDescriptor(desc)); + desc = nullptr; + } + }); + +void calculate(Tensor alpha, Tensor x, Tensor y) { + size_t seed = hash_combine(alpha, x, y); + + auto device_type = context::getDevice().getType(); + auto device_index = context::getDevice().getIndex(); + + auto &cache = caches.getCache(device_type, device_index); + + auto desc_opt = cache.get(seed); + infiniopAxpyDescriptor_t desc = nullptr; + + if (!desc_opt) { + INFINICORE_CHECK_ERROR(infiniopCreateAxpyDescriptor( + context::getInfiniopHandle(y->device()), &desc, + alpha->desc(), x->desc(), y->desc())); + cache.put(seed, desc); + } else { + desc = *desc_opt; + } + + size_t workspace_size = 0; + INFINICORE_CHECK_ERROR(infiniopGetAxpyWorkspaceSize(desc, &workspace_size)); + std::shared_ptr workspace = 
context::allocateMemory(workspace_size); + + INFINICORE_CHECK_ERROR(infiniopAxpy( + desc, workspace->data(), workspace_size, + alpha->data(), x->data(), y->data(), context::getStream())); +} + +static bool registered = []() { + Axpy::dispatcher().registerDevice({Device::Type::CPU, + Device::Type::CAMBRICON, + Device::Type::METAX}, + &calculate, + false); + return true; +}(); + +} // namespace infinicore::op::axpy_impl::infiniop diff --git a/src/infinicore/ops/blas_copy/blas_copy.cc b/src/infinicore/ops/blas_copy/blas_copy.cc new file mode 100644 index 000000000..374774f5c --- /dev/null +++ b/src/infinicore/ops/blas_copy/blas_copy.cc @@ -0,0 +1,22 @@ +#include "infinicore/ops/blas_copy.hpp" + +#include "../../utils.hpp" + +namespace infinicore::op { + +common::OpDispatcher &BlasCopy::dispatcher() { + static common::OpDispatcher dispatcher_; + return dispatcher_; +}; + +void BlasCopy::execute(Tensor x, Tensor y) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(x, y); + infinicore::context::setDevice(y->device()); + dispatcher().lookup(y->device().getType())(x, y); +} + +void blas_copy_(Tensor x, Tensor y) { + BlasCopy::execute(x, y); +} + +} // namespace infinicore::op diff --git a/src/infinicore/ops/blas_copy/blas_copy_infiniop.cc b/src/infinicore/ops/blas_copy/blas_copy_infiniop.cc new file mode 100644 index 000000000..aa1c8bc52 --- /dev/null +++ b/src/infinicore/ops/blas_copy/blas_copy_infiniop.cc @@ -0,0 +1,56 @@ +#include "../../utils.hpp" +#include "infinicore/common/hash.hpp" +#include "infinicore/ops/blas_copy.hpp" +#include "infinicore/ops/common/cache.hpp" +#include + +namespace infinicore::op::blas_copy_impl::infiniop { + +thread_local common::OpCache caches( + 100, // capacity + [](infiniopBlasCopyDescriptor_t &desc) { + if (desc != nullptr) { + INFINICORE_CHECK_ERROR(infiniopDestroyBlasCopyDescriptor(desc)); + desc = nullptr; + } + }); + +void calculate(Tensor x, Tensor y) { + size_t seed = hash_combine(x, y); + + auto device_type = 
context::getDevice().getType(); + auto device_index = context::getDevice().getIndex(); + + auto &cache = caches.getCache(device_type, device_index); + + auto desc_opt = cache.get(seed); + infiniopBlasCopyDescriptor_t desc = nullptr; + + if (!desc_opt) { + INFINICORE_CHECK_ERROR(infiniopCreateBlasCopyDescriptor( + context::getInfiniopHandle(y->device()), &desc, + x->desc(), y->desc())); + cache.put(seed, desc); + } else { + desc = *desc_opt; + } + + size_t workspace_size = 0; + INFINICORE_CHECK_ERROR(infiniopGetBlasCopyWorkspaceSize(desc, &workspace_size)); + std::shared_ptr workspace = context::allocateMemory(workspace_size); + + INFINICORE_CHECK_ERROR(infiniopBlasCopy( + desc, workspace->data(), workspace_size, + x->data(), y->data(), context::getStream())); +} + +static bool registered = []() { + BlasCopy::dispatcher().registerDevice({Device::Type::CPU, + Device::Type::CAMBRICON, + Device::Type::METAX}, + &calculate, + false); + return true; +}(); + +} // namespace infinicore::op::blas_copy_impl::infiniop diff --git a/src/infinicore/ops/scal/scal.cc b/src/infinicore/ops/scal/scal.cc new file mode 100644 index 000000000..c0ef9871a --- /dev/null +++ b/src/infinicore/ops/scal/scal.cc @@ -0,0 +1,22 @@ +#include "infinicore/ops/scal.hpp" + +#include "../../utils.hpp" + +namespace infinicore::op { + +common::OpDispatcher &Scal::dispatcher() { + static common::OpDispatcher dispatcher_; + return dispatcher_; +}; + +void Scal::execute(Tensor alpha, Tensor x) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(alpha, x); + infinicore::context::setDevice(x->device()); + dispatcher().lookup(x->device().getType())(alpha, x); +} + +void scal_(Tensor x, Tensor alpha) { + Scal::execute(alpha, x); +} + +} // namespace infinicore::op diff --git a/src/infinicore/ops/scal/scal_infiniop.cc b/src/infinicore/ops/scal/scal_infiniop.cc new file mode 100644 index 000000000..9e07e1798 --- /dev/null +++ b/src/infinicore/ops/scal/scal_infiniop.cc @@ -0,0 +1,56 @@ +#include "../../utils.hpp" +#include 
"infinicore/common/hash.hpp" +#include "infinicore/ops/common/cache.hpp" +#include "infinicore/ops/scal.hpp" +#include + +namespace infinicore::op::scal_impl::infiniop { + +thread_local common::OpCache caches( + 100, // capacity + [](infiniopScalDescriptor_t &desc) { + if (desc != nullptr) { + INFINICORE_CHECK_ERROR(infiniopDestroyScalDescriptor(desc)); + desc = nullptr; + } + }); + +void calculate(Tensor alpha, Tensor x) { + size_t seed = hash_combine(alpha, x); + + auto device_type = context::getDevice().getType(); + auto device_index = context::getDevice().getIndex(); + + auto &cache = caches.getCache(device_type, device_index); + + auto desc_opt = cache.get(seed); + infiniopScalDescriptor_t desc = nullptr; + + if (!desc_opt) { + INFINICORE_CHECK_ERROR(infiniopCreateScalDescriptor( + context::getInfiniopHandle(x->device()), &desc, + alpha->desc(), x->desc())); + cache.put(seed, desc); + } else { + desc = *desc_opt; + } + + size_t workspace_size = 0; + INFINICORE_CHECK_ERROR(infiniopGetScalWorkspaceSize(desc, &workspace_size)); + std::shared_ptr workspace = context::allocateMemory(workspace_size); + + INFINICORE_CHECK_ERROR(infiniopScal( + desc, workspace->data(), workspace_size, + alpha->data(), x->data(), context::getStream())); +} + +static bool registered = []() { + Scal::dispatcher().registerDevice({Device::Type::CPU, + Device::Type::CAMBRICON, + Device::Type::METAX}, + &calculate, + false); + return true; +}(); + +} // namespace infinicore::op::scal_impl::infiniop diff --git a/src/infinicore/ops/swap/swap.cc b/src/infinicore/ops/swap/swap.cc new file mode 100644 index 000000000..5bc078e5c --- /dev/null +++ b/src/infinicore/ops/swap/swap.cc @@ -0,0 +1,22 @@ +#include "infinicore/ops/swap.hpp" + +#include "../../utils.hpp" + +namespace infinicore::op { + +common::OpDispatcher &Swap::dispatcher() { + static common::OpDispatcher dispatcher_; + return dispatcher_; +}; + +void Swap::execute(Tensor x, Tensor y) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(x, y); + 
infinicore::context::setDevice(x->device()); + dispatcher().lookup(x->device().getType())(x, y); +} + +void swap_(Tensor x, Tensor y) { + Swap::execute(x, y); +} + +} // namespace infinicore::op diff --git a/src/infinicore/ops/swap/swap_infiniop.cc b/src/infinicore/ops/swap/swap_infiniop.cc new file mode 100644 index 000000000..f344e36a0 --- /dev/null +++ b/src/infinicore/ops/swap/swap_infiniop.cc @@ -0,0 +1,56 @@ +#include "../../utils.hpp" +#include "infinicore/common/hash.hpp" +#include "infinicore/ops/common/cache.hpp" +#include "infinicore/ops/swap.hpp" +#include + +namespace infinicore::op::swap_impl::infiniop { + +thread_local common::OpCache caches( + 100, // capacity + [](infiniopSwapDescriptor_t &desc) { + if (desc != nullptr) { + INFINICORE_CHECK_ERROR(infiniopDestroySwapDescriptor(desc)); + desc = nullptr; + } + }); + +void calculate(Tensor x, Tensor y) { + size_t seed = hash_combine(x, y); + + auto device_type = context::getDevice().getType(); + auto device_index = context::getDevice().getIndex(); + + auto &cache = caches.getCache(device_type, device_index); + + auto desc_opt = cache.get(seed); + infiniopSwapDescriptor_t desc = nullptr; + + if (!desc_opt) { + INFINICORE_CHECK_ERROR(infiniopCreateSwapDescriptor( + context::getInfiniopHandle(x->device()), &desc, + x->desc(), y->desc())); + cache.put(seed, desc); + } else { + desc = *desc_opt; + } + + size_t workspace_size = 0; + INFINICORE_CHECK_ERROR(infiniopGetSwapWorkspaceSize(desc, &workspace_size)); + std::shared_ptr workspace = context::allocateMemory(workspace_size); + + INFINICORE_CHECK_ERROR(infiniopSwap( + desc, workspace->data(), workspace_size, + x->data(), y->data(), context::getStream())); +} + +static bool registered = []() { + Swap::dispatcher().registerDevice({Device::Type::CPU, + Device::Type::CAMBRICON, + Device::Type::METAX}, + &calculate, + false); + return true; +}(); + +} // namespace infinicore::op::swap_impl::infiniop diff --git a/src/infinicore/pybind11/ops.hpp 
b/src/infinicore/pybind11/ops.hpp index 72da952dc..08475a979 100644 --- a/src/infinicore/pybind11/ops.hpp +++ b/src/infinicore/pybind11/ops.hpp @@ -20,12 +20,14 @@ #include "ops/atanh.hpp" #include "ops/attention.hpp" #include "ops/avg_pool1d.hpp" +#include "ops/axpy.hpp" #include "ops/baddbmm.hpp" #include "ops/bilinear.hpp" #include "ops/binary_cross_entropy_with_logits.hpp" #include "ops/bitwise_right_shift.hpp" #include "ops/blas_amax.hpp" #include "ops/blas_amin.hpp" +#include "ops/blas_copy.hpp" #include "ops/blas_dot.hpp" #include "ops/block_diag.hpp" #include "ops/broadcast_to.hpp" @@ -88,6 +90,7 @@ #include "ops/relu6.hpp" #include "ops/rms_norm.hpp" #include "ops/rope.hpp" +#include "ops/scal.hpp" #include "ops/scatter.hpp" #include "ops/selu.hpp" #include "ops/silu.hpp" @@ -97,6 +100,7 @@ #include "ops/softplus.hpp" #include "ops/softsign.hpp" #include "ops/sum.hpp" +#include "ops/swap.hpp" #include "ops/swiglu.hpp" #include "ops/take.hpp" #include "ops/tan.hpp" @@ -133,10 +137,12 @@ inline void bind(py::module &m) { bind_attention(m); bind_asinh(m); bind_asum(m); + bind_axpy(m); bind_baddbmm(m); bind_bilinear(m); bind_blas_amax(m); bind_blas_amin(m); + bind_blas_copy(m); bind_blas_dot(m); bind_block_diag(m); bind_bitwise_right_shift(m); @@ -200,6 +206,7 @@ inline void bind(py::module &m) { bind_flipud(m); bind_multi_margin_loss(m); bind_scatter(m); + bind_scal(m); bind_broadcast_to(m); bind_softplus(m); bind_softsign(m); @@ -227,6 +234,7 @@ inline void bind(py::module &m) { bind_lerp(m); bind_triplet_margin_loss(m); bind_selu(m); + bind_swap(m); bind_sinh(m); bind_layer_norm(m); bind_topksoftmax(m); diff --git a/src/infinicore/pybind11/ops/axpy.hpp b/src/infinicore/pybind11/ops/axpy.hpp new file mode 100644 index 000000000..f71270a08 --- /dev/null +++ b/src/infinicore/pybind11/ops/axpy.hpp @@ -0,0 +1,20 @@ +#pragma once + +#include + +#include "infinicore/ops/axpy.hpp" + +namespace py = pybind11; + +namespace infinicore::ops { + +inline void 
bind_axpy(py::module &m) { + m.def("axpy_", + py::overload_cast(&op::axpy_), + py::arg("alpha"), + py::arg("x"), + py::arg("y"), + R"doc(In-place BLAS level-1 axpy, updating y.)doc"); +} + +} // namespace infinicore::ops diff --git a/src/infinicore/pybind11/ops/blas_copy.hpp b/src/infinicore/pybind11/ops/blas_copy.hpp new file mode 100644 index 000000000..c348ac38b --- /dev/null +++ b/src/infinicore/pybind11/ops/blas_copy.hpp @@ -0,0 +1,19 @@ +#pragma once + +#include + +#include "infinicore/ops/blas_copy.hpp" + +namespace py = pybind11; + +namespace infinicore::ops { + +inline void bind_blas_copy(py::module &m) { + m.def("blas_copy_", + &op::blas_copy_, + py::arg("x"), + py::arg("y"), + R"doc(In-place BLAS level-1 copy from x to y.)doc"); +} + +} // namespace infinicore::ops diff --git a/src/infinicore/pybind11/ops/scal.hpp b/src/infinicore/pybind11/ops/scal.hpp new file mode 100644 index 000000000..9b914d1fe --- /dev/null +++ b/src/infinicore/pybind11/ops/scal.hpp @@ -0,0 +1,19 @@ +#pragma once + +#include + +#include "infinicore/ops/scal.hpp" + +namespace py = pybind11; + +namespace infinicore::ops { + +inline void bind_scal(py::module &m) { + m.def("scal_", + py::overload_cast(&op::scal_), + py::arg("x"), + py::arg("alpha"), + R"doc(In-place BLAS level-1 scal, updating x.)doc"); +} + +} // namespace infinicore::ops diff --git a/src/infinicore/pybind11/ops/swap.hpp b/src/infinicore/pybind11/ops/swap.hpp new file mode 100644 index 000000000..0b8f2ae3e --- /dev/null +++ b/src/infinicore/pybind11/ops/swap.hpp @@ -0,0 +1,19 @@ +#pragma once + +#include + +#include "infinicore/ops/swap.hpp" + +namespace py = pybind11; + +namespace infinicore::ops { + +inline void bind_swap(py::module &m) { + m.def("swap_", + &op::swap_, + py::arg("x"), + py::arg("y"), + R"doc(In-place BLAS level-1 swap.)doc"); +} + +} // namespace infinicore::ops diff --git a/test/infinicore/ops/axpy.py b/test/infinicore/ops/axpy.py new file mode 100644 index 000000000..cf37c8100 --- /dev/null +++ 
b/test/infinicore/ops/axpy.py @@ -0,0 +1,92 @@ +import os +import sys + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +from framework import ( + BaseOperatorTest, + GenericTestRunner, + TensorSpec, + TestCase, +) +from framework.tensor import TensorInitializer + +import infinicore + +_TEST_CASES_DATA = [ + ((3,), None, None), + ((8,), (2,), (3,)), + ((32,), None, (2,)), + ((257,), (3,), None), + ((65535,), None, None), +] + +_TENSOR_DTYPES = [ + infinicore.float16, + infinicore.float32, + # infinicore.float64, + infinicore.bfloat16, +] + +_TOLERANCE_MAP = { + infinicore.float16: {"atol": 1e-3, "rtol": 1e-3}, + infinicore.float32: {"atol": 1e-5, "rtol": 1e-5}, + infinicore.float64: {"atol": 1e-9, "rtol": 1e-9}, + infinicore.bfloat16: {"atol": 1e-2, "rtol": 1e-2}, +} + + +def torch_axpy(alpha, x, y): + y.add_(x, alpha=alpha.item()) + + return y + + +def parse_test_cases(): + test_cases = [] + for shape, x_strides, y_strides in _TEST_CASES_DATA: + for dtype in _TENSOR_DTYPES: + tol = _TOLERANCE_MAP.get(dtype, {"atol": 1e-5, "rtol": 1e-4}) + alpha_spec = TensorSpec.from_tensor( + (), None, dtype, init_mode=TensorInitializer.ONES + ) + x_spec = TensorSpec.from_tensor(shape, x_strides, dtype) + y_spec = TensorSpec.from_tensor(shape, y_strides, dtype) + + test_cases.append( + TestCase( + inputs=[alpha_spec, x_spec, y_spec], + kwargs={}, + output_spec=None, + comparison_target=2, + tolerance=tol, + description="axpy - INPLACE", + ) + ) + + return test_cases + + +class OpTest(BaseOperatorTest): + """BLAS Level-1 axpy operator test""" + + def __init__(self): + super().__init__("Axpy") + + def get_test_cases(self): + return parse_test_cases() + + def torch_operator(self, *args, **kwargs): + return torch_axpy(*args, **kwargs) + + def infinicore_operator(self, *args, **kwargs): + return infinicore.axpy(*args, **kwargs) + + +def main(): + runner = GenericTestRunner(OpTest) + runner.run_and_exit() + + +if __name__ == "__main__": + main() diff --git 
a/test/infinicore/ops/blas_copy.py b/test/infinicore/ops/blas_copy.py new file mode 100644 index 000000000..62b9b81a8 --- /dev/null +++ b/test/infinicore/ops/blas_copy.py @@ -0,0 +1,87 @@ +import os +import sys + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +from framework import ( + BaseOperatorTest, + GenericTestRunner, + TensorSpec, + TestCase, +) + +import infinicore + +_TEST_CASES_DATA = [ + ((3,), None, None), + ((8,), (2,), (3,)), + ((32,), None, (2,)), + ((257,), (3,), None), + ((65535,), None, None), +] + +_TENSOR_DTYPES = [ + # infinicore.float16, + infinicore.float32, + # infinicore.float64, + # infinicore.bfloat16, +] + +_TOLERANCE_MAP = { + infinicore.float16: {"atol": 1e-3, "rtol": 1e-3}, + infinicore.float32: {"atol": 1e-7, "rtol": 1e-7}, + infinicore.float64: {"atol": 1e-15, "rtol": 1e-15}, + infinicore.bfloat16: {"atol": 5e-3, "rtol": 1e-2}, +} + + +def torch_blas_copy(x, y): + y.copy_(x) + return y + + +def parse_test_cases(): + test_cases = [] + for shape, x_strides, y_strides in _TEST_CASES_DATA: + for dtype in _TENSOR_DTYPES: + tol = _TOLERANCE_MAP.get(dtype, {"atol": 1e-5, "rtol": 1e-4}) + x_spec = TensorSpec.from_tensor(shape, x_strides, dtype) + y_spec = TensorSpec.from_tensor(shape, y_strides, dtype) + + test_cases.append( + TestCase( + inputs=[x_spec, y_spec], + kwargs={}, + output_spec=None, + comparison_target=1, + tolerance=tol, + description="blas_copy - INPLACE", + ) + ) + + return test_cases + + +class OpTest(BaseOperatorTest): + """BLAS Level-1 copy operator test""" + + def __init__(self): + super().__init__("BlasCopy") + + def get_test_cases(self): + return parse_test_cases() + + def torch_operator(self, *args, **kwargs): + return torch_blas_copy(*args, **kwargs) + + def infinicore_operator(self, *args, **kwargs): + return infinicore.blas_copy(*args, **kwargs) + + +def main(): + runner = GenericTestRunner(OpTest) + runner.run_and_exit() + + +if __name__ == "__main__": + main() diff --git 
a/test/infinicore/ops/scal.py b/test/infinicore/ops/scal.py new file mode 100644 index 000000000..c044989a8 --- /dev/null +++ b/test/infinicore/ops/scal.py @@ -0,0 +1,91 @@ +import os +import sys + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +from framework import ( + BaseOperatorTest, + GenericTestRunner, + TensorSpec, + TestCase, +) +from framework.tensor import TensorInitializer + +import infinicore + +_TEST_CASES_DATA = [ + ((13,), None), + ((13,), (10,)), + ((5632,), None), + ((5632,), (5,)), + ((16,), (4,)), + ((5632,), (32,)), +] + +_TENSOR_DTYPES = [ + infinicore.float16, + infinicore.float32, + # infinicore.float64, + infinicore.bfloat16, +] + +_TOLERANCE_MAP = { + infinicore.float16: {"atol": 1e-3, "rtol": 1e-3}, + infinicore.float32: {"atol": 1e-7, "rtol": 1e-7}, + infinicore.float64: {"atol": 1e-15, "rtol": 1e-15}, + infinicore.bfloat16: {"atol": 5e-3, "rtol": 1e-2}, +} + + +def torch_scal(x, alpha): + x.mul_(alpha) + return x + + +def parse_test_cases(): + test_cases = [] + for shape, x_strides in _TEST_CASES_DATA: + for dtype in _TENSOR_DTYPES: + tol = _TOLERANCE_MAP.get(dtype, {"atol": 1e-5, "rtol": 1e-4}) + x_spec = TensorSpec.from_tensor(shape, x_strides, dtype) + alpha_spec = TensorSpec.from_tensor( + (1,), None, dtype, init_mode=TensorInitializer.ONES + ) + + test_cases.append( + TestCase( + inputs=[x_spec, alpha_spec], + kwargs={}, + output_spec=None, + comparison_target=0, + tolerance=tol, + description="scal - INPLACE", + ) + ) + + return test_cases + + +class OpTest(BaseOperatorTest): + """BLAS Level-1 scal operator test""" + + def __init__(self): + super().__init__("Scal") + + def get_test_cases(self): + return parse_test_cases() + + def torch_operator(self, *args, **kwargs): + return torch_scal(*args, **kwargs) + + def infinicore_operator(self, *args, **kwargs): + return infinicore.scal(*args, **kwargs) + + +def main(): + runner = GenericTestRunner(OpTest) + runner.run_and_exit() + + +if __name__ == "__main__": + 
main() diff --git a/test/infinicore/ops/swap.py b/test/infinicore/ops/swap.py new file mode 100644 index 000000000..b30f1290d --- /dev/null +++ b/test/infinicore/ops/swap.py @@ -0,0 +1,90 @@ +import os +import sys + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +from framework import ( + BaseOperatorTest, + GenericTestRunner, + TensorSpec, + TestCase, +) + +import infinicore + +_TEST_CASES_DATA = [ + ((13,), None, None), + ((13,), (10,), (10,)), + ((5632,), None, None), + ((5632,), (5,), (5,)), + ((16,), (4,), (4,)), + ((5632,), (32,), (32,)), +] + +_TENSOR_DTYPES = [ + # infinicore.float16, + infinicore.float32, + # infinicore.float64, + # infinicore.bfloat16, +] + +_TOLERANCE_MAP = { + infinicore.float16: {"atol": 1e-3, "rtol": 1e-3}, + infinicore.float32: {"atol": 1e-7, "rtol": 1e-7}, + infinicore.float64: {"atol": 1e-15, "rtol": 1e-15}, + infinicore.bfloat16: {"atol": 5e-3, "rtol": 1e-2}, +} + + +def torch_swap(x, y): + tmp = x.clone() + x.copy_(y) + y.copy_(tmp) + return x, y + + +def parse_test_cases(): + test_cases = [] + for shape, x_strides, y_strides in _TEST_CASES_DATA: + for dtype in _TENSOR_DTYPES: + tol = _TOLERANCE_MAP.get(dtype, {"atol": 1e-5, "rtol": 1e-4}) + x_spec = TensorSpec.from_tensor(shape, x_strides, dtype) + y_spec = TensorSpec.from_tensor(shape, y_strides, dtype) + + test_cases.append( + TestCase( + inputs=[x_spec, y_spec], + kwargs={}, + comparison_target=[0, 1], + tolerance=tol, + output_count=2, + description="swap - INPLACE", + ) + ) + + return test_cases + + +class OpTest(BaseOperatorTest): + """BLAS Level-1 swap operator test""" + + def __init__(self): + super().__init__("Swap") + + def get_test_cases(self): + return parse_test_cases() + + def torch_operator(self, *args, **kwargs): + return torch_swap(*args, **kwargs) + + def infinicore_operator(self, *args, **kwargs): + return infinicore.swap(*args, **kwargs) + + +def main(): + runner = GenericTestRunner(OpTest) + runner.run_and_exit() + + +if __name__ == 
"__main__": + main() From 74564bd3de5cb4a3d4a0f05c3dee588aa2bbdff9 Mon Sep 17 00:00:00 2001 From: xuzhengzhong Date: Wed, 6 May 2026 06:41:05 +0000 Subject: [PATCH 17/25] Add InfiniCore `rot`, `rotg`, `rotm` and `rotmg` wrappers --- include/infinicore/ops.hpp | 4 + include/infinicore/ops/rot.hpp | 17 ++ include/infinicore/ops/rotg.hpp | 17 ++ include/infinicore/ops/rotm.hpp | 17 ++ include/infinicore/ops/rotmg.hpp | 17 ++ python/infinicore/__init__.py | 8 + python/infinicore/ops/rot.py | 7 + python/infinicore/ops/rotg.py | 7 + python/infinicore/ops/rotm.py | 7 + python/infinicore/ops/rotmg.py | 13 ++ src/infinicore/ops/rot/rot.cc | 22 ++ src/infinicore/ops/rot/rot_infiniop.cc | 56 +++++ src/infinicore/ops/rotg/rotg.cc | 22 ++ src/infinicore/ops/rotg/rotg_infiniop.cc | 56 +++++ src/infinicore/ops/rotm/rotm.cc | 22 ++ src/infinicore/ops/rotm/rotm_infiniop.cc | 56 +++++ src/infinicore/ops/rotmg/rotmg.cc | 22 ++ src/infinicore/ops/rotmg/rotmg_infiniop.cc | 56 +++++ src/infinicore/pybind11/ops.hpp | 8 + src/infinicore/pybind11/ops/rot.hpp | 21 ++ src/infinicore/pybind11/ops/rotg.hpp | 21 ++ src/infinicore/pybind11/ops/rotm.hpp | 20 ++ src/infinicore/pybind11/ops/rotmg.hpp | 22 ++ test/infinicore/ops/rot.py | 107 ++++++++++ test/infinicore/ops/rotg.py | 145 +++++++++++++ test/infinicore/ops/rotm.py | 110 ++++++++++ test/infinicore/ops/rotmg.py | 233 +++++++++++++++++++++ 27 files changed, 1113 insertions(+) create mode 100644 include/infinicore/ops/rot.hpp create mode 100644 include/infinicore/ops/rotg.hpp create mode 100644 include/infinicore/ops/rotm.hpp create mode 100644 include/infinicore/ops/rotmg.hpp create mode 100644 python/infinicore/ops/rot.py create mode 100644 python/infinicore/ops/rotg.py create mode 100644 python/infinicore/ops/rotm.py create mode 100644 python/infinicore/ops/rotmg.py create mode 100644 src/infinicore/ops/rot/rot.cc create mode 100644 src/infinicore/ops/rot/rot_infiniop.cc create mode 100644 src/infinicore/ops/rotg/rotg.cc create mode 
100644 src/infinicore/ops/rotg/rotg_infiniop.cc create mode 100644 src/infinicore/ops/rotm/rotm.cc create mode 100644 src/infinicore/ops/rotm/rotm_infiniop.cc create mode 100644 src/infinicore/ops/rotmg/rotmg.cc create mode 100644 src/infinicore/ops/rotmg/rotmg_infiniop.cc create mode 100644 src/infinicore/pybind11/ops/rot.hpp create mode 100644 src/infinicore/pybind11/ops/rotg.hpp create mode 100644 src/infinicore/pybind11/ops/rotm.hpp create mode 100644 src/infinicore/pybind11/ops/rotmg.hpp create mode 100644 test/infinicore/ops/rot.py create mode 100644 test/infinicore/ops/rotg.py create mode 100644 test/infinicore/ops/rotm.py create mode 100644 test/infinicore/ops/rotmg.py diff --git a/include/infinicore/ops.hpp b/include/infinicore/ops.hpp index 72d6c311c..d72e8b5c3 100644 --- a/include/infinicore/ops.hpp +++ b/include/infinicore/ops.hpp @@ -48,6 +48,10 @@ #include "ops/relu.hpp" #include "ops/rms_norm.hpp" #include "ops/rope.hpp" +#include "ops/rot.hpp" +#include "ops/rotg.hpp" +#include "ops/rotm.hpp" +#include "ops/rotmg.hpp" #include "ops/scal.hpp" #include "ops/silu.hpp" #include "ops/silu_and_mul.hpp" diff --git a/include/infinicore/ops/rot.hpp b/include/infinicore/ops/rot.hpp new file mode 100644 index 000000000..6229983e4 --- /dev/null +++ b/include/infinicore/ops/rot.hpp @@ -0,0 +1,17 @@ +#pragma once + +#include "../device.hpp" +#include "common/op.hpp" + +namespace infinicore::op { + +class Rot { +public: + using schema = void (*)(Tensor, Tensor, Tensor, Tensor); + static void execute(Tensor x, Tensor y, Tensor c, Tensor s); + static common::OpDispatcher &dispatcher(); +}; + +void rot_(Tensor x, Tensor y, Tensor c, Tensor s); + +} // namespace infinicore::op diff --git a/include/infinicore/ops/rotg.hpp b/include/infinicore/ops/rotg.hpp new file mode 100644 index 000000000..c65f211dd --- /dev/null +++ b/include/infinicore/ops/rotg.hpp @@ -0,0 +1,17 @@ +#pragma once + +#include "../device.hpp" +#include "common/op.hpp" + +namespace infinicore::op { + 
+class Rotg { +public: + using schema = void (*)(Tensor, Tensor, Tensor, Tensor); + static void execute(Tensor x, Tensor y, Tensor c, Tensor s); + static common::OpDispatcher &dispatcher(); +}; + +void rotg_(Tensor x, Tensor y, Tensor c, Tensor s); + +} // namespace infinicore::op diff --git a/include/infinicore/ops/rotm.hpp b/include/infinicore/ops/rotm.hpp new file mode 100644 index 000000000..813fc27ee --- /dev/null +++ b/include/infinicore/ops/rotm.hpp @@ -0,0 +1,17 @@ +#pragma once + +#include "../device.hpp" +#include "common/op.hpp" + +namespace infinicore::op { + +class Rotm { +public: + using schema = void (*)(Tensor, Tensor, Tensor); + static void execute(Tensor x, Tensor y, Tensor param); + static common::OpDispatcher &dispatcher(); +}; + +void rotm_(Tensor x, Tensor y, Tensor param); + +} // namespace infinicore::op diff --git a/include/infinicore/ops/rotmg.hpp b/include/infinicore/ops/rotmg.hpp new file mode 100644 index 000000000..e245840a4 --- /dev/null +++ b/include/infinicore/ops/rotmg.hpp @@ -0,0 +1,17 @@ +#pragma once + +#include "../device.hpp" +#include "common/op.hpp" + +namespace infinicore::op { + +class Rotmg { +public: + using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor); + static void execute(Tensor d1, Tensor d2, Tensor x1, Tensor y1, Tensor param); + static common::OpDispatcher &dispatcher(); +}; + +void rotmg_(Tensor d1, Tensor d2, Tensor x1, Tensor y1, Tensor param); + +} // namespace infinicore::op diff --git a/python/infinicore/__init__.py b/python/infinicore/__init__.py index 883bf7098..3275117e3 100644 --- a/python/infinicore/__init__.py +++ b/python/infinicore/__init__.py @@ -114,6 +114,10 @@ from infinicore.ops.paged_caching import paged_caching from infinicore.ops.rearrange import rearrange from infinicore.ops.reciprocal import reciprocal +from infinicore.ops.rot import rot +from infinicore.ops.rotg import rotg +from infinicore.ops.rotm import rotm +from infinicore.ops.rotmg import rotmg from infinicore.ops.scal 
import scal from infinicore.ops.scatter import scatter from infinicore.ops.sinh import sinh @@ -247,6 +251,10 @@ "float_power", "flipud", "scatter", + "rot", + "rotg", + "rotm", + "rotmg", "scal", "logcumsumexp", "logical_not", diff --git a/python/infinicore/ops/rot.py b/python/infinicore/ops/rot.py new file mode 100644 index 000000000..091775c4e --- /dev/null +++ b/python/infinicore/ops/rot.py @@ -0,0 +1,7 @@ +from infinicore.lib import _infinicore +from infinicore.tensor import Tensor + + +def rot(x: Tensor, y: Tensor, c: Tensor, s: Tensor): + _infinicore.rot_(x._underlying, y._underlying, c._underlying, s._underlying) + return x, y diff --git a/python/infinicore/ops/rotg.py b/python/infinicore/ops/rotg.py new file mode 100644 index 000000000..8299e6205 --- /dev/null +++ b/python/infinicore/ops/rotg.py @@ -0,0 +1,7 @@ +from infinicore.lib import _infinicore +from infinicore.tensor import Tensor + + +def rotg(x: Tensor, y: Tensor, c: Tensor, s: Tensor): + _infinicore.rotg_(x._underlying, y._underlying, c._underlying, s._underlying) + return x, y, c, s diff --git a/python/infinicore/ops/rotm.py b/python/infinicore/ops/rotm.py new file mode 100644 index 000000000..68acfc231 --- /dev/null +++ b/python/infinicore/ops/rotm.py @@ -0,0 +1,7 @@ +from infinicore.lib import _infinicore +from infinicore.tensor import Tensor + + +def rotm(x: Tensor, y: Tensor, param: Tensor): + _infinicore.rotm_(x._underlying, y._underlying, param._underlying) + return x, y diff --git a/python/infinicore/ops/rotmg.py b/python/infinicore/ops/rotmg.py new file mode 100644 index 000000000..468dd687e --- /dev/null +++ b/python/infinicore/ops/rotmg.py @@ -0,0 +1,13 @@ +from infinicore.lib import _infinicore +from infinicore.tensor import Tensor + + +def rotmg(d1: Tensor, d2: Tensor, x1: Tensor, y1: Tensor, param: Tensor): + _infinicore.rotmg_( + d1._underlying, + d2._underlying, + x1._underlying, + y1._underlying, + param._underlying, + ) + return d1, d2, x1, param diff --git 
a/src/infinicore/ops/rot/rot.cc b/src/infinicore/ops/rot/rot.cc new file mode 100644 index 000000000..70aa8e22c --- /dev/null +++ b/src/infinicore/ops/rot/rot.cc @@ -0,0 +1,22 @@ +#include "infinicore/ops/rot.hpp" + +#include "../../utils.hpp" + +namespace infinicore::op { + +common::OpDispatcher &Rot::dispatcher() { + static common::OpDispatcher dispatcher_; + return dispatcher_; +}; + +void Rot::execute(Tensor x, Tensor y, Tensor c, Tensor s) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(x, y, c, s); + infinicore::context::setDevice(x->device()); + dispatcher().lookup(x->device().getType())(x, y, c, s); +} + +void rot_(Tensor x, Tensor y, Tensor c, Tensor s) { + Rot::execute(x, y, c, s); +} + +} // namespace infinicore::op diff --git a/src/infinicore/ops/rot/rot_infiniop.cc b/src/infinicore/ops/rot/rot_infiniop.cc new file mode 100644 index 000000000..113bc488f --- /dev/null +++ b/src/infinicore/ops/rot/rot_infiniop.cc @@ -0,0 +1,56 @@ +#include "../../utils.hpp" +#include "infinicore/common/hash.hpp" +#include "infinicore/ops/common/cache.hpp" +#include "infinicore/ops/rot.hpp" +#include + +namespace infinicore::op::rot_impl::infiniop { + +thread_local common::OpCache caches( + 100, + [](infiniopRotDescriptor_t &desc) { + if (desc != nullptr) { + INFINICORE_CHECK_ERROR(infiniopDestroyRotDescriptor(desc)); + desc = nullptr; + } + }); + +void calculate(Tensor x, Tensor y, Tensor c, Tensor s) { + size_t seed = hash_combine(x, y, c, s); + + auto device_type = context::getDevice().getType(); + auto device_index = context::getDevice().getIndex(); + + auto &cache = caches.getCache(device_type, device_index); + + auto desc_opt = cache.get(seed); + infiniopRotDescriptor_t desc = nullptr; + + if (!desc_opt) { + INFINICORE_CHECK_ERROR(infiniopCreateRotDescriptor( + context::getInfiniopHandle(x->device()), &desc, + x->desc(), y->desc(), c->desc(), s->desc())); + cache.put(seed, desc); + } else { + desc = *desc_opt; + } + + size_t workspace_size = 0; + 
INFINICORE_CHECK_ERROR(infiniopGetRotWorkspaceSize(desc, &workspace_size)); + std::shared_ptr workspace = context::allocateMemory(workspace_size); + + INFINICORE_CHECK_ERROR(infiniopRot( + desc, workspace->data(), workspace_size, + x->data(), y->data(), c->data(), s->data(), context::getStream())); +} + +static bool registered = []() { + Rot::dispatcher().registerDevice({Device::Type::CPU, + Device::Type::CAMBRICON, + Device::Type::METAX}, + &calculate, + false); + return true; +}(); + +} // namespace infinicore::op::rot_impl::infiniop diff --git a/src/infinicore/ops/rotg/rotg.cc b/src/infinicore/ops/rotg/rotg.cc new file mode 100644 index 000000000..233177b68 --- /dev/null +++ b/src/infinicore/ops/rotg/rotg.cc @@ -0,0 +1,22 @@ +#include "infinicore/ops/rotg.hpp" + +#include "../../utils.hpp" + +namespace infinicore::op { + +common::OpDispatcher &Rotg::dispatcher() { + static common::OpDispatcher dispatcher_; + return dispatcher_; +}; + +void Rotg::execute(Tensor x, Tensor y, Tensor c, Tensor s) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(x, y, c, s); + infinicore::context::setDevice(x->device()); + dispatcher().lookup(x->device().getType())(x, y, c, s); +} + +void rotg_(Tensor x, Tensor y, Tensor c, Tensor s) { + Rotg::execute(x, y, c, s); +} + +} // namespace infinicore::op diff --git a/src/infinicore/ops/rotg/rotg_infiniop.cc b/src/infinicore/ops/rotg/rotg_infiniop.cc new file mode 100644 index 000000000..d5e7a0e99 --- /dev/null +++ b/src/infinicore/ops/rotg/rotg_infiniop.cc @@ -0,0 +1,56 @@ +#include "../../utils.hpp" +#include "infinicore/common/hash.hpp" +#include "infinicore/ops/common/cache.hpp" +#include "infinicore/ops/rotg.hpp" +#include + +namespace infinicore::op::rotg_impl::infiniop { + +thread_local common::OpCache caches( + 100, + [](infiniopRotgDescriptor_t &desc) { + if (desc != nullptr) { + INFINICORE_CHECK_ERROR(infiniopDestroyRotgDescriptor(desc)); + desc = nullptr; + } + }); + +void calculate(Tensor x, Tensor y, Tensor c, Tensor s) { + size_t 
seed = hash_combine(x, y, c, s); + + auto device_type = context::getDevice().getType(); + auto device_index = context::getDevice().getIndex(); + + auto &cache = caches.getCache(device_type, device_index); + + auto desc_opt = cache.get(seed); + infiniopRotgDescriptor_t desc = nullptr; + + if (!desc_opt) { + INFINICORE_CHECK_ERROR(infiniopCreateRotgDescriptor( + context::getInfiniopHandle(x->device()), &desc, + x->desc(), y->desc(), c->desc(), s->desc())); + cache.put(seed, desc); + } else { + desc = *desc_opt; + } + + size_t workspace_size = 0; + INFINICORE_CHECK_ERROR(infiniopGetRotgWorkspaceSize(desc, &workspace_size)); + std::shared_ptr workspace = context::allocateMemory(workspace_size); + + INFINICORE_CHECK_ERROR(infiniopRotg( + desc, workspace->data(), workspace_size, + x->data(), y->data(), c->data(), s->data(), context::getStream())); +} + +static bool registered = []() { + Rotg::dispatcher().registerDevice({Device::Type::CPU, + Device::Type::CAMBRICON, + Device::Type::METAX}, + &calculate, + false); + return true; +}(); + +} // namespace infinicore::op::rotg_impl::infiniop diff --git a/src/infinicore/ops/rotm/rotm.cc b/src/infinicore/ops/rotm/rotm.cc new file mode 100644 index 000000000..2b1921f8b --- /dev/null +++ b/src/infinicore/ops/rotm/rotm.cc @@ -0,0 +1,22 @@ +#include "infinicore/ops/rotm.hpp" + +#include "../../utils.hpp" + +namespace infinicore::op { + +common::OpDispatcher &Rotm::dispatcher() { + static common::OpDispatcher dispatcher_; + return dispatcher_; +}; + +void Rotm::execute(Tensor x, Tensor y, Tensor param) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(x, y, param); + infinicore::context::setDevice(x->device()); + dispatcher().lookup(x->device().getType())(x, y, param); +} + +void rotm_(Tensor x, Tensor y, Tensor param) { + Rotm::execute(x, y, param); +} + +} // namespace infinicore::op diff --git a/src/infinicore/ops/rotm/rotm_infiniop.cc b/src/infinicore/ops/rotm/rotm_infiniop.cc new file mode 100644 index 000000000..9df5c9ab1 --- 
/dev/null +++ b/src/infinicore/ops/rotm/rotm_infiniop.cc @@ -0,0 +1,56 @@ +#include "../../utils.hpp" +#include "infinicore/common/hash.hpp" +#include "infinicore/ops/common/cache.hpp" +#include "infinicore/ops/rotm.hpp" +#include + +namespace infinicore::op::rotm_impl::infiniop { + +thread_local common::OpCache caches( + 100, + [](infiniopRotmDescriptor_t &desc) { + if (desc != nullptr) { + INFINICORE_CHECK_ERROR(infiniopDestroyRotmDescriptor(desc)); + desc = nullptr; + } + }); + +void calculate(Tensor x, Tensor y, Tensor param) { + size_t seed = hash_combine(x, y, param); + + auto device_type = context::getDevice().getType(); + auto device_index = context::getDevice().getIndex(); + + auto &cache = caches.getCache(device_type, device_index); + + auto desc_opt = cache.get(seed); + infiniopRotmDescriptor_t desc = nullptr; + + if (!desc_opt) { + INFINICORE_CHECK_ERROR(infiniopCreateRotmDescriptor( + context::getInfiniopHandle(x->device()), &desc, + x->desc(), y->desc(), param->desc())); + cache.put(seed, desc); + } else { + desc = *desc_opt; + } + + size_t workspace_size = 0; + INFINICORE_CHECK_ERROR(infiniopGetRotmWorkspaceSize(desc, &workspace_size)); + std::shared_ptr workspace = context::allocateMemory(workspace_size); + + INFINICORE_CHECK_ERROR(infiniopRotm( + desc, workspace->data(), workspace_size, + x->data(), y->data(), param->data(), context::getStream())); +} + +static bool registered = []() { + Rotm::dispatcher().registerDevice({Device::Type::CPU, + Device::Type::CAMBRICON, + Device::Type::METAX}, + &calculate, + false); + return true; +}(); + +} // namespace infinicore::op::rotm_impl::infiniop diff --git a/src/infinicore/ops/rotmg/rotmg.cc b/src/infinicore/ops/rotmg/rotmg.cc new file mode 100644 index 000000000..8a2860d63 --- /dev/null +++ b/src/infinicore/ops/rotmg/rotmg.cc @@ -0,0 +1,22 @@ +#include "infinicore/ops/rotmg.hpp" + +#include "../../utils.hpp" + +namespace infinicore::op { + +common::OpDispatcher &Rotmg::dispatcher() { + static 
common::OpDispatcher dispatcher_; + return dispatcher_; +}; + +void Rotmg::execute(Tensor d1, Tensor d2, Tensor x1, Tensor y1, Tensor param) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(d1, d2, x1, y1, param); + infinicore::context::setDevice(d1->device()); + dispatcher().lookup(d1->device().getType())(d1, d2, x1, y1, param); +} + +void rotmg_(Tensor d1, Tensor d2, Tensor x1, Tensor y1, Tensor param) { + Rotmg::execute(d1, d2, x1, y1, param); +} + +} // namespace infinicore::op diff --git a/src/infinicore/ops/rotmg/rotmg_infiniop.cc b/src/infinicore/ops/rotmg/rotmg_infiniop.cc new file mode 100644 index 000000000..ade0bd7aa --- /dev/null +++ b/src/infinicore/ops/rotmg/rotmg_infiniop.cc @@ -0,0 +1,56 @@ +#include "../../utils.hpp" +#include "infinicore/common/hash.hpp" +#include "infinicore/ops/common/cache.hpp" +#include "infinicore/ops/rotmg.hpp" +#include + +namespace infinicore::op::rotmg_impl::infiniop { + +thread_local common::OpCache caches( + 100, + [](infiniopRotmgDescriptor_t &desc) { + if (desc != nullptr) { + INFINICORE_CHECK_ERROR(infiniopDestroyRotmgDescriptor(desc)); + desc = nullptr; + } + }); + +void calculate(Tensor d1, Tensor d2, Tensor x1, Tensor y1, Tensor param) { + size_t seed = hash_combine(d1, d2, x1, y1, param); + + auto device_type = context::getDevice().getType(); + auto device_index = context::getDevice().getIndex(); + + auto &cache = caches.getCache(device_type, device_index); + + auto desc_opt = cache.get(seed); + infiniopRotmgDescriptor_t desc = nullptr; + + if (!desc_opt) { + INFINICORE_CHECK_ERROR(infiniopCreateRotmgDescriptor( + context::getInfiniopHandle(d1->device()), &desc, + d1->desc(), d2->desc(), x1->desc(), y1->desc(), param->desc())); + cache.put(seed, desc); + } else { + desc = *desc_opt; + } + + size_t workspace_size = 0; + INFINICORE_CHECK_ERROR(infiniopGetRotmgWorkspaceSize(desc, &workspace_size)); + std::shared_ptr workspace = context::allocateMemory(workspace_size); + + INFINICORE_CHECK_ERROR(infiniopRotmg( + desc, 
workspace->data(), workspace_size, + d1->data(), d2->data(), x1->data(), y1->data(), param->data(), context::getStream())); +} + +static bool registered = []() { + Rotmg::dispatcher().registerDevice({Device::Type::CPU, + Device::Type::CAMBRICON, + Device::Type::METAX}, + &calculate, + false); + return true; +}(); + +} // namespace infinicore::op::rotmg_impl::infiniop diff --git a/src/infinicore/pybind11/ops.hpp b/src/infinicore/pybind11/ops.hpp index 08475a979..198821594 100644 --- a/src/infinicore/pybind11/ops.hpp +++ b/src/infinicore/pybind11/ops.hpp @@ -90,6 +90,10 @@ #include "ops/relu6.hpp" #include "ops/rms_norm.hpp" #include "ops/rope.hpp" +#include "ops/rot.hpp" +#include "ops/rotg.hpp" +#include "ops/rotm.hpp" +#include "ops/rotmg.hpp" #include "ops/scal.hpp" #include "ops/scatter.hpp" #include "ops/selu.hpp" @@ -201,6 +205,10 @@ inline void bind(py::module &m) { bind_vander(m); bind_unfold(m); bind_rope(m); + bind_rot(m); + bind_rotg(m); + bind_rotm(m); + bind_rotmg(m); bind_floor_divide(m); bind_float_power(m); bind_flipud(m); diff --git a/src/infinicore/pybind11/ops/rot.hpp b/src/infinicore/pybind11/ops/rot.hpp new file mode 100644 index 000000000..9e2de6c47 --- /dev/null +++ b/src/infinicore/pybind11/ops/rot.hpp @@ -0,0 +1,21 @@ +#pragma once + +#include + +#include "infinicore/ops/rot.hpp" + +namespace py = pybind11; + +namespace infinicore::ops { + +inline void bind_rot(py::module &m) { + m.def("rot_", + py::overload_cast(&op::rot_), + py::arg("x"), + py::arg("y"), + py::arg("c"), + py::arg("s"), + R"doc(In-place BLAS level-1 rot, updating x and y.)doc"); +} + +} // namespace infinicore::ops diff --git a/src/infinicore/pybind11/ops/rotg.hpp b/src/infinicore/pybind11/ops/rotg.hpp new file mode 100644 index 000000000..a37e79336 --- /dev/null +++ b/src/infinicore/pybind11/ops/rotg.hpp @@ -0,0 +1,21 @@ +#pragma once + +#include + +#include "infinicore/ops/rotg.hpp" + +namespace py = pybind11; + +namespace infinicore::ops { + +inline void 
bind_rotg(py::module &m) { + m.def("rotg_", + py::overload_cast(&op::rotg_), + py::arg("x"), + py::arg("y"), + py::arg("c"), + py::arg("s"), + R"doc(In-place BLAS level-1 rotg, updating x, y, c, and s.)doc"); +} + +} // namespace infinicore::ops diff --git a/src/infinicore/pybind11/ops/rotm.hpp b/src/infinicore/pybind11/ops/rotm.hpp new file mode 100644 index 000000000..a88db38ba --- /dev/null +++ b/src/infinicore/pybind11/ops/rotm.hpp @@ -0,0 +1,20 @@ +#pragma once + +#include + +#include "infinicore/ops/rotm.hpp" + +namespace py = pybind11; + +namespace infinicore::ops { + +inline void bind_rotm(py::module &m) { + m.def("rotm_", + py::overload_cast(&op::rotm_), + py::arg("x"), + py::arg("y"), + py::arg("param"), + R"doc(In-place BLAS level-1 rotm, updating x and y.)doc"); +} + +} // namespace infinicore::ops diff --git a/src/infinicore/pybind11/ops/rotmg.hpp b/src/infinicore/pybind11/ops/rotmg.hpp new file mode 100644 index 000000000..72b816751 --- /dev/null +++ b/src/infinicore/pybind11/ops/rotmg.hpp @@ -0,0 +1,22 @@ +#pragma once + +#include + +#include "infinicore/ops/rotmg.hpp" + +namespace py = pybind11; + +namespace infinicore::ops { + +inline void bind_rotmg(py::module &m) { + m.def("rotmg_", + py::overload_cast(&op::rotmg_), + py::arg("d1"), + py::arg("d2"), + py::arg("x1"), + py::arg("y1"), + py::arg("param"), + R"doc(In-place BLAS level-1 rotmg, updating d1, d2, x1, and param.)doc"); +} + +} // namespace infinicore::ops diff --git a/test/infinicore/ops/rot.py b/test/infinicore/ops/rot.py new file mode 100644 index 000000000..bace043ee --- /dev/null +++ b/test/infinicore/ops/rot.py @@ -0,0 +1,107 @@ +import os +import sys + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +import torch +from framework import ( + BaseOperatorTest, + GenericTestRunner, + TensorSpec, + TestCase, +) +from framework.tensor import TensorInitializer + +import infinicore + +_TEST_CASES_DATA = [ + ((13,), None, None), + ((13,), (10,), (10,)), + ((5632,), None, 
None), + ((5632,), (5,), (5,)), + ((16,), (4,), (4,)), + ((5632,), (32,), (32,)), +] + +_TENSOR_DTYPES = [ + infinicore.float16, + infinicore.float32, + # infinicore.float64, + infinicore.bfloat16, +] + +_TOLERANCE_MAP = { + infinicore.float16: {"atol": 1e-3, "rtol": 1e-3}, + infinicore.float32: {"atol": 1e-5, "rtol": 1e-5}, + infinicore.float64: {"atol": 1e-9, "rtol": 1e-9}, + infinicore.bfloat16: {"atol": 1e-2, "rtol": 1e-2}, +} + + +def torch_rot(x, y, c, s): + x0 = x.clone() + y0 = y.clone() + x.copy_(c * x0 + s * y0) + y.copy_(c * y0 - s * x0) + return x, y + + +def parse_test_cases(): + test_cases = [] + for shape, x_strides, y_strides in _TEST_CASES_DATA: + for dtype in _TENSOR_DTYPES: + tol = _TOLERANCE_MAP.get(dtype, {"atol": 1e-5, "rtol": 1e-4}) + x_spec = TensorSpec.from_tensor(shape, x_strides, dtype) + y_spec = TensorSpec.from_tensor(shape, y_strides, dtype) + c_spec = TensorSpec.from_tensor( + (), + None, + dtype, + init_mode=TensorInitializer.MANUAL, + set_tensor=torch.tensor(0.6), + ) + s_spec = TensorSpec.from_tensor( + (), + None, + dtype, + init_mode=TensorInitializer.MANUAL, + set_tensor=torch.tensor(0.8), + ) + + test_cases.append( + TestCase( + inputs=[x_spec, y_spec, c_spec, s_spec], + kwargs={}, + comparison_target=[0, 1], + tolerance=tol, + output_count=2, + description="rot - INPLACE", + ) + ) + + return test_cases + + +class OpTest(BaseOperatorTest): + """BLAS Level-1 rot operator test""" + + def __init__(self): + super().__init__("Rot") + + def get_test_cases(self): + return parse_test_cases() + + def torch_operator(self, *args, **kwargs): + return torch_rot(*args, **kwargs) + + def infinicore_operator(self, *args, **kwargs): + return infinicore.rot(*args, **kwargs) + + +def main(): + runner = GenericTestRunner(OpTest) + runner.run_and_exit() + + +if __name__ == "__main__": + main() diff --git a/test/infinicore/ops/rotg.py b/test/infinicore/ops/rotg.py new file mode 100644 index 000000000..3b269925b --- /dev/null +++ 
b/test/infinicore/ops/rotg.py @@ -0,0 +1,145 @@ +import math +import os +import sys + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +import torch +from framework import ( + BaseOperatorTest, + GenericTestRunner, + TensorSpec, + TestCase, +) +from framework.tensor import TensorInitializer + +import infinicore + +_TEST_CASES_DATA = [ + (0.0, 0.0), + (3.0, 4.0), + (-2.5, 5.0), + (7.0, -1.5), + (-3.2, -8.4), +] + +_TENSOR_DTYPES = [ + # infinicore.float16, + infinicore.float32, + # infinicore.float64, + # infinicore.bfloat16, +] + +_TOLERANCE_MAP = { + infinicore.float16: {"atol": 1e-3, "rtol": 1e-3}, + infinicore.float32: {"atol": 1e-5, "rtol": 1e-5}, + infinicore.float64: {"atol": 1e-7, "rtol": 1e-7}, + infinicore.bfloat16: {"atol": 1e-2, "rtol": 1e-2}, +} + + +def torch_rotg(a, b, c, s): + a0 = a.item() + b0 = b.item() + anorm = abs(a0) + bnorm = abs(b0) + if bnorm == 0.0: + a.fill_(a0) + b.zero_() + c.fill_(1.0) + s.zero_() + return a, b, c, s + if anorm == 0.0: + a.fill_(b0) + b.fill_(1.0) + c.zero_() + s.fill_(1.0) + return a, b, c, s + + sigma = math.copysign(1.0, a0 if anorm > bnorm else b0) + r = sigma * math.hypot(a0, b0) + c0 = a0 / r + s0 = b0 / r + if anorm > bnorm: + z = s0 + elif c0 != 0.0: + z = 1.0 / c0 + else: + z = 1.0 + + a.fill_(r) + b.fill_(z) + c.fill_(c0) + s.fill_(s0) + return a, b, c, s + + +def parse_test_cases(): + test_cases = [] + for a_value, b_value in _TEST_CASES_DATA: + for dtype in _TENSOR_DTYPES: + tol = _TOLERANCE_MAP.get(dtype, {"atol": 1e-5, "rtol": 1e-4}) + a_spec = TensorSpec.from_tensor( + (1,), + None, + dtype, + init_mode=TensorInitializer.MANUAL, + set_tensor=torch.tensor([a_value]), + ) + b_spec = TensorSpec.from_tensor( + (1,), + None, + dtype, + init_mode=TensorInitializer.MANUAL, + set_tensor=torch.tensor([b_value]), + ) + c_spec = TensorSpec.from_tensor( + (), + None, + dtype, + init_mode=TensorInitializer.ZEROS, + ) + s_spec = TensorSpec.from_tensor( + (), + None, + dtype, + 
init_mode=TensorInitializer.ZEROS, + ) + + test_cases.append( + TestCase( + inputs=[a_spec, b_spec, c_spec, s_spec], + kwargs={}, + comparison_target=[0, 1, 2, 3], + tolerance=tol, + output_count=4, + description="rotg - INPLACE", + ) + ) + + return test_cases + + +class OpTest(BaseOperatorTest): + """BLAS Level-1 rotg operator test""" + + def __init__(self): + super().__init__("Rotg") + + def get_test_cases(self): + return parse_test_cases() + + def torch_operator(self, *args, **kwargs): + return torch_rotg(*args, **kwargs) + + def infinicore_operator(self, *args, **kwargs): + return infinicore.rotg(*args, **kwargs) + + +def main(): + runner = GenericTestRunner(OpTest) + runner.run_and_exit() + + +if __name__ == "__main__": + main() diff --git a/test/infinicore/ops/rotm.py b/test/infinicore/ops/rotm.py new file mode 100644 index 000000000..36c3ca019 --- /dev/null +++ b/test/infinicore/ops/rotm.py @@ -0,0 +1,110 @@ +import os +import sys + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +import torch +from framework import ( + BaseOperatorTest, + GenericTestRunner, + TensorSpec, + TestCase, +) +from framework.tensor import TensorInitializer + +import infinicore + +_TEST_CASES_DATA = [ + ((13,), None, None, (-1.0, 1.2, -0.3, 0.4, 0.8)), + ((13,), (10,), (10,), (0.0, 0.0, -0.25, 0.5, 0.0)), + ((5632,), None, None, (1.0, 1.1, 0.0, 0.0, 0.9)), + ((5632,), (5,), (5,), (-2.0, 0.0, 0.0, 0.0, 0.0)), +] + +_TENSOR_DTYPES = [ + # infinicore.float16, + infinicore.float32, + # infinicore.float64, + # infinicore.bfloat16, +] + +_TOLERANCE_MAP = { + infinicore.float16: {"atol": 1e-3, "rtol": 1e-3}, + infinicore.float32: {"atol": 1e-5, "rtol": 1e-5}, + infinicore.float64: {"atol": 1e-9, "rtol": 1e-9}, + infinicore.bfloat16: {"atol": 1e-2, "rtol": 1e-2}, +} + + +def torch_rotm(x, y, param): + sflag, sh11, sh21, sh12, sh22 = param + if sflag == -2.0: + return x, y + + w = x.clone() + z = y.clone() + + if sflag < 0.0: + x.copy_(w * sh11 + z * sh12) + y.copy_(w * 
sh21 + z * sh22) + elif sflag == 0.0: + x.copy_(w + z * sh12) + y.copy_(w * sh21 + z) + else: + x.copy_(w * sh11 + z) + y.copy_(-w + sh22 * z) + return x, y + + +def parse_test_cases(): + test_cases = [] + for shape, x_strides, y_strides, param in _TEST_CASES_DATA: + for dtype in _TENSOR_DTYPES: + tol = _TOLERANCE_MAP.get(dtype, {"atol": 1e-5, "rtol": 1e-4}) + x_spec = TensorSpec.from_tensor(shape, x_strides, dtype) + y_spec = TensorSpec.from_tensor(shape, y_strides, dtype) + param_spec = TensorSpec.from_tensor( + (5,), + None, + dtype, + init_mode=TensorInitializer.MANUAL, + set_tensor=torch.tensor(param), + ) + + test_cases.append( + TestCase( + inputs=[x_spec, y_spec, param_spec], + kwargs={}, + comparison_target=[0, 1], + tolerance=tol, + output_count=2, + description="rotm - INPLACE", + ) + ) + + return test_cases + + +class OpTest(BaseOperatorTest): + """BLAS Level-1 rotm operator test""" + + def __init__(self): + super().__init__("Rotm") + + def get_test_cases(self): + return parse_test_cases() + + def torch_operator(self, *args, **kwargs): + return torch_rotm(*args, **kwargs) + + def infinicore_operator(self, *args, **kwargs): + return infinicore.rotm(*args, **kwargs) + + +def main(): + runner = GenericTestRunner(OpTest) + runner.run_and_exit() + + +if __name__ == "__main__": + main() diff --git a/test/infinicore/ops/rotmg.py b/test/infinicore/ops/rotmg.py new file mode 100644 index 000000000..52fa1c4c3 --- /dev/null +++ b/test/infinicore/ops/rotmg.py @@ -0,0 +1,233 @@ +import os +import sys + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +import torch +from framework import ( + BaseOperatorTest, + GenericTestRunner, + TensorSpec, + TestCase, +) +from framework.tensor import TensorInitializer + +import infinicore + +_TEST_CASES_DATA = [ + (1.0, 2.0, 3.0, 4.0), + (2.5, 0.5, -1.2, 0.8), + (3.0, 4.0, 0.0, 2.0), + (1.5, 1.5, 2.0, -3.0), +] + +_TENSOR_DTYPES = [ + # infinicore.float16, + infinicore.float32, + # infinicore.float64, + # 
infinicore.bfloat16, +] + +_TOLERANCE_MAP = { + infinicore.float16: {"atol": 1e-3, "rtol": 1e-3}, + infinicore.float32: {"atol": 1e-7, "rtol": 1e-7}, + infinicore.float64: {"atol": 1e-12, "rtol": 1e-12}, + infinicore.bfloat16: {"atol": 1e-2, "rtol": 1e-2}, +} + + +def _rotmg_values(d1, d2, x1, y1): + zero = 0.0 + one = 1.0 + two = 2.0 + gam = 4096.0 + gamsq = 1.67772e7 + rgamsq = 5.96046e-8 + + param = [0.0] * 5 + sh11 = sh12 = sh21 = sh22 = 0.0 + + if d1 < zero: + sflag = -one + d1 = d2 = x1 = zero + else: + sp2 = d2 * y1 + if sp2 == zero: + param[0] = -two + return d1, d2, x1, param + + sp1 = d1 * x1 + sq2 = sp2 * y1 + sq1 = sp1 * x1 + + if abs(sq1) > abs(sq2): + sh21 = -y1 / x1 + sh12 = sp2 / sp1 + su = one - sh12 * sh21 + if su > zero: + sflag = zero + d1 = d1 / su + d2 = d2 / su + x1 = x1 * su + else: + sflag = -one + sh11 = sh12 = sh21 = sh22 = zero + d1 = d2 = x1 = zero + else: + if sq2 < zero: + sflag = -one + d1 = d2 = x1 = zero + else: + sflag = one + sh11 = sp1 / sp2 + sh22 = x1 / y1 + su = one + sh11 * sh22 + stemp = d2 / su + d2 = d1 / su + d1 = stemp + x1 = y1 * su + + if d1 != zero: + while d1 <= rgamsq or d1 >= gamsq: + if sflag == zero: + sh11 = one + sh22 = one + sflag = -one + else: + sh21 = -one + sh12 = one + sflag = -one + if d1 <= rgamsq: + d1 = d1 * gam * gam + x1 = x1 / gam + sh11 = sh11 / gam + sh12 = sh12 / gam + else: + d1 = d1 / (gam * gam) + x1 = x1 * gam + sh11 = sh11 * gam + sh12 = sh12 * gam + + if d2 != zero: + while abs(d2) <= rgamsq or abs(d2) >= gamsq: + if sflag == zero: + sh11 = one + sh22 = one + sflag = -one + else: + sh21 = -one + sh12 = one + sflag = -one + if abs(d2) <= rgamsq: + d2 = d2 * gam * gam + sh21 = sh21 / gam + sh22 = sh22 / gam + else: + d2 = d2 / (gam * gam) + sh21 = sh21 * gam + sh22 = sh22 * gam + + if sflag < zero: + param[1] = sh11 + param[2] = sh21 + param[3] = sh12 + param[4] = sh22 + elif sflag == zero: + param[2] = sh21 + param[3] = sh12 + else: + param[1] = sh11 + param[4] = sh22 + + param[0] = sflag 
+ return d1, d2, x1, param + + +def torch_rotmg(d1, d2, x1, y1, param): + out_d1, out_d2, out_x1, out_param = _rotmg_values( + d1.item(), d2.item(), x1.item(), y1.item() + ) + d1.fill_(out_d1) + d2.fill_(out_d2) + x1.fill_(out_x1) + param.copy_(torch.tensor(out_param, dtype=param.dtype, device=param.device)) + return d1, d2, x1, param + + +def parse_test_cases(): + test_cases = [] + for d1_value, d2_value, x1_value, y1_value in _TEST_CASES_DATA: + for dtype in _TENSOR_DTYPES: + tol = _TOLERANCE_MAP.get(dtype, {"atol": 1e-5, "rtol": 1e-4}) + d1_spec = TensorSpec.from_tensor( + (1,), + None, + dtype, + init_mode=TensorInitializer.MANUAL, + set_tensor=torch.tensor([d1_value]), + ) + d2_spec = TensorSpec.from_tensor( + (1,), + None, + dtype, + init_mode=TensorInitializer.MANUAL, + set_tensor=torch.tensor([d2_value]), + ) + x1_spec = TensorSpec.from_tensor( + (1,), + None, + dtype, + init_mode=TensorInitializer.MANUAL, + set_tensor=torch.tensor([x1_value]), + ) + y1_spec = TensorSpec.from_tensor( + (1,), + None, + dtype, + init_mode=TensorInitializer.MANUAL, + set_tensor=torch.tensor([y1_value]), + ) + param_spec = TensorSpec.from_tensor( + (5,), + None, + dtype, + init_mode=TensorInitializer.ZEROS, + ) + + test_cases.append( + TestCase( + inputs=[d1_spec, d2_spec, x1_spec, y1_spec, param_spec], + kwargs={}, + comparison_target=[0, 1, 2, 4], + tolerance=tol, + output_count=4, + description="rotmg - INPLACE", + ) + ) + + return test_cases + + +class OpTest(BaseOperatorTest): + """BLAS Level-1 rotmg operator test""" + + def __init__(self): + super().__init__("Rotmg") + + def get_test_cases(self): + return parse_test_cases() + + def torch_operator(self, *args, **kwargs): + return torch_rotmg(*args, **kwargs) + + def infinicore_operator(self, *args, **kwargs): + return infinicore.rotmg(*args, **kwargs) + + +def main(): + runner = GenericTestRunner(OpTest) + runner.run_and_exit() + + +if __name__ == "__main__": + main() From f91e4d34f919473c74efc0c0ee5caac47a3a823d Mon Sep 
17 00:00:00 2001 From: xuzhengzhong Date: Thu, 7 May 2026 06:25:01 +0000 Subject: [PATCH 18/25] Fix format and type conversion errors --- python/infinicore/__init__.py | 10 +- src/infiniop/ops/asum/bang/asum_bang.mlu | 4 +- .../ops/asum/bang/asum_bang_kernel.mlu | 106 +++++++++++---- src/infiniop/ops/asum/cpu/asum_cpu.cc | 12 +- src/infiniop/ops/asum/metax/asum_metax.cc | 4 +- src/infiniop/ops/axpy/bang/axpy_bang.mlu | 10 +- .../ops/axpy/bang/axpy_bang_kernel.mlu | 45 +++--- src/infiniop/ops/axpy/cpu/axpy_cpu.cc | 18 ++- src/infiniop/ops/axpy/metax/axpy_metax.cc | 8 +- .../ops/blas_amax/bang/blas_amax_bang.mlu | 8 +- .../blas_amax/bang/blas_amax_bang_kernel.mlu | 32 +++-- .../ops/blas_amax/cpu/blas_amax_cpu.cc | 18 +-- .../ops/blas_amax/metax/blas_amax_metax.cc | 8 +- .../ops/blas_amin/bang/blas_amin_bang.mlu | 8 +- .../blas_amin/bang/blas_amin_bang_kernel.mlu | 32 +++-- .../ops/blas_amin/cpu/blas_amin_cpu.cc | 18 +-- .../ops/blas_amin/metax/blas_amin_metax.cc | 8 +- .../ops/blas_copy/bang/blas_copy_bang.mlu | 14 +- .../blas_copy/bang/blas_copy_bang_kernel.mlu | 32 +++-- .../ops/blas_copy/cpu/blas_copy_cpu.cc | 8 +- .../ops/blas_copy/metax/blas_copy_metax.cc | 10 +- .../ops/blas_dot/bang/blas_dot_bang.mlu | 14 +- .../blas_dot/bang/blas_dot_bang_kernel.mlu | 128 ++++++++++++------ src/infiniop/ops/blas_dot/cpu/blas_dot_cpu.cc | 10 +- .../ops/blas_dot/metax/blas_dot_metax.cc | 8 +- src/infiniop/ops/nrm2/bang/nrm2_bang.mlu | 11 +- .../ops/nrm2/bang/nrm2_bang_kernel.mlu | 64 +++++---- src/infiniop/ops/nrm2/cpu/nrm2_cpu.cc | 6 +- src/infiniop/ops/nrm2/metax/nrm2_metax.cc | 6 +- src/infiniop/ops/rot/bang/rot_bang.mlu | 14 +- src/infiniop/ops/rot/bang/rot_bang_kernel.mlu | 56 ++++---- src/infiniop/ops/rot/cpu/rot_cpu.cc | 14 +- src/infiniop/ops/rot/metax/rot_metax.cc | 8 +- src/infiniop/ops/rotm/bang/rotm_bang.mlu | 14 +- .../ops/rotm/bang/rotm_bang_kernel.mlu | 90 ++++++------ src/infiniop/ops/rotm/cpu/rotm_cpu.cc | 14 +- src/infiniop/ops/rotm/metax/rotm_metax.cc | 10 +- 
src/infiniop/ops/scal/bang/scal_bang.mlu | 11 +- .../ops/scal/bang/scal_bang_kernel.mlu | 32 +++-- src/infiniop/ops/scal/cpu/scal_cpu.cc | 6 +- src/infiniop/ops/scal/metax/scal_metax.cc | 6 +- src/infiniop/ops/swap/bang/swap_bang.mlu | 14 +- .../ops/swap/bang/swap_bang_kernel.mlu | 40 +++--- src/infiniop/ops/swap/cpu/swap_cpu.cc | 9 +- src/infiniop/ops/swap/metax/swap_metax.cc | 10 +- 45 files changed, 579 insertions(+), 429 deletions(-) diff --git a/python/infinicore/__init__.py b/python/infinicore/__init__.py index 3275117e3..8c9adc64c 100644 --- a/python/infinicore/__init__.py +++ b/python/infinicore/__init__.py @@ -64,14 +64,14 @@ from infinicore.ops.axpy import axpy from infinicore.ops.baddbmm import baddbmm from infinicore.ops.bilinear import bilinear -from infinicore.ops.blas_amax import blas_amax -from infinicore.ops.blas_amin import blas_amin -from infinicore.ops.blas_copy import blas_copy -from infinicore.ops.blas_dot import blas_dot from infinicore.ops.binary_cross_entropy_with_logits import ( binary_cross_entropy_with_logits, ) from infinicore.ops.bitwise_right_shift import bitwise_right_shift +from infinicore.ops.blas_amax import blas_amax +from infinicore.ops.blas_amin import blas_amin +from infinicore.ops.blas_copy import blas_copy +from infinicore.ops.blas_dot import blas_dot from infinicore.ops.block_diag import block_diag from infinicore.ops.broadcast_to import broadcast_to from infinicore.ops.cat import cat @@ -121,9 +121,9 @@ from infinicore.ops.scal import scal from infinicore.ops.scatter import scatter from infinicore.ops.sinh import sinh -from infinicore.ops.swap import swap from infinicore.ops.squeeze import squeeze from infinicore.ops.sum import sum +from infinicore.ops.swap import swap from infinicore.ops.take import take from infinicore.ops.tan import tan from infinicore.ops.topk import topk diff --git a/src/infiniop/ops/asum/bang/asum_bang.mlu b/src/infiniop/ops/asum/bang/asum_bang.mlu index 4ea5ef5b6..9b85ba56a 100644 --- 
a/src/infiniop/ops/asum/bang/asum_bang.mlu +++ b/src/infiniop/ops/asum/bang/asum_bang.mlu @@ -33,8 +33,8 @@ infiniStatus_t calculateAsum( Tdata *result, cnrtQueue_t queue) { - const size_t n = info.n; - const ptrdiff_t incx = info.incx; + const int n = utils::cast(info.n); + const int incx = utils::cast(info.incx); cnrtDim3_t k_dim; cnrtFunctionType_t k_type; diff --git a/src/infiniop/ops/asum/bang/asum_bang_kernel.mlu b/src/infiniop/ops/asum/bang/asum_bang_kernel.mlu index 37ba8f681..633ad767d 100644 --- a/src/infiniop/ops/asum/bang/asum_bang_kernel.mlu +++ b/src/infiniop/ops/asum/bang/asum_bang_kernel.mlu @@ -1,54 +1,98 @@ #include "../../../devices/bang/common_bang.h" #include "asum_bang.h" +#include + __nram__ char nram_buffer[NRAM_MAX_SIZE]; +template +__mlu_device__ void asumToCompute(float *dst, const Tdata *src, int size) { + if constexpr (std::is_same_v) { + __bang_half2float(dst, src, size); + } else if constexpr (std::is_same_v) { + __bang_bfloat162float(dst, src, size); + } else { + __memcpy(dst, src, size * sizeof(float), NRAM2NRAM); + } +} + +template +__mlu_device__ float asumToCompute(Tdata value) { + if constexpr (std::is_same_v) { + return __half2float(value); + } else if constexpr (std::is_same_v) { + return __bfloat162float(value); + } else { + return static_cast(value); + } +} + +template +__mlu_device__ void asumStoreResult(Tdata *result, Tdata *nram_result, float *nram_compute, float value) { + nram_compute[0] = value; + if constexpr (std::is_same_v) { + __bang_float2half(nram_result, nram_compute, 1); + result[0] = nram_result[0]; + } else if constexpr (std::is_same_v) { + __bang_float2bfloat16(nram_result, nram_compute, 1); + result[0] = nram_result[0]; + } else { + result[0] = nram_compute[0]; + } +} + template __mlu_global__ void asumKernelContiguous( - size_t n, + int n, const Tdata *x, Tdata *result) { - __mlu_shared__ Tdata shared_partial_sum[4]; + __mlu_shared__ float shared_partial_sum[4]; - Tdata *nram_x = (Tdata 
*)(((size_t)nram_buffer + ALIGN_SIZE - 1) & ~(ALIGN_SIZE - 1)); + char *nram_aligned = (char *)(((size_t)nram_buffer + ALIGN_SIZE - 1) & ~(ALIGN_SIZE - 1)); - size_t nram_usable = NRAM_MAX_SIZE - ((char *)nram_x - nram_buffer); - size_t max_chunk_elements = nram_usable / sizeof(Tdata); + size_t nram_usable = NRAM_MAX_SIZE - (nram_aligned - nram_buffer); + size_t max_chunk_elements = nram_usable / (sizeof(Tdata) + sizeof(float)); - int align_elements = ALIGN_SIZE / sizeof(Tdata); + size_t align_elements = ALIGN_SIZE / sizeof(Tdata); if (align_elements == 0) { align_elements = 1; } - max_chunk_elements = (max_chunk_elements / align_elements) * align_elements; + int chunk_size = (int)((max_chunk_elements / align_elements) * align_elements); + + Tdata *nram_x = (Tdata *)nram_aligned; + float *nram_compute = (float *)(nram_x + chunk_size); int elements_per_core = n / taskDim; int remain = n % taskDim; int core_elements = elements_per_core + (taskId < remain ? 1 : 0); int core_offset = taskId < remain ? 
taskId * core_elements : taskId * elements_per_core + remain; - int chunks = core_elements / max_chunk_elements; - int chunk_rem = core_elements % max_chunk_elements; + int chunks = core_elements / chunk_size; + int chunk_rem = core_elements % chunk_size; - Tdata partial_sum = static_cast(0); + float partial_sum = 0.0f; for (int c = 0; c < chunks; c++) { - size_t current_offset = core_offset + c * max_chunk_elements; - __memcpy(nram_x, x + current_offset, max_chunk_elements * sizeof(Tdata), GDRAM2NRAM); + int current_offset = core_offset + c * chunk_size; - __bang_abs(nram_x, nram_x, max_chunk_elements); + __memcpy(nram_x, x + current_offset, chunk_size * sizeof(Tdata), GDRAM2NRAM); - partial_sum += __bang_sum(nram_x, max_chunk_elements); + asumToCompute(nram_compute, nram_x, chunk_size); + __bang_abs(nram_compute, nram_compute, chunk_size); + + partial_sum += __bang_sum(nram_compute, chunk_size); } if (chunk_rem > 0) { - size_t current_offset = core_offset + chunks * max_chunk_elements; + int current_offset = core_offset + chunks * chunk_size; __memcpy(nram_x, x + current_offset, chunk_rem * sizeof(Tdata), GDRAM2NRAM); - __bang_abs(nram_x, nram_x, chunk_rem); + asumToCompute(nram_compute, nram_x, chunk_rem); + __bang_abs(nram_compute, nram_compute, chunk_rem); - partial_sum += __bang_sum(nram_x, chunk_rem); + partial_sum += __bang_sum(nram_compute, chunk_rem); } shared_partial_sum[coreId] = partial_sum; @@ -56,35 +100,41 @@ __mlu_global__ void asumKernelContiguous( __sync_cluster(); if (coreId == 0) { - Tdata cluster_sum = static_cast(0); + float cluster_sum = 0.0f; for (int i = 0; i < coreDim; i++) { cluster_sum += shared_partial_sum[i]; } - result[0] = cluster_sum; + asumStoreResult(result, nram_x, nram_compute, cluster_sum); } } template __mlu_global__ void asumKernelStrided( - size_t n, + int n, const Tdata *x, - size_t incx, + int incx, Tdata *result) { - __mlu_shared__ Tdata shared_partial_sum[4]; + __mlu_shared__ float shared_partial_sum[4]; + + char 
*nram_aligned = (char *)(((size_t)nram_buffer + ALIGN_SIZE - 1) & ~(ALIGN_SIZE - 1)); + + float *nram_compute = (float *)nram_aligned; + Tdata *nram_result = (Tdata *)(nram_compute + 1); int elements_per_core = n / taskDim; int remain = n % taskDim; int actual_tasks = elements_per_core + (taskId < remain ? 1 : 0); int start_idx = taskId < remain ? taskId * actual_tasks : taskId * elements_per_core + remain; - Tdata partial_sum = static_cast(0); + float partial_sum = 0.0f; for (int i = start_idx; i < start_idx + actual_tasks; ++i) { - size_t offset = i * incx; - Tdata abs_val = x[offset] > static_cast(0) ? x[offset] : -x[offset]; + int offset = i * incx; + float x_val = asumToCompute(x[offset]); + float abs_val = x_val > 0.0f ? x_val : -x_val; partial_sum += abs_val; } @@ -94,12 +144,12 @@ __mlu_global__ void asumKernelStrided( __sync_cluster(); if (coreId == 0) { - Tdata cluster_sum = static_cast(0); + float cluster_sum = 0.0f; for (int i = 0; i < coreDim; i++) { cluster_sum += shared_partial_sum[i]; } - result[0] = cluster_sum; + asumStoreResult(result, nram_result, nram_compute, cluster_sum); } -} \ No newline at end of file +} diff --git a/src/infiniop/ops/asum/cpu/asum_cpu.cc b/src/infiniop/ops/asum/cpu/asum_cpu.cc index bae85f6c2..c1178a74c 100644 --- a/src/infiniop/ops/asum/cpu/asum_cpu.cc +++ b/src/infiniop/ops/asum/cpu/asum_cpu.cc @@ -31,22 +31,24 @@ infiniStatus_t calculateAsum( const Tdata *x, Tdata *result) { - const ptrdiff_t n = info.n; + const size_t n = info.n; const ptrdiff_t incx = info.incx; if constexpr (std::is_same::value || std::is_same::value) { float total_sum = 0.0; - for (ptrdiff_t i = 0; i < n; ++i) { - total_sum += std::abs(utils::cast(x[i * incx])); + for (size_t i = 0; i < n; ++i) { + const ptrdiff_t idx = utils::cast(i) * incx; + total_sum += std::abs(utils::cast(x[idx])); } result[0] = utils::cast(total_sum); } else { Tdata total_sum = 0.0; - for (ptrdiff_t i = 0; i < n; ++i) { - total_sum += std::abs(x[i * incx]); + for (size_t i = 
0; i < n; ++i) { + const ptrdiff_t idx = utils::cast(i) * incx; + total_sum += std::abs(x[idx]); } result[0] = total_sum; diff --git a/src/infiniop/ops/asum/metax/asum_metax.cc b/src/infiniop/ops/asum/metax/asum_metax.cc index 61098eed2..9fb7453bb 100644 --- a/src/infiniop/ops/asum/metax/asum_metax.cc +++ b/src/infiniop/ops/asum/metax/asum_metax.cc @@ -42,8 +42,8 @@ infiniStatus_t Descriptor::calculate( (void)workspace; (void)workspace_size; - const size_t n = _info.n; - const ptrdiff_t incx = _info.incx; + const int n = utils::cast(_info.n); + const int incx = utils::cast(_info.incx); const infiniDtype_t data_type = _info.data_type; CHECK_STATUS(_opaque->internal->useMcblas( diff --git a/src/infiniop/ops/axpy/bang/axpy_bang.mlu b/src/infiniop/ops/axpy/bang/axpy_bang.mlu index e74d07dbf..c3d72baa8 100644 --- a/src/infiniop/ops/axpy/bang/axpy_bang.mlu +++ b/src/infiniop/ops/axpy/bang/axpy_bang.mlu @@ -36,9 +36,9 @@ infiniStatus_t calculateAxpy( Tdata *y, cnrtQueue_t queue) { - const size_t size = info.n; - const ptrdiff_t incx = info.incx; - const ptrdiff_t incy = info.incy; + const int n = utils::cast(info.n); + const int incx = utils::cast(info.incx); + const int incy = utils::cast(info.incy); cnrtDim3_t k_dim; cnrtFunctionType_t k_type; @@ -50,13 +50,13 @@ infiniStatus_t calculateAxpy( if (incx == 1 && incy == 1) { axpyKernelContiguous<<>>( - size, + n, alpha, x, y); } else { axpyKernelStrided<<>>( - size, + n, alpha, x, incx, diff --git a/src/infiniop/ops/axpy/bang/axpy_bang_kernel.mlu b/src/infiniop/ops/axpy/bang/axpy_bang_kernel.mlu index 8e5ee3db0..df0409987 100644 --- a/src/infiniop/ops/axpy/bang/axpy_bang_kernel.mlu +++ b/src/infiniop/ops/axpy/bang/axpy_bang_kernel.mlu @@ -5,24 +5,24 @@ __nram__ char nram_buffer[NRAM_MAX_SIZE]; template __mlu_global__ void axpyKernelContiguous( - size_t n, + int n, const Tdata *alpha, const Tdata *x, Tdata *y) { - Tdata *nram_align = (Tdata *)(((size_t)nram_buffer + ALIGN_SIZE - 1) & ~(ALIGN_SIZE - 1)); + char *nram_aligned 
= (char *)(((size_t)nram_buffer + ALIGN_SIZE - 1) & ~(ALIGN_SIZE - 1)); - size_t nram_usable = NRAM_MAX_SIZE - ((char *)nram_align - nram_buffer); + size_t nram_usable = NRAM_MAX_SIZE - (nram_aligned - nram_buffer); size_t max_chunk_elements = nram_usable / (2 * sizeof(Tdata)); - int align_elements = ALIGN_SIZE / sizeof(Tdata); + size_t align_elements = ALIGN_SIZE / sizeof(Tdata); if (align_elements == 0) { align_elements = 1; } - max_chunk_elements = (max_chunk_elements / align_elements) * align_elements; + int chunk_size = (int)((max_chunk_elements / align_elements) * align_elements); - Tdata *nram_x = nram_align; - Tdata *nram_y = nram_align + max_chunk_elements; + Tdata *nram_x = (Tdata *)nram_aligned; + Tdata *nram_y = nram_x + chunk_size; int elements_per_core = n / taskDim; int remain = n % taskDim; @@ -33,22 +33,23 @@ __mlu_global__ void axpyKernelContiguous( return; } - int chunks = core_elements / max_chunk_elements; - int chunk_rem = core_elements % max_chunk_elements; + int chunks = core_elements / chunk_size; + int chunk_rem = core_elements % chunk_size; for (int c = 0; c < chunks; c++) { - size_t current_offset = core_offset + c * max_chunk_elements; - __memcpy(nram_x, x + current_offset, max_chunk_elements * sizeof(Tdata), GDRAM2NRAM); - __memcpy(nram_y, y + current_offset, max_chunk_elements * sizeof(Tdata), GDRAM2NRAM); + int current_offset = core_offset + c * chunk_size; - __bang_mul_scalar(nram_x, nram_x, alpha[0], max_chunk_elements); - __bang_add(nram_y, nram_y, nram_x, max_chunk_elements); + __memcpy(nram_x, x + current_offset, chunk_size * sizeof(Tdata), GDRAM2NRAM); + __memcpy(nram_y, y + current_offset, chunk_size * sizeof(Tdata), GDRAM2NRAM); - __memcpy(y + current_offset, nram_y, max_chunk_elements * sizeof(Tdata), NRAM2GDRAM); + __bang_mul_scalar(nram_x, nram_x, alpha[0], chunk_size); + __bang_add(nram_y, nram_y, nram_x, chunk_size); + + __memcpy(y + current_offset, nram_y, chunk_size * sizeof(Tdata), NRAM2GDRAM); } if (chunk_rem > 0) { 
- size_t current_offset = core_offset + chunks * max_chunk_elements; + int current_offset = core_offset + chunks * chunk_size; int align_rem = ((chunk_rem + align_elements - 1) / align_elements) * align_elements; __memcpy(nram_x, x + current_offset, chunk_rem * sizeof(Tdata), GDRAM2NRAM); @@ -63,12 +64,12 @@ __mlu_global__ void axpyKernelContiguous( template __mlu_global__ void axpyKernelStrided( - size_t n, + int n, const Tdata *alpha, const Tdata *x, - size_t incx, + int incx, Tdata *y, - size_t incy) { + int incy) { int elements_per_core = n / taskDim; int remain = n % taskDim; @@ -76,9 +77,9 @@ __mlu_global__ void axpyKernelStrided( int start_idx = taskId < remain ? taskId * actual_tasks : taskId * elements_per_core + remain; for (int i = start_idx; i < start_idx + actual_tasks; ++i) { - size_t idx_x = i * incx; - size_t idx_y = i * incy; + int idx_x = i * incx; + int idx_y = i * incy; y[idx_y] += alpha[0] * x[idx_x]; } -} \ No newline at end of file +} diff --git a/src/infiniop/ops/axpy/cpu/axpy_cpu.cc b/src/infiniop/ops/axpy/cpu/axpy_cpu.cc index 8c42c4855..d2f97f8df 100644 --- a/src/infiniop/ops/axpy/cpu/axpy_cpu.cc +++ b/src/infiniop/ops/axpy/cpu/axpy_cpu.cc @@ -33,21 +33,25 @@ infiniStatus_t calculateAxpy( const Tdata *x, Tdata *y) { - const ptrdiff_t size = info.n; + const size_t n = info.n; const ptrdiff_t incx = info.incx; const ptrdiff_t incy = info.incy; if constexpr (std::is_same::value || std::is_same::value) { const float alpha_f = utils::cast(alpha[0]); - for (ptrdiff_t i = 0; i < size; ++i) { - const float x_f = utils::cast(x[i * incx]); - const float y_f = utils::cast(y[i * incy]); - y[i * incy] = utils::cast(alpha_f * x_f + y_f); + for (size_t i = 0; i < n; ++i) { + const ptrdiff_t x_idx = utils::cast(i) * incx; + const ptrdiff_t y_idx = utils::cast(i) * incy; + const float x_f = utils::cast(x[x_idx]); + const float y_f = utils::cast(y[y_idx]); + y[y_idx] = utils::cast(alpha_f * x_f + y_f); } } else { const Tdata alpha_v = alpha[0]; - for 
(ptrdiff_t i = 0; i < size; ++i) { - y[i * incy] = alpha_v * x[i * incx] + y[i * incy]; + for (size_t i = 0; i < n; ++i) { + const ptrdiff_t x_idx = utils::cast(i) * incx; + const ptrdiff_t y_idx = utils::cast(i) * incy; + y[y_idx] = alpha_v * x[x_idx] + y[y_idx]; } } diff --git a/src/infiniop/ops/axpy/metax/axpy_metax.cc b/src/infiniop/ops/axpy/metax/axpy_metax.cc index f1fd491cf..3a36e31f3 100644 --- a/src/infiniop/ops/axpy/metax/axpy_metax.cc +++ b/src/infiniop/ops/axpy/metax/axpy_metax.cc @@ -44,9 +44,9 @@ infiniStatus_t Descriptor::calculate( (void)workspace; (void)workspace_size; - const size_t size = _info.n; - const ptrdiff_t incx = _info.incx; - const ptrdiff_t incy = _info.incy; + const int n = utils::cast(_info.n); + const int incx = utils::cast(_info.incx); + const int incy = utils::cast(_info.incy); const infiniDtype_t data_type = _info.data_type; hpccDataType alpha_type, x_type, y_type; @@ -82,7 +82,7 @@ infiniStatus_t Descriptor::calculate( CHECK_MCBLAS(hcblasAxpyEx( handle, - size, + n, alpha, alpha_type, x, diff --git a/src/infiniop/ops/blas_amax/bang/blas_amax_bang.mlu b/src/infiniop/ops/blas_amax/bang/blas_amax_bang.mlu index 5e0e6ac46..62f43aaf9 100644 --- a/src/infiniop/ops/blas_amax/bang/blas_amax_bang.mlu +++ b/src/infiniop/ops/blas_amax/bang/blas_amax_bang.mlu @@ -34,8 +34,8 @@ infiniStatus_t calculateBlasAmax( int *result, cnrtQueue_t queue) { - const size_t size = info.n; - const ptrdiff_t incx = info.incx; + const int n = utils::cast(info.n); + const int incx = utils::cast(info.incx); cnrtDim3_t k_dim; cnrtFunctionType_t k_type; @@ -47,12 +47,12 @@ infiniStatus_t calculateBlasAmax( if (incx == 1) { blasAmaxKernelContiguous<<>>( - size, + n, x, result); } else { blasAmaxKernelStrided<<>>( - size, + n, x, incx, result); diff --git a/src/infiniop/ops/blas_amax/bang/blas_amax_bang_kernel.mlu b/src/infiniop/ops/blas_amax/bang/blas_amax_bang_kernel.mlu index be6efb60a..e3736aa08 100644 --- 
a/src/infiniop/ops/blas_amax/bang/blas_amax_bang_kernel.mlu +++ b/src/infiniop/ops/blas_amax/bang/blas_amax_bang_kernel.mlu @@ -5,42 +5,44 @@ __nram__ char nram_buffer[NRAM_MAX_SIZE]; template __mlu_global__ void blasAmaxKernelContiguous( - size_t n, + int n, const Tdata *x, int *result) { __mlu_shared__ int shared_max_index[4]; __mlu_shared__ Tdata shared_max_value[4]; - Tdata *nram_x = (Tdata *)(((size_t)nram_buffer + ALIGN_SIZE - 1) & ~(ALIGN_SIZE - 1)); + char *nram_aligned = (char *)(((size_t)nram_buffer + ALIGN_SIZE - 1) & ~(ALIGN_SIZE - 1)); - size_t nram_usable = NRAM_MAX_SIZE - ((char *)nram_x - nram_buffer); + size_t nram_usable = NRAM_MAX_SIZE - (nram_aligned - nram_buffer); size_t max_chunk_elements = nram_usable / sizeof(Tdata); - int align_elements = ALIGN_SIZE / sizeof(Tdata); + size_t align_elements = ALIGN_SIZE / sizeof(Tdata); if (align_elements == 0) { align_elements = 1; } - max_chunk_elements = (max_chunk_elements / align_elements) * align_elements; + int chunk_size = (int)((max_chunk_elements / align_elements) * align_elements); + + Tdata *nram_x = (Tdata *)nram_aligned; int elements_per_core = n / taskDim; int remain = n % taskDim; int core_elements = elements_per_core + (taskId < remain ? 1 : 0); int core_offset = taskId < remain ? 
taskId * core_elements : taskId * elements_per_core + remain; - int chunks = core_elements / max_chunk_elements; - int chunk_rem = core_elements % max_chunk_elements; + int chunks = core_elements / chunk_size; + int chunk_rem = core_elements % chunk_size; int max_index = -1; Tdata max_value = static_cast(0); for (int c = 0; c < chunks; c++) { - size_t current_offset = core_offset + c * max_chunk_elements; - __memcpy(nram_x, x + current_offset, max_chunk_elements * sizeof(Tdata), GDRAM2NRAM); + int current_offset = core_offset + c * chunk_size; + __memcpy(nram_x, x + current_offset, chunk_size * sizeof(Tdata), GDRAM2NRAM); - __bang_abs(nram_x, nram_x, max_chunk_elements); + __bang_abs(nram_x, nram_x, chunk_size); - for (int i = 0; i < max_chunk_elements; i++) { + for (int i = 0; i < chunk_size; i++) { Tdata abs_val = nram_x[i]; if (abs_val > max_value) { max_value = abs_val; @@ -50,7 +52,7 @@ __mlu_global__ void blasAmaxKernelContiguous( } if (chunk_rem > 0) { - size_t current_offset = core_offset + chunks * max_chunk_elements; + int current_offset = core_offset + chunks * chunk_size; __memcpy(nram_x, x + current_offset, chunk_rem * sizeof(Tdata), GDRAM2NRAM); @@ -87,9 +89,9 @@ __mlu_global__ void blasAmaxKernelContiguous( template __mlu_global__ void blasAmaxKernelStrided( - size_t n, + int n, const Tdata *x, - size_t incx, + int incx, int *result) { __mlu_shared__ int shared_max_index[4]; @@ -104,7 +106,7 @@ __mlu_global__ void blasAmaxKernelStrided( Tdata max_value = static_cast(0); for (int i = start_idx; i < start_idx + actual_tasks; ++i) { - size_t offset = i * incx; + int offset = i * incx; Tdata abs_val = x[offset] > static_cast(0) ? 
x[offset] : -x[offset]; if (abs_val > max_value) { diff --git a/src/infiniop/ops/blas_amax/cpu/blas_amax_cpu.cc b/src/infiniop/ops/blas_amax/cpu/blas_amax_cpu.cc index 8d66d71f2..4b3553041 100644 --- a/src/infiniop/ops/blas_amax/cpu/blas_amax_cpu.cc +++ b/src/infiniop/ops/blas_amax/cpu/blas_amax_cpu.cc @@ -32,20 +32,21 @@ infiniStatus_t calculateBlasAmax( const Tdata *x, int *result) { - const ptrdiff_t size = info.n; + const size_t n = info.n; const ptrdiff_t incx = info.incx; - if (size < 1 || incx == 0) { + if (n < 1 || incx == 0) { result[0] = 0; return INFINI_STATUS_SUCCESS; } - int max_index = 0; + size_t max_index = 0; if constexpr (std::is_same::value || std::is_same::value) { float max_value = std::abs(utils::cast(x[0])); - for (ptrdiff_t i = 1; i < size; ++i) { - float current_value = std::abs(utils::cast(x[i * incx])); + for (size_t i = 1; i < n; ++i) { + const ptrdiff_t idx = utils::cast(i) * incx; + float current_value = std::abs(utils::cast(x[idx])); if (current_value > max_value) { max_value = current_value; max_index = i; @@ -54,8 +55,9 @@ infiniStatus_t calculateBlasAmax( } else { Tdata max_value = std::abs(x[0]); - for (ptrdiff_t i = 1; i < size; ++i) { - Tdata current_value = std::abs(x[i * incx]); + for (size_t i = 1; i < n; ++i) { + const ptrdiff_t idx = utils::cast(i) * incx; + Tdata current_value = std::abs(x[idx]); if (current_value > max_value) { max_value = current_value; max_index = i; @@ -100,4 +102,4 @@ infiniStatus_t Descriptor::calculate( #undef CALCULATE_BLAS_AMAX -} // namespace op::blas_amax::cpu \ No newline at end of file +} // namespace op::blas_amax::cpu diff --git a/src/infiniop/ops/blas_amax/metax/blas_amax_metax.cc b/src/infiniop/ops/blas_amax/metax/blas_amax_metax.cc index cb73bdd67..975586dc2 100644 --- a/src/infiniop/ops/blas_amax/metax/blas_amax_metax.cc +++ b/src/infiniop/ops/blas_amax/metax/blas_amax_metax.cc @@ -42,8 +42,8 @@ infiniStatus_t Descriptor::calculate( (void)workspace; (void)workspace_size; - const size_t 
size = _info.n; - const ptrdiff_t incx = _info.incx; + const int n = utils::cast(_info.n); + const int incx = utils::cast(_info.incx); const infiniDtype_t data_type = _info.data_type; CHECK_STATUS(_opaque->internal->useMcblas( @@ -53,10 +53,10 @@ infiniStatus_t Descriptor::calculate( switch (data_type) { case INFINI_DTYPE_F32: - CHECK_MCBLAS(hcblasIsamax(handle, size, (const float *)x, incx, (int *)result)); + CHECK_MCBLAS(hcblasIsamax(handle, n, (const float *)x, incx, (int *)result)); break; case INFINI_DTYPE_F64: - CHECK_MCBLAS(hcblasIdamax(handle, size, (const double *)x, incx, (int *)result)); + CHECK_MCBLAS(hcblasIdamax(handle, n, (const double *)x, incx, (int *)result)); break; default: return INFINI_STATUS_BAD_TENSOR_DTYPE; diff --git a/src/infiniop/ops/blas_amin/bang/blas_amin_bang.mlu b/src/infiniop/ops/blas_amin/bang/blas_amin_bang.mlu index 7c0f795f0..b3020d68a 100644 --- a/src/infiniop/ops/blas_amin/bang/blas_amin_bang.mlu +++ b/src/infiniop/ops/blas_amin/bang/blas_amin_bang.mlu @@ -34,8 +34,8 @@ infiniStatus_t calculateBlasAmin( int *result, cnrtQueue_t queue) { - const size_t size = info.n; - const ptrdiff_t incx = info.incx; + const int n = utils::cast(info.n); + const int incx = utils::cast(info.incx); cnrtDim3_t k_dim; cnrtFunctionType_t k_type; @@ -47,12 +47,12 @@ infiniStatus_t calculateBlasAmin( if (incx == 1) { blasAminKernelContiguous<<>>( - size, + n, x, result); } else { blasAminKernelStrided<<>>( - size, + n, x, incx, result); diff --git a/src/infiniop/ops/blas_amin/bang/blas_amin_bang_kernel.mlu b/src/infiniop/ops/blas_amin/bang/blas_amin_bang_kernel.mlu index 57b14169f..99d2a828b 100644 --- a/src/infiniop/ops/blas_amin/bang/blas_amin_bang_kernel.mlu +++ b/src/infiniop/ops/blas_amin/bang/blas_amin_bang_kernel.mlu @@ -5,43 +5,45 @@ __nram__ char nram_buffer[NRAM_MAX_SIZE]; template __mlu_global__ void blasAminKernelContiguous( - size_t n, + int n, const Tdata *x, int *result) { __mlu_shared__ int shared_min_index[4]; __mlu_shared__ Tdata 
shared_min_value[4]; - Tdata *nram_x = (Tdata *)(((size_t)nram_buffer + ALIGN_SIZE - 1) & ~(ALIGN_SIZE - 1)); + char *nram_aligned = (char *)(((size_t)nram_buffer + ALIGN_SIZE - 1) & ~(ALIGN_SIZE - 1)); - size_t nram_usable = NRAM_MAX_SIZE - ((char *)nram_x - nram_buffer); + size_t nram_usable = NRAM_MAX_SIZE - (nram_aligned - nram_buffer); size_t max_chunk_elements = nram_usable / sizeof(Tdata); - int align_elements = ALIGN_SIZE / sizeof(Tdata); + size_t align_elements = ALIGN_SIZE / sizeof(Tdata); if (align_elements == 0) { align_elements = 1; } - max_chunk_elements = (max_chunk_elements / align_elements) * align_elements; + int chunk_size = (int)((max_chunk_elements / align_elements) * align_elements); + + Tdata *nram_x = (Tdata *)nram_aligned; int elements_per_core = n / taskDim; int remain = n % taskDim; int core_elements = elements_per_core + (taskId < remain ? 1 : 0); int core_offset = taskId < remain ? taskId * core_elements : taskId * elements_per_core + remain; - int chunks = core_elements / max_chunk_elements; - int chunk_rem = core_elements % max_chunk_elements; + int chunks = core_elements / chunk_size; + int chunk_rem = core_elements % chunk_size; int min_index = -1; Tdata min_value = static_cast(0); bool initialized = false; for (int c = 0; c < chunks; c++) { - size_t current_offset = core_offset + c * max_chunk_elements; - __memcpy(nram_x, x + current_offset, max_chunk_elements * sizeof(Tdata), GDRAM2NRAM); + int current_offset = core_offset + c * chunk_size; + __memcpy(nram_x, x + current_offset, chunk_size * sizeof(Tdata), GDRAM2NRAM); - __bang_abs(nram_x, nram_x, max_chunk_elements); + __bang_abs(nram_x, nram_x, chunk_size); - for (int i = 0; i < max_chunk_elements; i++) { + for (int i = 0; i < chunk_size; i++) { Tdata abs_val = nram_x[i]; if (!initialized || abs_val < min_value) { min_value = abs_val; @@ -52,7 +54,7 @@ __mlu_global__ void blasAminKernelContiguous( } if (chunk_rem > 0) { - size_t current_offset = core_offset + chunks * 
max_chunk_elements; + int current_offset = core_offset + chunks * chunk_size; __memcpy(nram_x, x + current_offset, chunk_rem * sizeof(Tdata), GDRAM2NRAM); @@ -87,9 +89,9 @@ __mlu_global__ void blasAminKernelContiguous( template __mlu_global__ void blasAminKernelStrided( - size_t n, + int n, const Tdata *x, - size_t incx, + int incx, int *result) { __mlu_shared__ int shared_min_index[4]; @@ -105,7 +107,7 @@ __mlu_global__ void blasAminKernelStrided( bool initialized = false; for (int i = start_idx; i < start_idx + actual_tasks; ++i) { - size_t offset = i * incx; + int offset = i * incx; Tdata abs_val = x[offset] > static_cast(0) ? x[offset] : -x[offset]; if (!initialized || abs_val < min_value) { diff --git a/src/infiniop/ops/blas_amin/cpu/blas_amin_cpu.cc b/src/infiniop/ops/blas_amin/cpu/blas_amin_cpu.cc index 57c060c63..07cad0461 100644 --- a/src/infiniop/ops/blas_amin/cpu/blas_amin_cpu.cc +++ b/src/infiniop/ops/blas_amin/cpu/blas_amin_cpu.cc @@ -32,20 +32,21 @@ infiniStatus_t calculateBlasAmin( const Tdata *x, int *result) { - const ptrdiff_t size = info.n; + const size_t n = info.n; const ptrdiff_t incx = info.incx; - if (size < 1 || incx == 0) { + if (n < 1 || incx == 0) { result[0] = 0; return INFINI_STATUS_SUCCESS; } - int min_index = 0; + size_t min_index = 0; if constexpr (std::is_same::value || std::is_same::value) { float min_value = std::abs(utils::cast(x[0])); - for (ptrdiff_t i = 1; i < size; ++i) { - float current_value = std::abs(utils::cast(x[i * incx])); + for (size_t i = 1; i < n; ++i) { + const ptrdiff_t idx = utils::cast(i) * incx; + float current_value = std::abs(utils::cast(x[idx])); if (current_value < min_value) { min_value = current_value; min_index = i; @@ -54,8 +55,9 @@ infiniStatus_t calculateBlasAmin( } else { Tdata min_value = std::abs(x[0]); - for (ptrdiff_t i = 1; i < size; ++i) { - Tdata current_value = std::abs(x[i * incx]); + for (size_t i = 1; i < n; ++i) { + const ptrdiff_t idx = utils::cast(i) * incx; + Tdata current_value = 
std::abs(x[idx]); if (current_value < min_value) { min_value = current_value; min_index = i; @@ -100,4 +102,4 @@ infiniStatus_t Descriptor::calculate( #undef CALCULATE_BLAS_AMIN -} // namespace op::blas_amin::cpu \ No newline at end of file +} // namespace op::blas_amin::cpu diff --git a/src/infiniop/ops/blas_amin/metax/blas_amin_metax.cc b/src/infiniop/ops/blas_amin/metax/blas_amin_metax.cc index 8217eacf6..d75ff93ca 100644 --- a/src/infiniop/ops/blas_amin/metax/blas_amin_metax.cc +++ b/src/infiniop/ops/blas_amin/metax/blas_amin_metax.cc @@ -42,8 +42,8 @@ infiniStatus_t Descriptor::calculate( (void)workspace; (void)workspace_size; - const size_t size = _info.n; - const ptrdiff_t incx = _info.incx; + const int n = utils::cast(_info.n); + const int incx = utils::cast(_info.incx); const infiniDtype_t data_type = _info.data_type; CHECK_STATUS(_opaque->internal->useMcblas( @@ -53,10 +53,10 @@ infiniStatus_t Descriptor::calculate( switch (data_type) { case INFINI_DTYPE_F32: - CHECK_MCBLAS(hcblasIsamin(handle, size, (const float *)x, incx, (int *)result)); + CHECK_MCBLAS(hcblasIsamin(handle, n, (const float *)x, incx, (int *)result)); break; case INFINI_DTYPE_F64: - CHECK_MCBLAS(hcblasIdamin(handle, size, (const double *)x, incx, (int *)result)); + CHECK_MCBLAS(hcblasIdamin(handle, n, (const double *)x, incx, (int *)result)); break; default: return INFINI_STATUS_BAD_TENSOR_DTYPE; diff --git a/src/infiniop/ops/blas_copy/bang/blas_copy_bang.mlu b/src/infiniop/ops/blas_copy/bang/blas_copy_bang.mlu index 65e68b631..fe47729c0 100644 --- a/src/infiniop/ops/blas_copy/bang/blas_copy_bang.mlu +++ b/src/infiniop/ops/blas_copy/bang/blas_copy_bang.mlu @@ -33,6 +33,10 @@ infiniStatus_t calculateBlasCopy( Tdata *y, cnrtQueue_t queue) { + const int n = utils::cast(info.n); + const int incx = utils::cast(info.incx); + const int incy = utils::cast(info.incy); + cnrtDim3_t k_dim; cnrtFunctionType_t k_type; @@ -41,18 +45,18 @@ infiniStatus_t calculateBlasCopy( k_dim.z = 1; k_type = 
cnrtFuncTypeUnion1; - if (info.incx == 1 && info.incy == 1) { + if (incx == 1 && incy == 1) { blasCopyKernelContiguous<<>>( - info.n, + n, x, y); } else { blasCopyKernelStrided<<>>( - info.n, + n, x, - info.incx, + incx, y, - info.incy); + incy); } cnrtQueueSync(queue); diff --git a/src/infiniop/ops/blas_copy/bang/blas_copy_bang_kernel.mlu b/src/infiniop/ops/blas_copy/bang/blas_copy_bang_kernel.mlu index 2ab596ffa..63bc22a66 100644 --- a/src/infiniop/ops/blas_copy/bang/blas_copy_bang_kernel.mlu +++ b/src/infiniop/ops/blas_copy/bang/blas_copy_bang_kernel.mlu @@ -5,20 +5,22 @@ __nram__ char nram_buffer[NRAM_MAX_SIZE]; template __mlu_global__ void blasCopyKernelContiguous( - size_t n, + int n, const Tdata *x, Tdata *y) { - Tdata *nram_x = (Tdata *)(((size_t)nram_buffer + ALIGN_SIZE - 1) & ~(ALIGN_SIZE - 1)); + char *nram_aligned = (char *)(((size_t)nram_buffer + ALIGN_SIZE - 1) & ~(ALIGN_SIZE - 1)); - size_t nram_usable = NRAM_MAX_SIZE - ((char *)nram_x - nram_buffer); - size_t max_chunk_elements = nram_usable / sizeof(Tdata); + size_t nram_usable = NRAM_MAX_SIZE - (nram_aligned - nram_buffer); + size_t max_chunk_elements = nram_usable / (2 * sizeof(Tdata)); - int align_elements = ALIGN_SIZE / sizeof(Tdata); + size_t align_elements = ALIGN_SIZE / sizeof(Tdata); if (align_elements == 0) { align_elements = 1; } - max_chunk_elements = (max_chunk_elements / align_elements) * align_elements; + int chunk_size = (int)((max_chunk_elements / align_elements) * align_elements); + + Tdata *nram_x = (Tdata *)nram_aligned; int elements_per_core = n / taskDim; int remain = n % taskDim; @@ -29,17 +31,17 @@ __mlu_global__ void blasCopyKernelContiguous( return; } - int chunks = core_elements / max_chunk_elements; - int chunk_rem = core_elements % max_chunk_elements; + int chunks = core_elements / chunk_size; + int chunk_rem = core_elements % chunk_size; for (int c = 0; c < chunks; c++) { - size_t current_offset = core_offset + c * max_chunk_elements; - __memcpy(nram_x, x + 
current_offset, max_chunk_elements * sizeof(Tdata), GDRAM2NRAM); - __memcpy(y + current_offset, nram_x, max_chunk_elements * sizeof(Tdata), NRAM2GDRAM); + int current_offset = core_offset + c * chunk_size; + __memcpy(nram_x, x + current_offset, chunk_size * sizeof(Tdata), GDRAM2NRAM); + __memcpy(y + current_offset, nram_x, chunk_size * sizeof(Tdata), NRAM2GDRAM); } if (chunk_rem > 0) { - size_t current_offset = core_offset + chunks * max_chunk_elements; + int current_offset = core_offset + chunks * chunk_size; __memcpy(nram_x, x + current_offset, chunk_rem * sizeof(Tdata), GDRAM2NRAM); __memcpy(y + current_offset, nram_x, chunk_rem * sizeof(Tdata), NRAM2GDRAM); } @@ -47,11 +49,11 @@ __mlu_global__ void blasCopyKernelContiguous( template __mlu_global__ void blasCopyKernelStrided( - size_t n, + int n, const Tdata *x, - size_t incx, + int incx, Tdata *y, - size_t incy) { + int incy) { int elements_per_core = n / taskDim; int remain = n % taskDim; diff --git a/src/infiniop/ops/blas_copy/cpu/blas_copy_cpu.cc b/src/infiniop/ops/blas_copy/cpu/blas_copy_cpu.cc index d4671fb72..47123d371 100644 --- a/src/infiniop/ops/blas_copy/cpu/blas_copy_cpu.cc +++ b/src/infiniop/ops/blas_copy/cpu/blas_copy_cpu.cc @@ -31,11 +31,11 @@ infiniStatus_t calculateBlasCopy( const Tdata *x, Tdata *y) { - const ptrdiff_t size = info.n; + const size_t n = info.n; - for (ptrdiff_t i = 0; i < size; ++i) { - size_t x_idx = i * info.incx; - size_t y_idx = i * info.incy; + for (size_t i = 0; i < n; ++i) { + ptrdiff_t x_idx = utils::cast(i) * info.incx; + ptrdiff_t y_idx = utils::cast(i) * info.incy; y[y_idx] = x[x_idx]; } diff --git a/src/infiniop/ops/blas_copy/metax/blas_copy_metax.cc b/src/infiniop/ops/blas_copy/metax/blas_copy_metax.cc index 398688a87..e51ec95e9 100644 --- a/src/infiniop/ops/blas_copy/metax/blas_copy_metax.cc +++ b/src/infiniop/ops/blas_copy/metax/blas_copy_metax.cc @@ -42,9 +42,9 @@ infiniStatus_t Descriptor::calculate( (void)workspace; (void)workspace_size; - const size_t size = 
_info.n; - const ptrdiff_t incx = _info.incx; - const ptrdiff_t incy = _info.incy; + const int n = utils::cast(_info.n); + const int incx = utils::cast(_info.incx); + const int incy = utils::cast(_info.incy); const infiniDtype_t data_type = _info.data_type; CHECK_STATUS(_opaque->internal->useMcblas( @@ -54,10 +54,10 @@ infiniStatus_t Descriptor::calculate( switch (data_type) { case INFINI_DTYPE_F32: - CHECK_MCBLAS(hcblasScopy(handle, size, (const float *)x, incx, (float *)y, incy)); + CHECK_MCBLAS(hcblasScopy(handle, n, (const float *)x, incx, (float *)y, incy)); break; case INFINI_DTYPE_F64: - CHECK_MCBLAS(hcblasDcopy(handle, size, (const double *)x, incx, (double *)y, incy)); + CHECK_MCBLAS(hcblasDcopy(handle, n, (const double *)x, incx, (double *)y, incy)); break; default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; diff --git a/src/infiniop/ops/blas_dot/bang/blas_dot_bang.mlu b/src/infiniop/ops/blas_dot/bang/blas_dot_bang.mlu index 45c93d69c..8d7f1b816 100644 --- a/src/infiniop/ops/blas_dot/bang/blas_dot_bang.mlu +++ b/src/infiniop/ops/blas_dot/bang/blas_dot_bang.mlu @@ -35,6 +35,10 @@ infiniStatus_t calculateBlasDot( Tdata *result, cnrtQueue_t queue) { + const int n = utils::cast(info.n); + const int incx = utils::cast(info.incx); + const int incy = utils::cast(info.incy); + cnrtDim3_t k_dim; cnrtFunctionType_t k_type; @@ -43,19 +47,19 @@ infiniStatus_t calculateBlasDot( k_dim.z = 1; k_type = cnrtFuncTypeUnion1; - if (info.incx == 1 && info.incy == 1) { + if (incx == 1 && incy == 1) { blasDotKernelContiguous<<>>( - info.n, + n, x, y, result); } else { blasDotKernelStrided<<>>( - info.n, + n, x, - info.incx, + incx, y, - info.incy, + incy, result); } diff --git a/src/infiniop/ops/blas_dot/bang/blas_dot_bang_kernel.mlu b/src/infiniop/ops/blas_dot/bang/blas_dot_bang_kernel.mlu index a1e3a84d9..bf249101e 100644 --- a/src/infiniop/ops/blas_dot/bang/blas_dot_bang_kernel.mlu +++ b/src/infiniop/ops/blas_dot/bang/blas_dot_bang_kernel.mlu @@ -1,56 +1,101 @@ #include 
"../../../devices/bang/common_bang.h" #include "blas_dot_bang.h" +#include + __nram__ char nram_buffer[NRAM_MAX_SIZE]; +template +__mlu_device__ void blasDotToCompute(float *dst, const Tdata *src, int size) { + if constexpr (std::is_same_v) { + __bang_half2float(dst, src, size); + } else if constexpr (std::is_same_v) { + __bang_bfloat162float(dst, src, size); + } else { + __memcpy(dst, src, size * sizeof(float), NRAM2NRAM); + } +} + +template +__mlu_device__ float blasDotToCompute(Tdata value) { + if constexpr (std::is_same_v) { + return __half2float(value); + } else if constexpr (std::is_same_v) { + return __bfloat162float(value); + } else { + return static_cast(value); + } +} + +template +__mlu_device__ void blasDotStoreResult(Tdata *result, Tdata *nram_result, float *nram_compute, float value) { + nram_compute[0] = value; + if constexpr (std::is_same_v) { + __bang_float2half(nram_result, nram_compute, 1); + result[0] = nram_result[0]; + } else if constexpr (std::is_same_v) { + __bang_float2bfloat16(nram_result, nram_compute, 1); + result[0] = nram_result[0]; + } else { + result[0] = nram_compute[0]; + } +} + template __mlu_global__ void blasDotKernelContiguous( - size_t n, + int n, const Tdata *x, const Tdata *y, Tdata *result) { - __mlu_shared__ Tdata shared_partial_sum[4]; - Tdata *nram_align = (Tdata *)(((size_t)nram_buffer + ALIGN_SIZE - 1) & ~(ALIGN_SIZE - 1)); + __mlu_shared__ float shared_partial_sum[4]; + + char *nram_aligned = (char *)(((size_t)nram_buffer + ALIGN_SIZE - 1) & ~(ALIGN_SIZE - 1)); - size_t nram_usable = NRAM_MAX_SIZE - ((char *)nram_align - nram_buffer); - size_t max_chunk_elements = nram_usable / (2 * sizeof(Tdata)); + size_t nram_usable = NRAM_MAX_SIZE - (nram_aligned - nram_buffer); + size_t max_chunk_elements = nram_usable / (2 * sizeof(Tdata) + 2 * sizeof(float)); - int align_elements = ALIGN_SIZE / sizeof(Tdata); + size_t align_elements = ALIGN_SIZE / sizeof(Tdata); if (align_elements == 0) { align_elements = 1; } - 
max_chunk_elements = (max_chunk_elements / align_elements) * align_elements; + int chunk_size = (int)((max_chunk_elements / align_elements) * align_elements); - Tdata *nram_x = nram_align; - Tdata *nram_y = nram_align + max_chunk_elements; + Tdata *nram_x = (Tdata *)nram_aligned; + Tdata *nram_y = nram_x + chunk_size; + float *nram_compute_x = (float *)(nram_y + chunk_size); + float *nram_compute_y = nram_compute_x + chunk_size; int elements_per_core = n / taskDim; int remain = n % taskDim; int core_elements = elements_per_core + (taskId < remain ? 1 : 0); int core_offset = taskId < remain ? taskId * core_elements : taskId * elements_per_core + remain; - int chunks = core_elements / max_chunk_elements; - int chunk_rem = core_elements % max_chunk_elements; + int chunks = core_elements / chunk_size; + int chunk_rem = core_elements % chunk_size; - Tdata partial_sum = 0; + float partial_sum = 0.0f; for (int c = 0; c < chunks; c++) { - size_t current_offset = core_offset + c * max_chunk_elements; - __memcpy(nram_x, x + current_offset, max_chunk_elements * sizeof(Tdata), GDRAM2NRAM); - __memcpy(nram_y, y + current_offset, max_chunk_elements * sizeof(Tdata), GDRAM2NRAM); - - __bang_mul(nram_x, nram_x, nram_y, max_chunk_elements); - partial_sum += __bang_sum(nram_x, max_chunk_elements); + int current_offset = core_offset + c * chunk_size; + __memcpy(nram_x, x + current_offset, chunk_size * sizeof(Tdata), GDRAM2NRAM); + __memcpy(nram_y, y + current_offset, chunk_size * sizeof(Tdata), GDRAM2NRAM); + + blasDotToCompute(nram_compute_x, nram_x, chunk_size); + blasDotToCompute(nram_compute_y, nram_y, chunk_size); + __bang_mul(nram_compute_x, nram_compute_x, nram_compute_y, chunk_size); + partial_sum += __bang_sum(nram_compute_x, chunk_size); } if (chunk_rem > 0) { - size_t current_offset = core_offset + chunks * max_chunk_elements; + int current_offset = core_offset + chunks * chunk_size; __memcpy(nram_x, x + current_offset, chunk_rem * sizeof(Tdata), GDRAM2NRAM); 
__memcpy(nram_y, y + current_offset, chunk_rem * sizeof(Tdata), GDRAM2NRAM); - __bang_mul(nram_x, nram_x, nram_y, chunk_rem); - partial_sum += __bang_sum(nram_x, chunk_rem); + blasDotToCompute(nram_compute_x, nram_x, chunk_rem); + blasDotToCompute(nram_compute_y, nram_y, chunk_rem); + __bang_mul(nram_compute_x, nram_compute_x, nram_compute_y, chunk_rem); + partial_sum += __bang_sum(nram_compute_x, chunk_rem); } shared_partial_sum[coreId] = partial_sum; @@ -58,37 +103,42 @@ __mlu_global__ void blasDotKernelContiguous( __sync_cluster(); if (coreId == 0) { - Tdata cluster_sum = 0; + float cluster_sum = 0.0f; for (int i = 0; i < coreDim; i++) { cluster_sum += shared_partial_sum[i]; } - result[0] = cluster_sum; + blasDotStoreResult(result, nram_x, nram_compute_x, cluster_sum); } } template __mlu_global__ void blasDotKernelStrided( - size_t n, + int n, const Tdata *x, - ptrdiff_t incx, + int incx, const Tdata *y, - ptrdiff_t incy, + int incy, Tdata *result) { - __mlu_shared__ Tdata shared_partial_sum[4]; + __mlu_shared__ float shared_partial_sum[4]; + + char *nram_aligned = (char *)(((size_t)nram_buffer + ALIGN_SIZE - 1) & ~(ALIGN_SIZE - 1)); - size_t elements_per_core = n / taskDim; - size_t remain = n % taskDim; - size_t core_elements = elements_per_core + (taskId < remain ? 1 : 0); - size_t start_idx = taskId < remain ? taskId * core_elements : taskId * elements_per_core + remain; + float *nram_compute = (float *)nram_aligned; + Tdata *nram_result = (Tdata *)(nram_compute + 1); - Tdata partial_sum = 0; - ptrdiff_t x_offset = static_cast(start_idx) * incx; - ptrdiff_t y_offset = static_cast(start_idx) * incy; + int elements_per_core = n / taskDim; + int remain = n % taskDim; + int core_elements = elements_per_core + (taskId < remain ? 1 : 0); + int start_idx = taskId < remain ? 
taskId * core_elements : taskId * elements_per_core + remain; - for (size_t i = 0; i < core_elements; ++i) { - partial_sum += x[x_offset] * y[y_offset]; + float partial_sum = 0.0f; + int x_offset = start_idx * incx; + int y_offset = start_idx * incy; + + for (int i = 0; i < core_elements; ++i) { + partial_sum += blasDotToCompute(x[x_offset]) * blasDotToCompute(y[y_offset]); x_offset += incx; y_offset += incy; } @@ -98,10 +148,10 @@ __mlu_global__ void blasDotKernelStrided( __sync_cluster(); if (coreId == 0) { - Tdata cluster_sum = 0; + float cluster_sum = 0.0f; for (int i = 0; i < coreDim; ++i) { cluster_sum += shared_partial_sum[i]; } - result[0] = cluster_sum; + blasDotStoreResult(result, nram_result, nram_compute, cluster_sum); } -} \ No newline at end of file +} diff --git a/src/infiniop/ops/blas_dot/cpu/blas_dot_cpu.cc b/src/infiniop/ops/blas_dot/cpu/blas_dot_cpu.cc index e88b2abd9..e250aa2ee 100644 --- a/src/infiniop/ops/blas_dot/cpu/blas_dot_cpu.cc +++ b/src/infiniop/ops/blas_dot/cpu/blas_dot_cpu.cc @@ -33,17 +33,17 @@ infiniStatus_t calculateBlasDot( const Tdata *y, Tdata *result) { - const ptrdiff_t n = info.n; + const size_t n = info.n; const ptrdiff_t incx = info.incx; const ptrdiff_t incy = info.incy; - ptrdiff_t ix = (incx < 0) ? (1 - n) * incx : 0; - ptrdiff_t iy = (incy < 0) ? (1 - n) * incy : 0; + ptrdiff_t ix = (incx < 0) ? (1 - utils::cast(n)) * incx : 0; + ptrdiff_t iy = (incy < 0) ? 
(1 - utils::cast(n)) * incy : 0; if constexpr (std::is_same::value || std::is_same::value) { float total = 0.0f; - for (ptrdiff_t i = 0; i < n; ++i) { + for (size_t i = 0; i < n; ++i) { total += utils::cast(x[ix]) * utils::cast(y[iy]); ix += incx; iy += incy; @@ -53,7 +53,7 @@ infiniStatus_t calculateBlasDot( } else { Tdata total = utils::cast(0); - for (ptrdiff_t i = 0; i < n; ++i) { + for (size_t i = 0; i < n; ++i) { total += x[ix] * y[iy]; ix += incx; iy += incy; diff --git a/src/infiniop/ops/blas_dot/metax/blas_dot_metax.cc b/src/infiniop/ops/blas_dot/metax/blas_dot_metax.cc index 220886132..28e7f301a 100644 --- a/src/infiniop/ops/blas_dot/metax/blas_dot_metax.cc +++ b/src/infiniop/ops/blas_dot/metax/blas_dot_metax.cc @@ -44,9 +44,9 @@ infiniStatus_t Descriptor::calculate( (void)workspace; (void)workspace_size; - const size_t size = _info.n; - const ptrdiff_t incx = _info.incx; - const ptrdiff_t incy = _info.incy; + const int n = utils::cast(_info.n); + const int incx = utils::cast(_info.incx); + const int incy = utils::cast(_info.incy); const infiniDtype_t data_type = _info.data_type; hpccDataType x_type, y_type, result_type; @@ -82,7 +82,7 @@ infiniStatus_t Descriptor::calculate( CHECK_MCBLAS(hcblasDotEx( handle, - size, + n, x, x_type, incx, diff --git a/src/infiniop/ops/nrm2/bang/nrm2_bang.mlu b/src/infiniop/ops/nrm2/bang/nrm2_bang.mlu index 6e25f5b68..0d6e54517 100644 --- a/src/infiniop/ops/nrm2/bang/nrm2_bang.mlu +++ b/src/infiniop/ops/nrm2/bang/nrm2_bang.mlu @@ -37,21 +37,24 @@ infiniStatus_t calculateNrm2( cnrtDim3_t k_dim; cnrtFunctionType_t k_type; + const int n = utils::cast(info.n); + const int incx = utils::cast(info.incx); + k_dim.x = 4; k_dim.y = 1; k_dim.z = 1; k_type = cnrtFuncTypeUnion1; - if (info.incx == 1) { + if (incx == 1) { Nrm2KernelContiguous<<>>( - info.n, + n, x, result); } else { Nrm2KernelStrided<<>>( - info.n, + n, x, - info.incx, + incx, result); } diff --git a/src/infiniop/ops/nrm2/bang/nrm2_bang_kernel.mlu 
b/src/infiniop/ops/nrm2/bang/nrm2_bang_kernel.mlu index f48a4ad52..3778b24ec 100644 --- a/src/infiniop/ops/nrm2/bang/nrm2_bang_kernel.mlu +++ b/src/infiniop/ops/nrm2/bang/nrm2_bang_kernel.mlu @@ -6,7 +6,7 @@ __nram__ char nram_buffer[NRAM_MAX_SIZE]; template -__mlu_device__ void nrm2ToCompute(float *dst, const Tdata *src, size_t size) { +__mlu_device__ void nrm2ToCompute(float *dst, const Tdata *src, int size) { if constexpr (std::is_same_v) { __bang_half2float(dst, src, size); } else if constexpr (std::is_same_v) { @@ -32,45 +32,48 @@ __mlu_device__ void nrm2StoreResult(Tdata *result, Tdata *nram_result, float *nr template __mlu_global__ void Nrm2KernelContiguous( - size_t n, + int n, const Tdata *x, Tdata *result) { + __mlu_shared__ float shared_partial_sum[4]; - Tdata *nram_x = (Tdata *)(((size_t)nram_buffer + ALIGN_SIZE - 1) & ~(ALIGN_SIZE - 1)); + char *nram_aligned = (char *)(((size_t)nram_buffer + ALIGN_SIZE - 1) & ~(ALIGN_SIZE - 1)); - size_t nram_usable = NRAM_MAX_SIZE - ((char *)nram_x - nram_buffer); - size_t max_chunk_elements = (nram_usable - ALIGN_SIZE) / (sizeof(Tdata) + sizeof(float)); + size_t nram_usable = NRAM_MAX_SIZE - (nram_aligned - nram_buffer); + size_t max_chunk_elements = nram_usable / (sizeof(Tdata) + sizeof(float)); - int align_elements = ALIGN_SIZE / sizeof(Tdata); + size_t align_elements = ALIGN_SIZE / sizeof(Tdata); if (align_elements == 0) { align_elements = 1; } - max_chunk_elements = (max_chunk_elements / align_elements) * align_elements; - float *nram_compute = (float *)(((size_t)(nram_x + max_chunk_elements) + ALIGN_SIZE - 1) & ~(ALIGN_SIZE - 1)); + int chunk_size = (int)((max_chunk_elements / align_elements) * align_elements); + + Tdata *nram_x = (Tdata *)nram_aligned; + float *nram_compute = (float *)(nram_x + chunk_size); int elements_per_core = n / taskDim; int remain = n % taskDim; int core_elements = elements_per_core + (taskId < remain ? 1 : 0); int core_offset = taskId < remain ? 
taskId * core_elements : taskId * elements_per_core + remain; - int chunks = core_elements / max_chunk_elements; - int chunk_rem = core_elements % max_chunk_elements; + int chunks = core_elements / chunk_size; + int chunk_rem = core_elements % chunk_size; float partial_sum = 0.0f; for (int c = 0; c < chunks; c++) { - size_t current_offset = core_offset + c * max_chunk_elements; - __memcpy(nram_x, x + current_offset, max_chunk_elements * sizeof(Tdata), GDRAM2NRAM); + int current_offset = core_offset + c * chunk_size; + __memcpy(nram_x, x + current_offset, chunk_size * sizeof(Tdata), GDRAM2NRAM); - nrm2ToCompute(nram_compute, nram_x, max_chunk_elements); - __bang_square(nram_compute, nram_compute, max_chunk_elements); + nrm2ToCompute(nram_compute, nram_x, chunk_size); + __bang_square(nram_compute, nram_compute, chunk_size); - partial_sum += __bang_sum(nram_compute, max_chunk_elements); + partial_sum += __bang_sum(nram_compute, chunk_size); } if (chunk_rem > 0) { - size_t current_offset = core_offset + chunks * max_chunk_elements; + int current_offset = core_offset + chunks * chunk_size; __memcpy(nram_x, x + current_offset, chunk_rem * sizeof(Tdata), GDRAM2NRAM); @@ -97,23 +100,26 @@ __mlu_global__ void Nrm2KernelContiguous( template __mlu_global__ void Nrm2KernelStrided( - size_t n, + int n, const Tdata *x, - size_t incx, + int incx, Tdata *result) { + __mlu_shared__ float shared_partial_sum[4]; - Tdata *nram_x = (Tdata *)(((size_t)nram_buffer + ALIGN_SIZE - 1) & ~(ALIGN_SIZE - 1)); + char *nram_aligned = (char *)(((size_t)nram_buffer + ALIGN_SIZE - 1) & ~(ALIGN_SIZE - 1)); - size_t nram_usable = NRAM_MAX_SIZE - ((char *)nram_x - nram_buffer); - size_t max_chunk_elements = (nram_usable - ALIGN_SIZE) / (sizeof(Tdata) + sizeof(float)); + size_t nram_usable = NRAM_MAX_SIZE - (nram_aligned - nram_buffer); + size_t max_chunk_elements = nram_usable / (sizeof(Tdata) + sizeof(float)); - int align_elements = ALIGN_SIZE / sizeof(Tdata); + size_t align_elements = ALIGN_SIZE / 
sizeof(Tdata); if (align_elements == 0) { align_elements = 1; } - max_chunk_elements = (max_chunk_elements / align_elements) * align_elements; - float *nram_compute = (float *)(((size_t)(nram_x + max_chunk_elements) + ALIGN_SIZE - 1) & ~(ALIGN_SIZE - 1)); + int chunk_size = (int)((max_chunk_elements / align_elements) * align_elements); + + Tdata *nram_x = (Tdata *)nram_aligned; + float *nram_compute = (float *)(nram_x + chunk_size); int elements_per_core = n / taskDim; int remain = n % taskDim; @@ -122,12 +128,12 @@ __mlu_global__ void Nrm2KernelStrided( float partial_sum = 0.0f; - int chunks = actual_tasks / max_chunk_elements; - int chunk_rem = actual_tasks % max_chunk_elements; + int chunks = actual_tasks / chunk_size; + int chunk_rem = actual_tasks % chunk_size; for (int c = 0; c < chunks; c++) { - int current_elements = max_chunk_elements; - int current_start = start_idx + c * max_chunk_elements; + int current_elements = chunk_size; + int current_start = start_idx + c * chunk_size; for (int i = 0; i < current_elements; ++i) { nram_x[i] = x[(current_start + i) * incx]; } @@ -139,7 +145,7 @@ __mlu_global__ void Nrm2KernelStrided( } if (chunk_rem > 0) { - int current_start = start_idx + chunks * max_chunk_elements; + int current_start = start_idx + chunks * chunk_size; for (int i = 0; i < chunk_rem; ++i) { nram_x[i] = x[(current_start + i) * incx]; } diff --git a/src/infiniop/ops/nrm2/cpu/nrm2_cpu.cc b/src/infiniop/ops/nrm2/cpu/nrm2_cpu.cc index d6a45b2d2..b494fe7d3 100644 --- a/src/infiniop/ops/nrm2/cpu/nrm2_cpu.cc +++ b/src/infiniop/ops/nrm2/cpu/nrm2_cpu.cc @@ -37,7 +37,7 @@ infiniStatus_t calculateNrm2( using Tcompute = std::conditional_t, double, float>; - const ptrdiff_t n = info.n; + const size_t n = info.n; const ptrdiff_t incx = info.incx; // Blue's scaling constants (float vs double) @@ -79,9 +79,9 @@ infiniStatus_t calculateNrm2( Tcompute abig = Tcompute(0); // 0-based index; handle negative stride - ptrdiff_t ix = (incx < 0) ? 
(ptrdiff_t(1) - n) * incx : 0; + ptrdiff_t ix = (incx < 0) ? (ptrdiff_t(1) - utils::cast(n)) * incx : 0; - for (ptrdiff_t i = 0; i < n; ++i) { + for (size_t i = 0; i < n; ++i) { Tcompute ax = std::abs(utils::cast(x[ix])); if (ax > tbig) { diff --git a/src/infiniop/ops/nrm2/metax/nrm2_metax.cc b/src/infiniop/ops/nrm2/metax/nrm2_metax.cc index 6c7c6ff89..6b23502bf 100644 --- a/src/infiniop/ops/nrm2/metax/nrm2_metax.cc +++ b/src/infiniop/ops/nrm2/metax/nrm2_metax.cc @@ -42,8 +42,8 @@ infiniStatus_t Descriptor::calculate( (void)workspace; (void)workspace_size; - const size_t size = _info.n; - const ptrdiff_t incx = _info.incx; + const int n = utils::cast(_info.n); + const int incx = utils::cast(_info.incx); const infiniDtype_t data_type = _info.data_type; hpccDataType x_type, result_type; @@ -79,7 +79,7 @@ infiniStatus_t Descriptor::calculate( CHECK_MCBLAS(hcblasNrm2Ex( handle, - size, + n, x, x_type, incx, diff --git a/src/infiniop/ops/rot/bang/rot_bang.mlu b/src/infiniop/ops/rot/bang/rot_bang.mlu index 7da1728e8..b9601bde3 100644 --- a/src/infiniop/ops/rot/bang/rot_bang.mlu +++ b/src/infiniop/ops/rot/bang/rot_bang.mlu @@ -37,6 +37,10 @@ infiniStatus_t calculateRot( const Tdata *s, cnrtQueue_t queue) { + const int n = utils::cast(info.n); + const int incx = utils::cast(info.incx); + const int incy = utils::cast(info.incy); + cnrtDim3_t k_dim; cnrtFunctionType_t k_type; k_dim.x = 4; @@ -44,20 +48,20 @@ infiniStatus_t calculateRot( k_dim.z = 1; k_type = cnrtFuncTypeUnion1; - if (info.incx == 1 && info.incy == 1) { + if (incx == 1 && incy == 1) { rotKernelContiguous<<>>( - info.n, + n, x, y, c, s); } else { rotKernelStrided<<>>( - info.n, + n, x, - info.incx, + incx, y, - info.incy, + incy, c, s); } diff --git a/src/infiniop/ops/rot/bang/rot_bang_kernel.mlu b/src/infiniop/ops/rot/bang/rot_bang_kernel.mlu index 27f279e14..24cf3ead3 100644 --- a/src/infiniop/ops/rot/bang/rot_bang_kernel.mlu +++ b/src/infiniop/ops/rot/bang/rot_bang_kernel.mlu @@ -4,27 +4,27 @@ __nram__ char 
nram_buffer[NRAM_MAX_SIZE]; template __mlu_global__ void rotKernelContiguous( - size_t n, + int n, Tdata *x, Tdata *y, const Tdata *c, const Tdata *s) { - Tdata *nram_align = (Tdata *)(((size_t)nram_buffer + ALIGN_SIZE - 1) & ~(ALIGN_SIZE - 1)); + char *nram_aligned = (char *)(((size_t)nram_buffer + ALIGN_SIZE - 1) & ~(ALIGN_SIZE - 1)); - size_t nram_usable = NRAM_MAX_SIZE - ((char *)nram_align - nram_buffer); + size_t nram_usable = NRAM_MAX_SIZE - (nram_aligned - nram_buffer); size_t max_chunk_elements = nram_usable / (4 * sizeof(Tdata)); - int align_elements = ALIGN_SIZE / sizeof(Tdata); + size_t align_elements = ALIGN_SIZE / sizeof(Tdata); if (align_elements == 0) { align_elements = 1; } - max_chunk_elements = (max_chunk_elements / align_elements) * align_elements; + int chunk_size = (int)((max_chunk_elements / align_elements) * align_elements); - Tdata *nram_x = nram_align; - Tdata *nram_y = nram_align + max_chunk_elements; - Tdata *nram_x_out = nram_align + 2 * max_chunk_elements; - Tdata *nram_y_out = nram_align + 3 * max_chunk_elements; + Tdata *nram_x = (Tdata *)nram_aligned; + Tdata *nram_y = nram_x + chunk_size; + Tdata *nram_x_out = nram_y + chunk_size; + Tdata *nram_y_out = nram_x_out + chunk_size; int elements_per_core = n / taskDim; int remain = n % taskDim; @@ -35,29 +35,29 @@ __mlu_global__ void rotKernelContiguous( return; } - int chunks = core_elements / max_chunk_elements; - int chunk_rem = core_elements % max_chunk_elements; + int chunks = core_elements / chunk_size; + int chunk_rem = core_elements % chunk_size; for (int ck = 0; ck < chunks; ck++) { - size_t current_offset = core_offset + ck * max_chunk_elements; - __memcpy(nram_x, x + current_offset, max_chunk_elements * sizeof(Tdata), GDRAM2NRAM); - __memcpy(nram_y, y + current_offset, max_chunk_elements * sizeof(Tdata), GDRAM2NRAM); + int current_offset = core_offset + ck * chunk_size; + __memcpy(nram_x, x + current_offset, chunk_size * sizeof(Tdata), GDRAM2NRAM); + __memcpy(nram_y, y + 
current_offset, chunk_size * sizeof(Tdata), GDRAM2NRAM); - __bang_mul_scalar(nram_x_out, nram_x, c[0], max_chunk_elements); - __bang_mul_scalar(nram_y_out, nram_y, s[0], max_chunk_elements); - __bang_add(nram_x_out, nram_x_out, nram_y_out, max_chunk_elements); + __bang_mul_scalar(nram_x_out, nram_x, c[0], chunk_size); + __bang_mul_scalar(nram_y_out, nram_y, s[0], chunk_size); + __bang_add(nram_x_out, nram_x_out, nram_y_out, chunk_size); - __memcpy(x + current_offset, nram_x_out, max_chunk_elements * sizeof(Tdata), NRAM2GDRAM); + __memcpy(x + current_offset, nram_x_out, chunk_size * sizeof(Tdata), NRAM2GDRAM); - __bang_mul_scalar(nram_y_out, nram_y, c[0], max_chunk_elements); - __bang_mul_scalar(nram_x_out, nram_x, s[0], max_chunk_elements); - __bang_sub(nram_y_out, nram_y_out, nram_x_out, max_chunk_elements); + __bang_mul_scalar(nram_y_out, nram_y, c[0], chunk_size); + __bang_mul_scalar(nram_x_out, nram_x, s[0], chunk_size); + __bang_sub(nram_y_out, nram_y_out, nram_x_out, chunk_size); - __memcpy(y + current_offset, nram_y_out, max_chunk_elements * sizeof(Tdata), NRAM2GDRAM); + __memcpy(y + current_offset, nram_y_out, chunk_size * sizeof(Tdata), NRAM2GDRAM); } if (chunk_rem > 0) { - size_t current_offset = core_offset + chunks * max_chunk_elements; + int current_offset = core_offset + chunks * chunk_size; __memcpy(nram_x, x + current_offset, chunk_rem * sizeof(Tdata), GDRAM2NRAM); __memcpy(nram_y, y + current_offset, chunk_rem * sizeof(Tdata), GDRAM2NRAM); @@ -77,11 +77,11 @@ __mlu_global__ void rotKernelContiguous( template __mlu_global__ void rotKernelStrided( - size_t n, + int n, Tdata *x, - size_t incx, + int incx, Tdata *y, - size_t incy, + int incy, const Tdata *c, const Tdata *s) { @@ -91,8 +91,8 @@ __mlu_global__ void rotKernelStrided( int start_idx = taskId < remain ? 
taskId * actual_tasks : taskId * elements_per_core + remain; for (int i = start_idx; i < start_idx + actual_tasks; ++i) { - size_t x_idx = i * incx; - size_t y_idx = i * incy; + int x_idx = i * incx; + int y_idx = i * incy; Tdata x_val = x[x_idx]; Tdata y_val = y[y_idx]; diff --git a/src/infiniop/ops/rot/cpu/rot_cpu.cc b/src/infiniop/ops/rot/cpu/rot_cpu.cc index 6087487d1..ce804ee0f 100644 --- a/src/infiniop/ops/rot/cpu/rot_cpu.cc +++ b/src/infiniop/ops/rot/cpu/rot_cpu.cc @@ -40,20 +40,20 @@ infiniStatus_t calculateRot( const Tcompute c_val = utils::cast(c[0]); const Tcompute s_val = utils::cast(s[0]); - const ptrdiff_t size = static_cast(info.n); + const size_t n = info.n; const ptrdiff_t incx = info.incx; const ptrdiff_t incy = info.incy; - if (size <= 0) { + if (n == 0) { return INFINI_STATUS_SUCCESS; } - const ptrdiff_t ix = incx >= 0 ? 0 : (size - 1) * (-incx); - const ptrdiff_t iy = incy >= 0 ? 0 : (size - 1) * (-incy); + const ptrdiff_t ix = incx >= 0 ? 0 : utils::cast(n - 1) * (-incx); + const ptrdiff_t iy = incy >= 0 ? 
0 : utils::cast(n - 1) * (-incy); - for (ptrdiff_t i = 0; i < size; ++i) { - const ptrdiff_t x_idx = ix + i * incx; - const ptrdiff_t y_idx = iy + i * incy; + for (size_t i = 0; i < n; ++i) { + const ptrdiff_t x_idx = ix + utils::cast(i) * incx; + const ptrdiff_t y_idx = iy + utils::cast(i) * incy; const Tcompute x_val = utils::cast(x[x_idx]); const Tcompute y_val = utils::cast(y[y_idx]); diff --git a/src/infiniop/ops/rot/metax/rot_metax.cc b/src/infiniop/ops/rot/metax/rot_metax.cc index 0362643da..a1c3e13ef 100644 --- a/src/infiniop/ops/rot/metax/rot_metax.cc +++ b/src/infiniop/ops/rot/metax/rot_metax.cc @@ -46,9 +46,9 @@ infiniStatus_t Descriptor::calculate( (void)workspace; (void)workspace_size; - const size_t size = _info.n; - const ptrdiff_t incx = _info.incx; - const ptrdiff_t incy = _info.incy; + const int n = utils::cast(_info.n); + const int incx = utils::cast(_info.incx); + const int incy = utils::cast(_info.incy); const infiniDtype_t data_type = _info.data_type; hpccDataType x_type, y_type, cs_type; @@ -84,7 +84,7 @@ infiniStatus_t Descriptor::calculate( CHECK_MCBLAS(hcblasRotEx( handle, - size, + n, x, x_type, incx, diff --git a/src/infiniop/ops/rotm/bang/rotm_bang.mlu b/src/infiniop/ops/rotm/bang/rotm_bang.mlu index 5b1f28be7..2f3ee92f6 100644 --- a/src/infiniop/ops/rotm/bang/rotm_bang.mlu +++ b/src/infiniop/ops/rotm/bang/rotm_bang.mlu @@ -35,6 +35,10 @@ infiniStatus_t calculateRotm( const Tdata *param, cnrtQueue_t queue) { + const int n = utils::cast(info.n); + const int incx = utils::cast(info.incx); + const int incy = utils::cast(info.incy); + cnrtDim3_t k_dim; cnrtFunctionType_t k_type; k_dim.x = 4; @@ -42,19 +46,19 @@ infiniStatus_t calculateRotm( k_dim.z = 1; k_type = cnrtFuncTypeUnion1; - if (info.incx == 1 && info.incy == 1) { + if (incx == 1 && incy == 1) { rotmKernelContiguous<<>>( - info.n, + n, x, y, param); } else { rotmKernelStrided<<>>( - info.n, + n, x, - info.incx, + incx, y, - info.incy, + incy, param); } diff --git 
a/src/infiniop/ops/rotm/bang/rotm_bang_kernel.mlu b/src/infiniop/ops/rotm/bang/rotm_bang_kernel.mlu index 297c7bea7..d17e66856 100644 --- a/src/infiniop/ops/rotm/bang/rotm_bang_kernel.mlu +++ b/src/infiniop/ops/rotm/bang/rotm_bang_kernel.mlu @@ -4,7 +4,7 @@ __nram__ char nram_buffer[NRAM_MAX_SIZE]; template __mlu_global__ void rotmKernelContiguous( - size_t n, + int n, Tdata *x, Tdata *y, const Tdata *param) { @@ -32,21 +32,21 @@ __mlu_global__ void rotmKernelContiguous( h22 = param[4]; } - Tdata *nram_align = (Tdata *)(((size_t)nram_buffer + ALIGN_SIZE - 1) & ~(ALIGN_SIZE - 1)); + char *nram_aligned = (char *)(((size_t)nram_buffer + ALIGN_SIZE - 1) & ~(ALIGN_SIZE - 1)); - size_t nram_usable = NRAM_MAX_SIZE - ((char *)nram_align - nram_buffer); + size_t nram_usable = NRAM_MAX_SIZE - (nram_aligned - nram_buffer); size_t max_chunk_elements = nram_usable / (4 * sizeof(Tdata)); - int align_elements = ALIGN_SIZE / sizeof(Tdata); + size_t align_elements = ALIGN_SIZE / sizeof(Tdata); if (align_elements == 0) { align_elements = 1; } - max_chunk_elements = (max_chunk_elements / align_elements) * align_elements; + int chunk_size = (int)((max_chunk_elements / align_elements) * align_elements); - Tdata *nram_x = nram_align; - Tdata *nram_y = nram_align + max_chunk_elements; - Tdata *nram_x_out = nram_align + 2 * max_chunk_elements; - Tdata *nram_y_out = nram_align + 3 * max_chunk_elements; + Tdata *nram_x = (Tdata *)nram_aligned; + Tdata *nram_y = nram_x + chunk_size; + Tdata *nram_x_out = nram_y + chunk_size; + Tdata *nram_y_out = nram_x_out + chunk_size; int elements_per_core = n / taskDim; int remain = n % taskDim; @@ -57,42 +57,42 @@ __mlu_global__ void rotmKernelContiguous( return; } - int chunks = core_elements / max_chunk_elements; - int chunk_rem = core_elements % max_chunk_elements; + int chunks = core_elements / chunk_size; + int chunk_rem = core_elements % chunk_size; for (int c = 0; c < chunks; c++) { - size_t current_offset = core_offset + c * max_chunk_elements; 
- __memcpy(nram_x, x + current_offset, max_chunk_elements * sizeof(Tdata), GDRAM2NRAM); - __memcpy(nram_y, y + current_offset, max_chunk_elements * sizeof(Tdata), GDRAM2NRAM); + int current_offset = core_offset + c * chunk_size; + __memcpy(nram_x, x + current_offset, chunk_size * sizeof(Tdata), GDRAM2NRAM); + __memcpy(nram_y, y + current_offset, chunk_size * sizeof(Tdata), GDRAM2NRAM); if (flag < static_cast(0)) { - __bang_mul_scalar(nram_x_out, nram_x, h11, max_chunk_elements); - __bang_mul_scalar(nram_y_out, nram_y, h12, max_chunk_elements); - __bang_add(nram_x_out, nram_x_out, nram_y_out, max_chunk_elements); + __bang_mul_scalar(nram_x_out, nram_x, h11, chunk_size); + __bang_mul_scalar(nram_y_out, nram_y, h12, chunk_size); + __bang_add(nram_x_out, nram_x_out, nram_y_out, chunk_size); - __bang_mul_scalar(nram_y_out, nram_x, h21, max_chunk_elements); - __bang_mul_scalar(nram_y, nram_y, h22, max_chunk_elements); - __bang_add(nram_y_out, nram_y_out, nram_y, max_chunk_elements); + __bang_mul_scalar(nram_y_out, nram_x, h21, chunk_size); + __bang_mul_scalar(nram_y, nram_y, h22, chunk_size); + __bang_add(nram_y_out, nram_y_out, nram_y, chunk_size); } else if (flag == static_cast(0)) { - __bang_mul_scalar(nram_x_out, nram_y, h12, max_chunk_elements); - __bang_add(nram_x_out, nram_x, nram_x_out, max_chunk_elements); + __bang_mul_scalar(nram_x_out, nram_y, h12, chunk_size); + __bang_add(nram_x_out, nram_x, nram_x_out, chunk_size); - __bang_mul_scalar(nram_y_out, nram_x, h21, max_chunk_elements); - __bang_add(nram_y_out, nram_y_out, nram_y, max_chunk_elements); + __bang_mul_scalar(nram_y_out, nram_x, h21, chunk_size); + __bang_add(nram_y_out, nram_y_out, nram_y, chunk_size); } else { - __bang_mul_scalar(nram_x_out, nram_x, h11, max_chunk_elements); - __bang_add(nram_x_out, nram_x_out, nram_y, max_chunk_elements); + __bang_mul_scalar(nram_x_out, nram_x, h11, chunk_size); + __bang_add(nram_x_out, nram_x_out, nram_y, chunk_size); - __bang_mul_scalar(nram_y_out, nram_y, h22, 
max_chunk_elements); - __bang_sub(nram_y_out, nram_y_out, nram_x, max_chunk_elements); + __bang_mul_scalar(nram_y_out, nram_y, h22, chunk_size); + __bang_sub(nram_y_out, nram_y_out, nram_x, chunk_size); } - __memcpy(x + current_offset, nram_x_out, max_chunk_elements * sizeof(Tdata), NRAM2GDRAM); - __memcpy(y + current_offset, nram_y_out, max_chunk_elements * sizeof(Tdata), NRAM2GDRAM); + __memcpy(x + current_offset, nram_x_out, chunk_size * sizeof(Tdata), NRAM2GDRAM); + __memcpy(y + current_offset, nram_y_out, chunk_size * sizeof(Tdata), NRAM2GDRAM); } if (chunk_rem > 0) { - size_t current_offset = core_offset + chunks * max_chunk_elements; + int current_offset = core_offset + chunks * chunk_size; __memcpy(nram_x, x + current_offset, chunk_rem * sizeof(Tdata), GDRAM2NRAM); __memcpy(nram_y, y + current_offset, chunk_rem * sizeof(Tdata), GDRAM2NRAM); @@ -125,11 +125,11 @@ __mlu_global__ void rotmKernelContiguous( template __mlu_global__ void rotmKernelStrided( - size_t n, + int n, Tdata *x, - size_t incx, + int incx, Tdata *y, - size_t incy, + int incy, const Tdata *param) { const Tdata flag = param[0]; @@ -155,17 +155,17 @@ __mlu_global__ void rotmKernelStrided( h22 = param[4]; } - const size_t task = taskId; - const size_t tasks = taskDim; - const size_t per_task = n / tasks; - const size_t remain = n % tasks; - const size_t begin = task < remain ? task * (per_task + 1) : task * per_task + remain; - const size_t count = per_task + (task < remain ? 1 : 0); - - for (size_t i = 0; i < count; ++i) { - const size_t index = begin + i; - const size_t x_idx = index * incx; - const size_t y_idx = index * incy; + const int task = taskId; + const int tasks = taskDim; + const int per_task = n / tasks; + const int remain = n % tasks; + const int begin = task < remain ? task * (per_task + 1) : task * per_task + remain; + const int count = per_task + (task < remain ? 
1 : 0); + + for (int i = 0; i < count; ++i) { + const int index = begin + i; + const int x_idx = index * incx; + const int y_idx = index * incy; const Tdata w = x[x_idx]; const Tdata z = y[y_idx]; diff --git a/src/infiniop/ops/rotm/cpu/rotm_cpu.cc b/src/infiniop/ops/rotm/cpu/rotm_cpu.cc index 2f00ba97f..1850a654d 100644 --- a/src/infiniop/ops/rotm/cpu/rotm_cpu.cc +++ b/src/infiniop/ops/rotm/cpu/rotm_cpu.cc @@ -44,11 +44,11 @@ infiniStatus_t calculateRotm( return INFINI_STATUS_SUCCESS; } - const ptrdiff_t size = static_cast(info.n); + const size_t n = info.n; const ptrdiff_t incx = info.incx; const ptrdiff_t incy = info.incy; - const ptrdiff_t kx = incx >= 0 ? 0 : (size - 1) * (-incx); - const ptrdiff_t ky = incy >= 0 ? 0 : (size - 1) * (-incy); + const ptrdiff_t kx = incx >= 0 ? 0 : utils::cast(n - 1) * (-incx); + const ptrdiff_t ky = incy >= 0 ? 0 : utils::cast(n - 1) * (-incy); Tcompute sh11 = zero; Tcompute sh12 = zero; @@ -56,7 +56,7 @@ infiniStatus_t calculateRotm( Tcompute sh22 = zero; if (incx == incy && incx > 0) { - const ptrdiff_t nsteps = size * incx; + const ptrdiff_t nsteps = utils::cast(n) * incx; if (sflag < zero) { sh11 = utils::cast(param[1]); sh12 = utils::cast(param[3]); @@ -96,7 +96,7 @@ infiniStatus_t calculateRotm( sh12 = utils::cast(param[3]); sh21 = utils::cast(param[2]); sh22 = utils::cast(param[4]); - for (ptrdiff_t i = 0; i < size; ++i) { + for (size_t i = 0; i < n; ++i) { const Tcompute w = utils::cast(x[ix]); const Tcompute z = utils::cast(y[iy]); x[ix] = utils::cast(w * sh11 + z * sh12); @@ -107,7 +107,7 @@ infiniStatus_t calculateRotm( } else if (sflag == zero) { sh12 = utils::cast(param[3]); sh21 = utils::cast(param[2]); - for (ptrdiff_t i = 0; i < size; ++i) { + for (size_t i = 0; i < n; ++i) { const Tcompute w = utils::cast(x[ix]); const Tcompute z = utils::cast(y[iy]); x[ix] = utils::cast(w + z * sh12); @@ -118,7 +118,7 @@ infiniStatus_t calculateRotm( } else { sh11 = utils::cast(param[1]); sh22 = utils::cast(param[4]); - for 
(ptrdiff_t i = 0; i < size; ++i) { + for (size_t i = 0; i < n; ++i) { const Tcompute w = utils::cast(x[ix]); const Tcompute z = utils::cast(y[iy]); x[ix] = utils::cast(w * sh11 + z); diff --git a/src/infiniop/ops/rotm/metax/rotm_metax.cc b/src/infiniop/ops/rotm/metax/rotm_metax.cc index 60d750912..0911623e9 100644 --- a/src/infiniop/ops/rotm/metax/rotm_metax.cc +++ b/src/infiniop/ops/rotm/metax/rotm_metax.cc @@ -44,9 +44,9 @@ infiniStatus_t Descriptor::calculate( (void)workspace; (void)workspace_size; - const size_t size = _info.n; - const ptrdiff_t incx = _info.incx; - const ptrdiff_t incy = _info.incy; + const int n = utils::cast(_info.n); + const int incx = utils::cast(_info.incx); + const int incy = utils::cast(_info.incy); const infiniDtype_t data_type = _info.data_type; CHECK_STATUS(_opaque->internal->useMcblas( @@ -56,10 +56,10 @@ infiniStatus_t Descriptor::calculate( switch (data_type) { case INFINI_DTYPE_F32: - CHECK_MCBLAS(hcblasSrotm(handle, size, (float *)x, incx, (float *)y, incy, (const float *)param)); + CHECK_MCBLAS(hcblasSrotm(handle, n, (float *)x, incx, (float *)y, incy, (const float *)param)); break; case INFINI_DTYPE_F64: - CHECK_MCBLAS(hcblasDrotm(handle, size, (double *)x, incx, (double *)y, incy, (const double *)param)); + CHECK_MCBLAS(hcblasDrotm(handle, n, (double *)x, incx, (double *)y, incy, (const double *)param)); break; default: return INFINI_STATUS_BAD_TENSOR_DTYPE; diff --git a/src/infiniop/ops/scal/bang/scal_bang.mlu b/src/infiniop/ops/scal/bang/scal_bang.mlu index 31d8a480f..a49fb3196 100644 --- a/src/infiniop/ops/scal/bang/scal_bang.mlu +++ b/src/infiniop/ops/scal/bang/scal_bang.mlu @@ -33,6 +33,9 @@ infiniStatus_t calculateScal( Tdata *x, cnrtQueue_t queue) { + const int n = utils::cast(info.n); + const int incx = utils::cast(info.incx); + cnrtDim3_t k_dim; cnrtFunctionType_t k_type; @@ -41,17 +44,17 @@ infiniStatus_t calculateScal( k_dim.z = 1; k_type = cnrtFuncTypeUnion1; - if (info.incx == 1) { + if (incx == 1) { 
scalKernelContiguous<<>>( - info.n, + n, alpha, x); } else { scalKernelStrided<<>>( - info.n, + n, alpha, x, - info.incx); + incx); } cnrtQueueSync(queue); diff --git a/src/infiniop/ops/scal/bang/scal_bang_kernel.mlu b/src/infiniop/ops/scal/bang/scal_bang_kernel.mlu index 1e2f8c9a7..d320d3332 100644 --- a/src/infiniop/ops/scal/bang/scal_bang_kernel.mlu +++ b/src/infiniop/ops/scal/bang/scal_bang_kernel.mlu @@ -5,20 +5,22 @@ __nram__ char nram_buffer[NRAM_MAX_SIZE]; template __mlu_global__ void scalKernelContiguous( - size_t n, + int n, const Tdata *alpha, Tdata *x) { - Tdata *nram_x = (Tdata *)(((size_t)nram_buffer + ALIGN_SIZE - 1) & ~(ALIGN_SIZE - 1)); + char *nram_aligned = (char *)(((size_t)nram_buffer + ALIGN_SIZE - 1) & ~(ALIGN_SIZE - 1)); - size_t nram_usable = NRAM_MAX_SIZE - ((char *)nram_x - nram_buffer); + size_t nram_usable = NRAM_MAX_SIZE - (nram_aligned - nram_buffer); size_t max_chunk_elements = nram_usable / sizeof(Tdata); - int align_elements = ALIGN_SIZE / sizeof(Tdata); + size_t align_elements = ALIGN_SIZE / sizeof(Tdata); if (align_elements == 0) { align_elements = 1; } - max_chunk_elements = (max_chunk_elements / align_elements) * align_elements; + int chunk_size = (int)((max_chunk_elements / align_elements) * align_elements); + + Tdata *nram_x = (Tdata *)nram_aligned; int elements_per_core = n / taskDim; int remain = n % taskDim; @@ -29,20 +31,20 @@ __mlu_global__ void scalKernelContiguous( return; } - int chunks = core_elements / max_chunk_elements; - int chunk_rem = core_elements % max_chunk_elements; + int chunks = core_elements / chunk_size; + int chunk_rem = core_elements % chunk_size; for (int c = 0; c < chunks; c++) { - size_t current_offset = core_offset + c * max_chunk_elements; - __memcpy(nram_x, x + current_offset, max_chunk_elements * sizeof(Tdata), GDRAM2NRAM); + int current_offset = core_offset + c * chunk_size; + __memcpy(nram_x, x + current_offset, chunk_size * sizeof(Tdata), GDRAM2NRAM); - __bang_mul_scalar(nram_x, nram_x, 
alpha[0], max_chunk_elements); + __bang_mul_scalar(nram_x, nram_x, alpha[0], chunk_size); - __memcpy(x + current_offset, nram_x, max_chunk_elements * sizeof(Tdata), NRAM2GDRAM); + __memcpy(x + current_offset, nram_x, chunk_size * sizeof(Tdata), NRAM2GDRAM); } if (chunk_rem > 0) { - size_t current_offset = core_offset + chunks * max_chunk_elements; + int current_offset = core_offset + chunks * chunk_size; int align_rem = ((chunk_rem + align_elements - 1) / align_elements) * align_elements; __memcpy(nram_x, x + current_offset, chunk_rem * sizeof(Tdata), GDRAM2NRAM); @@ -55,10 +57,10 @@ __mlu_global__ void scalKernelContiguous( template __mlu_global__ void scalKernelStrided( - size_t n, + int n, const Tdata *alpha, Tdata *x, - size_t incx) { + int incx) { int elements_per_core = n / taskDim; int remain = n % taskDim; @@ -66,7 +68,7 @@ __mlu_global__ void scalKernelStrided( int start_idx = taskId < remain ? taskId * actual_tasks : taskId * elements_per_core + remain; for (int i = start_idx; i < start_idx + actual_tasks; ++i) { - size_t offset = i * incx; + int offset = i * incx; x[offset] *= alpha[0]; } diff --git a/src/infiniop/ops/scal/cpu/scal_cpu.cc b/src/infiniop/ops/scal/cpu/scal_cpu.cc index d54aa8543..3f26bf483 100644 --- a/src/infiniop/ops/scal/cpu/scal_cpu.cc +++ b/src/infiniop/ops/scal/cpu/scal_cpu.cc @@ -31,11 +31,11 @@ infiniStatus_t calculateScal( const Tdata *alpha, Tdata *x) { - const ptrdiff_t size = info.n; + const size_t n = info.n; const ptrdiff_t incx = info.incx; - for (ptrdiff_t i = 0; i < size; ++i) { - const ptrdiff_t idx = i * incx; + for (size_t i = 0; i < n; ++i) { + const ptrdiff_t idx = utils::cast(i) * incx; if constexpr (std::is_same_v || std::is_same_v) { x[idx] = utils::cast(utils::cast(x[idx]) * utils::cast(alpha[0])); diff --git a/src/infiniop/ops/scal/metax/scal_metax.cc b/src/infiniop/ops/scal/metax/scal_metax.cc index a4320cb23..5e67f0a05 100644 --- a/src/infiniop/ops/scal/metax/scal_metax.cc +++ 
b/src/infiniop/ops/scal/metax/scal_metax.cc @@ -42,8 +42,8 @@ infiniStatus_t Descriptor::calculate( (void)workspace; (void)workspace_size; - const size_t size = _info.n; - const ptrdiff_t incx = _info.incx; + const int n = utils::cast(_info.n); + const int incx = utils::cast(_info.incx); const infiniDtype_t data_type = _info.data_type; hpccDataType alpha_type, x_type; @@ -79,7 +79,7 @@ infiniStatus_t Descriptor::calculate( CHECK_MCBLAS(hcblasScalEx( handle, - size, + n, alpha, alpha_type, x, diff --git a/src/infiniop/ops/swap/bang/swap_bang.mlu b/src/infiniop/ops/swap/bang/swap_bang.mlu index 855997b38..ecd4a418a 100644 --- a/src/infiniop/ops/swap/bang/swap_bang.mlu +++ b/src/infiniop/ops/swap/bang/swap_bang.mlu @@ -33,6 +33,10 @@ infiniStatus_t calculateSwap( Tdata *y, cnrtQueue_t queue) { + const int n = utils::cast(info.n); + const int incx = utils::cast(info.incx); + const int incy = utils::cast(info.incy); + cnrtDim3_t k_dim; cnrtFunctionType_t k_type; @@ -41,18 +45,18 @@ infiniStatus_t calculateSwap( k_dim.z = 1; k_type = cnrtFuncTypeUnion1; - if (info.incx == 1 && info.incy == 1) { + if (incx == 1 && incy == 1) { swapKernelContiguous<<>>( - info.n, + n, x, y); } else { swapKernelStrided<<>>( - info.n, + n, x, - info.incx, + incx, y, - info.incy); + incy); } cnrtQueueSync(queue); diff --git a/src/infiniop/ops/swap/bang/swap_bang_kernel.mlu b/src/infiniop/ops/swap/bang/swap_bang_kernel.mlu index d5f44292d..6a7b86e2c 100644 --- a/src/infiniop/ops/swap/bang/swap_bang_kernel.mlu +++ b/src/infiniop/ops/swap/bang/swap_bang_kernel.mlu @@ -5,23 +5,23 @@ __nram__ char nram_buffer[NRAM_MAX_SIZE]; template __mlu_global__ void swapKernelContiguous( - size_t n, + int n, Tdata *x, Tdata *y) { - Tdata *nram_align = (Tdata *)(((size_t)nram_buffer + ALIGN_SIZE - 1) & ~(ALIGN_SIZE - 1)); + char *nram_aligned = (char *)(((size_t)nram_buffer + ALIGN_SIZE - 1) & ~(ALIGN_SIZE - 1)); - size_t nram_usable = NRAM_MAX_SIZE - ((char *)nram_align - nram_buffer); + size_t nram_usable = 
NRAM_MAX_SIZE - (nram_aligned - nram_buffer); size_t max_chunk_elements = nram_usable / (2 * sizeof(Tdata)); - int align_elements = ALIGN_SIZE / sizeof(Tdata); + size_t align_elements = ALIGN_SIZE / sizeof(Tdata); if (align_elements == 0) { align_elements = 1; } - max_chunk_elements = (max_chunk_elements / align_elements) * align_elements; + int chunk_size = (int)((max_chunk_elements / align_elements) * align_elements); - Tdata *nram_x = nram_align; - Tdata *nram_y = nram_align + max_chunk_elements; + Tdata *nram_x = (Tdata *)nram_aligned; + Tdata *nram_y = nram_x + chunk_size; int elements_per_core = n / taskDim; int remain = n % taskDim; @@ -32,19 +32,19 @@ __mlu_global__ void swapKernelContiguous( return; } - int chunks = core_elements / max_chunk_elements; - int chunk_rem = core_elements % max_chunk_elements; + int chunks = core_elements / chunk_size; + int chunk_rem = core_elements % chunk_size; for (int c = 0; c < chunks; c++) { - size_t current_offset = core_offset + c * max_chunk_elements; - __memcpy(nram_x, x + current_offset, max_chunk_elements * sizeof(Tdata), GDRAM2NRAM); - __memcpy(nram_y, y + current_offset, max_chunk_elements * sizeof(Tdata), GDRAM2NRAM); - __memcpy(x + current_offset, nram_y, max_chunk_elements * sizeof(Tdata), NRAM2GDRAM); - __memcpy(y + current_offset, nram_x, max_chunk_elements * sizeof(Tdata), NRAM2GDRAM); + int current_offset = core_offset + c * chunk_size; + __memcpy(nram_x, x + current_offset, chunk_size * sizeof(Tdata), GDRAM2NRAM); + __memcpy(nram_y, y + current_offset, chunk_size * sizeof(Tdata), GDRAM2NRAM); + __memcpy(x + current_offset, nram_y, chunk_size * sizeof(Tdata), NRAM2GDRAM); + __memcpy(y + current_offset, nram_x, chunk_size * sizeof(Tdata), NRAM2GDRAM); } if (chunk_rem > 0) { - size_t current_offset = core_offset + chunks * max_chunk_elements; + int current_offset = core_offset + chunks * chunk_size; __memcpy(nram_x, x + current_offset, chunk_rem * sizeof(Tdata), GDRAM2NRAM); __memcpy(nram_y, y + 
current_offset, chunk_rem * sizeof(Tdata), GDRAM2NRAM); __memcpy(x + current_offset, nram_y, chunk_rem * sizeof(Tdata), NRAM2GDRAM); @@ -54,11 +54,11 @@ __mlu_global__ void swapKernelContiguous( template __mlu_global__ void swapKernelStrided( - size_t n, + int n, Tdata *x, - size_t incx, + int incx, Tdata *y, - size_t incy) { + int incy) { int elements_per_core = n / taskDim; int remain = n % taskDim; @@ -66,8 +66,8 @@ __mlu_global__ void swapKernelStrided( int start_idx = taskId < remain ? taskId * actual_tasks : taskId * elements_per_core + remain; for (int i = start_idx; i < start_idx + actual_tasks; ++i) { - size_t x_idx = i * incx; - size_t y_idx = i * incy; + int x_idx = i * incx; + int y_idx = i * incy; Tdata temp = x[x_idx]; x[x_idx] = y[y_idx]; y[y_idx] = temp; diff --git a/src/infiniop/ops/swap/cpu/swap_cpu.cc b/src/infiniop/ops/swap/cpu/swap_cpu.cc index 3f804d974..1f55bcc98 100644 --- a/src/infiniop/ops/swap/cpu/swap_cpu.cc +++ b/src/infiniop/ops/swap/cpu/swap_cpu.cc @@ -31,14 +31,13 @@ infiniStatus_t calculateSwap( Tdata *x, Tdata *y) { - const ptrdiff_t size = info.n; + const size_t n = info.n; const ptrdiff_t incx = info.incx; const ptrdiff_t incy = info.incy; -#pragma omp parallel for if (size > 1024) - for (ptrdiff_t i = 0; i < size; ++i) { - const ptrdiff_t x_idx = i * incx; - const ptrdiff_t y_idx = i * incy; + for (size_t i = 0; i < n; ++i) { + const ptrdiff_t x_idx = utils::cast(i) * incx; + const ptrdiff_t y_idx = utils::cast(i) * incy; Tdata temp = x[x_idx]; x[x_idx] = y[y_idx]; y[y_idx] = temp; diff --git a/src/infiniop/ops/swap/metax/swap_metax.cc b/src/infiniop/ops/swap/metax/swap_metax.cc index 8c442e533..b84411df6 100644 --- a/src/infiniop/ops/swap/metax/swap_metax.cc +++ b/src/infiniop/ops/swap/metax/swap_metax.cc @@ -42,9 +42,9 @@ infiniStatus_t Descriptor::calculate( (void)workspace; (void)workspace_size; - const size_t size = _info.n; - const ptrdiff_t incx = _info.incx; - const ptrdiff_t incy = _info.incy; + const int n = 
utils::cast(_info.n); + const int incx = utils::cast(_info.incx); + const int incy = utils::cast(_info.incy); const infiniDtype_t data_type = _info.data_type; CHECK_STATUS(_opaque->internal->useMcblas( @@ -54,10 +54,10 @@ infiniStatus_t Descriptor::calculate( switch (data_type) { case INFINI_DTYPE_F32: - CHECK_MCBLAS(hcblasSswap(handle, size, (float *)x, incx, (float *)y, incy)); + CHECK_MCBLAS(hcblasSswap(handle, n, (float *)x, incx, (float *)y, incy)); break; case INFINI_DTYPE_F64: - CHECK_MCBLAS(hcblasDswap(handle, size, (double *)x, incx, (double *)y, incy)); + CHECK_MCBLAS(hcblasDswap(handle, n, (double *)x, incx, (double *)y, incy)); break; default: return INFINI_STATUS_BAD_TENSOR_DTYPE; From 20e014671455b0c43ae33cf2e298701957890859 Mon Sep 17 00:00:00 2001 From: xuzhengzhong Date: Thu, 7 May 2026 11:09:10 +0000 Subject: [PATCH 19/25] Fix: set InfiniCore device explicitly before running test cases Ensure the test framework explicitly sets the active InfiniCore device to match the selected test device before running test cases. --- test/infinicore/framework/base.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/infinicore/framework/base.py b/test/infinicore/framework/base.py index 80dcb3eb1..e57e8ec11 100644 --- a/test/infinicore/framework/base.py +++ b/test/infinicore/framework/base.py @@ -80,6 +80,9 @@ def run_tests(self, devices, test_func, test_type="Test"): print(f"Testing {test_type} on {InfiniDeviceNames[device]}") print(f"{'='*60}") + # Keep InfiniCore's runtime aligned with the selected test device. 
+ infinicore.set_device(infinicore.device(torch_device_map[device], 0)) + for test_case in self.test_cases: try: print(f"{test_case}") From 7317e42b237983bd82b12ecb03cb8d58f082b416 Mon Sep 17 00:00:00 2001 From: xuzhengzhong Date: Thu, 7 May 2026 13:26:42 +0000 Subject: [PATCH 20/25] Update `axpy`, `blas_dot`, `nrm2`, `rot` and `scal` to use the graph framework --- include/infinicore/ops/axpy.hpp | 10 +-- include/infinicore/ops/blas_dot.hpp | 12 +-- include/infinicore/ops/nrm2.hpp | 12 +-- include/infinicore/ops/rot.hpp | 10 +-- include/infinicore/ops/scal.hpp | 10 +-- python/infinicore/ops/blas_dot.py | 3 +- python/infinicore/ops/nrm2.py | 3 +- python/infinicore/ops/scal.py | 3 +- src/infinicore/ops/axpy/axpy.cc | 16 ++-- src/infinicore/ops/axpy/axpy_infiniop.cc | 76 +++++++++--------- src/infinicore/ops/blas_dot/blas_dot.cc | 26 +++---- .../ops/blas_dot/blas_dot_infiniop.cc | 76 +++++++++--------- src/infinicore/ops/nrm2/nrm2.cc | 26 +++---- src/infinicore/ops/nrm2/nrm2_infiniop.cc | 74 ++++++++---------- src/infinicore/ops/rot/rot.cc | 16 ++-- src/infinicore/ops/rot/rot_infiniop.cc | 78 +++++++++---------- src/infinicore/ops/scal/scal.cc | 16 ++-- src/infinicore/ops/scal/scal_infiniop.cc | 74 ++++++++---------- src/infinicore/pybind11/ops/axpy.hpp | 2 +- src/infinicore/pybind11/ops/blas_dot.hpp | 2 +- src/infinicore/pybind11/ops/nrm2.hpp | 2 +- src/infinicore/pybind11/ops/rot.hpp | 2 +- src/infinicore/pybind11/ops/scal.hpp | 4 +- 23 files changed, 257 insertions(+), 296 deletions(-) diff --git a/include/infinicore/ops/axpy.hpp b/include/infinicore/ops/axpy.hpp index 3f9da7c4a..280d5ab60 100644 --- a/include/infinicore/ops/axpy.hpp +++ b/include/infinicore/ops/axpy.hpp @@ -1,17 +1,13 @@ #pragma once #include "../device.hpp" +#include "../graph/graph.hpp" #include "common/op.hpp" namespace infinicore::op { -class Axpy { -public: - using schema = void (*)(Tensor, Tensor, Tensor); - static void execute(Tensor alpha, Tensor x, Tensor y); - static common::OpDispatcher 
&dispatcher(); -}; +INFINICORE_GRAPH_OP_CLASS(Axpy, const Tensor &, const Tensor &, Tensor); -void axpy_(Tensor alpha, Tensor x, Tensor y); +void axpy_(const Tensor &alpha, const Tensor &x, Tensor y); } // namespace infinicore::op diff --git a/include/infinicore/ops/blas_dot.hpp b/include/infinicore/ops/blas_dot.hpp index 2eff6f396..157c167f6 100644 --- a/include/infinicore/ops/blas_dot.hpp +++ b/include/infinicore/ops/blas_dot.hpp @@ -1,18 +1,14 @@ #pragma once #include "../device.hpp" +#include "../graph/graph.hpp" #include "common/op.hpp" namespace infinicore::op { -class BlasDot { -public: - using schema = void (*)(Tensor, Tensor, Tensor); - static void execute(Tensor result, Tensor x, Tensor y); - static common::OpDispatcher &dispatcher(); -}; +INFINICORE_GRAPH_OP_CLASS(BlasDot, const Tensor &, const Tensor &, Tensor); -Tensor blas_dot(Tensor x, Tensor y); -void blas_dot_(Tensor result, Tensor x, Tensor y); +Tensor blas_dot(const Tensor &x, const Tensor &y); +void blas_dot_(const Tensor &x, const Tensor &y, Tensor result); } // namespace infinicore::op diff --git a/include/infinicore/ops/nrm2.hpp b/include/infinicore/ops/nrm2.hpp index ea3c6f8c1..c5552a4c1 100644 --- a/include/infinicore/ops/nrm2.hpp +++ b/include/infinicore/ops/nrm2.hpp @@ -1,18 +1,14 @@ #pragma once #include "../device.hpp" +#include "../graph/graph.hpp" #include "common/op.hpp" namespace infinicore::op { -class Nrm2 { -public: - using schema = void (*)(Tensor, Tensor); - static void execute(Tensor result, Tensor x); - static common::OpDispatcher &dispatcher(); -}; +INFINICORE_GRAPH_OP_CLASS(Nrm2, const Tensor &, Tensor); -Tensor nrm2(Tensor x); -void nrm2_(Tensor result, Tensor x); +Tensor nrm2(const Tensor &x); +void nrm2_(const Tensor &x, Tensor result); } // namespace infinicore::op diff --git a/include/infinicore/ops/rot.hpp b/include/infinicore/ops/rot.hpp index 6229983e4..ff473f33e 100644 --- a/include/infinicore/ops/rot.hpp +++ b/include/infinicore/ops/rot.hpp @@ -1,17 +1,13 @@ 
#pragma once #include "../device.hpp" +#include "../graph/graph.hpp" #include "common/op.hpp" namespace infinicore::op { -class Rot { -public: - using schema = void (*)(Tensor, Tensor, Tensor, Tensor); - static void execute(Tensor x, Tensor y, Tensor c, Tensor s); - static common::OpDispatcher &dispatcher(); -}; +INFINICORE_GRAPH_OP_CLASS(Rot, Tensor, Tensor, const Tensor &, const Tensor &); -void rot_(Tensor x, Tensor y, Tensor c, Tensor s); +void rot_(Tensor x, Tensor y, const Tensor &c, const Tensor &s); } // namespace infinicore::op diff --git a/include/infinicore/ops/scal.hpp b/include/infinicore/ops/scal.hpp index 38045c980..d6cb5ce8c 100644 --- a/include/infinicore/ops/scal.hpp +++ b/include/infinicore/ops/scal.hpp @@ -1,17 +1,13 @@ #pragma once #include "../device.hpp" +#include "../graph/graph.hpp" #include "common/op.hpp" namespace infinicore::op { -class Scal { -public: - using schema = void (*)(Tensor, Tensor); - static void execute(Tensor alpha, Tensor x); - static common::OpDispatcher &dispatcher(); -}; +INFINICORE_GRAPH_OP_CLASS(Scal, const Tensor &, Tensor); -void scal_(Tensor x, Tensor alpha); +void scal_(const Tensor &alpha, Tensor x); } // namespace infinicore::op diff --git a/python/infinicore/ops/blas_dot.py b/python/infinicore/ops/blas_dot.py index e96eca838..60f0541b0 100644 --- a/python/infinicore/ops/blas_dot.py +++ b/python/infinicore/ops/blas_dot.py @@ -6,5 +6,6 @@ def blas_dot(x: Tensor, y: Tensor, *, out=None): if out is None: return Tensor(_infinicore.blas_dot(x._underlying, y._underlying)) - _infinicore.blas_dot_(out._underlying, x._underlying, y._underlying) + _infinicore.blas_dot_(x._underlying, y._underlying, out._underlying) + return out diff --git a/python/infinicore/ops/nrm2.py b/python/infinicore/ops/nrm2.py index 34de0e25a..fcc3a0d3a 100644 --- a/python/infinicore/ops/nrm2.py +++ b/python/infinicore/ops/nrm2.py @@ -6,5 +6,6 @@ def nrm2(x: Tensor, *, out=None): if out is None: return Tensor(_infinicore.nrm2(x._underlying)) - 
_infinicore.nrm2_(out._underlying, x._underlying) + _infinicore.nrm2_(x._underlying, out._underlying) + return out diff --git a/python/infinicore/ops/scal.py b/python/infinicore/ops/scal.py index 86b07eb2b..8302e74a5 100644 --- a/python/infinicore/ops/scal.py +++ b/python/infinicore/ops/scal.py @@ -3,5 +3,6 @@ def scal(x: Tensor, alpha: Tensor): - _infinicore.scal_(x._underlying, alpha._underlying) + _infinicore.scal_(alpha._underlying, x._underlying) + return x diff --git a/src/infinicore/ops/axpy/axpy.cc b/src/infinicore/ops/axpy/axpy.cc index 957b83ef6..d5d3241ec 100644 --- a/src/infinicore/ops/axpy/axpy.cc +++ b/src/infinicore/ops/axpy/axpy.cc @@ -4,18 +4,18 @@ namespace infinicore::op { -common::OpDispatcher &Axpy::dispatcher() { - static common::OpDispatcher dispatcher_; - return dispatcher_; -}; +INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(Axpy); -void Axpy::execute(Tensor alpha, Tensor x, Tensor y) { +Axpy::Axpy(const Tensor &alpha, const Tensor &x, Tensor y) { INFINICORE_ASSERT_TENSORS_SAME_DEVICE(alpha, x, y); - infinicore::context::setDevice(y->device()); - dispatcher().lookup(y->device().getType())(alpha, x, y); + INFINICORE_GRAPH_OP_DISPATCH(y->device().getType(), alpha, x, y); } -void axpy_(Tensor alpha, Tensor x, Tensor y) { +void Axpy::execute(const Tensor &alpha, const Tensor &x, Tensor y) { + INFINICORE_GRAPH_OP_RECORD_OR_RUN(Axpy, alpha, x, y); +} + +void axpy_(const Tensor &alpha, const Tensor &x, Tensor y) { Axpy::execute(alpha, x, y); } diff --git a/src/infinicore/ops/axpy/axpy_infiniop.cc b/src/infinicore/ops/axpy/axpy_infiniop.cc index 013b5b2d9..8d54c3823 100644 --- a/src/infinicore/ops/axpy/axpy_infiniop.cc +++ b/src/infinicore/ops/axpy/axpy_infiniop.cc @@ -1,56 +1,52 @@ -#include "../../utils.hpp" -#include "infinicore/common/hash.hpp" #include "infinicore/ops/axpy.hpp" -#include "infinicore/ops/common/cache.hpp" -#include + +#include "../infiniop_impl.hpp" namespace infinicore::op::axpy_impl::infiniop { -thread_local common::OpCache caches( - 
100, // capacity - [](infiniopAxpyDescriptor_t &desc) { - if (desc != nullptr) { - INFINICORE_CHECK_ERROR(infiniopDestroyAxpyDescriptor(desc)); - desc = nullptr; - } - }); +INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, Axpy, 100); -void calculate(Tensor alpha, Tensor x, Tensor y) { - size_t seed = hash_combine(alpha, x, y); +struct PlannedMeta { + std::shared_ptr descriptor; + graph::GraphTensor workspace, alpha, x, y; +}; - auto device_type = context::getDevice().getType(); - auto device_index = context::getDevice().getIndex(); +void *plan(const Tensor &alpha, const Tensor &x, Tensor y) { + size_t seed = hash_combine(y, alpha, x); - auto &cache = caches.getCache(device_type, device_index); + INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE( + Descriptor, descriptor, Axpy, + seed, + alpha->desc(), x->desc(), y->desc()); - auto desc_opt = cache.get(seed); - infiniopAxpyDescriptor_t desc = nullptr; + INFINIOP_WORKSPACE_TENSOR(workspace, Axpy, descriptor); - if (!desc_opt) { - INFINICORE_CHECK_ERROR(infiniopCreateAxpyDescriptor( - context::getInfiniopHandle(y->device()), &desc, - alpha->desc(), x->desc(), y->desc())); - cache.put(seed, desc); - } else { - desc = *desc_opt; - } + return new PlannedMeta{ + descriptor, + graph::GraphTensor(workspace), + graph::GraphTensor(alpha), + graph::GraphTensor(x), + graph::GraphTensor(y)}; +} - size_t workspace_size = 0; - INFINICORE_CHECK_ERROR(infiniopGetAxpyWorkspaceSize(desc, &workspace_size)); - std::shared_ptr workspace = context::allocateMemory(workspace_size); +void run(void *planned_meta) { + auto planned = reinterpret_cast(planned_meta); INFINICORE_CHECK_ERROR(infiniopAxpy( - desc, workspace->data(), workspace_size, - alpha->data(), x->data(), y->data(), context::getStream())); + planned->descriptor->desc, + planned->workspace->data(), + planned->workspace->numel(), + planned->alpha->data(), + planned->x->data(), + planned->y->data(), + context::getStream())); +} + +void cleanup(void **planned_meta_ptr) { + delete 
*reinterpret_cast(planned_meta_ptr); + *planned_meta_ptr = nullptr; } -static bool registered = []() { - Axpy::dispatcher().registerDevice({Device::Type::CPU, - Device::Type::CAMBRICON, - Device::Type::METAX}, - &calculate, - false); - return true; -}(); +INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(Axpy, &plan, &run, &cleanup); } // namespace infinicore::op::axpy_impl::infiniop diff --git a/src/infinicore/ops/blas_dot/blas_dot.cc b/src/infinicore/ops/blas_dot/blas_dot.cc index 66465bf27..8dba37acc 100644 --- a/src/infinicore/ops/blas_dot/blas_dot.cc +++ b/src/infinicore/ops/blas_dot/blas_dot.cc @@ -4,25 +4,25 @@ namespace infinicore::op { -common::OpDispatcher &BlasDot::dispatcher() { - static common::OpDispatcher dispatcher_; - return dispatcher_; -}; - -void BlasDot::execute(Tensor result, Tensor x, Tensor y) { - INFINICORE_ASSERT_TENSORS_SAME_DEVICE(result, x, y); - infinicore::context::setDevice(result->device()); - dispatcher().lookup(result->device().getType())(result, x, y); +INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(BlasDot); + +BlasDot::BlasDot(const Tensor &x, const Tensor &y, Tensor result) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(x, y, result); + INFINICORE_GRAPH_OP_DISPATCH(result->device().getType(), x, y, result); +} + +void BlasDot::execute(const Tensor &x, const Tensor &y, Tensor result) { + INFINICORE_GRAPH_OP_RECORD_OR_RUN(BlasDot, x, y, result); } -Tensor blas_dot(Tensor x, Tensor y) { +Tensor blas_dot(const Tensor &x, const Tensor &y) { auto result = Tensor::empty({}, x->dtype(), x->device()); - blas_dot_(result, x, y); + blas_dot_(x, y, result); return result; } -void blas_dot_(Tensor result, Tensor x, Tensor y) { - BlasDot::execute(result, x, y); +void blas_dot_(const Tensor &x, const Tensor &y, Tensor result) { + BlasDot::execute(x, y, result); } } // namespace infinicore::op diff --git a/src/infinicore/ops/blas_dot/blas_dot_infiniop.cc b/src/infinicore/ops/blas_dot/blas_dot_infiniop.cc index 51e13a511..78559eea7 100644 --- 
a/src/infinicore/ops/blas_dot/blas_dot_infiniop.cc +++ b/src/infinicore/ops/blas_dot/blas_dot_infiniop.cc @@ -1,56 +1,52 @@ -#include "../../utils.hpp" -#include "infinicore/common/hash.hpp" #include "infinicore/ops/blas_dot.hpp" -#include "infinicore/ops/common/cache.hpp" -#include + +#include "../infiniop_impl.hpp" namespace infinicore::op::blas_dot_impl::infiniop { -thread_local common::OpCache caches( - 100, // capacity - [](infiniopBlasDotDescriptor_t &desc) { - if (desc != nullptr) { - INFINICORE_CHECK_ERROR(infiniopDestroyBlasDotDescriptor(desc)); - desc = nullptr; - } - }); +INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, BlasDot, 100); -void calculate(Tensor result, Tensor x, Tensor y) { - size_t seed = hash_combine(result, x, y); +struct PlannedMeta { + std::shared_ptr descriptor; + graph::GraphTensor workspace, x, y, result; +}; - auto device_type = context::getDevice().getType(); - auto device_index = context::getDevice().getIndex(); +void *plan(const Tensor &x, const Tensor &y, Tensor result) { + size_t seed = hash_combine(x, y, result); - auto &cache = caches.getCache(device_type, device_index); + INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE( + Descriptor, descriptor, BlasDot, + seed, + x->desc(), y->desc(), result->desc()); - auto desc_opt = cache.get(seed); - infiniopBlasDotDescriptor_t desc = nullptr; + INFINIOP_WORKSPACE_TENSOR(workspace, BlasDot, descriptor); - if (!desc_opt) { - INFINICORE_CHECK_ERROR(infiniopCreateBlasDotDescriptor( - context::getInfiniopHandle(result->device()), &desc, - x->desc(), y->desc(), result->desc())); - cache.put(seed, desc); - } else { - desc = *desc_opt; - } + return new PlannedMeta{ + descriptor, + graph::GraphTensor(workspace), + graph::GraphTensor(x), + graph::GraphTensor(y), + graph::GraphTensor(result)}; +} - size_t workspace_size = 0; - INFINICORE_CHECK_ERROR(infiniopGetBlasDotWorkspaceSize(desc, &workspace_size)); - std::shared_ptr workspace = context::allocateMemory(workspace_size); +void run(void *planned_meta) { + 
auto planned = reinterpret_cast(planned_meta); INFINICORE_CHECK_ERROR(infiniopBlasDot( - desc, workspace->data(), workspace_size, - x->data(), y->data(), result->data(), context::getStream())); + planned->descriptor->desc, + planned->workspace->data(), + planned->workspace->numel(), + planned->x->data(), + planned->y->data(), + planned->result->data(), + context::getStream())); +} + +void cleanup(void **planned_meta_ptr) { + delete *reinterpret_cast(planned_meta_ptr); + *planned_meta_ptr = nullptr; } -static bool registered = []() { - BlasDot::dispatcher().registerDevice({Device::Type::CPU, - Device::Type::CAMBRICON, - Device::Type::METAX}, - &calculate, - false); - return true; -}(); +INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(BlasDot, &plan, &run, &cleanup); } // namespace infinicore::op::blas_dot_impl::infiniop diff --git a/src/infinicore/ops/nrm2/nrm2.cc b/src/infinicore/ops/nrm2/nrm2.cc index 26e9f1d1b..276e1015e 100644 --- a/src/infinicore/ops/nrm2/nrm2.cc +++ b/src/infinicore/ops/nrm2/nrm2.cc @@ -4,25 +4,25 @@ namespace infinicore::op { -common::OpDispatcher &Nrm2::dispatcher() { - static common::OpDispatcher dispatcher_; - return dispatcher_; -}; - -void Nrm2::execute(Tensor result, Tensor x) { - INFINICORE_ASSERT_TENSORS_SAME_DEVICE(result, x); - infinicore::context::setDevice(result->device()); - dispatcher().lookup(result->device().getType())(result, x); +INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(Nrm2); + +Nrm2::Nrm2(const Tensor &x, Tensor result) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(x, result); + INFINICORE_GRAPH_OP_DISPATCH(result->device().getType(), x, result); +} + +void Nrm2::execute(const Tensor &x, Tensor result) { + INFINICORE_GRAPH_OP_RECORD_OR_RUN(Nrm2, x, result); } -Tensor nrm2(Tensor x) { +Tensor nrm2(const Tensor &x) { auto result = Tensor::empty({}, x->dtype(), x->device()); - nrm2_(result, x); + nrm2_(x, result); return result; } -void nrm2_(Tensor result, Tensor x) { - Nrm2::execute(result, x); +void nrm2_(const Tensor &x, Tensor result) { 
+ Nrm2::execute(x, result); } } // namespace infinicore::op diff --git a/src/infinicore/ops/nrm2/nrm2_infiniop.cc b/src/infinicore/ops/nrm2/nrm2_infiniop.cc index 975237037..3f3ca8c74 100644 --- a/src/infinicore/ops/nrm2/nrm2_infiniop.cc +++ b/src/infinicore/ops/nrm2/nrm2_infiniop.cc @@ -1,56 +1,50 @@ -#include "../../utils.hpp" -#include "infinicore/common/hash.hpp" -#include "infinicore/ops/common/cache.hpp" #include "infinicore/ops/nrm2.hpp" -#include + +#include "../infiniop_impl.hpp" namespace infinicore::op::nrm2_impl::infiniop { -thread_local common::OpCache caches( - 100, // capacity - [](infiniopNrm2Descriptor_t &desc) { - if (desc != nullptr) { - INFINICORE_CHECK_ERROR(infiniopDestroyNrm2Descriptor(desc)); - desc = nullptr; - } - }); +INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, Nrm2, 100); -void calculate(Tensor result, Tensor x) { - size_t seed = hash_combine(result, x); +struct PlannedMeta { + std::shared_ptr descriptor; + graph::GraphTensor workspace, x, result; +}; - auto device_type = context::getDevice().getType(); - auto device_index = context::getDevice().getIndex(); +void *plan(const Tensor &x, Tensor result) { + size_t seed = hash_combine(x, result); - auto &cache = caches.getCache(device_type, device_index); + INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE( + Descriptor, descriptor, Nrm2, + seed, + x->desc(), result->desc()); - auto desc_opt = cache.get(seed); - infiniopNrm2Descriptor_t desc = nullptr; + INFINIOP_WORKSPACE_TENSOR(workspace, Nrm2, descriptor); - if (!desc_opt) { - INFINICORE_CHECK_ERROR(infiniopCreateNrm2Descriptor( - context::getInfiniopHandle(result->device()), &desc, - x->desc(), result->desc())); - cache.put(seed, desc); - } else { - desc = *desc_opt; - } + return new PlannedMeta{ + descriptor, + graph::GraphTensor(workspace), + graph::GraphTensor(x), + graph::GraphTensor(result)}; +} - size_t workspace_size = 0; - INFINICORE_CHECK_ERROR(infiniopGetNrm2WorkspaceSize(desc, &workspace_size)); - std::shared_ptr workspace = 
context::allocateMemory(workspace_size); +void run(void *planned_meta) { + auto planned = reinterpret_cast(planned_meta); INFINICORE_CHECK_ERROR(infiniopNrm2( - desc, workspace->data(), workspace_size, - x->data(), result->data(), context::getStream())); + planned->descriptor->desc, + planned->workspace->data(), + planned->workspace->numel(), + planned->x->data(), + planned->result->data(), + context::getStream())); +} + +void cleanup(void **planned_meta_ptr) { + delete *reinterpret_cast(planned_meta_ptr); + *planned_meta_ptr = nullptr; } -static bool registered = []() { - Nrm2::dispatcher().registerDevice({Device::Type::CPU, - Device::Type::CAMBRICON, - Device::Type::METAX}, - &calculate, - false); - return true; -}(); +INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(Nrm2, &plan, &run, &cleanup); } // namespace infinicore::op::nrm2_impl::infiniop diff --git a/src/infinicore/ops/rot/rot.cc b/src/infinicore/ops/rot/rot.cc index 70aa8e22c..262cce001 100644 --- a/src/infinicore/ops/rot/rot.cc +++ b/src/infinicore/ops/rot/rot.cc @@ -4,18 +4,18 @@ namespace infinicore::op { -common::OpDispatcher &Rot::dispatcher() { - static common::OpDispatcher dispatcher_; - return dispatcher_; -}; +INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(Rot); -void Rot::execute(Tensor x, Tensor y, Tensor c, Tensor s) { +Rot::Rot(Tensor x, Tensor y, const Tensor &c, const Tensor &s) { INFINICORE_ASSERT_TENSORS_SAME_DEVICE(x, y, c, s); - infinicore::context::setDevice(x->device()); - dispatcher().lookup(x->device().getType())(x, y, c, s); + INFINICORE_GRAPH_OP_DISPATCH(x->device().getType(), x, y, c, s); } -void rot_(Tensor x, Tensor y, Tensor c, Tensor s) { +void Rot::execute(Tensor x, Tensor y, const Tensor &c, const Tensor &s) { + INFINICORE_GRAPH_OP_RECORD_OR_RUN(Rot, x, y, c, s); +} + +void rot_(Tensor x, Tensor y, const Tensor &c, const Tensor &s) { Rot::execute(x, y, c, s); } diff --git a/src/infinicore/ops/rot/rot_infiniop.cc b/src/infinicore/ops/rot/rot_infiniop.cc index 113bc488f..a114bf110 100644 --- 
a/src/infinicore/ops/rot/rot_infiniop.cc +++ b/src/infinicore/ops/rot/rot_infiniop.cc @@ -1,56 +1,54 @@ -#include "../../utils.hpp" -#include "infinicore/common/hash.hpp" -#include "infinicore/ops/common/cache.hpp" #include "infinicore/ops/rot.hpp" -#include + +#include "../infiniop_impl.hpp" namespace infinicore::op::rot_impl::infiniop { -thread_local common::OpCache caches( - 100, - [](infiniopRotDescriptor_t &desc) { - if (desc != nullptr) { - INFINICORE_CHECK_ERROR(infiniopDestroyRotDescriptor(desc)); - desc = nullptr; - } - }); +INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, Rot, 100); -void calculate(Tensor x, Tensor y, Tensor c, Tensor s) { - size_t seed = hash_combine(x, y, c, s); +struct PlannedMeta { + std::shared_ptr descriptor; + graph::GraphTensor workspace, x, y, c, s; +}; - auto device_type = context::getDevice().getType(); - auto device_index = context::getDevice().getIndex(); +void *plan(Tensor x, Tensor y, const Tensor &c, const Tensor &s) { + size_t seed = hash_combine(x, y, c, s); - auto &cache = caches.getCache(device_type, device_index); + INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE( + Descriptor, descriptor, Rot, + seed, + x->desc(), y->desc(), c->desc(), s->desc()); - auto desc_opt = cache.get(seed); - infiniopRotDescriptor_t desc = nullptr; + INFINIOP_WORKSPACE_TENSOR(workspace, Rot, descriptor); - if (!desc_opt) { - INFINICORE_CHECK_ERROR(infiniopCreateRotDescriptor( - context::getInfiniopHandle(x->device()), &desc, - x->desc(), y->desc(), c->desc(), s->desc())); - cache.put(seed, desc); - } else { - desc = *desc_opt; - } + return new PlannedMeta{ + descriptor, + graph::GraphTensor(workspace), + graph::GraphTensor(x), + graph::GraphTensor(y), + graph::GraphTensor(c), + graph::GraphTensor(s)}; +} - size_t workspace_size = 0; - INFINICORE_CHECK_ERROR(infiniopGetRotWorkspaceSize(desc, &workspace_size)); - std::shared_ptr workspace = context::allocateMemory(workspace_size); +void run(void *planned_meta) { + auto planned = 
reinterpret_cast(planned_meta); INFINICORE_CHECK_ERROR(infiniopRot( - desc, workspace->data(), workspace_size, - x->data(), y->data(), c->data(), s->data(), context::getStream())); + planned->descriptor->desc, + planned->workspace->data(), + planned->workspace->numel(), + planned->x->data(), + planned->y->data(), + planned->c->data(), + planned->s->data(), + context::getStream())); +} + +void cleanup(void **planned_meta_ptr) { + delete *reinterpret_cast(planned_meta_ptr); + *planned_meta_ptr = nullptr; } -static bool registered = []() { - Rot::dispatcher().registerDevice({Device::Type::CPU, - Device::Type::CAMBRICON, - Device::Type::METAX}, - &calculate, - false); - return true; -}(); +INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(Rot, &plan, &run, &cleanup); } // namespace infinicore::op::rot_impl::infiniop diff --git a/src/infinicore/ops/scal/scal.cc b/src/infinicore/ops/scal/scal.cc index c0ef9871a..21258af2a 100644 --- a/src/infinicore/ops/scal/scal.cc +++ b/src/infinicore/ops/scal/scal.cc @@ -4,18 +4,18 @@ namespace infinicore::op { -common::OpDispatcher &Scal::dispatcher() { - static common::OpDispatcher dispatcher_; - return dispatcher_; -}; +INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(Scal); -void Scal::execute(Tensor alpha, Tensor x) { +Scal::Scal(const Tensor &alpha, Tensor x) { INFINICORE_ASSERT_TENSORS_SAME_DEVICE(alpha, x); - infinicore::context::setDevice(x->device()); - dispatcher().lookup(x->device().getType())(alpha, x); + INFINICORE_GRAPH_OP_DISPATCH(x->device().getType(), alpha, x); } -void scal_(Tensor x, Tensor alpha) { +void Scal::execute(const Tensor &alpha, Tensor x) { + INFINICORE_GRAPH_OP_RECORD_OR_RUN(Scal, alpha, x); +} + +void scal_(const Tensor &alpha, Tensor x) { Scal::execute(alpha, x); } diff --git a/src/infinicore/ops/scal/scal_infiniop.cc b/src/infinicore/ops/scal/scal_infiniop.cc index 9e07e1798..25521f4ee 100644 --- a/src/infinicore/ops/scal/scal_infiniop.cc +++ b/src/infinicore/ops/scal/scal_infiniop.cc @@ -1,56 +1,50 @@ -#include 
"../../utils.hpp" -#include "infinicore/common/hash.hpp" -#include "infinicore/ops/common/cache.hpp" #include "infinicore/ops/scal.hpp" -#include + +#include "../infiniop_impl.hpp" namespace infinicore::op::scal_impl::infiniop { -thread_local common::OpCache caches( - 100, // capacity - [](infiniopScalDescriptor_t &desc) { - if (desc != nullptr) { - INFINICORE_CHECK_ERROR(infiniopDestroyScalDescriptor(desc)); - desc = nullptr; - } - }); +INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, Scal, 100); -void calculate(Tensor alpha, Tensor x) { - size_t seed = hash_combine(alpha, x); +struct PlannedMeta { + std::shared_ptr descriptor; + graph::GraphTensor workspace, alpha, x; +}; - auto device_type = context::getDevice().getType(); - auto device_index = context::getDevice().getIndex(); +void *plan(const Tensor &alpha, Tensor x) { + size_t seed = hash_combine(alpha, x); - auto &cache = caches.getCache(device_type, device_index); + INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE( + Descriptor, descriptor, Scal, + seed, + alpha->desc(), x->desc()); - auto desc_opt = cache.get(seed); - infiniopScalDescriptor_t desc = nullptr; + INFINIOP_WORKSPACE_TENSOR(workspace, Scal, descriptor); - if (!desc_opt) { - INFINICORE_CHECK_ERROR(infiniopCreateScalDescriptor( - context::getInfiniopHandle(x->device()), &desc, - alpha->desc(), x->desc())); - cache.put(seed, desc); - } else { - desc = *desc_opt; - } + return new PlannedMeta{ + descriptor, + graph::GraphTensor(workspace), + graph::GraphTensor(alpha), + graph::GraphTensor(x)}; +} - size_t workspace_size = 0; - INFINICORE_CHECK_ERROR(infiniopGetScalWorkspaceSize(desc, &workspace_size)); - std::shared_ptr workspace = context::allocateMemory(workspace_size); +void run(void *planned_meta) { + auto planned = reinterpret_cast(planned_meta); INFINICORE_CHECK_ERROR(infiniopScal( - desc, workspace->data(), workspace_size, - alpha->data(), x->data(), context::getStream())); + planned->descriptor->desc, + planned->workspace->data(), + 
planned->workspace->numel(), + planned->alpha->data(), + planned->x->data(), + context::getStream())); +} + +void cleanup(void **planned_meta_ptr) { + delete *reinterpret_cast(planned_meta_ptr); + *planned_meta_ptr = nullptr; } -static bool registered = []() { - Scal::dispatcher().registerDevice({Device::Type::CPU, - Device::Type::CAMBRICON, - Device::Type::METAX}, - &calculate, - false); - return true; -}(); +INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(Scal, &plan, &run, &cleanup); } // namespace infinicore::op::scal_impl::infiniop diff --git a/src/infinicore/pybind11/ops/axpy.hpp b/src/infinicore/pybind11/ops/axpy.hpp index f71270a08..fd0aff633 100644 --- a/src/infinicore/pybind11/ops/axpy.hpp +++ b/src/infinicore/pybind11/ops/axpy.hpp @@ -10,7 +10,7 @@ namespace infinicore::ops { inline void bind_axpy(py::module &m) { m.def("axpy_", - py::overload_cast(&op::axpy_), + &op::axpy_, py::arg("alpha"), py::arg("x"), py::arg("y"), diff --git a/src/infinicore/pybind11/ops/blas_dot.hpp b/src/infinicore/pybind11/ops/blas_dot.hpp index 9290db5a7..73b4f0bc9 100644 --- a/src/infinicore/pybind11/ops/blas_dot.hpp +++ b/src/infinicore/pybind11/ops/blas_dot.hpp @@ -17,9 +17,9 @@ inline void bind_blas_dot(py::module &m) { m.def("blas_dot_", &op::blas_dot_, - py::arg("result"), py::arg("x"), py::arg("y"), + py::arg("result"), R"doc(In-place BLAS level-1 dot.)doc"); } diff --git a/src/infinicore/pybind11/ops/nrm2.hpp b/src/infinicore/pybind11/ops/nrm2.hpp index 5431046b7..02b21f53b 100644 --- a/src/infinicore/pybind11/ops/nrm2.hpp +++ b/src/infinicore/pybind11/ops/nrm2.hpp @@ -16,8 +16,8 @@ inline void bind_nrm2(py::module &m) { m.def("nrm2_", &op::nrm2_, - py::arg("result"), py::arg("x"), + py::arg("result"), R"doc(In-place BLAS level-1 nrm2.)doc"); } diff --git a/src/infinicore/pybind11/ops/rot.hpp b/src/infinicore/pybind11/ops/rot.hpp index 9e2de6c47..359cb9745 100644 --- a/src/infinicore/pybind11/ops/rot.hpp +++ b/src/infinicore/pybind11/ops/rot.hpp @@ -10,7 +10,7 @@ namespace 
infinicore::ops { inline void bind_rot(py::module &m) { m.def("rot_", - py::overload_cast(&op::rot_), + &op::rot_, py::arg("x"), py::arg("y"), py::arg("c"), diff --git a/src/infinicore/pybind11/ops/scal.hpp b/src/infinicore/pybind11/ops/scal.hpp index 9b914d1fe..75a2c5ca8 100644 --- a/src/infinicore/pybind11/ops/scal.hpp +++ b/src/infinicore/pybind11/ops/scal.hpp @@ -10,9 +10,9 @@ namespace infinicore::ops { inline void bind_scal(py::module &m) { m.def("scal_", - py::overload_cast(&op::scal_), - py::arg("x"), + &op::scal_, py::arg("alpha"), + py::arg("x"), R"doc(In-place BLAS level-1 scal, updating x.)doc"); } From 446d0a771806101ee8da27d3e1d8f30d6e1041af Mon Sep 17 00:00:00 2001 From: xuzhengzhong Date: Fri, 8 May 2026 06:03:40 +0000 Subject: [PATCH 21/25] Update `asum`, `blas_amax`, `blas_amin`, `blas_copy` and `swap` to use the graph framework --- include/infinicore/ops/asum.hpp | 12 +-- include/infinicore/ops/blas_amax.hpp | 12 +-- include/infinicore/ops/blas_amin.hpp | 12 +-- include/infinicore/ops/blas_copy.hpp | 10 +-- include/infinicore/ops/swap.hpp | 8 +- python/infinicore/ops/asum.py | 3 +- python/infinicore/ops/blas_amax.py | 3 +- python/infinicore/ops/blas_amin.py | 3 +- src/infinicore/ops/asum/asum.cc | 26 +++---- src/infinicore/ops/asum/asum_infiniop.cc | 74 +++++++++---------- src/infinicore/ops/blas_amax/blas_amax.cc | 26 +++---- .../ops/blas_amax/blas_amax_infiniop.cc | 74 +++++++++---------- src/infinicore/ops/blas_amin/blas_amin.cc | 26 +++---- .../ops/blas_amin/blas_amin_infiniop.cc | 74 +++++++++---------- src/infinicore/ops/blas_copy/blas_copy.cc | 16 ++-- .../ops/blas_copy/blas_copy_infiniop.cc | 74 +++++++++---------- src/infinicore/ops/swap/swap.cc | 14 ++-- src/infinicore/ops/swap/swap_infiniop.cc | 74 +++++++++---------- src/infinicore/pybind11/ops/asum.hpp | 2 +- src/infinicore/pybind11/ops/blas_amax.hpp | 2 +- src/infinicore/pybind11/ops/blas_amin.hpp | 2 +- 21 files changed, 250 insertions(+), 297 deletions(-) diff --git 
a/include/infinicore/ops/asum.hpp b/include/infinicore/ops/asum.hpp index df94f2183..6471e6fe4 100644 --- a/include/infinicore/ops/asum.hpp +++ b/include/infinicore/ops/asum.hpp @@ -1,18 +1,14 @@ #pragma once #include "../device.hpp" +#include "../graph/graph.hpp" #include "common/op.hpp" namespace infinicore::op { -class Asum { -public: - using schema = void (*)(Tensor, Tensor); - static void execute(Tensor result, Tensor x); - static common::OpDispatcher &dispatcher(); -}; +INFINICORE_GRAPH_OP_CLASS(Asum, const Tensor &, Tensor); -Tensor asum(Tensor x); -void asum_(Tensor result, Tensor x); +Tensor asum(const Tensor &x); +void asum_(const Tensor &x, Tensor result); } // namespace infinicore::op diff --git a/include/infinicore/ops/blas_amax.hpp b/include/infinicore/ops/blas_amax.hpp index b43dff58e..a6a571f95 100644 --- a/include/infinicore/ops/blas_amax.hpp +++ b/include/infinicore/ops/blas_amax.hpp @@ -1,18 +1,14 @@ #pragma once #include "../device.hpp" +#include "../graph/graph.hpp" #include "common/op.hpp" namespace infinicore::op { -class BlasAmax { -public: - using schema = void (*)(Tensor, Tensor); - static void execute(Tensor result, Tensor x); - static common::OpDispatcher &dispatcher(); -}; +INFINICORE_GRAPH_OP_CLASS(BlasAmax, const Tensor &, Tensor); -Tensor blas_amax(Tensor x); -void blas_amax_(Tensor result, Tensor x); +Tensor blas_amax(const Tensor &x); +void blas_amax_(const Tensor &x, Tensor result); } // namespace infinicore::op diff --git a/include/infinicore/ops/blas_amin.hpp b/include/infinicore/ops/blas_amin.hpp index d4161ce40..a2ed21c7b 100644 --- a/include/infinicore/ops/blas_amin.hpp +++ b/include/infinicore/ops/blas_amin.hpp @@ -1,18 +1,14 @@ #pragma once #include "../device.hpp" +#include "../graph/graph.hpp" #include "common/op.hpp" namespace infinicore::op { -class BlasAmin { -public: - using schema = void (*)(Tensor, Tensor); - static void execute(Tensor result, Tensor x); - static common::OpDispatcher &dispatcher(); -}; 
+INFINICORE_GRAPH_OP_CLASS(BlasAmin, const Tensor &, Tensor); -Tensor blas_amin(Tensor x); -void blas_amin_(Tensor result, Tensor x); +Tensor blas_amin(const Tensor &x); +void blas_amin_(const Tensor &x, Tensor result); } // namespace infinicore::op diff --git a/include/infinicore/ops/blas_copy.hpp b/include/infinicore/ops/blas_copy.hpp index 1ac048a41..dd32646a0 100644 --- a/include/infinicore/ops/blas_copy.hpp +++ b/include/infinicore/ops/blas_copy.hpp @@ -1,17 +1,13 @@ #pragma once #include "../device.hpp" +#include "../graph/graph.hpp" #include "common/op.hpp" namespace infinicore::op { -class BlasCopy { -public: - using schema = void (*)(Tensor, Tensor); - static void execute(Tensor x, Tensor y); - static common::OpDispatcher &dispatcher(); -}; +INFINICORE_GRAPH_OP_CLASS(BlasCopy, const Tensor &, Tensor); -void blas_copy_(Tensor x, Tensor y); +void blas_copy_(const Tensor &x, Tensor y); } // namespace infinicore::op diff --git a/include/infinicore/ops/swap.hpp b/include/infinicore/ops/swap.hpp index 9b2f85dea..aba3ad563 100644 --- a/include/infinicore/ops/swap.hpp +++ b/include/infinicore/ops/swap.hpp @@ -1,16 +1,12 @@ #pragma once #include "../device.hpp" +#include "../graph/graph.hpp" #include "common/op.hpp" namespace infinicore::op { -class Swap { -public: - using schema = void (*)(Tensor, Tensor); - static void execute(Tensor x, Tensor y); - static common::OpDispatcher &dispatcher(); -}; +INFINICORE_GRAPH_OP_CLASS(Swap, Tensor, Tensor); void swap_(Tensor x, Tensor y); diff --git a/python/infinicore/ops/asum.py b/python/infinicore/ops/asum.py index 5f8cce9af..589a02129 100644 --- a/python/infinicore/ops/asum.py +++ b/python/infinicore/ops/asum.py @@ -6,5 +6,6 @@ def asum(x: Tensor, *, out=None): if out is None: return Tensor(_infinicore.asum(x._underlying)) - _infinicore.asum_(out._underlying, x._underlying) + _infinicore.asum_(x._underlying, out._underlying) + return out diff --git a/python/infinicore/ops/blas_amax.py b/python/infinicore/ops/blas_amax.py 
index ed7fbaf54..65279963d 100644 --- a/python/infinicore/ops/blas_amax.py +++ b/python/infinicore/ops/blas_amax.py @@ -6,5 +6,6 @@ def blas_amax(x: Tensor, *, out=None): if out is None: return Tensor(_infinicore.blas_amax(x._underlying)) - _infinicore.blas_amax_(out._underlying, x._underlying) + _infinicore.blas_amax_(x._underlying, out._underlying) + return out diff --git a/python/infinicore/ops/blas_amin.py b/python/infinicore/ops/blas_amin.py index a93ee43a6..472313faa 100644 --- a/python/infinicore/ops/blas_amin.py +++ b/python/infinicore/ops/blas_amin.py @@ -6,5 +6,6 @@ def blas_amin(x: Tensor, *, out=None): if out is None: return Tensor(_infinicore.blas_amin(x._underlying)) - _infinicore.blas_amin_(out._underlying, x._underlying) + _infinicore.blas_amin_(x._underlying, out._underlying) + return out diff --git a/src/infinicore/ops/asum/asum.cc b/src/infinicore/ops/asum/asum.cc index abf59fd90..c757574bb 100644 --- a/src/infinicore/ops/asum/asum.cc +++ b/src/infinicore/ops/asum/asum.cc @@ -4,25 +4,25 @@ namespace infinicore::op { -common::OpDispatcher &Asum::dispatcher() { - static common::OpDispatcher dispatcher_; - return dispatcher_; -}; - -void Asum::execute(Tensor result, Tensor x) { - INFINICORE_ASSERT_TENSORS_SAME_DEVICE(result, x); - infinicore::context::setDevice(result->device()); - dispatcher().lookup(result->device().getType())(result, x); +INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(Asum); + +Asum::Asum(const Tensor &x, Tensor result) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(x, result); + INFINICORE_GRAPH_OP_DISPATCH(result->device().getType(), x, result); +} + +void Asum::execute(const Tensor &x, Tensor result) { + INFINICORE_GRAPH_OP_RECORD_OR_RUN(Asum, x, result); } -Tensor asum(Tensor x) { +Tensor asum(const Tensor &x) { auto result = Tensor::empty({}, x->dtype(), x->device()); - asum_(result, x); + asum_(x, result); return result; } -void asum_(Tensor result, Tensor x) { - Asum::execute(result, x); +void asum_(const Tensor &x, Tensor result) { + 
Asum::execute(x, result); } } // namespace infinicore::op diff --git a/src/infinicore/ops/asum/asum_infiniop.cc b/src/infinicore/ops/asum/asum_infiniop.cc index 65eded161..0cfd8b721 100644 --- a/src/infinicore/ops/asum/asum_infiniop.cc +++ b/src/infinicore/ops/asum/asum_infiniop.cc @@ -1,56 +1,50 @@ -#include "../../utils.hpp" -#include "infinicore/common/hash.hpp" #include "infinicore/ops/asum.hpp" -#include "infinicore/ops/common/cache.hpp" -#include + +#include "../infiniop_impl.hpp" namespace infinicore::op::asum_impl::infiniop { -thread_local common::OpCache caches( - 100, // capacity - [](infiniopAsumDescriptor_t &desc) { - if (desc != nullptr) { - INFINICORE_CHECK_ERROR(infiniopDestroyAsumDescriptor(desc)); - desc = nullptr; - } - }); +INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, Asum, 100); -void calculate(Tensor result, Tensor x) { - size_t seed = hash_combine(result, x); +struct PlannedMeta { + std::shared_ptr descriptor; + graph::GraphTensor workspace, x, result; +}; - auto device_type = context::getDevice().getType(); - auto device_index = context::getDevice().getIndex(); +void *plan(const Tensor &x, Tensor result) { + size_t seed = hash_combine(x, result); - auto &cache = caches.getCache(device_type, device_index); + INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE( + Descriptor, descriptor, Asum, + seed, + x->desc(), result->desc()); - auto desc_opt = cache.get(seed); - infiniopAsumDescriptor_t desc = nullptr; + INFINIOP_WORKSPACE_TENSOR(workspace, Asum, descriptor); - if (!desc_opt) { - INFINICORE_CHECK_ERROR(infiniopCreateAsumDescriptor( - context::getInfiniopHandle(result->device()), &desc, - x->desc(), result->desc())); - cache.put(seed, desc); - } else { - desc = *desc_opt; - } + return new PlannedMeta{ + descriptor, + graph::GraphTensor(workspace), + graph::GraphTensor(x), + graph::GraphTensor(result)}; +} - size_t workspace_size = 0; - INFINICORE_CHECK_ERROR(infiniopGetAsumWorkspaceSize(desc, &workspace_size)); - std::shared_ptr workspace = 
context::allocateMemory(workspace_size); +void run(void *planned_meta) { + auto planned = reinterpret_cast(planned_meta); INFINICORE_CHECK_ERROR(infiniopAsum( - desc, workspace->data(), workspace_size, - x->data(), result->data(), context::getStream())); + planned->descriptor->desc, + planned->workspace->data(), + planned->workspace->numel(), + planned->x->data(), + planned->result->data(), + context::getStream())); +} + +void cleanup(void **planned_meta_ptr) { + delete *reinterpret_cast(planned_meta_ptr); + *planned_meta_ptr = nullptr; } -static bool registered = []() { - Asum::dispatcher().registerDevice({Device::Type::CPU, - Device::Type::CAMBRICON, - Device::Type::METAX}, - &calculate, - false); - return true; -}(); +INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(Asum, &plan, &run, &cleanup); } // namespace infinicore::op::asum_impl::infiniop diff --git a/src/infinicore/ops/blas_amax/blas_amax.cc b/src/infinicore/ops/blas_amax/blas_amax.cc index 5ec8825c4..9579589ff 100644 --- a/src/infinicore/ops/blas_amax/blas_amax.cc +++ b/src/infinicore/ops/blas_amax/blas_amax.cc @@ -4,25 +4,25 @@ namespace infinicore::op { -common::OpDispatcher &BlasAmax::dispatcher() { - static common::OpDispatcher dispatcher_; - return dispatcher_; -}; - -void BlasAmax::execute(Tensor result, Tensor x) { - INFINICORE_ASSERT_TENSORS_SAME_DEVICE(result, x); - infinicore::context::setDevice(result->device()); - dispatcher().lookup(result->device().getType())(result, x); +INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(BlasAmax); + +BlasAmax::BlasAmax(const Tensor &x, Tensor result) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(x, result); + INFINICORE_GRAPH_OP_DISPATCH(result->device().getType(), x, result); +} + +void BlasAmax::execute(const Tensor &x, Tensor result) { + INFINICORE_GRAPH_OP_RECORD_OR_RUN(BlasAmax, x, result); } -Tensor blas_amax(Tensor x) { +Tensor blas_amax(const Tensor &x) { auto result = Tensor::empty({}, DataType::I32, x->device()); - blas_amax_(result, x); + blas_amax_(x, result); return 
result; } -void blas_amax_(Tensor result, Tensor x) { - BlasAmax::execute(result, x); +void blas_amax_(const Tensor &x, Tensor result) { + BlasAmax::execute(x, result); } } // namespace infinicore::op diff --git a/src/infinicore/ops/blas_amax/blas_amax_infiniop.cc b/src/infinicore/ops/blas_amax/blas_amax_infiniop.cc index 9b66fd164..780fca744 100644 --- a/src/infinicore/ops/blas_amax/blas_amax_infiniop.cc +++ b/src/infinicore/ops/blas_amax/blas_amax_infiniop.cc @@ -1,56 +1,50 @@ -#include "../../utils.hpp" -#include "infinicore/common/hash.hpp" #include "infinicore/ops/blas_amax.hpp" -#include "infinicore/ops/common/cache.hpp" -#include + +#include "../infiniop_impl.hpp" namespace infinicore::op::blas_amax_impl::infiniop { -thread_local common::OpCache caches( - 100, // capacity - [](infiniopBlasAmaxDescriptor_t &desc) { - if (desc != nullptr) { - INFINICORE_CHECK_ERROR(infiniopDestroyBlasAmaxDescriptor(desc)); - desc = nullptr; - } - }); +INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, BlasAmax, 100); -void calculate(Tensor result, Tensor x) { - size_t seed = hash_combine(result, x); +struct PlannedMeta { + std::shared_ptr descriptor; + graph::GraphTensor workspace, x, result; +}; - auto device_type = context::getDevice().getType(); - auto device_index = context::getDevice().getIndex(); +void *plan(const Tensor &x, Tensor result) { + size_t seed = hash_combine(x, result); - auto &cache = caches.getCache(device_type, device_index); + INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE( + Descriptor, descriptor, BlasAmax, + seed, + x->desc(), result->desc()); - auto desc_opt = cache.get(seed); - infiniopBlasAmaxDescriptor_t desc = nullptr; + INFINIOP_WORKSPACE_TENSOR(workspace, BlasAmax, descriptor); - if (!desc_opt) { - INFINICORE_CHECK_ERROR(infiniopCreateBlasAmaxDescriptor( - context::getInfiniopHandle(result->device()), &desc, - x->desc(), result->desc())); - cache.put(seed, desc); - } else { - desc = *desc_opt; - } + return new PlannedMeta{ + descriptor, + 
graph::GraphTensor(workspace), + graph::GraphTensor(x), + graph::GraphTensor(result)}; +} - size_t workspace_size = 0; - INFINICORE_CHECK_ERROR(infiniopGetBlasAmaxWorkspaceSize(desc, &workspace_size)); - std::shared_ptr workspace = context::allocateMemory(workspace_size); +void run(void *planned_meta) { + auto planned = reinterpret_cast(planned_meta); INFINICORE_CHECK_ERROR(infiniopBlasAmax( - desc, workspace->data(), workspace_size, - x->data(), result->data(), context::getStream())); + planned->descriptor->desc, + planned->workspace->data(), + planned->workspace->numel(), + planned->x->data(), + planned->result->data(), + context::getStream())); +} + +void cleanup(void **planned_meta_ptr) { + delete *reinterpret_cast(planned_meta_ptr); + *planned_meta_ptr = nullptr; } -static bool registered = []() { - BlasAmax::dispatcher().registerDevice({Device::Type::CPU, - Device::Type::CAMBRICON, - Device::Type::METAX}, - &calculate, - false); - return true; -}(); +INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(BlasAmax, &plan, &run, &cleanup); } // namespace infinicore::op::blas_amax_impl::infiniop diff --git a/src/infinicore/ops/blas_amin/blas_amin.cc b/src/infinicore/ops/blas_amin/blas_amin.cc index 058ae1fdc..e18e22739 100644 --- a/src/infinicore/ops/blas_amin/blas_amin.cc +++ b/src/infinicore/ops/blas_amin/blas_amin.cc @@ -4,25 +4,25 @@ namespace infinicore::op { -common::OpDispatcher &BlasAmin::dispatcher() { - static common::OpDispatcher dispatcher_; - return dispatcher_; -}; - -void BlasAmin::execute(Tensor result, Tensor x) { - INFINICORE_ASSERT_TENSORS_SAME_DEVICE(result, x); - infinicore::context::setDevice(result->device()); - dispatcher().lookup(result->device().getType())(result, x); +INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(BlasAmin); + +BlasAmin::BlasAmin(const Tensor &x, Tensor result) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(x, result); + INFINICORE_GRAPH_OP_DISPATCH(result->device().getType(), x, result); +} + +void BlasAmin::execute(const Tensor &x, Tensor result) 
{ + INFINICORE_GRAPH_OP_RECORD_OR_RUN(BlasAmin, x, result); } -Tensor blas_amin(Tensor x) { +Tensor blas_amin(const Tensor &x) { auto result = Tensor::empty({}, DataType::I32, x->device()); - blas_amin_(result, x); + blas_amin_(x, result); return result; } -void blas_amin_(Tensor result, Tensor x) { - BlasAmin::execute(result, x); +void blas_amin_(const Tensor &x, Tensor result) { + BlasAmin::execute(x, result); } } // namespace infinicore::op diff --git a/src/infinicore/ops/blas_amin/blas_amin_infiniop.cc b/src/infinicore/ops/blas_amin/blas_amin_infiniop.cc index f93575c6c..00abf77e1 100644 --- a/src/infinicore/ops/blas_amin/blas_amin_infiniop.cc +++ b/src/infinicore/ops/blas_amin/blas_amin_infiniop.cc @@ -1,56 +1,50 @@ -#include "../../utils.hpp" -#include "infinicore/common/hash.hpp" #include "infinicore/ops/blas_amin.hpp" -#include "infinicore/ops/common/cache.hpp" -#include + +#include "../infiniop_impl.hpp" namespace infinicore::op::blas_amin_impl::infiniop { -thread_local common::OpCache caches( - 100, // capacity - [](infiniopBlasAminDescriptor_t &desc) { - if (desc != nullptr) { - INFINICORE_CHECK_ERROR(infiniopDestroyBlasAminDescriptor(desc)); - desc = nullptr; - } - }); +INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, BlasAmin, 100); -void calculate(Tensor result, Tensor x) { - size_t seed = hash_combine(result, x); +struct PlannedMeta { + std::shared_ptr descriptor; + graph::GraphTensor workspace, x, result; +}; - auto device_type = context::getDevice().getType(); - auto device_index = context::getDevice().getIndex(); +void *plan(const Tensor &x, Tensor result) { + size_t seed = hash_combine(x, result); - auto &cache = caches.getCache(device_type, device_index); + INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE( + Descriptor, descriptor, BlasAmin, + seed, + x->desc(), result->desc()); - auto desc_opt = cache.get(seed); - infiniopBlasAminDescriptor_t desc = nullptr; + INFINIOP_WORKSPACE_TENSOR(workspace, BlasAmin, descriptor); - if (!desc_opt) { - 
INFINICORE_CHECK_ERROR(infiniopCreateBlasAminDescriptor( - context::getInfiniopHandle(result->device()), &desc, - x->desc(), result->desc())); - cache.put(seed, desc); - } else { - desc = *desc_opt; - } + return new PlannedMeta{ + descriptor, + graph::GraphTensor(workspace), + graph::GraphTensor(x), + graph::GraphTensor(result)}; +} - size_t workspace_size = 0; - INFINICORE_CHECK_ERROR(infiniopGetBlasAminWorkspaceSize(desc, &workspace_size)); - std::shared_ptr workspace = context::allocateMemory(workspace_size); +void run(void *planned_meta) { + auto planned = reinterpret_cast(planned_meta); INFINICORE_CHECK_ERROR(infiniopBlasAmin( - desc, workspace->data(), workspace_size, - x->data(), result->data(), context::getStream())); + planned->descriptor->desc, + planned->workspace->data(), + planned->workspace->numel(), + planned->x->data(), + planned->result->data(), + context::getStream())); +} + +void cleanup(void **planned_meta_ptr) { + delete *reinterpret_cast(planned_meta_ptr); + *planned_meta_ptr = nullptr; } -static bool registered = []() { - BlasAmin::dispatcher().registerDevice({Device::Type::CPU, - Device::Type::CAMBRICON, - Device::Type::METAX}, - &calculate, - false); - return true; -}(); +INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(BlasAmin, &plan, &run, &cleanup); } // namespace infinicore::op::blas_amin_impl::infiniop diff --git a/src/infinicore/ops/blas_copy/blas_copy.cc b/src/infinicore/ops/blas_copy/blas_copy.cc index 374774f5c..77b2e7a5f 100644 --- a/src/infinicore/ops/blas_copy/blas_copy.cc +++ b/src/infinicore/ops/blas_copy/blas_copy.cc @@ -4,18 +4,18 @@ namespace infinicore::op { -common::OpDispatcher &BlasCopy::dispatcher() { - static common::OpDispatcher dispatcher_; - return dispatcher_; -}; +INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(BlasCopy); -void BlasCopy::execute(Tensor x, Tensor y) { +BlasCopy::BlasCopy(const Tensor &x, Tensor y) { INFINICORE_ASSERT_TENSORS_SAME_DEVICE(x, y); - infinicore::context::setDevice(y->device()); - 
dispatcher().lookup(y->device().getType())(x, y); + INFINICORE_GRAPH_OP_DISPATCH(y->device().getType(), x, y); } -void blas_copy_(Tensor x, Tensor y) { +void BlasCopy::execute(const Tensor &x, Tensor y) { + INFINICORE_GRAPH_OP_RECORD_OR_RUN(BlasCopy, x, y); +} + +void blas_copy_(const Tensor &x, Tensor y) { BlasCopy::execute(x, y); } diff --git a/src/infinicore/ops/blas_copy/blas_copy_infiniop.cc b/src/infinicore/ops/blas_copy/blas_copy_infiniop.cc index aa1c8bc52..33a70523b 100644 --- a/src/infinicore/ops/blas_copy/blas_copy_infiniop.cc +++ b/src/infinicore/ops/blas_copy/blas_copy_infiniop.cc @@ -1,56 +1,50 @@ -#include "../../utils.hpp" -#include "infinicore/common/hash.hpp" #include "infinicore/ops/blas_copy.hpp" -#include "infinicore/ops/common/cache.hpp" -#include + +#include "../infiniop_impl.hpp" namespace infinicore::op::blas_copy_impl::infiniop { -thread_local common::OpCache caches( - 100, // capacity - [](infiniopBlasCopyDescriptor_t &desc) { - if (desc != nullptr) { - INFINICORE_CHECK_ERROR(infiniopDestroyBlasCopyDescriptor(desc)); - desc = nullptr; - } - }); +INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, BlasCopy, 100); -void calculate(Tensor x, Tensor y) { - size_t seed = hash_combine(x, y); +struct PlannedMeta { + std::shared_ptr descriptor; + graph::GraphTensor workspace, x, y; +}; - auto device_type = context::getDevice().getType(); - auto device_index = context::getDevice().getIndex(); +void *plan(const Tensor &x, Tensor y) { + size_t seed = hash_combine(x, y); - auto &cache = caches.getCache(device_type, device_index); + INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE( + Descriptor, descriptor, BlasCopy, + seed, + x->desc(), y->desc()); - auto desc_opt = cache.get(seed); - infiniopBlasCopyDescriptor_t desc = nullptr; + INFINIOP_WORKSPACE_TENSOR(workspace, BlasCopy, descriptor); - if (!desc_opt) { - INFINICORE_CHECK_ERROR(infiniopCreateBlasCopyDescriptor( - context::getInfiniopHandle(y->device()), &desc, - x->desc(), y->desc())); - cache.put(seed, desc); - 
} else { - desc = *desc_opt; - } + return new PlannedMeta{ + descriptor, + graph::GraphTensor(workspace), + graph::GraphTensor(x), + graph::GraphTensor(y)}; +} - size_t workspace_size = 0; - INFINICORE_CHECK_ERROR(infiniopGetBlasCopyWorkspaceSize(desc, &workspace_size)); - std::shared_ptr workspace = context::allocateMemory(workspace_size); +void run(void *planned_meta) { + auto planned = reinterpret_cast(planned_meta); INFINICORE_CHECK_ERROR(infiniopBlasCopy( - desc, workspace->data(), workspace_size, - x->data(), y->data(), context::getStream())); + planned->descriptor->desc, + planned->workspace->data(), + planned->workspace->numel(), + planned->x->data(), + planned->y->data(), + context::getStream())); +} + +void cleanup(void **planned_meta_ptr) { + delete *reinterpret_cast(planned_meta_ptr); + *planned_meta_ptr = nullptr; } -static bool registered = []() { - BlasCopy::dispatcher().registerDevice({Device::Type::CPU, - Device::Type::CAMBRICON, - Device::Type::METAX}, - &calculate, - false); - return true; -}(); +INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(BlasCopy, &plan, &run, &cleanup); } // namespace infinicore::op::blas_copy_impl::infiniop diff --git a/src/infinicore/ops/swap/swap.cc b/src/infinicore/ops/swap/swap.cc index 5bc078e5c..6ce4b86b9 100644 --- a/src/infinicore/ops/swap/swap.cc +++ b/src/infinicore/ops/swap/swap.cc @@ -4,15 +4,15 @@ namespace infinicore::op { -common::OpDispatcher &Swap::dispatcher() { - static common::OpDispatcher dispatcher_; - return dispatcher_; -}; +INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(Swap); -void Swap::execute(Tensor x, Tensor y) { +Swap::Swap(Tensor x, Tensor y) { INFINICORE_ASSERT_TENSORS_SAME_DEVICE(x, y); - infinicore::context::setDevice(x->device()); - dispatcher().lookup(x->device().getType())(x, y); + INFINICORE_GRAPH_OP_DISPATCH(x->device().getType(), x, y); +} + +void Swap::execute(Tensor x, Tensor y) { + INFINICORE_GRAPH_OP_RECORD_OR_RUN(Swap, x, y); } void swap_(Tensor x, Tensor y) { diff --git 
a/src/infinicore/ops/swap/swap_infiniop.cc b/src/infinicore/ops/swap/swap_infiniop.cc index f344e36a0..32a163f65 100644 --- a/src/infinicore/ops/swap/swap_infiniop.cc +++ b/src/infinicore/ops/swap/swap_infiniop.cc @@ -1,56 +1,50 @@ -#include "../../utils.hpp" -#include "infinicore/common/hash.hpp" -#include "infinicore/ops/common/cache.hpp" #include "infinicore/ops/swap.hpp" -#include + +#include "../infiniop_impl.hpp" namespace infinicore::op::swap_impl::infiniop { -thread_local common::OpCache caches( - 100, // capacity - [](infiniopSwapDescriptor_t &desc) { - if (desc != nullptr) { - INFINICORE_CHECK_ERROR(infiniopDestroySwapDescriptor(desc)); - desc = nullptr; - } - }); +INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, Swap, 100); -void calculate(Tensor x, Tensor y) { - size_t seed = hash_combine(x, y); +struct PlannedMeta { + std::shared_ptr descriptor; + graph::GraphTensor workspace, x, y; +}; - auto device_type = context::getDevice().getType(); - auto device_index = context::getDevice().getIndex(); +void *plan(Tensor x, Tensor y) { + size_t seed = hash_combine(x, y); - auto &cache = caches.getCache(device_type, device_index); + INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE( + Descriptor, descriptor, Swap, + seed, + x->desc(), y->desc()); - auto desc_opt = cache.get(seed); - infiniopSwapDescriptor_t desc = nullptr; + INFINIOP_WORKSPACE_TENSOR(workspace, Swap, descriptor); - if (!desc_opt) { - INFINICORE_CHECK_ERROR(infiniopCreateSwapDescriptor( - context::getInfiniopHandle(x->device()), &desc, - x->desc(), y->desc())); - cache.put(seed, desc); - } else { - desc = *desc_opt; - } + return new PlannedMeta{ + descriptor, + graph::GraphTensor(workspace), + graph::GraphTensor(x), + graph::GraphTensor(y)}; +} - size_t workspace_size = 0; - INFINICORE_CHECK_ERROR(infiniopGetSwapWorkspaceSize(desc, &workspace_size)); - std::shared_ptr workspace = context::allocateMemory(workspace_size); +void run(void *planned_meta) { + auto planned = reinterpret_cast(planned_meta); 
INFINICORE_CHECK_ERROR(infiniopSwap( - desc, workspace->data(), workspace_size, - x->data(), y->data(), context::getStream())); + planned->descriptor->desc, + planned->workspace->data(), + planned->workspace->numel(), + planned->x->data(), + planned->y->data(), + context::getStream())); +} + +void cleanup(void **planned_meta_ptr) { + delete *reinterpret_cast(planned_meta_ptr); + *planned_meta_ptr = nullptr; } -static bool registered = []() { - Swap::dispatcher().registerDevice({Device::Type::CPU, - Device::Type::CAMBRICON, - Device::Type::METAX}, - &calculate, - false); - return true; -}(); +INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(Swap, &plan, &run, &cleanup); } // namespace infinicore::op::swap_impl::infiniop diff --git a/src/infinicore/pybind11/ops/asum.hpp b/src/infinicore/pybind11/ops/asum.hpp index 40661c562..b094d12f5 100644 --- a/src/infinicore/pybind11/ops/asum.hpp +++ b/src/infinicore/pybind11/ops/asum.hpp @@ -16,8 +16,8 @@ inline void bind_asum(py::module &m) { m.def("asum_", &op::asum_, - py::arg("result"), py::arg("x"), + py::arg("result"), R"doc(In-place BLAS level-1 asum.)doc"); } diff --git a/src/infinicore/pybind11/ops/blas_amax.hpp b/src/infinicore/pybind11/ops/blas_amax.hpp index 5afdee62c..51e8cfe4d 100644 --- a/src/infinicore/pybind11/ops/blas_amax.hpp +++ b/src/infinicore/pybind11/ops/blas_amax.hpp @@ -16,8 +16,8 @@ inline void bind_blas_amax(py::module &m) { m.def("blas_amax_", &op::blas_amax_, - py::arg("result"), py::arg("x"), + py::arg("result"), R"doc(In-place BLAS level-1 amax.)doc"); } diff --git a/src/infinicore/pybind11/ops/blas_amin.hpp b/src/infinicore/pybind11/ops/blas_amin.hpp index 4f9b316de..8961a9363 100644 --- a/src/infinicore/pybind11/ops/blas_amin.hpp +++ b/src/infinicore/pybind11/ops/blas_amin.hpp @@ -16,8 +16,8 @@ inline void bind_blas_amin(py::module &m) { m.def("blas_amin_", &op::blas_amin_, - py::arg("result"), py::arg("x"), + py::arg("result"), R"doc(In-place BLAS level-1 amin.)doc"); } From 
0a8fa567a438bb04f9fb61c455deb5f4c13300c7 Mon Sep 17 00:00:00 2001 From: xuzhengzhong Date: Sat, 9 May 2026 03:07:16 +0000 Subject: [PATCH 22/25] Add `InsertNewlineAtEOF: true` to `.clang-format` --- .clang-format | 1 + 1 file changed, 1 insertion(+) diff --git a/.clang-format b/.clang-format index c05c7633b..a802aaf49 100644 --- a/.clang-format +++ b/.clang-format @@ -28,3 +28,4 @@ BraceWrapping: SplitEmptyFunction: true SplitEmptyRecord: true SplitEmptyNamespace: true +InsertNewlineAtEOF: true From d1cee8a6b7677257bf1edd7701ef60fe39439133 Mon Sep 17 00:00:00 2001 From: xuzhengzhong Date: Sat, 9 May 2026 03:13:53 +0000 Subject: [PATCH 23/25] Run `scripts/format.py` to fix the format --- include/infiniop/ops/axpy.h | 2 +- include/infiniop/ops/blas_amax.h | 2 +- include/infiniop/ops/blas_amin.h | 2 +- include/infiniop/ops/rotg.h | 2 +- include/infiniop/ops/rotm.h | 2 +- include/infiniop/ops/rotmg.h | 2 +- src/infiniop/ops/asum/asum.h | 2 +- src/infiniop/ops/asum/bang/asum_bang.h | 2 +- src/infiniop/ops/asum/bang/asum_bang.mlu | 2 +- src/infiniop/ops/asum/cpu/asum_cpu.h | 2 +- src/infiniop/ops/asum/info.h | 2 +- src/infiniop/ops/asum/metax/asum_metax.cc | 2 +- src/infiniop/ops/asum/metax/asum_metax.h | 2 +- src/infiniop/ops/axpy/axpy.h | 2 +- src/infiniop/ops/axpy/bang/axpy_bang.h | 2 +- src/infiniop/ops/axpy/bang/axpy_bang.mlu | 2 +- src/infiniop/ops/axpy/cpu/axpy_cpu.h | 2 +- src/infiniop/ops/axpy/info.h | 2 +- src/infiniop/ops/axpy/metax/axpy_metax.cc | 2 +- src/infiniop/ops/axpy/metax/axpy_metax.h | 2 +- src/infiniop/ops/axpy/operator.cc | 2 +- src/infiniop/ops/blas_amax/bang/blas_amax_bang.h | 2 +- src/infiniop/ops/blas_amax/bang/blas_amax_bang.mlu | 2 +- src/infiniop/ops/blas_amax/bang/blas_amax_bang_kernel.mlu | 2 +- src/infiniop/ops/blas_amax/blas_amax.h | 2 +- src/infiniop/ops/blas_amax/cpu/blas_amax_cpu.h | 2 +- src/infiniop/ops/blas_amax/info.h | 2 +- src/infiniop/ops/blas_amax/metax/blas_amax_metax.cc | 2 +- 
src/infiniop/ops/blas_amax/metax/blas_amax_metax.h | 2 +- src/infiniop/ops/blas_amax/operator.cc | 2 +- src/infiniop/ops/blas_amin/bang/blas_amin_bang.h | 2 +- src/infiniop/ops/blas_amin/bang/blas_amin_bang.mlu | 2 +- src/infiniop/ops/blas_amin/bang/blas_amin_bang_kernel.mlu | 2 +- src/infiniop/ops/blas_amin/blas_amin.h | 2 +- src/infiniop/ops/blas_amin/cpu/blas_amin_cpu.h | 2 +- src/infiniop/ops/blas_amin/info.h | 2 +- src/infiniop/ops/blas_amin/metax/blas_amin_metax.cc | 2 +- src/infiniop/ops/blas_amin/metax/blas_amin_metax.h | 2 +- src/infiniop/ops/blas_amin/operator.cc | 2 +- src/infiniop/ops/blas_copy/info.h | 2 +- src/infiniop/ops/blas_dot/bang/blas_dot_bang.h | 2 +- src/infiniop/ops/blas_dot/cpu/blas_dot_cpu.h | 2 +- src/infiniop/ops/blas_dot/info.h | 2 +- src/infiniop/ops/blas_dot/metax/blas_dot_metax.h | 2 +- src/infiniop/ops/nrm2/bang/nrm2_bang.h | 2 +- src/infiniop/ops/nrm2/cpu/nrm2_cpu.h | 2 +- src/infiniop/ops/nrm2/info.h | 2 +- src/infiniop/ops/nrm2/metax/nrm2_metax.h | 2 +- src/infiniop/ops/nrm2/operator.cc | 2 +- src/infiniop/ops/rot/bang/rot_bang.h | 2 +- src/infiniop/ops/rot/cpu/rot_cpu.h | 2 +- src/infiniop/ops/rot/info.h | 2 +- src/infiniop/ops/rot/metax/rot_metax.h | 2 +- src/infiniop/ops/rotg/info.h | 2 +- src/infiniop/ops/rotm/bang/rotm_bang.h | 2 +- src/infiniop/ops/rotm/bang/rotm_bang_kernel.mlu | 2 +- src/infiniop/ops/rotm/cpu/rotm_cpu.h | 2 +- src/infiniop/ops/rotm/info.h | 2 +- src/infiniop/ops/rotm/metax/rotm_metax.h | 2 +- src/infiniop/ops/rotm/operator.cc | 2 +- src/infiniop/ops/rotmg/bang/rotmg_bang_kernel.mlu | 2 +- src/infiniop/ops/rotmg/info.h | 2 +- src/infiniop/ops/scal/bang/scal_bang.h | 2 +- src/infiniop/ops/scal/cpu/scal_cpu.h | 2 +- src/infiniop/ops/scal/info.h | 2 +- src/infiniop/ops/scal/metax/scal_metax.h | 2 +- src/infiniop/ops/swap/info.h | 2 +- test/infiniop/libinfiniop/op_register.py | 1 + 68 files changed, 68 insertions(+), 67 deletions(-) diff --git a/include/infiniop/ops/axpy.h b/include/infiniop/ops/axpy.h index 
1cf459602..ce6b2a23b 100644 --- a/include/infiniop/ops/axpy.h +++ b/include/infiniop/ops/axpy.h @@ -23,4 +23,4 @@ __INFINI_C __export infiniStatus_t infiniopAxpy(infiniopAxpyDescriptor_t desc, __INFINI_C __export infiniStatus_t infiniopDestroyAxpyDescriptor(infiniopAxpyDescriptor_t desc); -#endif // __INFINIOP_AXPY_API_H__ \ No newline at end of file +#endif // __INFINIOP_AXPY_API_H__ diff --git a/include/infiniop/ops/blas_amax.h b/include/infiniop/ops/blas_amax.h index be69ad4df..9981c7fb7 100644 --- a/include/infiniop/ops/blas_amax.h +++ b/include/infiniop/ops/blas_amax.h @@ -21,4 +21,4 @@ __INFINI_C __export infiniStatus_t infiniopBlasAmax(infiniopBlasAmaxDescriptor_t __INFINI_C __export infiniStatus_t infiniopDestroyBlasAmaxDescriptor(infiniopBlasAmaxDescriptor_t desc); -#endif // __INFINIOP_BLAS_AMAX_API_H__ \ No newline at end of file +#endif // __INFINIOP_BLAS_AMAX_API_H__ diff --git a/include/infiniop/ops/blas_amin.h b/include/infiniop/ops/blas_amin.h index f568acd4a..6bc8680ba 100644 --- a/include/infiniop/ops/blas_amin.h +++ b/include/infiniop/ops/blas_amin.h @@ -21,4 +21,4 @@ __INFINI_C __export infiniStatus_t infiniopBlasAmin(infiniopBlasAminDescriptor_t __INFINI_C __export infiniStatus_t infiniopDestroyBlasAminDescriptor(infiniopBlasAminDescriptor_t desc); -#endif // __INFINIOP_BLAS_AMIN_API_H__ \ No newline at end of file +#endif // __INFINIOP_BLAS_AMIN_API_H__ diff --git a/include/infiniop/ops/rotg.h b/include/infiniop/ops/rotg.h index a9121b873..63c2dad46 100644 --- a/include/infiniop/ops/rotg.h +++ b/include/infiniop/ops/rotg.h @@ -25,4 +25,4 @@ __INFINI_C __export infiniStatus_t infiniopRotg(infiniopRotgDescriptor_t desc, __INFINI_C __export infiniStatus_t infiniopDestroyRotgDescriptor(infiniopRotgDescriptor_t desc); -#endif // __INFINIOP_ROTG_API_H__ \ No newline at end of file +#endif // __INFINIOP_ROTG_API_H__ diff --git a/include/infiniop/ops/rotm.h b/include/infiniop/ops/rotm.h index 279d18129..6cc6a636c 100644 --- 
a/include/infiniop/ops/rotm.h +++ b/include/infiniop/ops/rotm.h @@ -23,4 +23,4 @@ __INFINI_C __export infiniStatus_t infiniopRotm(infiniopRotmDescriptor_t desc, __INFINI_C __export infiniStatus_t infiniopDestroyRotmDescriptor(infiniopRotmDescriptor_t desc); -#endif // __INFINIOP_ROTM_API_H__ \ No newline at end of file +#endif // __INFINIOP_ROTM_API_H__ diff --git a/include/infiniop/ops/rotmg.h b/include/infiniop/ops/rotmg.h index 93f425216..0295339cb 100644 --- a/include/infiniop/ops/rotmg.h +++ b/include/infiniop/ops/rotmg.h @@ -27,4 +27,4 @@ __INFINI_C __export infiniStatus_t infiniopRotmg(infiniopRotmgDescriptor_t desc, __INFINI_C __export infiniStatus_t infiniopDestroyRotmgDescriptor(infiniopRotmgDescriptor_t desc); -#endif // __INFINIOP_ROTMG_API_H__ \ No newline at end of file +#endif // __INFINIOP_ROTMG_API_H__ diff --git a/src/infiniop/ops/asum/asum.h b/src/infiniop/ops/asum/asum.h index 71cc97cad..dd7aef22e 100644 --- a/src/infiniop/ops/asum/asum.h +++ b/src/infiniop/ops/asum/asum.h @@ -44,4 +44,4 @@ }; \ } -#endif // __ASUM_H__ \ No newline at end of file +#endif // __ASUM_H__ diff --git a/src/infiniop/ops/asum/bang/asum_bang.h b/src/infiniop/ops/asum/bang/asum_bang.h index e5d99dac3..bf388c744 100644 --- a/src/infiniop/ops/asum/bang/asum_bang.h +++ b/src/infiniop/ops/asum/bang/asum_bang.h @@ -5,4 +5,4 @@ DESCRIPTOR(bang) -#endif // __ASUM_BANG_H__ \ No newline at end of file +#endif // __ASUM_BANG_H__ diff --git a/src/infiniop/ops/asum/bang/asum_bang.mlu b/src/infiniop/ops/asum/bang/asum_bang.mlu index 9b85ba56a..1c8cba0c3 100644 --- a/src/infiniop/ops/asum/bang/asum_bang.mlu +++ b/src/infiniop/ops/asum/bang/asum_bang.mlu @@ -92,4 +92,4 @@ infiniStatus_t Descriptor::calculate( #undef CALCULATE_ASUM -} // namespace op::asum::bang \ No newline at end of file +} // namespace op::asum::bang diff --git a/src/infiniop/ops/asum/cpu/asum_cpu.h b/src/infiniop/ops/asum/cpu/asum_cpu.h index 6cebbaf34..84b7572d8 100644 --- a/src/infiniop/ops/asum/cpu/asum_cpu.h +++ 
b/src/infiniop/ops/asum/cpu/asum_cpu.h @@ -5,4 +5,4 @@ DESCRIPTOR(cpu) -#endif // __ASUM_CPU_H__ \ No newline at end of file +#endif // __ASUM_CPU_H__ diff --git a/src/infiniop/ops/asum/info.h b/src/infiniop/ops/asum/info.h index 17841f48f..3efcc0e39 100644 --- a/src/infiniop/ops/asum/info.h +++ b/src/infiniop/ops/asum/info.h @@ -38,4 +38,4 @@ class AsumInfo { } }; -#endif // __ASUM_INFO_H__ \ No newline at end of file +#endif // __ASUM_INFO_H__ diff --git a/src/infiniop/ops/asum/metax/asum_metax.cc b/src/infiniop/ops/asum/metax/asum_metax.cc index 9fb7453bb..4a084fbf9 100644 --- a/src/infiniop/ops/asum/metax/asum_metax.cc +++ b/src/infiniop/ops/asum/metax/asum_metax.cc @@ -68,4 +68,4 @@ infiniStatus_t Descriptor::calculate( return INFINI_STATUS_SUCCESS; } -} // namespace op::asum::metax \ No newline at end of file +} // namespace op::asum::metax diff --git a/src/infiniop/ops/asum/metax/asum_metax.h b/src/infiniop/ops/asum/metax/asum_metax.h index 4b675b583..f5ade8e58 100644 --- a/src/infiniop/ops/asum/metax/asum_metax.h +++ b/src/infiniop/ops/asum/metax/asum_metax.h @@ -5,4 +5,4 @@ DESCRIPTOR(metax) -#endif // __ASUM_METAX_H__ \ No newline at end of file +#endif // __ASUM_METAX_H__ diff --git a/src/infiniop/ops/axpy/axpy.h b/src/infiniop/ops/axpy/axpy.h index 7dedb68b4..617e9d01b 100644 --- a/src/infiniop/ops/axpy/axpy.h +++ b/src/infiniop/ops/axpy/axpy.h @@ -46,4 +46,4 @@ }; \ } -#endif // __AXPY_H__ \ No newline at end of file +#endif // __AXPY_H__ diff --git a/src/infiniop/ops/axpy/bang/axpy_bang.h b/src/infiniop/ops/axpy/bang/axpy_bang.h index a448303bb..dbce7e5ca 100644 --- a/src/infiniop/ops/axpy/bang/axpy_bang.h +++ b/src/infiniop/ops/axpy/bang/axpy_bang.h @@ -5,4 +5,4 @@ DESCRIPTOR(bang) -#endif // __AXPY_BANG_H__ \ No newline at end of file +#endif // __AXPY_BANG_H__ diff --git a/src/infiniop/ops/axpy/bang/axpy_bang.mlu b/src/infiniop/ops/axpy/bang/axpy_bang.mlu index c3d72baa8..a7cd75bf7 100644 --- a/src/infiniop/ops/axpy/bang/axpy_bang.mlu +++ 
b/src/infiniop/ops/axpy/bang/axpy_bang.mlu @@ -101,4 +101,4 @@ infiniStatus_t Descriptor::calculate( #undef CALCULATE_AXPY -} // namespace op::axpy::bang \ No newline at end of file +} // namespace op::axpy::bang diff --git a/src/infiniop/ops/axpy/cpu/axpy_cpu.h b/src/infiniop/ops/axpy/cpu/axpy_cpu.h index f4bf63602..f25b49ef3 100644 --- a/src/infiniop/ops/axpy/cpu/axpy_cpu.h +++ b/src/infiniop/ops/axpy/cpu/axpy_cpu.h @@ -5,4 +5,4 @@ DESCRIPTOR(cpu) -#endif // __AXPY_CPU_H__ \ No newline at end of file +#endif // __AXPY_CPU_H__ diff --git a/src/infiniop/ops/axpy/info.h b/src/infiniop/ops/axpy/info.h index 7155c846b..9dfaa6c07 100644 --- a/src/infiniop/ops/axpy/info.h +++ b/src/infiniop/ops/axpy/info.h @@ -46,4 +46,4 @@ class AxpyInfo { } }; -#endif // __AXPY_INFO_H__ \ No newline at end of file +#endif // __AXPY_INFO_H__ diff --git a/src/infiniop/ops/axpy/metax/axpy_metax.cc b/src/infiniop/ops/axpy/metax/axpy_metax.cc index 3a36e31f3..55f173cb0 100644 --- a/src/infiniop/ops/axpy/metax/axpy_metax.cc +++ b/src/infiniop/ops/axpy/metax/axpy_metax.cc @@ -99,4 +99,4 @@ infiniStatus_t Descriptor::calculate( return INFINI_STATUS_SUCCESS; } -} // namespace op::axpy::metax \ No newline at end of file +} // namespace op::axpy::metax diff --git a/src/infiniop/ops/axpy/metax/axpy_metax.h b/src/infiniop/ops/axpy/metax/axpy_metax.h index 09144a535..8129ca2e4 100644 --- a/src/infiniop/ops/axpy/metax/axpy_metax.h +++ b/src/infiniop/ops/axpy/metax/axpy_metax.h @@ -5,4 +5,4 @@ DESCRIPTOR(metax) -#endif // __AXPY_METAX_H__ \ No newline at end of file +#endif // __AXPY_METAX_H__ diff --git a/src/infiniop/ops/axpy/operator.cc b/src/infiniop/ops/axpy/operator.cc index 1c45f0d35..cd57223e4 100644 --- a/src/infiniop/ops/axpy/operator.cc +++ b/src/infiniop/ops/axpy/operator.cc @@ -124,4 +124,4 @@ infiniopDestroyAxpyDescriptor(infiniopAxpyDescriptor_t desc) { } #undef DELETE -} \ No newline at end of file +} diff --git a/src/infiniop/ops/blas_amax/bang/blas_amax_bang.h 
b/src/infiniop/ops/blas_amax/bang/blas_amax_bang.h index 88a0250a7..8bc6ca2c5 100644 --- a/src/infiniop/ops/blas_amax/bang/blas_amax_bang.h +++ b/src/infiniop/ops/blas_amax/bang/blas_amax_bang.h @@ -5,4 +5,4 @@ DESCRIPTOR(bang) -#endif // __BLAS_AMAX_BANG_H__ \ No newline at end of file +#endif // __BLAS_AMAX_BANG_H__ diff --git a/src/infiniop/ops/blas_amax/bang/blas_amax_bang.mlu b/src/infiniop/ops/blas_amax/bang/blas_amax_bang.mlu index 62f43aaf9..26cc3c501 100644 --- a/src/infiniop/ops/blas_amax/bang/blas_amax_bang.mlu +++ b/src/infiniop/ops/blas_amax/bang/blas_amax_bang.mlu @@ -93,4 +93,4 @@ infiniStatus_t Descriptor::calculate( #undef CALCULATE_BLAS_AMAX -} // namespace op::blas_amax::bang \ No newline at end of file +} // namespace op::blas_amax::bang diff --git a/src/infiniop/ops/blas_amax/bang/blas_amax_bang_kernel.mlu b/src/infiniop/ops/blas_amax/bang/blas_amax_bang_kernel.mlu index e3736aa08..2939f2dd7 100644 --- a/src/infiniop/ops/blas_amax/bang/blas_amax_bang_kernel.mlu +++ b/src/infiniop/ops/blas_amax/bang/blas_amax_bang_kernel.mlu @@ -133,4 +133,4 @@ __mlu_global__ void blasAmaxKernelStrided( result[0] = cluster_max_index + 1; // Convert to 1-based index } -} \ No newline at end of file +} diff --git a/src/infiniop/ops/blas_amax/blas_amax.h b/src/infiniop/ops/blas_amax/blas_amax.h index e5b3400c3..627b7b754 100644 --- a/src/infiniop/ops/blas_amax/blas_amax.h +++ b/src/infiniop/ops/blas_amax/blas_amax.h @@ -44,4 +44,4 @@ }; \ } -#endif // __BLAS_AMAX_H__ \ No newline at end of file +#endif // __BLAS_AMAX_H__ diff --git a/src/infiniop/ops/blas_amax/cpu/blas_amax_cpu.h b/src/infiniop/ops/blas_amax/cpu/blas_amax_cpu.h index 2aa42e756..66197c9c1 100644 --- a/src/infiniop/ops/blas_amax/cpu/blas_amax_cpu.h +++ b/src/infiniop/ops/blas_amax/cpu/blas_amax_cpu.h @@ -5,4 +5,4 @@ DESCRIPTOR(cpu) -#endif // __BLAS_AMAX_CPU_H__ \ No newline at end of file +#endif // __BLAS_AMAX_CPU_H__ diff --git a/src/infiniop/ops/blas_amax/info.h 
b/src/infiniop/ops/blas_amax/info.h index b9b1924c8..b10f84046 100644 --- a/src/infiniop/ops/blas_amax/info.h +++ b/src/infiniop/ops/blas_amax/info.h @@ -38,4 +38,4 @@ class BlasAmaxInfo { } }; -#endif // __BLAS_AMAX_INFO_H__ \ No newline at end of file +#endif // __BLAS_AMAX_INFO_H__ diff --git a/src/infiniop/ops/blas_amax/metax/blas_amax_metax.cc b/src/infiniop/ops/blas_amax/metax/blas_amax_metax.cc index 975586dc2..05816020d 100644 --- a/src/infiniop/ops/blas_amax/metax/blas_amax_metax.cc +++ b/src/infiniop/ops/blas_amax/metax/blas_amax_metax.cc @@ -68,4 +68,4 @@ infiniStatus_t Descriptor::calculate( return INFINI_STATUS_SUCCESS; } -} // namespace op::blas_amax::metax \ No newline at end of file +} // namespace op::blas_amax::metax diff --git a/src/infiniop/ops/blas_amax/metax/blas_amax_metax.h b/src/infiniop/ops/blas_amax/metax/blas_amax_metax.h index e4f9af0c0..19e79851f 100644 --- a/src/infiniop/ops/blas_amax/metax/blas_amax_metax.h +++ b/src/infiniop/ops/blas_amax/metax/blas_amax_metax.h @@ -5,4 +5,4 @@ DESCRIPTOR(metax) -#endif // __BLAS_AMAX_METAX_H__ \ No newline at end of file +#endif // __BLAS_AMAX_METAX_H__ diff --git a/src/infiniop/ops/blas_amax/operator.cc b/src/infiniop/ops/blas_amax/operator.cc index 2e8e37247..c6b48eeb4 100644 --- a/src/infiniop/ops/blas_amax/operator.cc +++ b/src/infiniop/ops/blas_amax/operator.cc @@ -119,4 +119,4 @@ __INFINI_C infiniStatus_t infiniopDestroyBlasAmaxDescriptor(infiniopBlasAmaxDesc } #undef DELETE -} \ No newline at end of file +} diff --git a/src/infiniop/ops/blas_amin/bang/blas_amin_bang.h b/src/infiniop/ops/blas_amin/bang/blas_amin_bang.h index 6cc84d3a8..ba9dbaa21 100644 --- a/src/infiniop/ops/blas_amin/bang/blas_amin_bang.h +++ b/src/infiniop/ops/blas_amin/bang/blas_amin_bang.h @@ -5,4 +5,4 @@ DESCRIPTOR(bang) -#endif // __BLAS_AMIN_BANG_H__ \ No newline at end of file +#endif // __BLAS_AMIN_BANG_H__ diff --git a/src/infiniop/ops/blas_amin/bang/blas_amin_bang.mlu 
b/src/infiniop/ops/blas_amin/bang/blas_amin_bang.mlu index b3020d68a..dfd2a7e64 100644 --- a/src/infiniop/ops/blas_amin/bang/blas_amin_bang.mlu +++ b/src/infiniop/ops/blas_amin/bang/blas_amin_bang.mlu @@ -93,4 +93,4 @@ infiniStatus_t Descriptor::calculate( #undef CALCULATE_BLAS_AMIN -} // namespace op::blas_amin::bang \ No newline at end of file +} // namespace op::blas_amin::bang diff --git a/src/infiniop/ops/blas_amin/bang/blas_amin_bang_kernel.mlu b/src/infiniop/ops/blas_amin/bang/blas_amin_bang_kernel.mlu index 99d2a828b..2c5ff7c48 100644 --- a/src/infiniop/ops/blas_amin/bang/blas_amin_bang_kernel.mlu +++ b/src/infiniop/ops/blas_amin/bang/blas_amin_bang_kernel.mlu @@ -132,4 +132,4 @@ __mlu_global__ void blasAminKernelStrided( result[0] = min_index + 1; // Convert to 1-based index } -} \ No newline at end of file +} diff --git a/src/infiniop/ops/blas_amin/blas_amin.h b/src/infiniop/ops/blas_amin/blas_amin.h index 8cf7ceca6..5128fcf85 100644 --- a/src/infiniop/ops/blas_amin/blas_amin.h +++ b/src/infiniop/ops/blas_amin/blas_amin.h @@ -44,4 +44,4 @@ }; \ } -#endif // __BLAS_AMIN_H__ \ No newline at end of file +#endif // __BLAS_AMIN_H__ diff --git a/src/infiniop/ops/blas_amin/cpu/blas_amin_cpu.h b/src/infiniop/ops/blas_amin/cpu/blas_amin_cpu.h index ea4daa397..c5e4936d7 100644 --- a/src/infiniop/ops/blas_amin/cpu/blas_amin_cpu.h +++ b/src/infiniop/ops/blas_amin/cpu/blas_amin_cpu.h @@ -5,4 +5,4 @@ DESCRIPTOR(cpu) -#endif // __BLAS_AMIN_CPU_H__ \ No newline at end of file +#endif // __BLAS_AMIN_CPU_H__ diff --git a/src/infiniop/ops/blas_amin/info.h b/src/infiniop/ops/blas_amin/info.h index a9b187f92..0522edecb 100644 --- a/src/infiniop/ops/blas_amin/info.h +++ b/src/infiniop/ops/blas_amin/info.h @@ -38,4 +38,4 @@ class BlasAminInfo { } }; -#endif // __BLAS_AMIN_INFO_H__ \ No newline at end of file +#endif // __BLAS_AMIN_INFO_H__ diff --git a/src/infiniop/ops/blas_amin/metax/blas_amin_metax.cc b/src/infiniop/ops/blas_amin/metax/blas_amin_metax.cc index 
d75ff93ca..dadd2706d 100644 --- a/src/infiniop/ops/blas_amin/metax/blas_amin_metax.cc +++ b/src/infiniop/ops/blas_amin/metax/blas_amin_metax.cc @@ -68,4 +68,4 @@ infiniStatus_t Descriptor::calculate( return INFINI_STATUS_SUCCESS; } -} // namespace op::blas_amin::metax \ No newline at end of file +} // namespace op::blas_amin::metax diff --git a/src/infiniop/ops/blas_amin/metax/blas_amin_metax.h b/src/infiniop/ops/blas_amin/metax/blas_amin_metax.h index ce4b6b0b7..42a5b5fe9 100644 --- a/src/infiniop/ops/blas_amin/metax/blas_amin_metax.h +++ b/src/infiniop/ops/blas_amin/metax/blas_amin_metax.h @@ -5,4 +5,4 @@ DESCRIPTOR(metax) -#endif // __BLAS_AMIN_METAX_H__ \ No newline at end of file +#endif // __BLAS_AMIN_METAX_H__ diff --git a/src/infiniop/ops/blas_amin/operator.cc b/src/infiniop/ops/blas_amin/operator.cc index f47044263..7f960b773 100644 --- a/src/infiniop/ops/blas_amin/operator.cc +++ b/src/infiniop/ops/blas_amin/operator.cc @@ -119,4 +119,4 @@ __INFINI_C infiniStatus_t infiniopDestroyBlasAminDescriptor(infiniopBlasAminDesc } #undef DELETE -} \ No newline at end of file +} diff --git a/src/infiniop/ops/blas_copy/info.h b/src/infiniop/ops/blas_copy/info.h index ec936b1bd..585138ab1 100644 --- a/src/infiniop/ops/blas_copy/info.h +++ b/src/infiniop/ops/blas_copy/info.h @@ -42,4 +42,4 @@ class BlasCopyInfo { } }; -#endif // __BLAS_COPY_INFO_H__ \ No newline at end of file +#endif // __BLAS_COPY_INFO_H__ diff --git a/src/infiniop/ops/blas_dot/bang/blas_dot_bang.h b/src/infiniop/ops/blas_dot/bang/blas_dot_bang.h index 2020935ad..1c2c18019 100644 --- a/src/infiniop/ops/blas_dot/bang/blas_dot_bang.h +++ b/src/infiniop/ops/blas_dot/bang/blas_dot_bang.h @@ -5,4 +5,4 @@ DESCRIPTOR(bang) -#endif // __BLAS_DOT_BANG_H__ \ No newline at end of file +#endif // __BLAS_DOT_BANG_H__ diff --git a/src/infiniop/ops/blas_dot/cpu/blas_dot_cpu.h b/src/infiniop/ops/blas_dot/cpu/blas_dot_cpu.h index e589316a9..0f09f8d08 100644 --- a/src/infiniop/ops/blas_dot/cpu/blas_dot_cpu.h +++ 
b/src/infiniop/ops/blas_dot/cpu/blas_dot_cpu.h @@ -5,4 +5,4 @@ DESCRIPTOR(cpu) -#endif // __BLAS_DOT_CPU_H__ \ No newline at end of file +#endif // __BLAS_DOT_CPU_H__ diff --git a/src/infiniop/ops/blas_dot/info.h b/src/infiniop/ops/blas_dot/info.h index b06decd6b..01e145f6e 100644 --- a/src/infiniop/ops/blas_dot/info.h +++ b/src/infiniop/ops/blas_dot/info.h @@ -46,4 +46,4 @@ class BlasDotInfo { } }; -#endif // __BLAS_DOT_INFO_H__ \ No newline at end of file +#endif // __BLAS_DOT_INFO_H__ diff --git a/src/infiniop/ops/blas_dot/metax/blas_dot_metax.h b/src/infiniop/ops/blas_dot/metax/blas_dot_metax.h index 19a2385b0..0c5cefbf8 100644 --- a/src/infiniop/ops/blas_dot/metax/blas_dot_metax.h +++ b/src/infiniop/ops/blas_dot/metax/blas_dot_metax.h @@ -5,4 +5,4 @@ DESCRIPTOR(metax) -#endif // __BLAS_DOT_METAX_H__ \ No newline at end of file +#endif // __BLAS_DOT_METAX_H__ diff --git a/src/infiniop/ops/nrm2/bang/nrm2_bang.h b/src/infiniop/ops/nrm2/bang/nrm2_bang.h index af7961392..1a9a4761f 100644 --- a/src/infiniop/ops/nrm2/bang/nrm2_bang.h +++ b/src/infiniop/ops/nrm2/bang/nrm2_bang.h @@ -5,4 +5,4 @@ DESCRIPTOR(bang) -#endif // __NRM2_BANG_H__ \ No newline at end of file +#endif // __NRM2_BANG_H__ diff --git a/src/infiniop/ops/nrm2/cpu/nrm2_cpu.h b/src/infiniop/ops/nrm2/cpu/nrm2_cpu.h index 320e18f28..cf1e3fdef 100644 --- a/src/infiniop/ops/nrm2/cpu/nrm2_cpu.h +++ b/src/infiniop/ops/nrm2/cpu/nrm2_cpu.h @@ -5,4 +5,4 @@ DESCRIPTOR(cpu) -#endif // __NRM2_CPU_H__ \ No newline at end of file +#endif // __NRM2_CPU_H__ diff --git a/src/infiniop/ops/nrm2/info.h b/src/infiniop/ops/nrm2/info.h index acf11ba25..04a64abee 100644 --- a/src/infiniop/ops/nrm2/info.h +++ b/src/infiniop/ops/nrm2/info.h @@ -38,4 +38,4 @@ class Nrm2Info { } }; -#endif // __NRM2_INFO_H__ \ No newline at end of file +#endif // __NRM2_INFO_H__ diff --git a/src/infiniop/ops/nrm2/metax/nrm2_metax.h b/src/infiniop/ops/nrm2/metax/nrm2_metax.h index b398b5755..5ebf0aaf2 100644 --- 
a/src/infiniop/ops/nrm2/metax/nrm2_metax.h +++ b/src/infiniop/ops/nrm2/metax/nrm2_metax.h @@ -5,4 +5,4 @@ DESCRIPTOR(metax) -#endif // __NRM2_METAX_H__ \ No newline at end of file +#endif // __NRM2_METAX_H__ diff --git a/src/infiniop/ops/nrm2/operator.cc b/src/infiniop/ops/nrm2/operator.cc index 8f9e00c0b..d1ecfa2bb 100644 --- a/src/infiniop/ops/nrm2/operator.cc +++ b/src/infiniop/ops/nrm2/operator.cc @@ -118,4 +118,4 @@ __INFINI_C infiniStatus_t infiniopDestroyNrm2Descriptor(infiniopNrm2Descriptor_t } #undef DELETE -} \ No newline at end of file +} diff --git a/src/infiniop/ops/rot/bang/rot_bang.h b/src/infiniop/ops/rot/bang/rot_bang.h index 86ce84f58..7ef66bb62 100644 --- a/src/infiniop/ops/rot/bang/rot_bang.h +++ b/src/infiniop/ops/rot/bang/rot_bang.h @@ -5,4 +5,4 @@ DESCRIPTOR(bang) -#endif // __ROT_BANG_H__ \ No newline at end of file +#endif // __ROT_BANG_H__ diff --git a/src/infiniop/ops/rot/cpu/rot_cpu.h b/src/infiniop/ops/rot/cpu/rot_cpu.h index 57d0a054e..2a5bd0ab8 100644 --- a/src/infiniop/ops/rot/cpu/rot_cpu.h +++ b/src/infiniop/ops/rot/cpu/rot_cpu.h @@ -5,4 +5,4 @@ DESCRIPTOR(cpu) -#endif // __ROT_CPU_H__ \ No newline at end of file +#endif // __ROT_CPU_H__ diff --git a/src/infiniop/ops/rot/info.h b/src/infiniop/ops/rot/info.h index 93de6413b..5a1ddf5ed 100644 --- a/src/infiniop/ops/rot/info.h +++ b/src/infiniop/ops/rot/info.h @@ -50,4 +50,4 @@ class RotInfo { } }; -#endif // __ROT_INFO_H__ \ No newline at end of file +#endif // __ROT_INFO_H__ diff --git a/src/infiniop/ops/rot/metax/rot_metax.h b/src/infiniop/ops/rot/metax/rot_metax.h index 2c64a3a0b..b6bdd1553 100644 --- a/src/infiniop/ops/rot/metax/rot_metax.h +++ b/src/infiniop/ops/rot/metax/rot_metax.h @@ -5,4 +5,4 @@ DESCRIPTOR(metax) -#endif // __ROT_METAX_H__ \ No newline at end of file +#endif // __ROT_METAX_H__ diff --git a/src/infiniop/ops/rotg/info.h b/src/infiniop/ops/rotg/info.h index e96bea587..99e486a1b 100644 --- a/src/infiniop/ops/rotg/info.h +++ b/src/infiniop/ops/rotg/info.h @@ -38,4 
+38,4 @@ class RotgInfo { } }; -#endif // __ROTG_INFO_H__ \ No newline at end of file +#endif // __ROTG_INFO_H__ diff --git a/src/infiniop/ops/rotm/bang/rotm_bang.h b/src/infiniop/ops/rotm/bang/rotm_bang.h index 4d2473b20..49cbd6789 100644 --- a/src/infiniop/ops/rotm/bang/rotm_bang.h +++ b/src/infiniop/ops/rotm/bang/rotm_bang.h @@ -5,4 +5,4 @@ DESCRIPTOR(bang) -#endif // __ROTM_BANG_H__ \ No newline at end of file +#endif // __ROTM_BANG_H__ diff --git a/src/infiniop/ops/rotm/bang/rotm_bang_kernel.mlu b/src/infiniop/ops/rotm/bang/rotm_bang_kernel.mlu index d17e66856..e943d1464 100644 --- a/src/infiniop/ops/rotm/bang/rotm_bang_kernel.mlu +++ b/src/infiniop/ops/rotm/bang/rotm_bang_kernel.mlu @@ -180,4 +180,4 @@ __mlu_global__ void rotmKernelStrided( y[y_idx] = -w + h22 * z; } } -} \ No newline at end of file +} diff --git a/src/infiniop/ops/rotm/cpu/rotm_cpu.h b/src/infiniop/ops/rotm/cpu/rotm_cpu.h index 740d741ad..972bd4c8a 100644 --- a/src/infiniop/ops/rotm/cpu/rotm_cpu.h +++ b/src/infiniop/ops/rotm/cpu/rotm_cpu.h @@ -5,4 +5,4 @@ DESCRIPTOR(cpu) -#endif // __ROTM_CPU_H__ \ No newline at end of file +#endif // __ROTM_CPU_H__ diff --git a/src/infiniop/ops/rotm/info.h b/src/infiniop/ops/rotm/info.h index 6509e6d75..0cf44af91 100644 --- a/src/infiniop/ops/rotm/info.h +++ b/src/infiniop/ops/rotm/info.h @@ -47,4 +47,4 @@ class RotmInfo { } }; -#endif // __ROTM_INFO_H__ \ No newline at end of file +#endif // __ROTM_INFO_H__ diff --git a/src/infiniop/ops/rotm/metax/rotm_metax.h b/src/infiniop/ops/rotm/metax/rotm_metax.h index 7dfe84cab..07b031336 100644 --- a/src/infiniop/ops/rotm/metax/rotm_metax.h +++ b/src/infiniop/ops/rotm/metax/rotm_metax.h @@ -5,4 +5,4 @@ DESCRIPTOR(metax) -#endif // __ROTM_METAX_H__ \ No newline at end of file +#endif // __ROTM_METAX_H__ diff --git a/src/infiniop/ops/rotm/operator.cc b/src/infiniop/ops/rotm/operator.cc index fb8eda223..fcfe2060c 100644 --- a/src/infiniop/ops/rotm/operator.cc +++ b/src/infiniop/ops/rotm/operator.cc @@ -126,4 +126,4 @@ 
__INFINI_C infiniStatus_t infiniopDestroyRotmDescriptor(infiniopRotmDescriptor_t } #undef DELETE -} \ No newline at end of file +} diff --git a/src/infiniop/ops/rotmg/bang/rotmg_bang_kernel.mlu b/src/infiniop/ops/rotmg/bang/rotmg_bang_kernel.mlu index f45d131da..a9237e2d9 100644 --- a/src/infiniop/ops/rotmg/bang/rotmg_bang_kernel.mlu +++ b/src/infiniop/ops/rotmg/bang/rotmg_bang_kernel.mlu @@ -153,4 +153,4 @@ __mlu_global__ void rotmgKernel( *d1 = static_cast(d1_val); *d2 = static_cast(d2_val); *x1 = static_cast(x1_val); -} \ No newline at end of file +} diff --git a/src/infiniop/ops/rotmg/info.h b/src/infiniop/ops/rotmg/info.h index cc9f2826b..e7fdb8532 100644 --- a/src/infiniop/ops/rotmg/info.h +++ b/src/infiniop/ops/rotmg/info.h @@ -45,4 +45,4 @@ class RotmgInfo { } }; -#endif // __ROTMG_INFO_H__ \ No newline at end of file +#endif // __ROTMG_INFO_H__ diff --git a/src/infiniop/ops/scal/bang/scal_bang.h b/src/infiniop/ops/scal/bang/scal_bang.h index c61532ce7..726790b8a 100644 --- a/src/infiniop/ops/scal/bang/scal_bang.h +++ b/src/infiniop/ops/scal/bang/scal_bang.h @@ -5,4 +5,4 @@ DESCRIPTOR(bang) -#endif // __SCAL_BANG_H__ \ No newline at end of file +#endif // __SCAL_BANG_H__ diff --git a/src/infiniop/ops/scal/cpu/scal_cpu.h b/src/infiniop/ops/scal/cpu/scal_cpu.h index efbba93ba..1a0f2fe6b 100644 --- a/src/infiniop/ops/scal/cpu/scal_cpu.h +++ b/src/infiniop/ops/scal/cpu/scal_cpu.h @@ -5,4 +5,4 @@ DESCRIPTOR(cpu) -#endif // __SCAL_CPU_H__ \ No newline at end of file +#endif // __SCAL_CPU_H__ diff --git a/src/infiniop/ops/scal/info.h b/src/infiniop/ops/scal/info.h index fd104e5d5..c9c3122b7 100644 --- a/src/infiniop/ops/scal/info.h +++ b/src/infiniop/ops/scal/info.h @@ -37,4 +37,4 @@ class ScalInfo { } }; -#endif // __SCAL_INFO_H__ \ No newline at end of file +#endif // __SCAL_INFO_H__ diff --git a/src/infiniop/ops/scal/metax/scal_metax.h b/src/infiniop/ops/scal/metax/scal_metax.h index 8e63ac399..1e5760ffe 100644 --- a/src/infiniop/ops/scal/metax/scal_metax.h +++ 
b/src/infiniop/ops/scal/metax/scal_metax.h @@ -5,4 +5,4 @@ DESCRIPTOR(metax) -#endif // __SCAL_METAX_H__ \ No newline at end of file +#endif // __SCAL_METAX_H__ diff --git a/src/infiniop/ops/swap/info.h b/src/infiniop/ops/swap/info.h index bb4d84b31..0dad381f7 100644 --- a/src/infiniop/ops/swap/info.h +++ b/src/infiniop/ops/swap/info.h @@ -41,4 +41,4 @@ class SwapInfo { } }; -#endif // __SWAP_INFO_H__ \ No newline at end of file +#endif // __SWAP_INFO_H__ diff --git a/test/infiniop/libinfiniop/op_register.py b/test/infiniop/libinfiniop/op_register.py index 0eeb99098..e43f3a18d 100644 --- a/test/infiniop/libinfiniop/op_register.py +++ b/test/infiniop/libinfiniop/op_register.py @@ -1353,6 +1353,7 @@ def gptq_qyblas_gemm_(lib): infiniopOperatorDescriptor_t, ] + @OpRegister.operator def softplus_(lib): lib.infiniopCreateSoftplusDescriptor.restype = c_int32 From 0ce6e8ea46e3f0a5f1f90738b745302b60650d2e Mon Sep 17 00:00:00 2001 From: xuzhengzhong Date: Sat, 9 May 2026 07:22:40 +0000 Subject: [PATCH 24/25] Add empty `getMemInfo` and `getDeviceResourceSnapshot` to `infinirt_bang.cc` --- src/infinirt/bang/infinirt_bang.cc | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/infinirt/bang/infinirt_bang.cc b/src/infinirt/bang/infinirt_bang.cc index 5384add19..fe43e15cf 100644 --- a/src/infinirt/bang/infinirt_bang.cc +++ b/src/infinirt/bang/infinirt_bang.cc @@ -172,4 +172,12 @@ infiniStatus_t graphLuanch(infinirtGraphExec_t graph_exec, infinirtStream_t stre return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; } +infiniStatus_t getMemInfo(int device_id, size_t *free_bytes, size_t *total_bytes) { + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; +} + +infiniStatus_t getDeviceResourceSnapshot(int device_id, infinirtDeviceResourceSnapshot_t *snapshot) { + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; +} + } // namespace infinirt::bang From bbb364e6cd35e557d5d28044a0dbe4fd4e257b6c Mon Sep 17 00:00:00 2001 From: xuzhengzhong Date: Sat, 9 May 2026 07:35:31 +0000 Subject: [PATCH 
25/25] Revert `.clang-format` --- .clang-format | 1 - 1 file changed, 1 deletion(-) diff --git a/.clang-format b/.clang-format index a802aaf49..c05c7633b 100644 --- a/.clang-format +++ b/.clang-format @@ -28,4 +28,3 @@ BraceWrapping: SplitEmptyFunction: true SplitEmptyRecord: true SplitEmptyNamespace: true -InsertNewlineAtEOF: true