Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions src/base/silu_and_mul.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
#ifndef INFINI_OPS_BASE_SILU_AND_MUL_H_
#define INFINI_OPS_BASE_SILU_AND_MUL_H_

#include "operator.h"

namespace infini::ops {

// SiLU-gated linear unit: splits `input` along `dim` into two halves and
// computes `silu(first_half) * second_half`. Matches
// `vllm._C.silu_and_mul`; `dim` defaults to `-1` (PyTorch `F.glu`
// convention).
class SiluAndMul : public Operator<SiluAndMul> {
public:
SiluAndMul(const Tensor input, int64_t dim, Tensor out)
: input_shape_{input.shape()},
input_strides_{input.strides()},
out_shape_{out.shape()},
out_strides_{out.strides()},
input_dtype_{input.dtype()},
out_dtype_{out.dtype()},
dim_{dim},
ndim_{input.ndim()},
is_input_contiguous_{input.IsContiguous()},
is_out_contiguous_{out.IsContiguous()} {
assert(input_dtype_ == out_dtype_ &&
"`SiluAndMul`: `input` and `out` must have the same dtype");
}

SiluAndMul(const Tensor input, Tensor out) : SiluAndMul{input, -1, out} {}

virtual void operator()(const Tensor input, int64_t dim,
Tensor out) const = 0;

virtual void operator()(const Tensor input, Tensor out) const {
return operator()(input, -1, out);
}

protected:
Tensor::Shape input_shape_;

Tensor::Strides input_strides_;

Tensor::Shape out_shape_;

Tensor::Strides out_strides_;

const DataType input_dtype_;

const DataType out_dtype_;

int64_t dim_;

Tensor::Size ndim_;

bool is_input_contiguous_;

bool is_out_contiguous_;
};

} // namespace infini::ops

#endif
127 changes: 127 additions & 0 deletions src/native/ascend/ops/silu_and_mul/kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
#ifndef INFINI_OPS_ASCEND_SILU_AND_MUL_KERNEL_H_
#define INFINI_OPS_ASCEND_SILU_AND_MUL_KERNEL_H_

#include <vector>

#include "acl/acl.h"
#include "aclnn/aclnn_base.h"
#include "aclnn_copy.h"
#include "aclnnop/aclnn_swi_glu.h"
#include "base/silu_and_mul.h"
#include "native/ascend/common.h"
#include "native/ascend/workspace_pool_.h"
#include "operator.h"

