Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 109 additions & 0 deletions src/native/ascend/ops/swiglu/kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
#ifndef INFINI_OPS_ASCEND_SWIGLU_KERNEL_H_
#define INFINI_OPS_ASCEND_SWIGLU_KERNEL_H_

#include "acl/acl.h"
#include "aclnn/aclnn_base.h"
#include "aclnn_mul.h"
#include "aclnn_silu.h"
#include "base/swiglu.h"
#include "data_type.h"
#include "native/ascend/common.h"
#include "native/ascend/workspace_pool_.h"
#include "operator.h"

namespace infini::ops {

// Implements SwiGLU as two ACLNN calls: `aclnnSilu(gate)` into a `temp`
// buffer, then elementwise `aclnnMul(input, temp)` into `out`.
// `aclnnSiluMul` was not used because it fuses silu-and-mul on the same
// tensor (`x * silu(x)`), whereas SwiGLU requires `input * silu(gate)` —
// two distinct inputs.
template <>
class Operator<Swiglu, Device::Type::kAscend, 0> : public Swiglu {
public:
Operator(const Tensor input, const Tensor gate, Tensor out)
: Swiglu(input, gate, out),
in_cache_(input),
gate_cache_(gate),
out_cache_(out) {
temp_size_ = input.numel() * kDataTypeToSize.at(input.dtype());

// Build the `temp` cache from `gate` geometry (contiguous, same
// shape/dtype). No data pointer yet — it is set on the first `get()`
// call.
Tensor temp_t{nullptr, gate.shape(), gate.dtype(), gate.device()};
temp_cache_ = ascend::AclTensorCache(temp_t);
}

~Operator() {
if (!ascend::IsAclRuntimeAlive()) return;

// Null cached descriptors — see `AclTensorCache::release()`.
in_cache_.release();
gate_cache_.release();
out_cache_.release();
temp_cache_.release();
}

void operator()(const Tensor input, const Tensor gate,
Tensor out) const override {
auto t_in = in_cache_.get(const_cast<void*>(input.data()));
auto t_gate = gate_cache_.get(const_cast<void*>(gate.data()));
auto t_out = out_cache_.get(out.data());
auto stream = static_cast<aclrtStream>(stream_);

// Obtain shared `temp` buffer from the pool.
auto& temp = ascend::GetWorkspacePool().Ensure(stream, temp_size_, "temp");
auto t_temp = temp_cache_.get(temp.buf);

// Step 1: `silu(gate) -> temp`.
if (!silu_exec_) {
aclnnSiluGetWorkspaceSize(t_gate, t_temp, &silu_ws_, &silu_exec_);
aclSetAclOpExecutorRepeatable(silu_exec_);
} else {
aclSetInputTensorAddr(silu_exec_, 0, t_gate,
const_cast<void*>(gate.data()));
aclSetOutputTensorAddr(silu_exec_, 0, t_temp, temp.buf);
}
auto& silu_arena = ascend::GetWorkspacePool().Ensure(stream, silu_ws_);
aclnnSilu(silu_arena.buf, silu_ws_, silu_exec_, stream);

// Step 2: `mul(input, temp) -> out`.
if (!mul_exec_) {
aclnnMulGetWorkspaceSize(t_in, t_temp, t_out, &mul_ws_, &mul_exec_);
aclSetAclOpExecutorRepeatable(mul_exec_);
} else {
aclSetInputTensorAddr(mul_exec_, 0, t_in,
const_cast<void*>(input.data()));
aclSetInputTensorAddr(mul_exec_, 1, t_temp, temp.buf);
aclSetOutputTensorAddr(mul_exec_, 0, t_out, out.data());
}
auto& mul_arena = ascend::GetWorkspacePool().Ensure(stream, mul_ws_);
aclnnMul(mul_arena.buf, mul_ws_, mul_exec_, stream);
}

private:
mutable ascend::AclTensorCache in_cache_;

mutable ascend::AclTensorCache gate_cache_;

mutable ascend::AclTensorCache out_cache_;

mutable ascend::AclTensorCache temp_cache_;

uint64_t temp_size_ = 0;

mutable aclOpExecutor* silu_exec_ = nullptr;

mutable uint64_t silu_ws_ = 0;

mutable aclOpExecutor* mul_exec_ = nullptr;

mutable uint64_t mul_ws_ = 0;
};

} // namespace infini::ops

