Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions src/base/scaled_softmax.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
#ifndef INFINI_OPS_BASE_SCALED_SOFTMAX_H_
#define INFINI_OPS_BASE_SCALED_SOFTMAX_H_

#include <cassert>
#include <cmath>
#include <cstddef>

#include "data_type.h"
#include "operator.h"
#include "tensor.h"

namespace infini::ops {

// Base descriptor for a scaled softmax over a 2D `[batch, vocab]` tensor:
// `out = softmax(input * scale)` along the last (vocab) dimension.
// Validates the operand contract once at construction and caches the
// metadata (sizes, dtype, strides) that backend kernels need at launch time.
class ScaledSoftmax : public Operator<ScaledSoftmax> {
 public:
  // `input` and `out` must be 2D, same shape, same floating-point dtype;
  // `scale` must be finite. Violations are programmer errors (asserts).
  ScaledSoftmax(const Tensor input, double scale, Tensor out)
      : scale_{scale},
        batch_size_{input.size(0)},
        vocab_size_{input.size(1)},
        dtype_{input.dtype()},
        input_strides_{input.strides()},
        out_strides_{out.strides()} {
    assert(input.ndim() == 2 &&
           "`ScaledSoftmax` currently supports 2D `[batch, vocab]` input");
    assert(input.shape() == out.shape() &&
           "`ScaledSoftmax` requires `input` and `out` to have the same shape");
    assert(input.dtype() == out.dtype() &&
           "`ScaledSoftmax` requires `input` and `out` to have the same dtype");
    assert((dtype_ == DataType::kFloat16 || dtype_ == DataType::kBFloat16 ||
            dtype_ == DataType::kFloat32 || dtype_ == DataType::kFloat64) &&
           "`ScaledSoftmax` requires a floating point dtype");
    assert(std::isfinite(scale_) &&
           "`ScaledSoftmax` requires a finite `scale`");
  }

  // Polymorphic base with a pure virtual call operator: a virtual destructor
  // is required so deleting a derived kernel through a `ScaledSoftmax*`
  // runs the derived destructor (derived kernels release ACL resources).
  virtual ~ScaledSoftmax() = default;

  // Executes the kernel. Backends may assert that `scale` matches the value
  // captured at construction.
  virtual void operator()(const Tensor input, double scale,
                          Tensor out) const = 0;

 protected:
  // Multiplier applied to `input` before the softmax.
  double scale_{1.0};

  // Number of rows (softmax is computed independently per row).
  Tensor::Size batch_size_{0};

  // Length of the reduction (last) dimension.
  Tensor::Size vocab_size_{0};

  // Common dtype of `input` and `out` (validated in the constructor).
  DataType dtype_;

  // Strides captured at construction, for backends that need raw layout.
  Tensor::Strides input_strides_;

  Tensor::Strides out_strides_;
};

} // namespace infini::ops

#endif // INFINI_OPS_BASE_SCALED_SOFTMAX_H_
124 changes: 124 additions & 0 deletions src/native/ascend/ops/scaled_softmax/kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
#ifndef INFINI_OPS_ASCEND_SCALED_SOFTMAX_KERNEL_H_
#define INFINI_OPS_ASCEND_SCALED_SOFTMAX_KERNEL_H_

#include <cmath>

#include "acl/acl.h"
#include "aclnn/aclnn_base.h"
#include "aclnn_mul.h"
#include "aclnn_softmax.h"
#include "base/scaled_softmax.h"
#include "data_type.h"
#include "native/ascend/common.h"
#include "native/ascend/workspace_pool_.h"
#include "operator.h"

