diff --git a/src/base/scaled_softmax.h b/src/base/scaled_softmax.h
new file mode 100644
index 00000000..707bf1ff
--- /dev/null
+++ b/src/base/scaled_softmax.h
@@ -0,0 +1,55 @@
+#ifndef INFINI_OPS_BASE_SCALED_SOFTMAX_H_
+#define INFINI_OPS_BASE_SCALED_SOFTMAX_H_
+
+#include <cassert>
+#include <cmath>
+#include <cstdint>
+
+#include "data_type.h"
+#include "operator.h"
+#include "tensor.h"
+
+namespace infini::ops {
+
+class ScaledSoftmax : public Operator {
+ public:
+  ScaledSoftmax(const Tensor input, double scale, Tensor out)
+      : scale_{scale},
+        batch_size_{input.size(0)},
+        vocab_size_{input.size(1)},
+        dtype_{input.dtype()},
+        input_strides_{input.strides()},
+        out_strides_{out.strides()} {
+    assert(input.ndim() == 2 &&
+           "`ScaledSoftmax` currently supports 2D `[batch, vocab]` input");
+    assert(input.shape() == out.shape() &&
+           "`ScaledSoftmax` requires `input` and `out` to have the same shape");
+    assert(input.dtype() == out.dtype() &&
+           "`ScaledSoftmax` requires `input` and `out` to have the same dtype");
+    assert((dtype_ == DataType::kFloat16 || dtype_ == DataType::kBFloat16 ||
+            dtype_ == DataType::kFloat32 || dtype_ == DataType::kFloat64) &&
+           "`ScaledSoftmax` requires a floating point dtype");
+    assert(std::isfinite(scale_) &&
+           "`ScaledSoftmax` requires a finite `scale`");
+  }
+
+  virtual void operator()(const Tensor input, double scale,
+                          Tensor out) const = 0;
+
+ protected:
+  double scale_{1.0};
+
+  Tensor::Size batch_size_{0};
+
+  Tensor::Size vocab_size_{0};
+
+  DataType dtype_;
+
+  Tensor::Strides input_strides_;
+
+  Tensor::Strides out_strides_;
+};
+
+}  // namespace infini::ops
+
+#endif  // INFINI_OPS_BASE_SCALED_SOFTMAX_H_
diff --git a/src/native/ascend/ops/scaled_softmax/kernel.h b/src/native/ascend/ops/scaled_softmax/kernel.h
new file mode 100644
index 00000000..c6c0df6f
--- /dev/null
+++ b/src/native/ascend/ops/scaled_softmax/kernel.h
@@ -0,0 +1,124 @@
+#ifndef INFINI_OPS_ASCEND_SCALED_SOFTMAX_KERNEL_H_
+#define INFINI_OPS_ASCEND_SCALED_SOFTMAX_KERNEL_H_
+
+#include <cmath>
+
+#include "acl/acl.h"
+#include "aclnn/aclnn_base.h"
+#include "aclnn_mul.h"
+#include "aclnn_softmax.h"
+#include "base/scaled_softmax.h"
+#include "data_type.h"
+#include "native/ascend/common.h"
+#include "native/ascend/workspace_pool_.h"
+#include "operator.h"
+
+namespace infini::ops {
+
+template <>
+class Operator<ScaledSoftmax> : public ScaledSoftmax {
+ public:
+  Operator(const Tensor input, double scale, Tensor out)
+      : ScaledSoftmax(input, scale, out),
+        in_cache_(input),
+        out_cache_(out),
+        temp_cache_(input),
+        scale_storage_(static_cast<float>(scale)),
+        needs_scale_(std::fabs(scale - 1.0) > 1e-6) {
+    assert((dtype_ == DataType::kFloat16 || dtype_ == DataType::kBFloat16 ||
+            dtype_ == DataType::kFloat32) &&
+           "`ScaledSoftmax` Ascend path requires float16, bfloat16, or "
+           "float32 input");
+    assert(input.IsContiguous() &&
+           "`ScaledSoftmax` Ascend path requires contiguous input");
+    assert(out.IsContiguous() &&
+           "`ScaledSoftmax` Ascend path requires contiguous output");
+
+    temp_size_ = input.numel() * kDataTypeToSize.at(dtype_);
+    scale_scalar_ = aclCreateScalar(&scale_storage_, ACL_FLOAT);
+  }
+
+  ~Operator() {
+    if (!ascend::IsAclRuntimeAlive()) return;
+
+    in_cache_.release();
+    out_cache_.release();
+    temp_cache_.release();
+
+    if (scale_scalar_) aclDestroyScalar(scale_scalar_);
+  }
+
+  void operator()(const Tensor input, double scale, Tensor out) const override {
+    assert(scale == scale_ &&
+           "`ScaledSoftmax` scale changed after descriptor creation");
+
+    auto stream = static_cast<aclrtStream>(stream_);
+    auto t_in = in_cache_.get(const_cast<void*>(input.data()));
+    auto t_out = out_cache_.get(out.data());
+    aclTensor* t_softmax_in = t_in;
+    void* softmax_in_data = const_cast<void*>(input.data());
+
+    if (needs_scale_) {
+      auto& temp =
+          ascend::GetWorkspacePool().Ensure(stream, temp_size_, "temp");
+      auto t_temp = temp_cache_.get(temp.buf);
+
+      if (!muls_exec_) {
+        aclnnMulsGetWorkspaceSize(t_in, scale_scalar_, t_temp, &muls_ws_,
+                                  &muls_exec_);
+        aclSetAclOpExecutorRepeatable(muls_exec_);
+      } else {
+        aclSetInputTensorAddr(muls_exec_, 0, t_in,
+                              const_cast<void*>(input.data()));
+        aclSetOutputTensorAddr(muls_exec_, 0, t_temp, temp.buf);
+      }
+
+      auto& muls_arena = ascend::GetWorkspacePool().Ensure(stream, muls_ws_);
+      aclnnMuls(muls_arena.buf, muls_ws_, muls_exec_, stream);
+
+      t_softmax_in = t_temp;
+      softmax_in_data = temp.buf;
+    }
+
+    if (!softmax_exec_) {
+      constexpr int64_t kLastDim = -1;
+      aclnnSoftmaxGetWorkspaceSize(t_softmax_in, kLastDim, t_out, &softmax_ws_,
+                                   &softmax_exec_);
+      aclSetAclOpExecutorRepeatable(softmax_exec_);
+    } else {
+      aclSetInputTensorAddr(softmax_exec_, 0, t_softmax_in, softmax_in_data);
+      aclSetOutputTensorAddr(softmax_exec_, 0, t_out, out.data());
+    }
+
+    auto& softmax_arena =
+        ascend::GetWorkspacePool().Ensure(stream, softmax_ws_);
+    aclnnSoftmax(softmax_arena.buf, softmax_ws_, softmax_exec_, stream);
+  }
+
+ private:
+  mutable ascend::AclTensorCache in_cache_;
+
+  mutable ascend::AclTensorCache out_cache_;
+
+  mutable ascend::AclTensorCache temp_cache_;
+
+  float scale_storage_{1.0f};
+
+  aclScalar* scale_scalar_ = nullptr;
+
+  bool needs_scale_{false};
+
+  uint64_t temp_size_{0};
+
+  mutable aclOpExecutor* muls_exec_ = nullptr;
+
+  mutable uint64_t muls_ws_ = 0;
+
+  mutable aclOpExecutor* softmax_exec_ = nullptr;
+
+  mutable uint64_t softmax_ws_ = 0;
+};
+
+}  // namespace infini::ops
+
+#endif  // INFINI_OPS_ASCEND_SCALED_SOFTMAX_KERNEL_H_
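The kernel above builds each aclnn executor once, marks it reusable with `aclSetAclOpExecutorRepeatable`, and on later launches only rebinds the device addresses instead of re-planning the op. A minimal sketch of that caching pattern, using only the ACL calls that appear in the patch; `CachedMuls` is a hypothetical name, and error checking, tensor construction, and workspace allocation are elided:

```cpp
#include <cstdint>

#include "acl/acl.h"
#include "aclnn/aclnn_base.h"
#include "aclnn_mul.h"

// Sketch only (assumed setup, not part of the patch): caches an aclnn "muls"
// executor across launches with identical shapes/dtypes.
struct CachedMuls {
  aclOpExecutor* exec = nullptr;  // built on the first launch, then reused
  uint64_t workspace_bytes = 0;

  void Launch(aclTensor* in, aclScalar* scale, aclTensor* out, void* in_addr,
              void* out_addr, void* workspace, aclrtStream stream) {
    if (!exec) {
      // First launch: plan the op and pin the executor for reuse.
      aclnnMulsGetWorkspaceSize(in, scale, out, &workspace_bytes, &exec);
      aclSetAclOpExecutorRepeatable(exec);
    } else {
      // Later launches: only the device addresses may change.
      aclSetInputTensorAddr(exec, 0, in, in_addr);
      aclSetOutputTensorAddr(exec, 0, out, out_addr);
    }
    // `workspace` must be at least `workspace_bytes` long.
    aclnnMuls(workspace, workspace_bytes, exec, stream);
  }
};
```

The payoff is that the planning call (`...GetWorkspaceSize`) runs once per descriptor rather than once per launch, which is why the kernel asserts that `scale` and the tensor shapes never change after construction.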
diff --git a/tests/test_scaled_softmax.py b/tests/test_scaled_softmax.py
new file mode 100644
index 00000000..ec860b34
--- /dev/null
+++ b/tests/test_scaled_softmax.py
@@ -0,0 +1,66 @@
+import infini.ops
+import pytest
+import torch
+
+from tests.utils import Payload, empty_strided, get_stream, randn_strided
+
+
+@pytest.mark.auto_act_and_assert
+@pytest.mark.parametrize(
+    "shape",
+    (
+        (1, 7),
+        (3, 11),
+        (16, 512),
+    ),
+)
+@pytest.mark.parametrize("scale", (1.0, 0.5, 1.7))
+@pytest.mark.parametrize(
+    ("dtype", "rtol", "atol"),
+    (
+        (torch.float32, 1e-5, 1e-5),
+        (torch.float16, 1e-2, 1e-2),
+        (torch.bfloat16, 1e-2, 1e-2),
+    ),
+)
+def test_scaled_softmax(
+    shape,
+    scale,
+    dtype,
+    device,
+    implementation_index,
+    rtol,
+    atol,
+):
+    input_tensor = randn_strided(shape, None, dtype=dtype, device=device)
+    out = empty_strided(shape, None, dtype=dtype, device=device)
+
+    return Payload(
+        _scaled_softmax,
+        _torch_scaled_softmax,
+        (input_tensor, out),
+        {"scale": scale, "implementation_index": implementation_index},
+        rtol=rtol,
+        atol=atol,
+    )
+
+
+def _scaled_softmax(input_tensor, out, *, scale, implementation_index):
+    infini.ops.scaled_softmax(
+        input_tensor,
+        scale,
+        out,
+        stream=get_stream(input_tensor.device),
+        implementation_index=implementation_index,
+    )
+
+    return out
+
+
+def _torch_scaled_softmax(input_tensor, out, *, scale, implementation_index):
+    del implementation_index
+
+    result = torch.nn.functional.softmax(input_tensor.to(torch.float32) * scale, dim=-1)
+    out.copy_(result.to(input_tensor.dtype))
+
+    return out
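`_torch_scaled_softmax` above pins down the expected numerics: upcast to float32, multiply by `scale`, softmax over the last dimension, cast back. For reference, the same row-wise computation can be written out directly. A minimal host-side C++ sketch over a contiguous `[batch, vocab]` float buffer, with the usual max-subtraction for numerical stability; it is illustrative only and not part of the patch:

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>

// Row-wise scaled softmax over a contiguous [batch, vocab] buffer:
//   out[b, v] = exp(scale * in[b, v] - m_b) / sum_v exp(scale * in[b, v] - m_b)
// where m_b = max_v(scale * in[b, v]) keeps exp() from overflowing.
void ScaledSoftmaxRef(const float* in, float* out, std::size_t batch,
                      std::size_t vocab, float scale) {
  for (std::size_t b = 0; b < batch; ++b) {
    const float* row_in = in + b * vocab;
    float* row_out = out + b * vocab;

    // Row maximum of the scaled logits.
    float max_val = scale * row_in[0];
    for (std::size_t v = 1; v < vocab; ++v)
      max_val = std::max(max_val, scale * row_in[v]);

    // Shifted exponentials and their sum.
    float sum = 0.0f;
    for (std::size_t v = 0; v < vocab; ++v) {
      row_out[v] = std::exp(scale * row_in[v] - max_val);
      sum += row_out[v];
    }

    // Normalize so each row sums to 1.
    for (std::size_t v = 0; v < vocab; ++v) row_out[v] /= sum;
  }
}
```

Doing the whole computation in float, as this sketch and the torch reference both do, is why the half-precision test cases get the looser 1e-2 tolerances: only the final cast back to float16/bfloat16 loses precision, not the reduction itself.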