diff --git a/src/native/ascend/ops/causal_softmax/kernel.h b/src/native/ascend/ops/causal_softmax/kernel.h new file mode 100644 index 000000000..a2c704bbc --- /dev/null +++ b/src/native/ascend/ops/causal_softmax/kernel.h @@ -0,0 +1,173 @@ +#ifndef INFINI_OPS_ASCEND_CAUSAL_SOFTMAX_KERNEL_H_ +#define INFINI_OPS_ASCEND_CAUSAL_SOFTMAX_KERNEL_H_ + +#include +#include + +#include "acl/acl.h" +#include "aclnn/aclnn_base.h" +#include "aclnn_copy.h" +#include "aclnn_masked_fill_scalar.h" +#include "aclnn_softmax.h" +#include "base/causal_softmax.h" +#include "data_type.h" +#include "native/ascend/common.h" +#include "native/ascend/workspace_pool_.h" +#include "operator.h" + +namespace infini::ops { + +// CANN 8.5 has no single API covering causal-mask-then-softmax: the nearest +// candidates (`aclnnSoftmaxV2`, `aclnnScaledSoftmaxGrad`) do not accept a +// boolean mask argument, and `aclnnScaledMaskedSoftmax` requires a +// pre-scaled attention-score tensor produced inside flash-attention, not a +// standalone softmax input. Decomposing into three ACLNN calls is therefore +// unavoidable until a `aclnnCausalSoftmax` ships: +// 1. `aclnnInplaceCopy(temp, input)` — stride-aware copy to a contiguous +// `temp` buffer. +// 2. `aclnnInplaceMaskedFillScalar(temp, mask, -inf)` — apply the +// upper-triangle mask. +// 3. `aclnnSoftmax(temp, dim=-1, out)` — softmax over the last dimension. +// +// The boolean causal mask is pre-computed and uploaded to device once in the +// constructor. Its shape `(seq_len, total_seq_len)` broadcasts over the +// batch dimension. +template <> +class Operator : public CausalSoftmax { + public: + Operator(const Tensor input, Tensor out) + : CausalSoftmax(input, out), in_cache_(input), out_cache_(out) { + // Compute `temp` buffer size — allocated lazily from the pool in + // `operator()`. + size_t n_elems = input.numel(); + size_t elem_bytes = kDataTypeToSize.at(dtype_); + temp_size_ = n_elems * elem_bytes; + + // Build a contiguous `Tensor` descriptor — data pointer set on first use. + Tensor temp_t{nullptr, input.shape(), input.dtype(), input.device()}; + temp_cache_ = ascend::AclTensorCache(temp_t); + + // Causal mask: `mask[i][j] = 1` when position `j` must be masked for + // query `i`. Shape `(seq_len, total_seq_len)` broadcasts over the batch + // dimension. + size_t mask_elems = seq_len_ * total_seq_len_; + std::vector mask_host(mask_elems, 0); + + for (size_t i = 0; i < seq_len_; ++i) { + auto vis_end = static_cast(total_seq_len_ - seq_len_ + i); + + for (auto j = vis_end + 1; j < static_cast(total_seq_len_); + ++j) { + mask_host[i * total_seq_len_ + j] = 1; + } + } + + aclrtMalloc(&mask_buf_, mask_elems, ACL_MEM_MALLOC_NORMAL_ONLY); + aclrtMemcpy(mask_buf_, mask_elems, mask_host.data(), mask_elems, + ACL_MEMCPY_HOST_TO_DEVICE); + + std::vector mshape = {static_cast(seq_len_), + static_cast(total_seq_len_)}; + std::vector mstrides = {static_cast(total_seq_len_), 1}; + mask_tensor_ = aclCreateTensor(mshape.data(), mshape.size(), ACL_BOOL, + mstrides.data(), 0, ACL_FORMAT_ND, + mshape.data(), mshape.size(), mask_buf_); + + // Scalar `-inf` for the masked-fill step. `aclCreateScalar` stores the + // pointer rather than copying, so `neg_inf_storage_` must stay alive + // with the object. + neg_inf_ = aclCreateScalar(&neg_inf_storage_, ACL_FLOAT); + // Workspaces are allocated lazily on the first `operator()` call. + } + + ~Operator() { + if (!ascend::IsAclRuntimeAlive()) return; + + // Null cached descriptors — see `AclTensorCache::release()`. + in_cache_.release(); + out_cache_.release(); + temp_cache_.release(); + + // `mask_tensor_` leaks with `fill_exec_` at shutdown (see `64c367c`). + if (mask_buf_) aclrtFree(mask_buf_); + if (neg_inf_) aclDestroyScalar(neg_inf_); + } + + void operator()(const Tensor input, Tensor out) const override { + auto t_in = in_cache_.get(const_cast(input.data())); + auto t_out = out_cache_.get(out.data()); + auto stream = static_cast(stream_); + + // Obtain shared `temp` buffer from the pool. + auto& temp = ascend::GetWorkspacePool().Ensure(stream, temp_size_, "temp"); + auto t_temp = temp_cache_.get(temp.buf); + + // Step 1: copy `input` (possibly non-contiguous) into a contiguous `temp`. + if (!copy_exec_) { + aclnnInplaceCopyGetWorkspaceSize(t_temp, t_in, ©_ws_, ©_exec_); + aclSetAclOpExecutorRepeatable(copy_exec_); + } else { + aclSetInputTensorAddr(copy_exec_, 0, t_temp, temp.buf); + aclSetInputTensorAddr(copy_exec_, 1, t_in, + const_cast(input.data())); + } + auto& copy_arena = ascend::GetWorkspacePool().Ensure(stream, copy_ws_); + aclnnInplaceCopy(copy_arena.buf, copy_ws_, copy_exec_, stream); + + // Step 2: mask upper-triangle positions with `-inf` in-place. + // `mask_tensor_` and `neg_inf_` have stable addresses — first-call only. + if (!fill_exec_) { + aclnnInplaceMaskedFillScalarGetWorkspaceSize( + t_temp, mask_tensor_, neg_inf_, &fill_ws_, &fill_exec_); + aclSetAclOpExecutorRepeatable(fill_exec_); + } + auto& fill_arena = ascend::GetWorkspacePool().Ensure(stream, fill_ws_); + aclnnInplaceMaskedFillScalar(fill_arena.buf, fill_ws_, fill_exec_, stream); + + // Step 3: softmax over the last dimension -> `out`. + if (!softmax_exec_) { + constexpr int64_t kLastDim = -1; + aclnnSoftmaxGetWorkspaceSize(t_temp, kLastDim, t_out, &softmax_ws_, + &softmax_exec_); + aclSetAclOpExecutorRepeatable(softmax_exec_); + } else { + aclSetOutputTensorAddr(softmax_exec_, 0, t_out, out.data()); + } + auto& softmax_arena = + ascend::GetWorkspacePool().Ensure(stream, softmax_ws_); + aclnnSoftmax(softmax_arena.buf, softmax_ws_, softmax_exec_, stream); + } + + private: + mutable ascend::AclTensorCache in_cache_; + + mutable ascend::AclTensorCache out_cache_; + + mutable ascend::AclTensorCache temp_cache_; + + float neg_inf_storage_ = -std::numeric_limits::infinity(); + + uint64_t temp_size_ = 0; + + void* mask_buf_ = nullptr; + + aclTensor* mask_tensor_ = nullptr; + + aclScalar* neg_inf_ = nullptr; + + mutable aclOpExecutor* copy_exec_ = nullptr; + + mutable uint64_t copy_ws_ = 0; + + mutable aclOpExecutor* fill_exec_ = nullptr; + + mutable uint64_t fill_ws_ = 0; + + mutable aclOpExecutor* softmax_exec_ = nullptr; + + mutable uint64_t softmax_ws_ = 0; +}; + +} // namespace infini::ops + +#endif diff --git a/tests/test_causal_softmax.py b/tests/test_causal_softmax.py index ef48601f7..79e6be8e0 100644 --- a/tests/test_causal_softmax.py +++ b/tests/test_causal_softmax.py @@ -46,10 +46,9 @@ def _causal_softmax(input, out): def _torch_causal_softmax(input, out): - input_cpu = input.detach().cpu().to(torch.float32) - mask = torch.tril(torch.ones_like(input_cpu), diagonal=-1).flip(dims=[-2, -1]) - masked = torch.where(mask == 1, -torch.inf, input_cpu) + mask = torch.tril(torch.ones_like(input), diagonal=-1).flip(dims=[-2, -1]) + masked = torch.where(mask == 1, -torch.inf, input.to(torch.float32)) result = torch.nn.functional.softmax(masked, dim=-1) - out.copy_(result.to(device=out.device, dtype=out.dtype)) + out.copy_(result) return out