diff --git a/src/base/silu_and_mul.h b/src/base/silu_and_mul.h
new file mode 100644
index 000000000..6cede6e44
--- /dev/null
+++ b/src/base/silu_and_mul.h
@@ -0,0 +1,83 @@
+#ifndef INFINI_OPS_BASE_SILU_AND_MUL_H_
+#define INFINI_OPS_BASE_SILU_AND_MUL_H_
+
+#include <cassert>
+#include <cstdint>
+
+#include "operator.h"
+
+namespace infini::ops {
+
+// SiLU-gated linear unit: splits `input` along `dim` into two halves and
+// computes `silu(first_half) * second_half`. Matches
+// `vllm._C.silu_and_mul`; `dim` defaults to `-1` (PyTorch `F.glu`
+// convention).
+//
+// The constructor records the shape, strides, dtype and contiguity of
+// `input` and `out` in the protected members below; backends implement
+// the three-argument `operator()`.
+//
+// NOTE(review): the base-class template arguments were lost in extraction;
+// `<SiluAndMul>` is reconstructed from the `template <>` specialization of
+// `Operator` in the Ascend kernel — confirm against `operator.h`.
+class SiluAndMul : public Operator<SiluAndMul> {
+ public:
+  // Records metadata of `input` and `out`. In debug builds, asserts that
+  // both tensors share a dtype and that `dim` is a valid axis of `input`.
+  SiluAndMul(const Tensor input, int64_t dim, Tensor out)
+      : input_shape_{input.shape()},
+        input_strides_{input.strides()},
+        out_shape_{out.shape()},
+        out_strides_{out.strides()},
+        input_dtype_{input.dtype()},
+        out_dtype_{out.dtype()},
+        dim_{dim},
+        ndim_{input.ndim()},
+        is_input_contiguous_{input.IsContiguous()},
+        is_out_contiguous_{out.IsContiguous()} {
+    assert(input_dtype_ == out_dtype_ &&
+           "`SiluAndMul`: `input` and `out` must have the same dtype");
+    assert(dim >= -static_cast<int64_t>(ndim_) &&
+           dim < static_cast<int64_t>(ndim_) &&
+           "`SiluAndMul`: `dim` must be in `[-ndim, ndim)`");
+  }
+
+  // Convenience overload that splits along the last dimension (`dim = -1`).
+  SiluAndMul(const Tensor input, Tensor out) : SiluAndMul{input, -1, out} {}
+
+  // Backend entry point, implemented by each device specialization.
+  virtual void operator()(const Tensor input, int64_t dim,
+                          Tensor out) const = 0;
+
+  // Forwards to the three-argument overload with `dim = -1`.
+  virtual void operator()(const Tensor input, Tensor out) const {
+    return operator()(input, -1, out);
+  }
+
+ protected:
+  Tensor::Shape input_shape_;
+
+  Tensor::Strides input_strides_;
+
+  Tensor::Shape out_shape_;
+
+  Tensor::Strides out_strides_;
+
+  const DataType input_dtype_;
+
+  const DataType out_dtype_;
+
+  // Split axis exactly as passed by the caller (may be negative).
+  int64_t dim_;
+
+  // Number of dimensions of `input`.
+  Tensor::Size ndim_;
+
+  bool is_input_contiguous_;
+
+  bool is_out_contiguous_;
+};
+
+}  // namespace infini::ops
+
+#endif
diff --git a/src/native/ascend/ops/silu_and_mul/kernel.h b/src/native/ascend/ops/silu_and_mul/kernel.h
new file mode 100644
index 000000000..763e5683e
--- /dev/null
+++ b/src/native/ascend/ops/silu_and_mul/kernel.h
@@ -0,0 +1,150 @@
+#ifndef INFINI_OPS_ASCEND_SILU_AND_MUL_KERNEL_H_
+#define INFINI_OPS_ASCEND_SILU_AND_MUL_KERNEL_H_
+
+#include <cstdint>
+#include <optional>
+#include <vector>
+
+#include "acl/acl.h"
+#include "aclnn/aclnn_base.h"
+#include "aclnn_copy.h"
+#include "aclnnop/aclnn_swi_glu.h"
+#include "base/silu_and_mul.h"
+#include "native/ascend/common.h"
+#include "native/ascend/workspace_pool_.h"
+#include "operator.h"
+
+namespace infini::ops {
+
+// Calls `aclnnSwiGlu` directly on the concatenated `x = [gate, up]` tensor.
+//
+// `aclnnSwiGlu` splits `x` along `dim` into `[first_half, second_half]` and
+// computes `second_half * silu(first_half)`, i.e. `up * silu(gate)`.
+//
+// `aclnnSwiGlu` ignores output strides and writes contiguously. When the
+// output is non-contiguous, a contiguous staging buffer is used and the
+// result is copied back via `aclnnInplaceCopy`.
+//
+// Each executor is created on the first call, marked Repeatable, and
+// re-bound to the current data pointers on every later call.
+//
+// NOTE(review): angle-bracket contents were lost in extraction and are
+// reconstructed in this file (`Operator<...>`, `const_cast<void*>`,
+// `static_cast<aclrtStream>`, `std::vector<int64_t>`,
+// `std::optional<ascend::AclTensorCache>`) — confirm against the original
+// patch, in particular the specialization arguments on the next lines.
+template <>
+class Operator<SiluAndMul, Device::Type::kAscend> : public SiluAndMul {
+ public:
+  // Keep the base-class two-argument `operator()` visible; declaring the
+  // three-argument override below would otherwise hide it.
+  using SiluAndMul::operator();
+
+  // Builds the descriptor caches for `input` and `out` and, when `out` is
+  // not contiguous, records the staging size (`numel` * element size).
+  Operator(const Tensor input, int64_t dim, Tensor out)
+      : SiluAndMul(input, dim, out), input_cache_(input), out_cache_(out) {
+    needs_copy_ = !is_out_contiguous_;
+
+    if (needs_copy_) {
+      out_staging_size_ = out.numel() * kDataTypeToSize.at(out.dtype());
+    }
+  }
+
+  ~Operator() {
+    if (!ascend::IsAclRuntimeAlive()) return;
+
+    // Null cached descriptors — see `AclTensorCache::release()`. Inputs and
+    // outputs are referenced by the Repeatable executors (`swiglu_exec_`,
+    // `copy_exec_`); releasing them here prevents `~AclTensorCache()` from
+    // double-freeing at shutdown.
+    //
+    // NOTE(review): neither executor is destroyed here — confirm that this
+    // is intentional.
+    input_cache_.release();
+    out_cache_.release();
+
+    // The staging cache is held by `swiglu_exec_` / `copy_exec_`; release to
+    // avoid double-free on destruction.
+    if (out_staging_cache_) out_staging_cache_->release();
+  }
+
+  // Runs `aclnnSwiGlu` on `input` and stores the result in `out`. `dim` is
+  // not consulted: the executor is built with the `dim_` captured at
+  // construction.
+  void operator()(const Tensor input, int64_t dim, Tensor out) const override {
+    auto t_input = input_cache_.get(const_cast<void*>(input.data()));
+    auto t_out = out_cache_.get(out.data());
+    auto stream = static_cast<aclrtStream>(stream_);
+
+    // Determine effective output target.
+    aclTensor* t_swiglu_out = t_out;
+    void* swiglu_out_data = out.data();
+
+    if (needs_copy_) {
+      auto& staging = ascend::GetWorkspacePool().Ensure(
+          stream, out_staging_size_, "staging");
+
+      if (!out_staging_cache_) {
+        std::vector<int64_t> out_shape(out_shape_.begin(), out_shape_.end());
+        out_staging_cache_.emplace(out_shape, ascend::ToAclDtype(out_dtype_),
+                                   staging.buf);
+      }
+
+      t_swiglu_out = out_staging_cache_->get(staging.buf);
+      swiglu_out_data = staging.buf;
+    }
+
+    // Call `aclnnSwiGlu`.
+    if (!swiglu_exec_) {
+      aclnnSwiGluGetWorkspaceSize(t_input, dim_, t_swiglu_out, &swiglu_ws_,
+                                  &swiglu_exec_);
+      aclSetAclOpExecutorRepeatable(swiglu_exec_);
+    } else {
+      aclSetInputTensorAddr(swiglu_exec_, 0, t_input,
+                            const_cast<void*>(input.data()));
+      aclSetOutputTensorAddr(swiglu_exec_, 0, t_swiglu_out, swiglu_out_data);
+    }
+
+    auto& arena = ascend::GetWorkspacePool().Ensure(stream, swiglu_ws_);
+    aclnnSwiGlu(arena.buf, swiglu_ws_, swiglu_exec_, stream);
+
+    // Copy staging buffer back to non-contiguous output if needed.
+    if (needs_copy_) {
+      if (!copy_exec_) {
+        aclnnInplaceCopyGetWorkspaceSize(t_out, t_swiglu_out, &copy_ws_,
+                                         &copy_exec_);
+        aclSetAclOpExecutorRepeatable(copy_exec_);
+      } else {
+        aclSetInputTensorAddr(copy_exec_, 0, t_out, out.data());
+        aclSetInputTensorAddr(copy_exec_, 1, t_swiglu_out, swiglu_out_data);
+      }
+
+      auto& copy_arena = ascend::GetWorkspacePool().Ensure(stream, copy_ws_);
+      aclnnInplaceCopy(copy_arena.buf, copy_ws_, copy_exec_, stream);
+    }
+  }
+
+ private:
+  mutable ascend::AclTensorCache input_cache_;
+
+  mutable ascend::AclTensorCache out_cache_;
+
+  mutable std::optional<ascend::AclTensorCache> out_staging_cache_;
+
+  bool needs_copy_ = false;
+
+  uint64_t out_staging_size_ = 0;
+
+  mutable aclOpExecutor* swiglu_exec_ = nullptr;
+
+  mutable uint64_t swiglu_ws_ = 0;
+
+  mutable aclOpExecutor* copy_exec_ = nullptr;
+
+  mutable uint64_t copy_ws_ = 0;
+};
+
+}  // namespace infini::ops
+
+#endif
diff --git a/tests/test_silu_and_mul.py b/tests/test_silu_and_mul.py
new file mode 100644
index 000000000..c1bb62e4b
--- /dev/null
+++ b/tests/test_silu_and_mul.py
@@ -0,0 +1,76 @@
+import infini.ops
+import pytest
+import torch
+
+from tests.utils import Payload, empty_strided, get_stream, rand_strided
+
+
+@pytest.mark.auto_act_and_assert
+@pytest.mark.parametrize(
+    "shape, x_strides, out_strides",
+    (
+        ((13, 8), None, None),
+        ((16, 11264), None, None),
+        ((4, 4, 11264), None, None),
+        ((1, 8), None, None),
+        ((32, 5632), None, None),
+        # Non-contiguous `x` (inner stride > inner dim doubled).
+        ((13, 8), (16, 1), (4, 1)),
+        # Non-contiguous across all dims (3-D with larger outer stride).
+        ((4, 4, 16), (128, 16, 1), (64, 8, 1)),
+    ),
+)
+@pytest.mark.parametrize(
+    ("dtype", "rtol", "atol"),
+    (
+        (torch.float32, 1e-7, 1e-7),
+        (torch.float16, 1e-3, 1e-3),
+        (torch.bfloat16, 1e-2, 5e-3),
+    ),
+)
+def test_silu_and_mul(
+    shape,
+    x_strides,
+    out_strides,
+    implementation_index,
+    dtype,
+    device,
+    rtol,
+    atol,
+):
+    # `out` keeps the leading dims of `x` and half of its last dimension.
+    half = shape[-1] // 2
+    x = rand_strided(shape, x_strides, dtype=dtype, device=device)
+    out = empty_strided((*shape[:-1], half), out_strides, dtype=dtype, device=device)
+
+    def run(*args, **kwargs):
+        return _silu_and_mul(
+            *args, **kwargs, implementation_index=implementation_index
+        )
+
+    return Payload(
+        run,
+        _torch_silu_and_mul,
+        (x, out),
+        {},
+        rtol=rtol,
+        atol=atol,
+    )
+
+
+def _silu_and_mul(x, out, implementation_index=0):
+    # Runs the operator under test along the last dimension, on `x`'s stream.
+    stream = get_stream(x.device)
+    infini.ops.silu_and_mul(
+        x, -1, out, implementation_index=implementation_index, stream=stream
+    )
+
+    return out
+
+
+def _torch_silu_and_mul(x, out):
+    # Reference: `up * sigmoid(gate) * gate`, cast to the dtype of `out`.
+    half = x.shape[-1] // 2
+    gate, up = x[..., :half], x[..., half:]
+
+    return (up * torch.sigmoid(gate) * gate).to(out.dtype)