From 94320aceea2623ef5396ed85bc8340be519cfe73 Mon Sep 17 00:00:00 2001
From: zhangyue
Date: Tue, 30 Jun 2026 15:14:05 +0800
Subject: [PATCH] feat(ascend): add `rms_norm` operator

---
 src/native/ascend/ops/rms_norm/kernel.h       | 100 +++++++++++
 .../ascend/ops/rms_norm/kernel_custom.h       | 155 ++++++++++++++++++
 2 files changed, 255 insertions(+)
 create mode 100644 src/native/ascend/ops/rms_norm/kernel.h
 create mode 100644 src/native/ascend/ops/rms_norm/kernel_custom.h

diff --git a/src/native/ascend/ops/rms_norm/kernel.h b/src/native/ascend/ops/rms_norm/kernel.h
new file mode 100644
index 000000000..143c99dfb
--- /dev/null
+++ b/src/native/ascend/ops/rms_norm/kernel.h
@@ -0,0 +1,100 @@
+#ifndef INFINI_OPS_ASCEND_RMS_NORM_KERNEL_H_
+#define INFINI_OPS_ASCEND_RMS_NORM_KERNEL_H_
+
+// NOTE(review): the target of this `#include` is missing in this copy of the
+// patch — presumably a standard header such as the one providing
+// `std::vector`; restore it from the original commit.
+#include
+
+#include "acl/acl.h"
+#include "aclnn/aclnn_base.h"
+#include "aclnn_rms_norm.h"
+#include "base/rms_norm.h"
+#include "native/ascend/common.h"
+#include "native/ascend/workspace_pool_.h"
+#include "operator.h"
+
+namespace infini::ops {
+
+// `RmsNorm` implementation that dispatches to the
+// `aclnnRmsNormGetWorkspaceSize` / `aclnnRmsNorm` pair.
+//
+// The `rstd` side-output descriptor and the executor are created on the
+// first call of `operator()` and reused afterwards; later calls only re-bind
+// addresses before launching.
+//
+// NOTE(review): the template arguments of this `Operator` specialization and
+// the type arguments of the `static_cast` / `const_cast` / `std::vector`
+// uses below are missing in this copy of the patch — restore them from the
+// original commit before applying.
+template <>
+class Operator : public RmsNorm {
+ public:
+  // Builds the descriptor caches for `input`, `weight` and `out`, and records
+  // the shape (`batch_size_` x `nhead_`) and byte size of the fp32 `rstd`
+  // output. `batch_size_` and `nhead_` are not declared in this class —
+  // presumably they are set by the `RmsNorm` base constructor; TODO confirm.
+  Operator(const Tensor input, const Tensor weight, float eps, Tensor out)
+      : RmsNorm(input, weight, eps, out),
+        in_cache_(input),
+        weight_cache_(weight),
+        out_cache_(out) {
+    // `aclnnRmsNorm` writes `rstd` as a required side output. Size is
+    // computed here; the buffer is obtained from the pool in `operator()`.
+    rstd_shape_ = {static_cast(batch_size_),
+                   static_cast(nhead_)};
+    rstd_size_ = batch_size_ * nhead_ * sizeof(float);
+  }
+
+  // Calls `release()` on the three descriptor caches, unless
+  // `IsAclRuntimeAlive()` reports false, in which case nothing is touched.
+  ~Operator() {
+    if (!ascend::IsAclRuntimeAlive()) return;
+
+    // Null cached descriptors — see `AclTensorCache::release()`.
+    in_cache_.release();
+    weight_cache_.release();
+    out_cache_.release();
+    // `rstd_tensor_` leaks with the executor at shutdown (see `64c367c`).
+  }
+
+  // Launches `aclnnRmsNorm` for `input` / `weight` / `out` on `stream_`.
+  //
+  // First call: creates the `rstd` descriptor, queries the workspace size,
+  // obtains the executor and marks it repeatable. Later calls: update the
+  // addresses bound to the cached descriptor and executor instead.
+  //
+  // NOTE(review): none of the ACL return values in this method is checked.
+  void operator()(const Tensor input, const Tensor weight, float eps,
+                  Tensor out) const override {
+    // Fetch descriptors from the caches for the current data pointers.
+    auto t_in = in_cache_.get(const_cast(input.data()));
+    auto t_weight = weight_cache_.get(const_cast(weight.data()));
+    auto t_out = out_cache_.get(out.data());
+    // `stream_` is not declared here — presumably inherited; TODO confirm.
+    auto stream = static_cast(stream_);
+
+    // Obtain shared `rstd` buffer from pool.
+    // NOTE(review): the `"temp"` argument presumably selects a pool slot
+    // distinct from the workspace requested further down — verify against
+    // `Ensure()`.
+    auto& rstd_arena =
+        ascend::GetWorkspacePool().Ensure(stream, rstd_size_, "temp");
+
+    // Lazily create the `rstd` tensor descriptor on first call.
+    if (!rstd_tensor_) {
+      rstd_tensor_ = aclCreateTensor(rstd_shape_.data(), 2, ACL_FLOAT,
+                                     /*strides=*/nullptr, 0, ACL_FORMAT_ND,
+                                     rstd_shape_.data(), 2, rstd_arena.buf);
+    } else {
+      // Descriptor already exists: only point it at the current buffer.
+      aclSetRawTensorAddr(rstd_tensor_, rstd_arena.buf);
+    }
+
+    if (!executor_) {
+      // First call: this also fills in `ws_size_` for the launch below.
+      aclnnRmsNormGetWorkspaceSize(t_in, t_weight, eps, t_out, rstd_tensor_,
+                                   &ws_size_, &executor_);
+      aclSetAclOpExecutorRepeatable(executor_);
+    } else {
+      // Re-bind by index: inputs 0 / 1 are `input` / `weight`, outputs
+      // 0 / 1 are `out` / `rstd`.
+      aclSetInputTensorAddr(executor_, 0, t_in,
+                            const_cast(input.data()));
+      aclSetInputTensorAddr(executor_, 1, t_weight,
+                            const_cast(weight.data()));
+      aclSetOutputTensorAddr(executor_, 0, t_out, out.data());
+      aclSetOutputTensorAddr(executor_, 1, rstd_tensor_, rstd_arena.buf);
+    }
+
+    // Workspace of the size reported by the `GetWorkspaceSize` call.
+    auto& arena = ascend::GetWorkspacePool().Ensure(stream, ws_size_);
+    aclnnRmsNorm(arena.buf, ws_size_, executor_, stream);
+  }
+
+ private:
+  // Descriptor cache for `input`; `mutable` because `operator()` is const.
+  mutable ascend::AclTensorCache in_cache_;
+
+  // Descriptor cache for `weight`.
+  mutable ascend::AclTensorCache weight_cache_;
+
+  // Descriptor cache for `out`.
+  mutable ascend::AclTensorCache out_cache_;
+
+  // Executor obtained on the first call and reused afterwards.
+  mutable aclOpExecutor* executor_ = nullptr;
+
+  // Workspace size written by `aclnnRmsNormGetWorkspaceSize`.
+  mutable uint64_t ws_size_ = 0;
+
+  // Shape of the `rstd` output: {`batch_size_`, `nhead_`}.
+  std::vector rstd_shape_;
+
+  // Size of the `rstd` output in bytes (fp32 elements).
+  uint64_t rstd_size_ = 0;
+
+  // Descriptor of the `rstd` output, created on the first call.
+  mutable aclTensor* rstd_tensor_ = nullptr;
+};
+
+}  // namespace infini::ops
+
+#include "native/ascend/ops/rms_norm/kernel_custom.h"
+
+#endif
diff --git a/src/native/ascend/ops/rms_norm/kernel_custom.h b/src/native/ascend/ops/rms_norm/kernel_custom.h
new file mode 100644
index 000000000..11d56732b
--- /dev/null
+++ 
b/src/native/ascend/ops/rms_norm/kernel_custom.h
@@ -0,0 +1,155 @@
+#ifndef INFINI_OPS_ASCEND_RMS_NORM_KERNEL_CUSTOM_H_
+#define INFINI_OPS_ASCEND_RMS_NORM_KERNEL_CUSTOM_H_
+
+#ifdef INFINI_HAS_CUSTOM_KERNELS
+
+// NOTE(review): the targets of the next three `#include` lines are missing
+// in this copy of the patch — presumably the standard headers behind
+// `std::min`, `assert` and the fixed-width integer types; restore them from
+// the original commit.
+#include
+#include
+#include
+
+#include "acl/acl.h"
+#include "aclnn/aclnn_base.h"
+#include "aclnnop/aclnn_cast.h"
+#include "base/rms_norm.h"
+#include "native/ascend/common.h"
+#include "native/ascend/workspace_pool_.h"
+#include "operator.h"
+
+// Forward-declare the `aclrtlaunch_RmsNorm` launch symbol defined by
+// the AscendC toolchain from `custom/rms_norm/op_kernel/`.
+//
+// Arguments, as supplied by the `Operator` specialization in this header:
+//   block_dim         number of blocks to launch (at most 40 there).
+//   stream            stream handle, passed as an opaque pointer.
+//   input, out        data pointers of the input and output tensors.
+//   weight            pointer to the weight held as fp32.
+//   total_rows        `batch_size_ * nhead_`.
+//   dim_length        length of the last dimension (`dim_`).
+//   dim_length_align  `dim_length` rounded up to a whole number of 32-byte
+//                     groups of elements.
+//   former_num, former_length, tail_length
+//                     row split across blocks: `former_num` blocks of
+//                     `former_length` rows plus the remaining blocks at
+//                     `tail_length` rows add up to `total_rows`.
+//   eps               epsilon, forwarded unchanged from the caller.
+//   dtype_code        `DataType` of `input`, cast to an integer.
+//
+// The meaning of the `uint32_t` result is not visible here (presumably a
+// status code — TODO confirm); the caller in this header ignores it.
+extern "C" uint32_t aclrtlaunch_RmsNorm(
+    uint32_t block_dim, void* stream, void* input, void* weight,
+    int64_t total_rows, int64_t dim_length, int64_t dim_length_align,
+    int64_t former_num, int64_t former_length, int64_t tail_length, float eps,
+    int64_t dtype_code, void* out);
+
+namespace infini::ops {
+
+// Custom AscendC fused `RmsNorm` kernel (implementation index 1).
+//
+// A single-kernel implementation that computes `RMSNorm` in one launch,
+// avoiding the 5-sub-op decomposition of `aclnnRmsNorm` (index 0). Uses
+// `Sqrt` + scalar division instead of `Rsqrt` for higher precision (~1e-7
+// `fp32` error vs ~0.2% with `Rsqrt`).
+//
+// Select via `implementation_index=1` in Python:
+//   `infini.ops.rms_norm(input, weight, eps, out, implementation_index=1,
+//   stream=s)`.
+//
+// Requirements:
+//   - Input last dimension must be 32-byte aligned (divisible by 16 for
+//     `fp16` / `bf16` or 8 for `fp32`). All standard LLM hidden dimensions
+//     satisfy this.
+//   - `weight` must have the same dtype as `input`.
+//   - The custom kernel binary must be linked (`BUILD_ASCEND_CUSTOM=ON`).
+template <>
+class Operator : public RmsNorm {
+ public:
+  // Validates the dtype and last-dimension alignment, derives the row count,
+  // and — for `fp16` / `bf16` inputs — allocates the fp32 shadow buffer for
+  // `weight` together with the descriptor caches used by the cast.
+  //
+  // `dim_`, `batch_size_` and `nhead_` are not declared in this class —
+  // presumably they are set by the `RmsNorm` base constructor; TODO confirm.
+  //
+  // NOTE(review): all validation here uses `assert`, so it disappears in
+  // builds that define `NDEBUG`.
+  Operator(const Tensor input, const Tensor weight, float eps, Tensor out)
+      : RmsNorm(input, weight, eps, out), dtype_{input.dtype()} {
+    assert((dtype_ == DataType::kFloat16 || dtype_ == DataType::kBFloat16 ||
+            dtype_ == DataType::kFloat32) &&
+           "`RmsNorm` custom kernel: `input` must be `fp16`, `bf16`, or "
+           "`fp32`");
+
+    // 32-byte alignment on the last dimension — kernel relies on aligned
+    // `DataCopyPad` loads/stores.
+    int64_t align_elems = 32 / static_cast(kDataTypeToSize.at(dtype_));
+    dim_length_align_ =
+        ((static_cast(dim_) + align_elems - 1) / align_elems) *
+        align_elems;
+    assert(static_cast(dim_) == dim_length_align_ &&
+           "`RmsNorm` custom kernel: last dimension must be 32-byte aligned");
+
+    total_rows_ =
+        static_cast(batch_size_) * static_cast(nhead_);
+
+    // The custom kernel always reads `weight` as fp32, so fp16 / bf16
+    // inputs need a cached `aclnnCast` invocation in `operator()` to
+    // produce an fp32 shadow buffer on every launch.
+    if (dtype_ != DataType::kFloat32) {
+      size_t fp32_bytes = static_cast(dim_) * sizeof(float);
+      // A failed allocation used to go unnoticed and the cast in
+      // `operator()` would then write through an invalid pointer.
+      [[maybe_unused]] const aclError malloc_ret = aclrtMalloc(
+          &weight_fp32_data_, fp32_bytes, ACL_MEM_MALLOC_NORMAL_ONLY);
+      assert(malloc_ret == ACL_SUCCESS &&
+             "`RmsNorm` custom kernel: `aclrtMalloc` for fp32 weight failed");
+
+      // The source descriptor starts without an address; `operator()`
+      // supplies the current `weight` pointer on every call.
+      weight_src_cache_ = ascend::AclTensorCache(
+          {static_cast(dim_)}, ascend::ToAclDtype(dtype_), nullptr);
+      weight_dst_cache_ = ascend::AclTensorCache({static_cast(dim_)},
+                                                 ACL_FLOAT, weight_fp32_data_);
+    }
+  }
+
+  // Calls `release()` on the cast descriptor caches and frees the fp32
+  // shadow buffer, unless `IsAclRuntimeAlive()` reports false.
+  //
+  // NOTE(review): this class frees `weight_fp32_data_` but declares no copy
+  // or move operations; a copied instance would free the buffer twice.
+  // Confirm that operators are never copied.
+  ~Operator() {
+    if (!ascend::IsAclRuntimeAlive()) return;
+
+    // Null cached descriptors — see `AclTensorCache::release()`.
+    weight_src_cache_.release();
+    weight_dst_cache_.release();
+
+    if (weight_fp32_data_) aclrtFree(weight_fp32_data_);
+  }
+
+  // Launches the fused kernel for `input` / `weight` / `out` on `stream_`.
+  //
+  // For `fp16` / `bf16` inputs `weight` is first cast into the fp32 shadow
+  // buffer with `aclnnCast`; the cast executor is created on the first call
+  // and re-bound afterwards. An input without rows returns without
+  // launching anything.
+  //
+  // NOTE(review): none of the ACL return values in this method is checked.
+  void operator()(const Tensor input, const Tensor weight, float eps,
+                  Tensor out) const override {
+    // Nothing to normalise for an empty input. Returning here also keeps
+    // `used_cores` below from becoming 0, which made the tiling arithmetic
+    // divide by zero.
+    if (total_rows_ <= 0) return;
+
+    // `stream_` is not declared here — presumably inherited; TODO confirm.
+    auto stream = static_cast(stream_);
+
+    void* weight_fp32 = nullptr;
+
+    if (dtype_ != DataType::kFloat32) {
+      auto t_src = weight_src_cache_.get(const_cast(weight.data()));
+      auto t_dst = weight_dst_cache_.get(weight_fp32_data_);
+
+      if (!cast_exec_) {
+        // First call: this also fills in `cast_ws_` for the launch below.
+        aclnnCastGetWorkspaceSize(t_src, ACL_FLOAT, t_dst, &cast_ws_,
+                                  &cast_exec_);
+        aclSetAclOpExecutorRepeatable(cast_exec_);
+      } else {
+        aclSetInputTensorAddr(cast_exec_, 0, t_src,
+                              const_cast(weight.data()));
+        aclSetOutputTensorAddr(cast_exec_, 0, t_dst, weight_fp32_data_);
+      }
+
+      auto& arena = ascend::GetWorkspacePool().Ensure(stream, cast_ws_);
+      aclnnCast(arena.buf, cast_ws_, cast_exec_, stream);
+      weight_fp32 = weight_fp32_data_;
+    } else {
+      // Already fp32: hand the caller's buffer to the kernel directly.
+      weight_fp32 = const_cast(weight.data());
+    }
+
+    // Block-level tiling. Ascend 910B has 20–40 AIV cores; over-subscribing
+    // is safe (runtime multiplexes) but wastes one weight load per block.
+    //
+    // Rows are split so that `former_num` blocks of `former_length` rows
+    // plus the remaining blocks at `tail_length` (= `former_length - 1`)
+    // rows add up to `total_rows_`. When the rows divide evenly,
+    // `former_num == used_cores` and no block uses `tail_length`.
+    static constexpr int64_t kMaxBlockDim = 40;
+    int64_t used_cores = std::min(total_rows_, kMaxBlockDim);
+    int64_t former_length = (total_rows_ + used_cores - 1) / used_cores;
+    int64_t tail_length = former_length - 1;
+    int64_t former_num = total_rows_ - tail_length * used_cores;
+    uint32_t block_dim = static_cast(used_cores);
+
+    aclrtlaunch_RmsNorm(block_dim, stream, const_cast(input.data()),
+                        weight_fp32, total_rows_, static_cast(dim_),
+                        dim_length_align_, former_num, former_length,
+                        tail_length, eps, static_cast(dtype_),
+                        out.data());
+  }
+
+ private:
+  // Dtype of `input`, captured at construction.
+  DataType dtype_;
+
+  // Last-dimension length rounded up to a whole number of 32-byte groups.
+  int64_t dim_length_align_ = 0;
+
+  // `batch_size_ * nhead_`.
+  int64_t total_rows_ = 0;
+
+  // fp32 shadow buffer for `weight`; stays null for fp32 inputs.
+  void* weight_fp32_data_ = nullptr;
+
+  // Descriptor cache for the cast source (`weight` in its original dtype).
+  mutable ascend::AclTensorCache weight_src_cache_;
+
+  // Descriptor cache for the cast destination (`weight_fp32_data_`).
+  mutable ascend::AclTensorCache weight_dst_cache_;
+
+  // Cast executor, obtained on the first call and reused afterwards.
+  mutable aclOpExecutor* cast_exec_ = nullptr;
+
+  // Workspace size written by `aclnnCastGetWorkspaceSize`.
+  mutable uint64_t cast_ws_ = 0;
+};
+
+}  // namespace infini::ops
+
+#endif  // INFINI_HAS_CUSTOM_KERNELS
+#endif  // INFINI_OPS_ASCEND_RMS_NORM_KERNEL_CUSTOM_H_