diff --git a/include/infinicore/ops/cosh.hpp b/include/infinicore/ops/cosh.hpp new file mode 100644 index 000000000..78dd2ccb7 --- /dev/null +++ b/include/infinicore/ops/cosh.hpp @@ -0,0 +1,16 @@ +#pragma once + +#include "../device.hpp" +#include "common/op.hpp" + +namespace infinicore::op { +class Cosh { +public: + using schema = void (*)(Tensor, Tensor); + static void execute(Tensor y, Tensor x); + static common::OpDispatcher &dispatcher(); +}; + +Tensor cosh(Tensor x); +void cosh_(Tensor y, Tensor x); +} // namespace infinicore::op \ No newline at end of file diff --git a/include/infiniop.h b/include/infiniop.h index 56b29f9f5..6c4444c78 100644 --- a/include/infiniop.h +++ b/include/infiniop.h @@ -19,6 +19,7 @@ #include "infiniop/ops/cdist.h" #include "infiniop/ops/clip.h" #include "infiniop/ops/conv.h" +#include "infiniop/ops/cosh.h" #include "infiniop/ops/cross_entropy.h" #include "infiniop/ops/dequant/per_tensor_dequant_int8.h" #include "infiniop/ops/dequantize_awq.h" diff --git a/include/infiniop/ops/cosh.h b/include/infiniop/ops/cosh.h new file mode 100644 index 000000000..3a52d8bda --- /dev/null +++ b/include/infiniop/ops/cosh.h @@ -0,0 +1,24 @@ +#ifndef __INFINIOP_COSH_API_H__ +#define __INFINIOP_COSH_API_H__ + +#include "../operator_descriptor.h" + +typedef struct InfiniopDescriptor *infiniopCoshDescriptor_t; + +__INFINI_C __export infiniStatus_t infiniopCreateCoshDescriptor(infiniopHandle_t handle, + infiniopCoshDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t y, + infiniopTensorDescriptor_t x); + +__INFINI_C __export infiniStatus_t infiniopGetCoshWorkspaceSize(infiniopCoshDescriptor_t desc, size_t *size); + +__INFINI_C __export infiniStatus_t infiniopCosh(infiniopCoshDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *y, + const void *x, + void *stream); + +__INFINI_C __export infiniStatus_t infiniopDestroyCoshDescriptor(infiniopCoshDescriptor_t desc); + +#endif \ No newline at end of file diff --git 
def round(input: "Tensor", decimals=0, *, out=None) -> "Tensor":
    r"""Round the elements of ``input`` to ``decimals`` decimal places.

    The computation is delegated to ``infinicore.ntops.torch.round``; this
    wrapper only adds support for an ``out`` destination.

    Args:
        input: Tensor to round.
        decimals: Forwarded unchanged to the ntops implementation
            (``0`` rounds to whole numbers).
        out: Optional destination. When given, the rounded values are copied
            into it with ``out.copy_`` and ``out`` itself is returned.

    Returns:
        The result produced by ntops when ``out`` is ``None``, otherwise ``out``.
    """
    # NOTE(review): tie-breaking (the previous docstring said banker's rounding)
    # is whatever the ntops implementation does — confirm there.
    # NOTE(review): every path goes through ntops, so the former
    # ``use_ntops`` / device-type check selected between identical calls and
    # has been dropped. Reintroduce a dispatch here once a native kernel exists.
    result = infinicore.ntops.torch.round(input, decimals=decimals)

    if out is None:
        return result

    out.copy_(result)
    return out
