Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 100 additions & 0 deletions src/native/ascend/ops/rms_norm/kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
#ifndef INFINI_OPS_ASCEND_RMS_NORM_KERNEL_H_
#define INFINI_OPS_ASCEND_RMS_NORM_KERNEL_H_

#include <vector>

#include "acl/acl.h"
#include "aclnn/aclnn_base.h"
#include "aclnn_rms_norm.h"
#include "base/rms_norm.h"
#include "native/ascend/common.h"
#include "native/ascend/workspace_pool_.h"
#include "operator.h"

namespace infini::ops {

template <>
class Operator<RmsNorm, Device::Type::kAscend> : public RmsNorm {
public:
Operator(const Tensor input, const Tensor weight, float eps, Tensor out)
: RmsNorm(input, weight, eps, out),
in_cache_(input),
weight_cache_(weight),
out_cache_(out) {
// `aclnnRmsNorm` writes `rstd` as a required side output. Size is
// computed here; the buffer is obtained from the pool in `operator()`.
rstd_shape_ = {static_cast<int64_t>(batch_size_),
static_cast<int64_t>(nhead_)};
rstd_size_ = batch_size_ * nhead_ * sizeof(float);
}

~Operator() {
if (!ascend::IsAclRuntimeAlive()) return;

// Null cached descriptors — see `AclTensorCache::release()`.
in_cache_.release();
weight_cache_.release();
out_cache_.release();
// `rstd_tensor_` leaks with the executor at shutdown (see `64c367c`).
}

void operator()(const Tensor input, const Tensor weight, float eps,
Tensor out) const override {
auto t_in = in_cache_.get(const_cast<void*>(input.data()));
auto t_weight = weight_cache_.get(const_cast<void*>(weight.data()));
auto t_out = out_cache_.get(out.data());
auto stream = static_cast<aclrtStream>(stream_);

// Obtain shared `rstd` buffer from pool.
auto& rstd_arena =
ascend::GetWorkspacePool().Ensure(stream, rstd_size_, "temp");

// Lazily create the `rstd` tensor descriptor on first call.
if (!rstd_tensor_) {
rstd_tensor_ = aclCreateTensor(rstd_shape_.data(), 2, ACL_FLOAT,
/*strides=*/nullptr, 0, ACL_FORMAT_ND,
rstd_shape_.data(), 2, rstd_arena.buf);
} else {
aclSetRawTensorAddr(rstd_tensor_, rstd_arena.buf);
}

if (!executor_) {
aclnnRmsNormGetWorkspaceSize(t_in, t_weight, eps, t_out, rstd_tensor_,
&ws_size_, &executor_);
aclSetAclOpExecutorRepeatable(executor_);
} else {
aclSetInputTensorAddr(executor_, 0, t_in,
const_cast<void*>(input.data()));
aclSetInputTensorAddr(executor_, 1, t_weight,
const_cast<void*>(weight.data()));
aclSetOutputTensorAddr(executor_, 0, t_out, out.data());
aclSetOutputTensorAddr(executor_, 1, rstd_tensor_, rstd_arena.buf);
}

auto& arena = ascend::GetWorkspacePool().Ensure(stream, ws_size_);
aclnnRmsNorm(arena.buf, ws_size_, executor_, stream);
}

private:
mutable ascend::AclTensorCache in_cache_;

mutable ascend::AclTensorCache weight_cache_;

mutable ascend::AclTensorCache out_cache_;

mutable aclOpExecutor* executor_ = nullptr;

mutable uint64_t ws_size_ = 0;

std::vector<int64_t> rstd_shape_;

uint64_t rstd_size_ = 0;

mutable aclTensor* rstd_tensor_ = nullptr;
};

} // namespace infini::ops

#include "native/ascend/ops/rms_norm/kernel_custom.h"

#endif
155 changes: 155 additions & 0 deletions src/native/ascend/ops/rms_norm/kernel_custom.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
#ifndef INFINI_OPS_ASCEND_RMS_NORM_KERNEL_CUSTOM_H_
#define INFINI_OPS_ASCEND_RMS_NORM_KERNEL_CUSTOM_H_

#ifdef INFINI_HAS_CUSTOM_KERNELS

#include <algorithm>
#include <cstdint>
#include <vector>

#include "acl/acl.h"
#include "aclnn/aclnn_base.h"
#include "aclnnop/aclnn_cast.h"
#include "base/rms_norm.h"
#include "native/ascend/common.h"
#include "native/ascend/workspace_pool_.h"
#include "operator.h"

// Forward-declare the `aclrtlaunch_RmsNorm` launch symbol defined by
// the AscendC toolchain from `custom/rms_norm/op_kernel/`.
extern "C" uint32_t aclrtlaunch_RmsNorm(
uint32_t block_dim, void* stream, void* input, void* weight,
int64_t total_rows, int64_t dim_length, int64_t dim_length_align,
int64_t former_num, int64_t former_length, int64_t tail_length, float eps,
int64_t dtype_code, void* out);

