Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/infiniop/ops/sigmoid/metax/sigmoid_metax.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#ifndef __SIGMOID_METAX_API_H__
#define __SIGMOID_METAX_API_H__

#include "../../../elementwise/metax/elementwise_metax_api.h"

ELEMENTWISE_DESCRIPTOR(sigmoid, metax)

#endif // __SIGMOID_METAX_API_H__
60 changes: 60 additions & 0 deletions src/infiniop/ops/sigmoid/metax/sigmoid_metax.maca
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#include "sigmoid_metax.h"

#include "../../../elementwise/metax/elementwise_metax.h"

#include "../cuda/kernel.cuh"

namespace op::sigmoid::metax {

Descriptor::~Descriptor() = default;

infiniStatus_t Descriptor::create(
infiniopHandle_t handle_,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t out_desc,
std::vector<infiniopTensorDescriptor_t> input_desc_vec) {

auto handle = reinterpret_cast<device::metax::Handle *>(handle_);
auto dtype = out_desc->dtype();

const auto &input_desc = input_desc_vec.at(0);
const auto &output_shape = out_desc->shape();
const auto &input_shape = input_desc->shape();

CHECK_DTYPE(dtype, INFINI_DTYPE_BF16, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64);

CHECK_SAME_SHAPE(output_shape, input_shape);

// create METAX elementwise descriptor
CREATE_ELEMENTWISE_METAX_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec)

return INFINI_STATUS_SUCCESS;
}

infiniStatus_t Descriptor::calculate(
void *workspace,
size_t workspace_size,
void *output,
std::vector<const void *> inputs,
void *stream) const {

if (workspace_size < _workspace_size) {
return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
}

switch (_dtype) {
case INFINI_DTYPE_BF16:
return _device_info->calculate<256, cuda::SigmoidOp, cuda_bfloat16>(_info, workspace, output, inputs, stream);
case INFINI_DTYPE_F16:
return _device_info->calculate<256, cuda::SigmoidOp, half>(_info, workspace, output, inputs, stream);
case INFINI_DTYPE_F32:
return _device_info->calculate<256, cuda::SigmoidOp, float>(_info, workspace, output, inputs, stream);
case INFINI_DTYPE_F64:
return _device_info->calculate<256, cuda::SigmoidOp, double>(_info, workspace, output, inputs, stream);
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}

return INFINI_STATUS_SUCCESS;
}
} // namespace op::sigmoid::metax
15 changes: 15 additions & 0 deletions src/infiniop/ops/sigmoid/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API) || defined(ENABLE_ALI_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_HYGON_API)
#include "nvidia/sigmoid_nvidia.cuh"
#endif
#ifdef ENABLE_METAX_API
#include "metax/sigmoid_metax.h"
#endif

__INFINI_C infiniStatus_t infiniopCreateSigmoidDescriptor(
infiniopHandle_t handle,
Expand Down Expand Up @@ -43,6 +46,9 @@ __INFINI_C infiniStatus_t infiniopCreateSigmoidDescriptor(
#ifdef ENABLE_HYGON_API
CREATE(INFINI_DEVICE_HYGON, nvidia);
#endif
#ifdef ENABLE_METAX_API
CREATE(INFINI_DEVICE_METAX, metax);
#endif

default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
Expand Down Expand Up @@ -76,6 +82,9 @@ __INFINI_C infiniStatus_t infiniopGetSigmoidWorkspaceSize(infiniopSigmoidDescrip
#endif
#ifdef ENABLE_HYGON_API
GET(INFINI_DEVICE_HYGON, nvidia)
#endif
#ifdef ENABLE_METAX_API
GET(INFINI_DEVICE_METAX, metax)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
Expand Down Expand Up @@ -118,6 +127,9 @@ __INFINI_C infiniStatus_t infiniopSigmoid(
#ifdef ENABLE_HYGON_API
CALCULATE(INFINI_DEVICE_HYGON, nvidia);
#endif
#ifdef ENABLE_METAX_API
CALCULATE(INFINI_DEVICE_METAX, metax);
#endif

default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
Expand Down Expand Up @@ -154,6 +166,9 @@ infiniopDestroySigmoidDescriptor(infiniopSigmoidDescriptor_t desc) {
#ifdef ENABLE_HYGON_API
DELETE(INFINI_DEVICE_HYGON, nvidia);
#endif
#ifdef ENABLE_METAX_API
DELETE(INFINI_DEVICE_METAX, metax);
#endif

default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
Expand Down