caches.getCache(device_type, device_index); + + auto desc_opt = cache.get(seed); + infiniopCoshDescriptor_t desc = nullptr; + + if (!desc_opt) { + INFINICORE_CHECK_ERROR(infiniopCreateCoshDescriptor( + context::getInfiniopHandle(y->device()), &desc, + y->desc(), x->desc())); + cache.put(seed, desc); + } else { + desc = *desc_opt; + } + + size_t workspace_size = 0; + INFINICORE_CHECK_ERROR(infiniopGetCoshWorkspaceSize(desc, &workspace_size)); + std::shared_ptr workspace = context::allocateMemory(workspace_size); + + INFINICORE_CHECK_ERROR(infiniopCosh( + desc, workspace->data(), workspace_size, + y->data(), x->data(), context::getStream())); +} + +static bool registered = []() { + Cosh::dispatcher().registerAll(&calculate, false); + return true; +}(); + +} // namespace infinicore::op::cosh_impl::infiniop \ No newline at end of file diff --git a/src/infinicore/pybind11/ops.hpp b/src/infinicore/pybind11/ops.hpp index a8072aea4..67117a036 100644 --- a/src/infinicore/pybind11/ops.hpp +++ b/src/infinicore/pybind11/ops.hpp @@ -21,6 +21,7 @@ #include "ops/cat.hpp" #include "ops/causal_softmax.hpp" #include "ops/cdist.hpp" +#include "ops/cosh.hpp" #include "ops/cross_entropy.hpp" #include "ops/embedding.hpp" #include "ops/equal.hpp" @@ -78,6 +79,7 @@ inline void bind(py::module &m) { bind_baddbmm(m); bind_bilinear(m); bind_causal_softmax(m); + bind_cosh(m); bind_flash_attention(m); bind_kv_caching(m); bind_fmod(m); diff --git a/src/infinicore/pybind11/ops/cosh.hpp b/src/infinicore/pybind11/ops/cosh.hpp new file mode 100644 index 000000000..fab93fa2a --- /dev/null +++ b/src/infinicore/pybind11/ops/cosh.hpp @@ -0,0 +1,24 @@ +#pragma once + +#include + +#include "infinicore/ops/cosh.hpp" + +namespace py = pybind11; + +namespace infinicore::ops { + +inline void bind_cosh(py::module &m) { + m.def("cosh", + &op::cosh, + py::arg("x"), + R"doc(Element-wise hyperbolic cosine function.)doc"); + + m.def("cosh_", + &op::cosh_, + py::arg("y"), + py::arg("x"), + R"doc(In-place 
element-wise hyperbolic cosine function.)doc"); +} + +} // namespace infinicore::ops \ No newline at end of file diff --git a/src/infiniop/ops/cosh/cpu/cosh_cpu.cc b/src/infiniop/ops/cosh/cpu/cosh_cpu.cc new file mode 100644 index 000000000..c4de3d38d --- /dev/null +++ b/src/infiniop/ops/cosh/cpu/cosh_cpu.cc @@ -0,0 +1,49 @@ +#include "cosh_cpu.h" + +namespace op::cosh::cpu { + +Descriptor::~Descriptor() = default; + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t out_desc, + std::vector input_desc_vec) { + + auto handle = reinterpret_cast(handle_); + auto dtype = out_desc->dtype(); + + const auto &x_desc = input_desc_vec.at(0); + const auto &y_shape = out_desc->shape(); + const auto &x_shape = x_desc->shape(); + + CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64, INFINI_DTYPE_BF16); + CHECK_SAME_SHAPE(y_shape, x_shape); + + CREATE_ELEMENTWISE_CPU_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec); + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *output, + std::vector inputs, + void *stream) const { + + switch (_dtype) { + case INFINI_DTYPE_F16: + return _device_info->calculate(_info, output, inputs, stream); + case INFINI_DTYPE_F32: + return _device_info->calculate(_info, output, inputs, stream); + case INFINI_DTYPE_F64: + return _device_info->calculate(_info, output, inputs, stream); + case INFINI_DTYPE_BF16: + return _device_info->calculate(_info, output, inputs, stream); + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } +} + +} // namespace op::cosh::cpu \ No newline at end of file diff --git a/src/infiniop/ops/cosh/cpu/cosh_cpu.h b/src/infiniop/ops/cosh/cpu/cosh_cpu.h new file mode 100644 index 000000000..6bf895eeb --- /dev/null +++ b/src/infiniop/ops/cosh/cpu/cosh_cpu.h @@ -0,0 +1,22 @@ +#ifndef __COSH_CPU_H__ +#define __COSH_CPU_H__ + +#include + +#include 
namespace op::cosh::cpu {

/**
 * Element-wise functor that evaluates the hyperbolic cosine of one operand.
 *
 * `num_inputs` states how many operands the functor consumes (one).
 */
struct CoshOp {
    static constexpr size_t num_inputs = 1;