namespace infini::ops {

// Custom AscendC fused `RmsNorm` kernel (implementation index 1).
//
// A single-kernel implementation that computes `RMSNorm` in one launch,
// avoiding the 5-sub-op decomposition of `aclnnRmsNorm` (index 0). Uses
// `Sqrt` + scalar division instead of `Rsqrt` for higher precision (~1e-7
// `fp32` error vs ~0.2% with `Rsqrt`).
//
// Select via `implementation_index=1` in Python:
// `infini.ops.rms_norm(input, weight, eps, out, implementation_index=1,
// stream=s)`.
//
// Requirements:
// - Input last dimension must be 32-byte aligned (divisible by 16 for
// `fp16` or 8 for `fp32`). All standard LLM hidden dimensions satisfy
// this.
// - `weight` must have the same dtype as `input`.
// - The custom kernel binary must be linked (`BUILD_ASCEND_CUSTOM=ON`).
template <>
class Operator<RmsNorm, Device::Type::kAscend, 1> : public RmsNorm {
public:
Operator(const Tensor input, const Tensor weight, float eps, Tensor out)
: RmsNorm(input, weight, eps, out), dtype_{input.dtype()} {
assert((dtype_ == DataType::kFloat16 || dtype_ == DataType::kBFloat16 ||
dtype_ == DataType::kFloat32) &&
"`RmsNorm` custom kernel: `input` must be `fp16`, `bf16`, or "
"`fp32`");

// 32-byte alignment on the last dimension — kernel relies on aligned
// `DataCopyPad` loads/stores.
int64_t align_elems = 32 / static_cast<int64_t>(kDataTypeToSize.at(dtype_));
dim_length_align_ =
((static_cast<int64_t>(dim_) + align_elems - 1) / align_elems) *
align_elems;
assert(static_cast<int64_t>(dim_) == dim_length_align_ &&
"`RmsNorm` custom kernel: last dimension must be 32-byte aligned");

total_rows_ =
static_cast<int64_t>(batch_size_) * static_cast<int64_t>(nhead_);

// The custom kernel always reads `weight` as fp32, so fp16 / bf16
// inputs need a cached `aclnnCast` invocation in `operator()` to
// produce an fp32 shadow buffer on every launch.
if (dtype_ != DataType::kFloat32) {
size_t fp32_bytes = static_cast<size_t>(dim_) * sizeof(float);
aclrtMalloc(&weight_fp32_data_, fp32_bytes, ACL_MEM_MALLOC_NORMAL_ONLY);

weight_src_cache_ = ascend::AclTensorCache(
{static_cast<int64_t>(dim_)}, ascend::ToAclDtype(dtype_), nullptr);
weight_dst_cache_ = ascend::AclTensorCache({static_cast<int64_t>(dim_)},
ACL_FLOAT, weight_fp32_data_);
}
}

~Operator() {
if (!ascend::IsAclRuntimeAlive()) return;

// Null cached descriptors — see `AclTensorCache::release()`.
weight_src_cache_.release();
weight_dst_cache_.release();

if (weight_fp32_data_) aclrtFree(weight_fp32_data_);
}

void operator()(const Tensor input, const Tensor weight, float eps,
Tensor out) const override {
auto stream = static_cast<aclrtStream>(stream_);

void* weight_fp32;

if (dtype_ != DataType::kFloat32) {
auto t_src = weight_src_cache_.get(const_cast<void*>(weight.data()));
auto t_dst = weight_dst_cache_.get(weight_fp32_data_);

if (!cast_exec_) {
aclnnCastGetWorkspaceSize(t_src, ACL_FLOAT, t_dst, &cast_ws_,
&cast_exec_);
aclSetAclOpExecutorRepeatable(cast_exec_);
} else {
aclSetInputTensorAddr(cast_exec_, 0, t_src,
const_cast<void*>(weight.data()));
aclSetOutputTensorAddr(cast_exec_, 0, t_dst, weight_fp32_data_);
}

auto& arena = ascend::GetWorkspacePool().Ensure(stream, cast_ws_);
aclnnCast(arena.buf, cast_ws_, cast_exec_, stream);
weight_fp32 = weight_fp32_data_;
} else {
weight_fp32 = const_cast<void*>(weight.data());
}

// Block-level tiling. Ascend 910B has 20–40 AIV cores; over-subscribing
// is safe (runtime multiplexes) but wastes one weight load per block.
static constexpr int64_t kMaxBlockDim = 40;
int64_t used_cores = std::min(total_rows_, kMaxBlockDim);
int64_t former_length = (total_rows_ + used_cores - 1) / used_cores;
int64_t tail_length = former_length - 1;
int64_t former_num = total_rows_ - tail_length * used_cores;
uint32_t block_dim = static_cast<uint32_t>(used_cores);

aclrtlaunch_RmsNorm(block_dim, stream, const_cast<void*>(input.data()),
weight_fp32, total_rows_, static_cast<int64_t>(dim_),
dim_length_align_, former_num, former_length,
tail_length, eps, static_cast<int64_t>(dtype_),
out.data());
}

private:
DataType dtype_;

int64_t dim_length_align_;

int64_t total_rows_;

void* weight_fp32_data_ = nullptr;

mutable ascend::AclTensorCache weight_src_cache_;

mutable ascend::AclTensorCache weight_dst_cache_;

mutable aclOpExecutor* cast_exec_ = nullptr;

mutable uint64_t cast_ws_ = 0;
};

} // namespace infini::ops

#endif // INFINI_HAS_CUSTOM_KERNELS
#endif // INFINI_OPS_ASCEND_RMS_NORM_KERNEL_CUSTOM_H_
Loading