diff --git a/src/native/ascend/ops/swiglu/kernel.h b/src/native/ascend/ops/swiglu/kernel.h new file mode 100644 index 000000000..795616e33 --- /dev/null +++ b/src/native/ascend/ops/swiglu/kernel.h @@ -0,0 +1,109 @@ +#ifndef INFINI_OPS_ASCEND_SWIGLU_KERNEL_H_ +#define INFINI_OPS_ASCEND_SWIGLU_KERNEL_H_ + +#include "acl/acl.h" +#include "aclnn/aclnn_base.h" +#include "aclnn_mul.h" +#include "aclnn_silu.h" +#include "base/swiglu.h" +#include "data_type.h" +#include "native/ascend/common.h" +#include "native/ascend/workspace_pool_.h" +#include "operator.h" + +namespace infini::ops { + +// Implements SwiGLU as two ACLNN calls: `aclnnSilu(gate)` into a `temp` +// buffer, then elementwise `aclnnMul(input, temp)` into `out`. +// `aclnnSiluMul` was not used because it fuses silu-and-mul on the same +// tensor (`x * silu(x)`), whereas SwiGLU requires `input * silu(gate)` — +// two distinct inputs. +template <> +class Operator : public Swiglu { + public: + Operator(const Tensor input, const Tensor gate, Tensor out) + : Swiglu(input, gate, out), + in_cache_(input), + gate_cache_(gate), + out_cache_(out) { + temp_size_ = input.numel() * kDataTypeToSize.at(input.dtype()); + + // Build the `temp` cache from `gate` geometry (contiguous, same + // shape/dtype). No data pointer yet — it is set on the first `get()` + // call. + Tensor temp_t{nullptr, gate.shape(), gate.dtype(), gate.device()}; + temp_cache_ = ascend::AclTensorCache(temp_t); + } + + ~Operator() { + if (!ascend::IsAclRuntimeAlive()) return; + + // Null cached descriptors — see `AclTensorCache::release()`. + in_cache_.release(); + gate_cache_.release(); + out_cache_.release(); + temp_cache_.release(); + } + + void operator()(const Tensor input, const Tensor gate, + Tensor out) const override { + auto t_in = in_cache_.get(const_cast(input.data())); + auto t_gate = gate_cache_.get(const_cast(gate.data())); + auto t_out = out_cache_.get(out.data()); + auto stream = static_cast(stream_); + + // Obtain shared `temp` buffer from the pool. + auto& temp = ascend::GetWorkspacePool().Ensure(stream, temp_size_, "temp"); + auto t_temp = temp_cache_.get(temp.buf); + + // Step 1: `silu(gate) -> temp`. + if (!silu_exec_) { + aclnnSiluGetWorkspaceSize(t_gate, t_temp, &silu_ws_, &silu_exec_); + aclSetAclOpExecutorRepeatable(silu_exec_); + } else { + aclSetInputTensorAddr(silu_exec_, 0, t_gate, + const_cast(gate.data())); + aclSetOutputTensorAddr(silu_exec_, 0, t_temp, temp.buf); + } + auto& silu_arena = ascend::GetWorkspacePool().Ensure(stream, silu_ws_); + aclnnSilu(silu_arena.buf, silu_ws_, silu_exec_, stream); + + // Step 2: `mul(input, temp) -> out`. + if (!mul_exec_) { + aclnnMulGetWorkspaceSize(t_in, t_temp, t_out, &mul_ws_, &mul_exec_); + aclSetAclOpExecutorRepeatable(mul_exec_); + } else { + aclSetInputTensorAddr(mul_exec_, 0, t_in, + const_cast(input.data())); + aclSetInputTensorAddr(mul_exec_, 1, t_temp, temp.buf); + aclSetOutputTensorAddr(mul_exec_, 0, t_out, out.data()); + } + auto& mul_arena = ascend::GetWorkspacePool().Ensure(stream, mul_ws_); + aclnnMul(mul_arena.buf, mul_ws_, mul_exec_, stream); + } + + private: + mutable ascend::AclTensorCache in_cache_; + + mutable ascend::AclTensorCache gate_cache_; + + mutable ascend::AclTensorCache out_cache_; + + mutable ascend::AclTensorCache temp_cache_; + + uint64_t temp_size_ = 0; + + mutable aclOpExecutor* silu_exec_ = nullptr; + + mutable uint64_t silu_ws_ = 0; + + mutable aclOpExecutor* mul_exec_ = nullptr; + + mutable uint64_t mul_ws_ = 0; +}; + +} // namespace infini::ops + +#include "native/ascend/ops/swiglu/kernel_fused.h" + +#endif diff --git a/src/native/ascend/ops/swiglu/kernel_fused.h b/src/native/ascend/ops/swiglu/kernel_fused.h new file mode 100644 index 000000000..6d8af71e9 --- /dev/null +++ b/src/native/ascend/ops/swiglu/kernel_fused.h @@ -0,0 +1,202 @@ +#ifndef INFINI_OPS_ASCEND_SWIGLU_KERNEL_FUSED_H_ +#define INFINI_OPS_ASCEND_SWIGLU_KERNEL_FUSED_H_ + +#include + +#include "acl/acl.h" +#include "aclnn/aclnn_base.h" +#include "aclnn_copy.h" +#include "aclnnop/aclnn_cat.h" +#include "aclnnop/aclnn_swi_glu.h" +#include "base/swiglu.h" +#include "native/ascend/common.h" +#include "native/ascend/workspace_pool_.h" +#include "operator.h" + +namespace infini::ops { + +// Fused implementation via `aclnnSwiGlu` (implementation index 1). +// +// Concatenates `[gate, input]` into a `temp` buffer via `aclnnCat`, then +// calls `aclnnSwiGlu` which computes `second_half * silu(first_half)` in a +// single fused kernel, i.e. `input * silu(gate)`. +// +// This trades an extra `aclnnCat` launch for a single fused SwiGLU kernel +// instead of separate `aclnnSilu` + `aclnnMul`. The net benefit is one +// fewer intermediate buffer materialised on-device (the `silu` temp is +// eliminated). +// +// `aclnnSwiGlu` requires a contiguous output tensor. When the caller's +// output is non-contiguous, a contiguous staging buffer is used and the +// result is copied back via `aclnnInplaceCopy`. +// +// Select via `implementation_index=1` in Python: +// `infini.ops.swiglu(..., implementation_index=1, stream=s)`. +template <> +class Operator : public Swiglu { + public: + Operator(const Tensor input, const Tensor gate, Tensor out) + : Swiglu(input, gate, out), + gate_cache_(gate), + in_cache_(input), + out_cache_(out) { + // Compute the concatenated shape: same as input but with last dim doubled. + cat_shape_.assign(input.shape().begin(), input.shape().end()); + cat_shape_.back() *= 2; + + uint64_t cat_elems = 1; + + for (auto d : cat_shape_) { + cat_elems *= static_cast(d); + } + + cat_size_ = cat_elems * kDataTypeToSize.at(input.dtype()); + + // `aclnnSwiGlu` ignores output strides and writes contiguously. + // When the output is non-contiguous we need a contiguous staging buffer. + needs_copy_ = !is_out_contiguous_; + + if (needs_copy_) { + out_staging_size_ = output_size_ * kDataTypeToSize.at(out.dtype()); + } + } + + ~Operator() { + if (!ascend::IsAclRuntimeAlive()) return; + + // Null cached descriptors — see `AclTensorCache::release()`. The inputs + // and outputs are referenced by the Repeatable executors (`cat_exec_`, + // `swiglu_exec_`, `copy_exec_`) via `cat_tensor_list_`; releasing them + // here prevents `~AclTensorCache()` from double-freeing at shutdown. + gate_cache_.release(); + in_cache_.release(); + out_cache_.release(); + + // Optional caches are held by `swiglu_exec_` / `copy_exec_`; release to + // avoid double-free on destruction. + if (cat_out_cache_) cat_out_cache_->release(); + if (out_staging_cache_) out_staging_cache_->release(); + + // `cat_tensor_list_` leaks with `cat_exec_` at shutdown (see `64c367c`). + } + + void operator()(const Tensor input, const Tensor gate, + Tensor out) const override { + auto t_gate = gate_cache_.get(const_cast(gate.data())); + auto t_in = in_cache_.get(const_cast(input.data())); + auto t_out = out_cache_.get(out.data()); + auto stream = static_cast(stream_); + + // Obtain shared `temp` buffer for the concatenated tensor. + auto& cat_arena = + ascend::GetWorkspacePool().Ensure(stream, cat_size_, "temp"); + + // Lazily build the `aclnnCat` output tensor cache on first call. + if (!cat_out_cache_) { + cat_out_cache_.emplace(cat_shape_, ascend::ToAclDtype(input_type_), + cat_arena.buf); + } + + auto t_cat = cat_out_cache_->get(cat_arena.buf); + + // Step 1: `aclnnCat([gate, input], dim=-1) -> cat_buf`. + if (!cat_exec_) { + aclTensor* tensors[2] = {t_gate, t_in}; + cat_tensor_list_ = + aclCreateTensorList(const_cast(tensors), 2); + aclnnCatGetWorkspaceSize(cat_tensor_list_, + static_cast(ndim_ - 1), t_cat, &cat_ws_, + &cat_exec_); + aclSetAclOpExecutorRepeatable(cat_exec_); + } else { + // The tensor list references the same `aclTensor*` objects whose data + // pointers were already updated by `get()` above. + aclSetOutputTensorAddr(cat_exec_, 0, t_cat, cat_arena.buf); + } + + auto& cat_ws_arena = ascend::GetWorkspacePool().Ensure(stream, cat_ws_); + aclnnCat(cat_ws_arena.buf, cat_ws_, cat_exec_, stream); + + // Step 2: `aclnnSwiGlu(cat_buf, dim=-1) -> out` (or staging buffer). + aclTensor* t_swiglu_out = t_out; + void* swiglu_out_data = out.data(); + + if (needs_copy_) { + auto& staging = ascend::GetWorkspacePool().Ensure( + stream, out_staging_size_, "staging"); + + if (!out_staging_cache_) { + std::vector out_shape(out_shape_.begin(), out_shape_.end()); + out_staging_cache_.emplace(out_shape, ascend::ToAclDtype(out_type_), + staging.buf); + } + + t_swiglu_out = out_staging_cache_->get(staging.buf); + swiglu_out_data = staging.buf; + } + + if (!swiglu_exec_) { + aclnnSwiGluGetWorkspaceSize(t_cat, static_cast(ndim_ - 1), + t_swiglu_out, &swiglu_ws_, &swiglu_exec_); + aclSetAclOpExecutorRepeatable(swiglu_exec_); + } else { + aclSetInputTensorAddr(swiglu_exec_, 0, t_cat, cat_arena.buf); + aclSetOutputTensorAddr(swiglu_exec_, 0, t_swiglu_out, swiglu_out_data); + } + + auto& swiglu_arena = ascend::GetWorkspacePool().Ensure(stream, swiglu_ws_); + aclnnSwiGlu(swiglu_arena.buf, swiglu_ws_, swiglu_exec_, stream); + + // Step 3 (non-contiguous output only): copy staging -> `out`. + if (needs_copy_) { + if (!copy_exec_) { + aclnnInplaceCopyGetWorkspaceSize(t_out, t_swiglu_out, ©_ws_, + ©_exec_); + aclSetAclOpExecutorRepeatable(copy_exec_); + } else { + aclSetInputTensorAddr(copy_exec_, 0, t_out, out.data()); + aclSetInputTensorAddr(copy_exec_, 1, t_swiglu_out, swiglu_out_data); + } + + auto& copy_arena = ascend::GetWorkspacePool().Ensure(stream, copy_ws_); + aclnnInplaceCopy(copy_arena.buf, copy_ws_, copy_exec_, stream); + } + } + + private: + mutable ascend::AclTensorCache gate_cache_; + + mutable ascend::AclTensorCache in_cache_; + + mutable ascend::AclTensorCache out_cache_; + + mutable std::optional cat_out_cache_; + + mutable std::optional out_staging_cache_; + + std::vector cat_shape_; + + uint64_t cat_size_ = 0; + + bool needs_copy_ = false; + + uint64_t out_staging_size_ = 0; + + mutable aclTensorList* cat_tensor_list_ = nullptr; + + mutable aclOpExecutor* cat_exec_ = nullptr; + + mutable uint64_t cat_ws_ = 0; + + mutable aclOpExecutor* swiglu_exec_ = nullptr; + + mutable uint64_t swiglu_ws_ = 0; + + mutable aclOpExecutor* copy_exec_ = nullptr; + + mutable uint64_t copy_ws_ = 0; +}; + +} // namespace infini::ops + +#endif