    /**
     * @param x Operand value.
     * @return `std::cosh(x)`, converted to `T` by the return statement.
     */
    template <typename T>
    T operator()(const T &x) const {
        return std::cosh(x);
    }
};

} // namespace op::cosh::cpu
x_desc->shape(); + + CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64, INFINI_DTYPE_BF16); + CHECK_SAME_SHAPE(y_shape, x_shape); + + CREATE_ELEMENTWISE_METAX_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec) + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *output, + std::vector inputs, + void *stream) const { + + if (workspace_size < _workspace_size) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + + switch (_dtype) { + case INFINI_DTYPE_F16: + return _device_info->calculate<256, cuda::CoshOp, half>(_info, workspace, output, inputs, stream); + case INFINI_DTYPE_BF16: + return _device_info->calculate<256, cuda::CoshOp, cuda_bfloat16>(_info, workspace, output, inputs, stream); + case INFINI_DTYPE_F32: + return _device_info->calculate<256, cuda::CoshOp, float>(_info, workspace, output, inputs, stream); + case INFINI_DTYPE_F64: + return _device_info->calculate<256, cuda::CoshOp, double>(_info, workspace, output, inputs, stream); + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } +} + +} // namespace op::cosh::metax \ No newline at end of file diff --git a/src/infiniop/ops/cosh/metax/cosh_metax.h b/src/infiniop/ops/cosh/metax/cosh_metax.h new file mode 100644 index 000000000..73bce4f9a --- /dev/null +++ b/src/infiniop/ops/cosh/metax/cosh_metax.h @@ -0,0 +1,8 @@ +#ifndef __COSH_METAX_API_H__ +#define __COSH_METAX_API_H__ + +#include "../../../elementwise/metax/elementwise_metax_api.h" + +ELEMENTWISE_DESCRIPTOR(cosh, metax) + +#endif // __COSH_METAX_API_H__ \ No newline at end of file diff --git a/src/infiniop/ops/cosh/moore/cosh_moore.h b/src/infiniop/ops/cosh/moore/cosh_moore.h new file mode 100644 index 000000000..361bd7195 --- /dev/null +++ b/src/infiniop/ops/cosh/moore/cosh_moore.h @@ -0,0 +1,8 @@ +#ifndef __COSH_MOORE_API_H__ +#define __COSH_MOORE_API_H__ + +#include "../../../elementwise/moore/elementwise_moore_api.h" + +ELEMENTWISE_DESCRIPTOR(cosh, 
moore) + +#endif // __COSH_MOORE_API_H__ \ No newline at end of file diff --git a/src/infiniop/ops/cosh/moore/cosh_moore.mu b/src/infiniop/ops/cosh/moore/cosh_moore.mu new file mode 100644 index 000000000..69735cf9d --- /dev/null +++ b/src/infiniop/ops/cosh/moore/cosh_moore.mu @@ -0,0 +1,57 @@ +#include "cosh_moore.h" + +#include "../../../elementwise/moore/elementwise_moore.h" + +#include "../cuda/kernel.cuh" + +namespace op::cosh::moore { +Descriptor::~Descriptor() = default; + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t out_desc, + std::vector input_desc_vec) { + + auto handle = reinterpret_cast(handle_); + auto dtype = out_desc->dtype(); + + const auto &x_desc = input_desc_vec.at(0); + const auto &y_shape = out_desc->shape(); + const auto &x_shape = x_desc->shape(); + + CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64, INFINI_DTYPE_BF16); + CHECK_SAME_SHAPE(y_shape, x_shape); + + CREATE_ELEMENTWISE_MOORE_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec) + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *output, + std::vector inputs, + void *stream) const { + + if (workspace_size < _workspace_size) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + + switch (_dtype) { + case INFINI_DTYPE_F16: + return _device_info->calculate<256, cuda::CoshOp, half>(_info, workspace, output, inputs, stream); + case INFINI_DTYPE_BF16: + return _device_info->calculate<256, cuda::CoshOp, cuda_bfloat16>(_info, workspace, output, inputs, stream); + case INFINI_DTYPE_F32: + return _device_info->calculate<256, cuda::CoshOp, float>(_info, workspace, output, inputs, stream); + case INFINI_DTYPE_F64: + return _device_info->calculate<256, cuda::CoshOp, double>(_info, workspace, output, inputs, stream); + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + return INFINI_STATUS_SUCCESS; +} +} // namespace 
op::cosh::moore \ No newline at end of file diff --git a/src/infiniop/ops/cosh/nvidia/cosh_nvidia.cu b/src/infiniop/ops/cosh/nvidia/cosh_nvidia.cu new file mode 100644 index 000000000..1ec259509 --- /dev/null +++ b/src/infiniop/ops/cosh/nvidia/cosh_nvidia.cu @@ -0,0 +1,55 @@ +#include "../../../devices/nvidia/nvidia_common.cuh" +#include "../../../elementwise/nvidia/elementwise_nvidia.cuh" + +#include "../cuda/kernel.cuh" +#include "cosh_nvidia.cuh" + +namespace op::cosh::nvidia { + +Descriptor::~Descriptor() = default; + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t out_desc, + std::vector input_desc_vec) { + + auto handle = reinterpret_cast(handle_); + auto dtype = out_desc->dtype(); + + const auto &x_desc = input_desc_vec.at(0); + const auto &y_shape = out_desc->shape(); + const auto &x_shape = x_desc->shape(); + + CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64, INFINI_DTYPE_BF16); + CHECK_SAME_SHAPE(y_shape, x_shape); + + CREATE_ELEMENTWISE_CUDA_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec) + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *output, + std::vector inputs, + void *stream) const { + if (workspace_size < _workspace_size) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + switch (_dtype) { + case INFINI_DTYPE_F16: + return _device_info->calculate<256, cuda::CoshOp, half>(_info, workspace, output, inputs, stream); + case INFINI_DTYPE_BF16: + return _device_info->calculate<256, cuda::CoshOp, cuda_bfloat16>(_info, workspace, output, inputs, stream); + case INFINI_DTYPE_F32: + return _device_info->calculate<256, cuda::CoshOp, float>(_info, workspace, output, inputs, stream); + case INFINI_DTYPE_F64: + return _device_info->calculate<256, cuda::CoshOp, double>(_info, workspace, output, inputs, stream); + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } +} + +} // 
namespace op::cosh::nvidia \ No newline at end of file diff --git a/src/infiniop/ops/cosh/nvidia/cosh_nvidia.cuh b/src/infiniop/ops/cosh/nvidia/cosh_nvidia.cuh new file mode 100644 index 000000000..6a032b0bb --- /dev/null +++ b/src/infiniop/ops/cosh/nvidia/cosh_nvidia.cuh @@ -0,0 +1,8 @@ +#ifndef __COSH_NVIDIA_API_H__ +#define __COSH_NVIDIA_API_H__ + +#include "../../../elementwise/nvidia/elementwise_nvidia_api.cuh" + +ELEMENTWISE_DESCRIPTOR(cosh, nvidia) + +#endif // __COSH_NVIDIA_API_H__ diff --git a/src/infiniop/ops/cosh/operator.cc b/src/infiniop/ops/cosh/operator.cc new file mode 100644 index 000000000..aaaf60d01 --- /dev/null +++ b/src/infiniop/ops/cosh/operator.cc @@ -0,0 +1,139 @@ +#include "../../operator.h" +#include "../../handle.h" +#include "infiniop/ops/cosh.h" + +#ifdef ENABLE_CPU_API +#include "cpu/cosh_cpu.h" +#endif +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) +#include "nvidia/cosh_nvidia.cuh" +#endif +#ifdef ENABLE_METAX_API +#include "metax/cosh_metax.h" +#endif +#ifdef ENABLE_MOORE_API +#include "moore/cosh_moore.h" +#endif + +__INFINI_C infiniStatus_t infiniopCreateCoshDescriptor( + infiniopHandle_t handle, + infiniopCoshDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t x_desc) { +#define CREATE(CASE, NAMESPACE) \ + case CASE: \ + return op::cosh::NAMESPACE::Descriptor::create( \ + handle, \ + reinterpret_cast(desc_ptr), \ + y_desc, \ + {x_desc}) + switch (handle->device) { +#ifdef ENABLE_CPU_API + CREATE(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_NVIDIA_API + CREATE(INFINI_DEVICE_NVIDIA, nvidia); +#endif +#ifdef ENABLE_ILUVATAR_API + CREATE(INFINI_DEVICE_ILUVATAR, nvidia); +#endif +#ifdef ENABLE_METAX_API + CREATE(INFINI_DEVICE_METAX, metax); +#endif +#ifdef ENABLE_MOORE_API + CREATE(INFINI_DEVICE_MOORE, moore); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +#undef CREATE +} + +__INFINI_C infiniStatus_t 
infiniopGetCoshWorkspaceSize(infiniopCoshDescriptor_t desc, size_t *size) { +#define GET(CASE, NAMESPACE) \ + case CASE: \ + *size = reinterpret_cast(desc)->workspaceSize(); \ + return INFINI_STATUS_SUCCESS + + switch (desc->device_type) { +#ifdef ENABLE_CPU_API + GET(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_NVIDIA_API + GET(INFINI_DEVICE_NVIDIA, nvidia); +#endif +#ifdef ENABLE_ILUVATAR_API + GET(INFINI_DEVICE_ILUVATAR, nvidia); +#endif +#ifdef ENABLE_METAX_API + GET(INFINI_DEVICE_METAX, metax); +#endif +#ifdef ENABLE_MOORE_API + GET(INFINI_DEVICE_MOORE, moore); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +#undef GET + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; +} + +__INFINI_C infiniStatus_t infiniopCosh(infiniopCoshDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *y, + const void *x, + void *stream) { +#define CALCULATE(CASE, NAMESPACE) \ + case CASE: \ + return reinterpret_cast(desc) \ + ->calculate(workspace, workspace_size, y, {x}, stream); + switch (desc->device_type) { +#ifdef ENABLE_CPU_API + CALCULATE(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_NVIDIA_API + CALCULATE(INFINI_DEVICE_NVIDIA, nvidia); +#endif +#ifdef ENABLE_ILUVATAR_API + CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia); +#endif +#ifdef ENABLE_METAX_API + CALCULATE(INFINI_DEVICE_METAX, metax); +#endif +#ifdef ENABLE_MOORE_API + CALCULATE(INFINI_DEVICE_MOORE, moore); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +#undef CALCULATE +} + +__INFINI_C infiniStatus_t infiniopDestroyCoshDescriptor(infiniopCoshDescriptor_t desc) { +#define DESTROY(CASE, NAMESPACE) \ + case CASE: \ + delete reinterpret_cast(desc); \ + return INFINI_STATUS_SUCCESS; + + switch (desc->device_type) { +#ifdef ENABLE_CPU_API + DESTROY(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_NVIDIA_API + DESTROY(INFINI_DEVICE_NVIDIA, nvidia); +#endif +#ifdef ENABLE_ILUVATAR_API + DESTROY(INFINI_DEVICE_ILUVATAR, nvidia); +#endif +#ifdef 
ENABLE_METAX_API + DESTROY(INFINI_DEVICE_METAX, metax); +#endif +#ifdef ENABLE_MOORE_API + DESTROY(INFINI_DEVICE_MOORE, moore); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +#undef DESTROY +} \ No newline at end of file diff --git a/test/infinicore/ops/cosh.py b/test/infinicore/ops/cosh.py index 16d3c3d7b..9d6d6cf9d 100644 --- a/test/infinicore/ops/cosh.py +++ b/test/infinicore/ops/cosh.py @@ -1,18 +1,19 @@ -import sys import os +import sys sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) -import infinicore import torch from framework import ( BaseOperatorTest, + GenericTestRunner, TensorSpec, TestCase, - GenericTestRunner, is_broadcast, ) +import infinicore + # ======================================================================= # Test cases format: (shape, input_strides_or_None) # ======================================================================= @@ -97,9 +98,9 @@ def get_test_cases(self): def torch_operator(self, *args, **kwargs): return torch.cosh(*args, **kwargs) - # def infinicore_operator(self, *args, **kwargs): - # """InfiniCore implementation (operator not yet available).""" - # return infinicore.cosh(*args, **kwargs) + def infinicore_operator(self, *args, **kwargs): + """InfiniCore cosh implementation""" + return infinicore.cosh(*args, **kwargs) def main(): diff --git a/test/infinicore/ops/round.py b/test/infinicore/ops/round.py index 7cd48f660..0365fb5c5 100644 --- a/test/infinicore/ops/round.py +++ b/test/infinicore/ops/round.py @@ -93,9 +93,8 @@ def get_test_cases(self): def torch_operator(self, *args, **kwargs): return torch.round(*args, **kwargs) - # def infinicore_operator(self, *args, **kwargs): - # """InfiniCore implementation (operator not yet available).""" - # return infinicore.round(*args, **kwargs) + def infinicore_operator(self, *args, **kwargs): + return infinicore.round(*args, **kwargs) def main():