namespace infini::ops {

// Calls `aclnnSwiGlu` directly on the concatenated `x = [gate, up]` tensor.
//
// `aclnnSwiGlu` splits `x` along `dim` into `[first_half, second_half]` and
// computes `second_half * silu(first_half)`, i.e. `up * silu(gate)`.
//
// `aclnnSwiGlu` ignores output strides and writes contiguously. When the
// output is non-contiguous, a contiguous staging buffer is used and the
// result is copied back via `aclnnInplaceCopy`.
template <>
class Operator<SiluAndMul, Device::Type::kAscend, 0> : public SiluAndMul {
public:
Operator(const Tensor input, int64_t dim, Tensor out)
: SiluAndMul(input, dim, out), input_cache_(input), out_cache_(out) {
needs_copy_ = !is_out_contiguous_;

if (needs_copy_) {
out_staging_size_ = out.numel() * kDataTypeToSize.at(out.dtype());
}
}

~Operator() {
if (!ascend::IsAclRuntimeAlive()) return;

// Null cached descriptors — see `AclTensorCache::release()`. Inputs and
// outputs are referenced by the Repeatable executors (`swiglu_exec_`,
// `copy_exec_`); releasing them here prevents `~AclTensorCache()` from
// double-freeing at shutdown.
input_cache_.release();
out_cache_.release();

// The staging cache is held by `swiglu_exec_` / `copy_exec_`; release to
// avoid double-free on destruction.
if (out_staging_cache_) out_staging_cache_->release();
}

void operator()(const Tensor input, int64_t dim, Tensor out) const override {
auto t_input = input_cache_.get(const_cast<void*>(input.data()));
auto t_out = out_cache_.get(out.data());
auto stream = static_cast<aclrtStream>(stream_);

// Determine effective output target.
aclTensor* t_swiglu_out = t_out;
void* swiglu_out_data = out.data();

if (needs_copy_) {
auto& staging = ascend::GetWorkspacePool().Ensure(
stream, out_staging_size_, "staging");

if (!out_staging_cache_) {
std::vector<int64_t> out_shape(out_shape_.begin(), out_shape_.end());
out_staging_cache_.emplace(out_shape, ascend::ToAclDtype(out_dtype_),
staging.buf);
}

t_swiglu_out = out_staging_cache_->get(staging.buf);
swiglu_out_data = staging.buf;
}

// Call `aclnnSwiGlu`.
if (!swiglu_exec_) {
aclnnSwiGluGetWorkspaceSize(t_input, dim_, t_swiglu_out, &swiglu_ws_,
&swiglu_exec_);
aclSetAclOpExecutorRepeatable(swiglu_exec_);
} else {
aclSetInputTensorAddr(swiglu_exec_, 0, t_input,
const_cast<void*>(input.data()));
aclSetOutputTensorAddr(swiglu_exec_, 0, t_swiglu_out, swiglu_out_data);
}

auto& arena = ascend::GetWorkspacePool().Ensure(stream, swiglu_ws_);
aclnnSwiGlu(arena.buf, swiglu_ws_, swiglu_exec_, stream);

// Copy staging buffer back to non-contiguous output if needed.
if (needs_copy_) {
if (!copy_exec_) {
aclnnInplaceCopyGetWorkspaceSize(t_out, t_swiglu_out, &copy_ws_,
&copy_exec_);
aclSetAclOpExecutorRepeatable(copy_exec_);
} else {
aclSetInputTensorAddr(copy_exec_, 0, t_out, out.data());
aclSetInputTensorAddr(copy_exec_, 1, t_swiglu_out, swiglu_out_data);
}

auto& copy_arena = ascend::GetWorkspacePool().Ensure(stream, copy_ws_);
aclnnInplaceCopy(copy_arena.buf, copy_ws_, copy_exec_, stream);
}
}

private:
mutable ascend::AclTensorCache input_cache_;

mutable ascend::AclTensorCache out_cache_;

mutable std::optional<ascend::AclTensorCache> out_staging_cache_;

bool needs_copy_ = false;

uint64_t out_staging_size_ = 0;

mutable aclOpExecutor* swiglu_exec_ = nullptr;

mutable uint64_t swiglu_ws_ = 0;

mutable aclOpExecutor* copy_exec_ = nullptr;

mutable uint64_t copy_ws_ = 0;
};

} // namespace infini::ops

#endif
76 changes: 76 additions & 0 deletions tests/test_silu_and_mul.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import infini.ops
import pytest
import torch

from tests.utils import Payload, empty_strided, get_stream, rand_strided


@pytest.mark.auto_act_and_assert
@pytest.mark.parametrize(
"shape, x_strides, out_strides",
(
((13, 8), None, None),
((16, 11264), None, None),
((4, 4, 11264), None, None),
((1, 8), None, None),
((32, 5632), None, None),
# Non-contiguous `x` (inner stride > inner dim doubled).
((13, 8), (16, 1), (4, 1)),
# Non-contiguous across all dims (3-D with larger outer stride).
((4, 4, 16), (128, 16, 1), (64, 8, 1)),
),
)
@pytest.mark.parametrize(
("dtype", "rtol", "atol"),
(
(torch.float32, 1e-7, 1e-7),
(torch.float16, 1e-3, 1e-3),
(torch.bfloat16, 1e-2, 5e-3),
),
)
def test_silu_and_mul(
shape,
x_strides,
out_strides,
implementation_index,
dtype,
device,
rtol,
atol,
):
x = rand_strided(shape, x_strides, dtype=dtype, device=device)
d = shape[-1] // 2
out_shape = (*shape[:-1], d)
out = empty_strided(out_shape, out_strides, dtype=dtype, device=device)

return Payload(
lambda *args, **kwargs: _silu_and_mul(
*args, **kwargs, implementation_index=implementation_index
),
_torch_silu_and_mul,
(x, out),
{},
rtol=rtol,
atol=atol,
)


def _silu_and_mul(x, out, implementation_index=0):
infini.ops.silu_and_mul(
x,
-1,
out,
implementation_index=implementation_index,
stream=get_stream(x.device),
)

return out


def _torch_silu_and_mul(x, out):
d = x.shape[-1] // 2
gate = x[..., :d]
up = x[..., d:]
result = up * torch.sigmoid(gate) * gate

return result.to(out.dtype)
Loading