#include "native/ascend/ops/swiglu/kernel_fused.h"

#endif
202 changes: 202 additions & 0 deletions src/native/ascend/ops/swiglu/kernel_fused.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
#ifndef INFINI_OPS_ASCEND_SWIGLU_KERNEL_FUSED_H_
#define INFINI_OPS_ASCEND_SWIGLU_KERNEL_FUSED_H_

#include <vector>

#include "acl/acl.h"
#include "aclnn/aclnn_base.h"
#include "aclnn_copy.h"
#include "aclnnop/aclnn_cat.h"
#include "aclnnop/aclnn_swi_glu.h"
#include "base/swiglu.h"
#include "native/ascend/common.h"
#include "native/ascend/workspace_pool_.h"
#include "operator.h"

namespace infini::ops {

// Fused implementation via `aclnnSwiGlu` (implementation index 1).
//
// Concatenates `[gate, input]` into a `temp` buffer via `aclnnCat`, then
// calls `aclnnSwiGlu` which computes `second_half * silu(first_half)` in a
// single fused kernel, i.e. `input * silu(gate)`.
//
// This trades an extra `aclnnCat` launch for a single fused SwiGLU kernel
// instead of separate `aclnnSilu` + `aclnnMul`. The net benefit is one
// fewer intermediate buffer materialised on-device (the `silu` temp is
// eliminated).
//
// `aclnnSwiGlu` requires a contiguous output tensor. When the caller's
// output is non-contiguous, a contiguous staging buffer is used and the
// result is copied back via `aclnnInplaceCopy`.
//
// Select via `implementation_index=1` in Python:
// `infini.ops.swiglu(..., implementation_index=1, stream=s)`.
template <>
class Operator<Swiglu, Device::Type::kAscend, 1> : public Swiglu {
public:
Operator(const Tensor input, const Tensor gate, Tensor out)
: Swiglu(input, gate, out),
gate_cache_(gate),
in_cache_(input),
out_cache_(out) {
// Compute the concatenated shape: same as input but with last dim doubled.
cat_shape_.assign(input.shape().begin(), input.shape().end());
cat_shape_.back() *= 2;

uint64_t cat_elems = 1;

for (auto d : cat_shape_) {
cat_elems *= static_cast<uint64_t>(d);
}

cat_size_ = cat_elems * kDataTypeToSize.at(input.dtype());

// `aclnnSwiGlu` ignores output strides and writes contiguously.
// When the output is non-contiguous we need a contiguous staging buffer.
needs_copy_ = !is_out_contiguous_;

if (needs_copy_) {
out_staging_size_ = output_size_ * kDataTypeToSize.at(out.dtype());
}
}

~Operator() {
if (!ascend::IsAclRuntimeAlive()) return;

// Null cached descriptors — see `AclTensorCache::release()`. The inputs
// and outputs are referenced by the Repeatable executors (`cat_exec_`,
// `swiglu_exec_`, `copy_exec_`) via `cat_tensor_list_`; releasing them
// here prevents `~AclTensorCache()` from double-freeing at shutdown.
gate_cache_.release();
in_cache_.release();
out_cache_.release();

// Optional caches are held by `swiglu_exec_` / `copy_exec_`; release to
// avoid double-free on destruction.
if (cat_out_cache_) cat_out_cache_->release();
if (out_staging_cache_) out_staging_cache_->release();

// `cat_tensor_list_` leaks with `cat_exec_` at shutdown (see `64c367c`).
}

void operator()(const Tensor input, const Tensor gate,
Tensor out) const override {
auto t_gate = gate_cache_.get(const_cast<void*>(gate.data()));
auto t_in = in_cache_.get(const_cast<void*>(input.data()));
auto t_out = out_cache_.get(out.data());
auto stream = static_cast<aclrtStream>(stream_);

// Obtain shared `temp` buffer for the concatenated tensor.
auto& cat_arena =
ascend::GetWorkspacePool().Ensure(stream, cat_size_, "temp");

// Lazily build the `aclnnCat` output tensor cache on first call.
if (!cat_out_cache_) {
cat_out_cache_.emplace(cat_shape_, ascend::ToAclDtype(input_type_),
cat_arena.buf);
}

auto t_cat = cat_out_cache_->get(cat_arena.buf);

// Step 1: `aclnnCat([gate, input], dim=-1) -> cat_buf`.
if (!cat_exec_) {
aclTensor* tensors[2] = {t_gate, t_in};
cat_tensor_list_ =
aclCreateTensorList(const_cast<const aclTensor**>(tensors), 2);
aclnnCatGetWorkspaceSize(cat_tensor_list_,
static_cast<int64_t>(ndim_ - 1), t_cat, &cat_ws_,
&cat_exec_);
aclSetAclOpExecutorRepeatable(cat_exec_);
} else {
// The tensor list references the same `aclTensor*` objects whose data
// pointers were already updated by `get()` above.
aclSetOutputTensorAddr(cat_exec_, 0, t_cat, cat_arena.buf);
}

auto& cat_ws_arena = ascend::GetWorkspacePool().Ensure(stream, cat_ws_);
aclnnCat(cat_ws_arena.buf, cat_ws_, cat_exec_, stream);

// Step 2: `aclnnSwiGlu(cat_buf, dim=-1) -> out` (or staging buffer).
aclTensor* t_swiglu_out = t_out;
void* swiglu_out_data = out.data();

if (needs_copy_) {
auto& staging = ascend::GetWorkspacePool().Ensure(
stream, out_staging_size_, "staging");

if (!out_staging_cache_) {
std::vector<int64_t> out_shape(out_shape_.begin(), out_shape_.end());
out_staging_cache_.emplace(out_shape, ascend::ToAclDtype(out_type_),
staging.buf);
}

t_swiglu_out = out_staging_cache_->get(staging.buf);
swiglu_out_data = staging.buf;
}

if (!swiglu_exec_) {
aclnnSwiGluGetWorkspaceSize(t_cat, static_cast<int64_t>(ndim_ - 1),
t_swiglu_out, &swiglu_ws_, &swiglu_exec_);
aclSetAclOpExecutorRepeatable(swiglu_exec_);
} else {
aclSetInputTensorAddr(swiglu_exec_, 0, t_cat, cat_arena.buf);
aclSetOutputTensorAddr(swiglu_exec_, 0, t_swiglu_out, swiglu_out_data);
}

auto& swiglu_arena = ascend::GetWorkspacePool().Ensure(stream, swiglu_ws_);
aclnnSwiGlu(swiglu_arena.buf, swiglu_ws_, swiglu_exec_, stream);

// Step 3 (non-contiguous output only): copy staging -> `out`.
if (needs_copy_) {
if (!copy_exec_) {
aclnnInplaceCopyGetWorkspaceSize(t_out, t_swiglu_out, &copy_ws_,
&copy_exec_);
aclSetAclOpExecutorRepeatable(copy_exec_);
} else {
aclSetInputTensorAddr(copy_exec_, 0, t_out, out.data());
aclSetInputTensorAddr(copy_exec_, 1, t_swiglu_out, swiglu_out_data);
}

auto& copy_arena = ascend::GetWorkspacePool().Ensure(stream, copy_ws_);
aclnnInplaceCopy(copy_arena.buf, copy_ws_, copy_exec_, stream);
}
}

private:
mutable ascend::AclTensorCache gate_cache_;

mutable ascend::AclTensorCache in_cache_;

mutable ascend::AclTensorCache out_cache_;

mutable std::optional<ascend::AclTensorCache> cat_out_cache_;

mutable std::optional<ascend::AclTensorCache> out_staging_cache_;

std::vector<int64_t> cat_shape_;

uint64_t cat_size_ = 0;

bool needs_copy_ = false;

uint64_t out_staging_size_ = 0;

mutable aclTensorList* cat_tensor_list_ = nullptr;

mutable aclOpExecutor* cat_exec_ = nullptr;

mutable uint64_t cat_ws_ = 0;

mutable aclOpExecutor* swiglu_exec_ = nullptr;

mutable uint64_t swiglu_ws_ = 0;

mutable aclOpExecutor* copy_exec_ = nullptr;

mutable uint64_t copy_ws_ = 0;
};

} // namespace infini::ops

#endif
Loading