namespace infini::ops {

template <>
class Operator<ScaledSoftmax, Device::Type::kAscend> : public ScaledSoftmax {
public:
Operator(const Tensor input, double scale, Tensor out)
: ScaledSoftmax(input, scale, out),
in_cache_(input),
out_cache_(out),
temp_cache_(input),
scale_storage_(static_cast<float>(scale)),
needs_scale_(std::fabs(scale - 1.0) > 1e-6) {
assert((dtype_ == DataType::kFloat16 || dtype_ == DataType::kBFloat16 ||
dtype_ == DataType::kFloat32) &&
"`ScaledSoftmax` Ascend path requires float16, bfloat16, or "
"float32 input");
assert(input.IsContiguous() &&
"`ScaledSoftmax` Ascend path requires contiguous input");
assert(out.IsContiguous() &&
"`ScaledSoftmax` Ascend path requires contiguous output");

temp_size_ = input.numel() * kDataTypeToSize.at(dtype_);
scale_scalar_ = aclCreateScalar(&scale_storage_, ACL_FLOAT);
}

~Operator() {
if (!ascend::IsAclRuntimeAlive()) return;

in_cache_.release();
out_cache_.release();
temp_cache_.release();

if (scale_scalar_) aclDestroyScalar(scale_scalar_);
}

void operator()(const Tensor input, double scale, Tensor out) const override {
assert(scale == scale_ &&
"`ScaledSoftmax` scale changed after descriptor creation");

auto stream = static_cast<aclrtStream>(stream_);
auto t_in = in_cache_.get(const_cast<void*>(input.data()));
auto t_out = out_cache_.get(out.data());
aclTensor* t_softmax_in = t_in;
void* softmax_in_data = const_cast<void*>(input.data());

if (needs_scale_) {
auto& temp =
ascend::GetWorkspacePool().Ensure(stream, temp_size_, "temp");
auto t_temp = temp_cache_.get(temp.buf);

if (!muls_exec_) {
aclnnMulsGetWorkspaceSize(t_in, scale_scalar_, t_temp, &muls_ws_,
&muls_exec_);
aclSetAclOpExecutorRepeatable(muls_exec_);
} else {
aclSetInputTensorAddr(muls_exec_, 0, t_in,
const_cast<void*>(input.data()));
aclSetOutputTensorAddr(muls_exec_, 0, t_temp, temp.buf);
}

auto& muls_arena = ascend::GetWorkspacePool().Ensure(stream, muls_ws_);
aclnnMuls(muls_arena.buf, muls_ws_, muls_exec_, stream);

t_softmax_in = t_temp;
softmax_in_data = temp.buf;
}

if (!softmax_exec_) {
constexpr int64_t kLastDim = -1;
aclnnSoftmaxGetWorkspaceSize(t_softmax_in, kLastDim, t_out, &softmax_ws_,
&softmax_exec_);
aclSetAclOpExecutorRepeatable(softmax_exec_);
} else {
aclSetInputTensorAddr(softmax_exec_, 0, t_softmax_in, softmax_in_data);
aclSetOutputTensorAddr(softmax_exec_, 0, t_out, out.data());
}

auto& softmax_arena =
ascend::GetWorkspacePool().Ensure(stream, softmax_ws_);
aclnnSoftmax(softmax_arena.buf, softmax_ws_, softmax_exec_, stream);
}

private:
mutable ascend::AclTensorCache in_cache_;

mutable ascend::AclTensorCache out_cache_;

mutable ascend::AclTensorCache temp_cache_;

float scale_storage_{1.0f};

aclScalar* scale_scalar_ = nullptr;

bool needs_scale_{false};

uint64_t temp_size_{0};

mutable aclOpExecutor* muls_exec_ = nullptr;

mutable uint64_t muls_ws_ = 0;

mutable aclOpExecutor* softmax_exec_ = nullptr;

mutable uint64_t softmax_ws_ = 0;
};

} // namespace infini::ops

#endif // INFINI_OPS_ASCEND_SCALED_SOFTMAX_KERNEL_H_
66 changes: 66 additions & 0 deletions tests/test_scaled_softmax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import infini.ops
import pytest
import torch

from tests.utils import Payload, empty_strided, get_stream, randn_strided


@pytest.mark.auto_act_and_assert
@pytest.mark.parametrize(
    "shape",
    (
        (1, 7),
        (3, 11),
        (16, 512),
    ),
)
@pytest.mark.parametrize("scale", (1.0, 0.5, 1.7))
@pytest.mark.parametrize(
    ("dtype", "rtol", "atol"),
    (
        (torch.float32, 1e-5, 1e-5),
        (torch.float16, 1e-2, 1e-2),
        (torch.bfloat16, 1e-2, 1e-2),
    ),
)
def test_scaled_softmax(
    shape,
    scale,
    dtype,
    device,
    implementation_index,
    rtol,
    atol,
):
    """Compare the infini scaled-softmax kernel against the torch reference."""
    x = randn_strided(shape, None, dtype=dtype, device=device)
    result = empty_strided(shape, None, dtype=dtype, device=device)
    kwargs = {"scale": scale, "implementation_index": implementation_index}

    return Payload(
        _scaled_softmax,
        _torch_scaled_softmax,
        (x, result),
        kwargs,
        rtol=rtol,
        atol=atol,
    )


def _scaled_softmax(input_tensor, out, *, scale, implementation_index):
    """Launch the infini scaled-softmax kernel, writing into ``out``."""
    stream = get_stream(input_tensor.device)
    infini.ops.scaled_softmax(
        input_tensor,
        scale,
        out,
        stream=stream,
        implementation_index=implementation_index,
    )
    return out


def _torch_scaled_softmax(input_tensor, out, *, scale, implementation_index):
del implementation_index

result = torch.nn.functional.softmax(input_tensor.to(torch.float32) * scale, dim=-1)
out.copy_(result.to(input_tensor.dtype))

return out
Loading