Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
199 changes: 199 additions & 0 deletions src/base/mha_varlen_fwd.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
#ifndef INFINI_OPS_BASE_MHA_VARLEN_FWD_H_
#define INFINI_OPS_BASE_MHA_VARLEN_FWD_H_

#include <cassert>
#include <cmath>
#include <optional>

#include "data_type.h"
#include "operator.h"

namespace infini::ops {

// FlashAttention-compatible variable-length forward attention.
//
// Layout follows `flash_attn` `varlen_fwd`:
// `q`: `[total_q, num_heads, head_size]`.
// `k` / `v`: `[total_k, num_kv_heads, head_size]`.
// Paged `k` / `v`: `[num_blocks, block_size, num_kv_heads, head_size]`
// when `block_table` is supplied.
// `cu_seqlens_q` / `cu_seqlens_k`: `[batch_size + 1]`.
//
// InfiniOps is an in-place operator API, so `out` must be supplied even
// though FlashAttention can allocate it when omitted.
// This signature intentionally keeps `out` near the FlashAttention argument
// position for compatibility with existing callers; this is an exception to
// the normal InfiniOps output-last convention.
//
// This class only captures shape metadata and scalar options and validates
// the arguments; the computation itself is the pure-virtual `operator()`,
// which derived classes must provide. All validation uses `assert`, so the
// checks are compiled out when `NDEBUG` is defined.
class MhaVarlenFwd : public Operator<MhaVarlenFwd> {
public:
// Records metadata from the arguments and asserts the input contract.
//
// Tensors are not stored. For the optional arguments (`seqused_k`,
// `leftpad_k`, `block_table`, `alibi_slopes`, `generator`) only a presence
// flag is kept. `num_kv_heads_` is taken from dim 2 of `k` when
// `block_table` is supplied (paged layout) and from dim 1 otherwise.
//
// NOTE(review): the member initializers read `q.size(1)`, `q.size(2)` and
// `k.size(1)` / `k.size(2)` before the rank asserts in the body execute, and
// `batch_size_` is computed before `cu_seqlens_q.numel() > 1` is asserted.
// Whether `Tensor::size` bounds-checks on its own is not visible here --
// confirm that a wrong-rank tensor cannot be indexed out of range.
MhaVarlenFwd(const Tensor q, const Tensor k, const Tensor v, Tensor out,
const Tensor cu_seqlens_q, const Tensor cu_seqlens_k,
std::optional<Tensor> seqused_k, std::optional<Tensor> leftpad_k,
std::optional<Tensor> block_table,
std::optional<Tensor> alibi_slopes, int64_t max_seqlen_q,
int64_t max_seqlen_k, float p_dropout, float softmax_scale,
bool zero_tensors, bool is_causal, int64_t window_size_left,
int64_t window_size_right, float softcap, bool return_softmax,
std::optional<int64_t> generator = std::nullopt,
int64_t num_splits = 0)
: batch_size_{cu_seqlens_q.numel() - 1},
total_q_{q.size(0)},
num_heads_{q.size(1)},
head_size_{q.size(2)},
num_kv_heads_{block_table.has_value() ? k.size(2) : k.size(1)},
max_seqlen_q_{max_seqlen_q},
max_seqlen_k_{max_seqlen_k},
p_dropout_{p_dropout},
softmax_scale_{softmax_scale},
zero_tensors_{zero_tensors},
is_causal_{is_causal},
window_size_left_{window_size_left},
window_size_right_{window_size_right},
softcap_{softcap},
return_softmax_{return_softmax},
num_splits_{num_splits},
q_dtype_{q.dtype()},
has_seqused_k_{seqused_k.has_value()},
has_leftpad_k_{leftpad_k.has_value()},
has_block_table_{block_table.has_value()},
has_alibi_slopes_{alibi_slopes.has_value()},
has_generator_{generator.has_value()} {
// --- Rank and shape of `q`, `k`, `v` ---
assert(q.ndim() == 3 &&
"`MhaVarlenFwd` requires `q` to be `[total_q, heads, dim]`");
if (has_block_table_) {
// Paged layout: `k` / `v` are 4D and `block_table` is a 2D `int32` tensor.
assert(k.ndim() == 4 && v.ndim() == 4 &&
"`MhaVarlenFwd` with `block_table` requires paged `k` / `v` to "
"be `[num_blocks, block_size, heads, dim]`.");
assert(block_table->ndim() == 2 &&
"`MhaVarlenFwd` requires `block_table` to be 2D.");
assert(block_table->dtype() == DataType::kInt32 &&
"`MhaVarlenFwd` requires `block_table` to be `int32`.");
assert(k.shape() == v.shape() &&
"`MhaVarlenFwd` requires paged `k` and `v` to have same shape.");
assert(k.size(3) == head_size_ &&
"`MhaVarlenFwd` requires paged `k` / `v` last dim to match `q`.");
} else {
// Dense layout: `k` / `v` are 3D with the head dimension last.
assert(k.ndim() == 3 && v.ndim() == 3 &&
"`MhaVarlenFwd` requires `k` / `v` to be "
"`[total_k, heads, dim]`.");
assert(k.shape() == v.shape() &&
"`MhaVarlenFwd` requires `k` and `v` to have same shape.");
assert(k.size(2) == head_size_ &&
"`MhaVarlenFwd` requires `k` / `v` last dim to match `q`.");
}
// --- Dtype agreement and memory layout ---
assert(k.dtype() == q.dtype() && v.dtype() == q.dtype() &&
"`MhaVarlenFwd` requires `q`, `k`, and `v` to have same dtype");
assert(out.dtype() == q.dtype() &&
"`MhaVarlenFwd` requires `out` to have same dtype as `q`.");
assert(k.stride(-1) == 1 && v.stride(-1) == 1 && q.stride(-1) == 1 &&
"`MhaVarlenFwd` requires contiguous last dimension");
// --- Head configuration: query heads must split evenly over KV heads ---
assert(num_heads_ % num_kv_heads_ == 0 &&
"`MhaVarlenFwd` requires `num_heads` divisible by `num_kv_heads`");
assert(head_size_ <= 256 &&
"`MhaVarlenFwd` supports `head_size` up to 256");
assert(head_size_ % 8 == 0 &&
"`MhaVarlenFwd` requires `head_size` to be a multiple of 8 in "
"InfiniOps v1.");
// --- Cumulative sequence-length tensors ---
assert(cu_seqlens_q.ndim() == 1 && cu_seqlens_k.ndim() == 1 &&
"`MhaVarlenFwd` requires 1D `cu_seqlens_q` / `cu_seqlens_k`");
assert(cu_seqlens_q.dtype() == DataType::kInt32 &&
cu_seqlens_k.dtype() == DataType::kInt32 &&
"`MhaVarlenFwd` requires `cu_seqlens_q` and `cu_seqlens_k` to be "
"`int32`.");
// More than one entry is required, i.e. at least one sequence.
assert(cu_seqlens_q.numel() > 1 &&
"`MhaVarlenFwd` requires non-empty `cu_seqlens`.");
assert(cu_seqlens_q.numel() == cu_seqlens_k.numel() &&
"`MhaVarlenFwd` requires matching `cu_seqlens` lengths");
// --- Output tensor ---
assert(out.shape() == q.shape() &&
"`MhaVarlenFwd` requires `out` to match `q` shape");
assert(out.stride(-1) == 1 &&
"`MhaVarlenFwd` requires `out` contiguous last dimension");
// --- Scalar options ---
assert(p_dropout >= 0.0f && p_dropout <= 1.0f &&
"`MhaVarlenFwd` requires `p_dropout` in `[0, 1]`.");
assert(std::isfinite(softmax_scale_) &&
"`MhaVarlenFwd` requires finite `softmax_scale`.");
assert(max_seqlen_q_ >= 0 && max_seqlen_k_ >= 0 &&
"`MhaVarlenFwd` requires non-negative max sequence lengths.");

// --- Optional tensors: only rank and dtype are checked when present ---
if (seqused_k.has_value()) {
assert(seqused_k->ndim() == 1 &&
"`MhaVarlenFwd` requires `seqused_k` to be 1D.");
assert(seqused_k->dtype() == DataType::kInt32 &&
"`MhaVarlenFwd` requires `seqused_k` to be `int32`.");
}

if (leftpad_k.has_value()) {
assert(leftpad_k->ndim() == 1 &&
"`MhaVarlenFwd` requires `leftpad_k` to be 1D.");
assert(leftpad_k->dtype() == DataType::kInt32 &&
"`MhaVarlenFwd` requires `leftpad_k` to be `int32`.");
}

if (alibi_slopes.has_value()) {
assert((alibi_slopes->ndim() == 1 || alibi_slopes->ndim() == 2) &&
"`MhaVarlenFwd` requires `alibi_slopes` to be 1D or 2D.");
assert(alibi_slopes->dtype() == DataType::kFloat32 &&
"`MhaVarlenFwd` requires `alibi_slopes` to be `float32`.");
}
}

// Runs the attention forward pass. Pure virtual: the argument list mirrors
// the constructor's, and `out` is the caller-supplied output tensor.
virtual void operator()(
const Tensor q, const Tensor k, const Tensor v, Tensor out,
const Tensor cu_seqlens_q, const Tensor cu_seqlens_k,
std::optional<Tensor> seqused_k, std::optional<Tensor> leftpad_k,
std::optional<Tensor> block_table, std::optional<Tensor> alibi_slopes,
int64_t max_seqlen_q, int64_t max_seqlen_k, float p_dropout,
float softmax_scale, bool zero_tensors, bool is_causal,
int64_t window_size_left, int64_t window_size_right, float softcap,
bool return_softmax, std::optional<int64_t> generator = std::nullopt,
int64_t num_splits = 0) const = 0;

protected:
// Every member below is set by the constructor's initializer list; the
// in-class defaults are overwritten there.

// Number of sequences: `cu_seqlens_q.numel() - 1`.
Tensor::Size batch_size_{0};

// `q.size(0)`.
Tensor::Size total_q_{0};

// `q.size(1)`.
Tensor::Size num_heads_{0};

// `q.size(2)`.
Tensor::Size head_size_{0};

// `k.size(2)` when `block_table` is supplied, `k.size(1)` otherwise.
Tensor::Size num_kv_heads_{0};

// Scalar options, copied verbatim from the constructor arguments.
int64_t max_seqlen_q_{0};

int64_t max_seqlen_k_{0};

float p_dropout_{0.0f};

float softmax_scale_{0.0f};

bool zero_tensors_{false};

bool is_causal_{false};

int64_t window_size_left_{-1};

int64_t window_size_right_{-1};

float softcap_{0.0f};

bool return_softmax_{false};

int64_t num_splits_{0};

// dtype of `q` at construction; `k`, `v` and `out` are asserted to share it.
const DataType q_dtype_;

// Presence flags for the optional constructor arguments; the optional values
// themselves are not stored.
bool has_seqused_k_{false};

bool has_leftpad_k_{false};

bool has_block_table_{false};

bool has_alibi_slopes_{false};

bool has_generator_{false};
};

} // namespace infini::ops

#endif
157 changes: 157 additions & 0 deletions src/native/ascend/ops/fia_common_.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
#ifndef INFINI_OPS_ASCEND_FIA_COMMON__H_
#define INFINI_OPS_ASCEND_FIA_COMMON__H_

#include <cassert>
#include <cstdint>
#include <vector>

#include "acl/acl.h"
#include "aclnn/acl_meta.h"
#include "data_type.h"
#include "native/ascend/common.h"
#include "tensor.h"

namespace infini::ops::ascend::fia {

// Debug-build guard that `dtype` is `float16` or `bfloat16`.
//
// The check is an `assert`, so it disappears when `NDEBUG` is defined.
// `op_name` is currently unused; the cast below only silences
// unused-parameter warnings.
inline void AssertSupportedDtype(const DataType dtype, const char* op_name) {
  static_cast<void>(op_name);

  assert((dtype == DataType::kFloat16 || dtype == DataType::kBFloat16) &&
         "`FIA`: only `float16` and `bfloat16` are supported");
}

inline std::vector<int64_t> ReadIntTensor(const Tensor& tensor,
aclrtStream stream) {
assert(tensor.ndim() == 1 && "`FIA`: sequence tensor must be 1D");

const auto n = tensor.numel();
std::vector<int64_t> result(n);

if (tensor.dtype() == DataType::kInt32) {
std::vector<int32_t> tmp(n);
const int32_t* src = nullptr;

if (tensor.device().type() == Device::Type::kCpu) {
src = static_cast<const int32_t*>(tensor.data());
} else {
auto ret = aclrtMemcpyAsync(tmp.data(), n * sizeof(int32_t),
tensor.data(), n * sizeof(int32_t),
ACL_MEMCPY_DEVICE_TO_HOST, stream);
assert(ret == ACL_SUCCESS && "`FIA`: D2H copy failed");
ret = aclrtSynchronizeStream(stream);
assert(ret == ACL_SUCCESS && "`FIA`: stream synchronize failed");
src = tmp.data();
}

for (std::size_t i = 0; i < n; ++i) {
result[i] = static_cast<int64_t>(src[i]);
}

return result;
}

if (tensor.dtype() == DataType::kInt64) {
if (tensor.device().type() == Device::Type::kCpu) {
const auto* src = static_cast<const int64_t*>(tensor.data());
result.assign(src, src + n);
} else {
auto ret = aclrtMemcpyAsync(result.data(), n * sizeof(int64_t),
tensor.data(), n * sizeof(int64_t),
ACL_MEMCPY_DEVICE_TO_HOST, stream);
assert(ret == ACL_SUCCESS && "`FIA`: D2H copy failed");
ret = aclrtSynchronizeStream(stream);
assert(ret == ACL_SUCCESS && "`FIA`: stream synchronize failed");
}

return result;
}

assert(false && "`FIA`: sequence tensor must be `int32` or `int64`");

return result;
}

// Builds an `aclIntArray` holding every entry of `cu_seqlens` except the
// first one.
//
// `cu_seqlens` is read to the host through `ReadIntTensor` on `stream` and
// must contain more than one element (asserted).
//
// NOTE(review): the host vector is destroyed when this function returns, so
// this relies on `aclCreateIntArray` copying its input -- confirm against the
// ACL documentation.
inline aclIntArray* CreateCumSeqLengths(const Tensor& cu_seqlens,
                                        aclrtStream stream) {
  const std::vector<int64_t> values = ReadIntTensor(cu_seqlens, stream);
  assert(values.size() > 1 && "`FIA`: `cu_seqlens` must contain a batch");

  // Skip the leading entry and hand over the remaining `size() - 1` values.
  const int64_t* tail = values.data() + 1;
  const auto tail_count = static_cast<int64_t>(values.size() - 1);

  return aclCreateIntArray(tail, tail_count);
}

// Builds an `aclIntArray` of adjacent differences of `cu_seqlens`:
// element `i` of the result is `cu_seqlens[i + 1] - cu_seqlens[i]`.
//
// `cu_seqlens` is read to the host through `ReadIntTensor` on `stream` and
// must contain more than one element (asserted).
//
// NOTE(review): `lengths` is a local that dies on return, so this relies on
// `aclCreateIntArray` copying its input -- confirm against the ACL
// documentation.
inline aclIntArray* CreateDiffSeqLengths(const Tensor& cu_seqlens,
                                         aclrtStream stream) {
  const auto values = ReadIntTensor(cu_seqlens, stream);
  assert(values.size() > 1 && "`FIA`: `cu_seqlens` must contain a batch");

  const std::size_t batch = values.size() - 1;

  std::vector<int64_t> lengths;
  lengths.reserve(batch);
  for (std::size_t end = 1; end <= batch; ++end) {
    lengths.push_back(values[end] - values[end - 1]);
  }

  const auto count = static_cast<int64_t>(lengths.size());

  return aclCreateIntArray(lengths.data(), count);
}

// Builds an `aclIntArray` with the same values as the 1D integer tensor
// `seqlens`, read to the host through `ReadIntTensor` on `stream`.
//
// NOTE(review): the host vector is destroyed on return, so this relies on
// `aclCreateIntArray` copying its input -- confirm against the ACL
// documentation.
inline aclIntArray* CreateSeqLengths(const Tensor& seqlens,
                                     aclrtStream stream) {
  const std::vector<int64_t> host_lengths = ReadIntTensor(seqlens, stream);
  const auto count = static_cast<int64_t>(host_lengths.size());

  return aclCreateIntArray(host_lengths.data(), count);
}

// Allocates a 2048 x 2048 `uint8` mask on the device, fills it, and wraps it
// in an `aclTensor`.
//
// Entry `(row, col)` is 1 when `col > row` (strictly above the diagonal) and
// 0 otherwise.
//
// `mask_buf`: out-parameter that receives the device allocation backing the
//             returned tensor. This function never frees it; presumably the
//             caller releases it once the tensor is no longer needed.
//
// Allocation and upload failures are reported through `assert` only.
inline aclTensor* MakeCausalMask(void** mask_buf) {
  constexpr int64_t kMaskDim = 2048;
  constexpr int64_t kMaskElems = kMaskDim * kMaskDim;
  // One byte per element, so the byte count equals the element count.
  const auto mask_bytes = static_cast<std::size_t>(kMaskElems);

  auto ret = aclrtMalloc(mask_buf, mask_bytes, ACL_MEM_MALLOC_NORMAL_ONLY);
  assert(ret == ACL_SUCCESS && "`FIA`: causal mask allocation failed");

  // The vector starts zeroed, so only the strictly-upper triangle needs to
  // be written.
  std::vector<uint8_t> host_mask(kMaskElems);
  for (int64_t row = 0; row < kMaskDim; ++row) {
    uint8_t* const row_begin = host_mask.data() + row * kMaskDim;
    for (int64_t col = row + 1; col < kMaskDim; ++col) {
      row_begin[col] = 1;
    }
  }

  ret = aclrtMemcpy(*mask_buf, mask_bytes, host_mask.data(), mask_bytes,
                    ACL_MEMCPY_HOST_TO_DEVICE);
  assert(ret == ACL_SUCCESS && "`FIA`: causal mask upload failed");

  // Row-major 2D view over a flat storage of `kMaskElems` bytes.
  const int64_t view_shape[] = {kMaskDim, kMaskDim};
  const int64_t view_strides[] = {kMaskDim, 1};
  const int64_t storage_shape[] = {kMaskElems};

  return aclCreateTensor(view_shape, 2, ACL_UINT8, view_strides, 0,
                         ACL_FORMAT_ND, storage_shape, 1, *mask_buf);
}

// Translates FlashAttention-style causal / sliding-window arguments into the
// `sparse_mode`, `pre_tokens` and `next_tokens` triple.
//
// A negative window size means "no limit on that side" and maps to the
// sentinel 2147483647.
//
//   non-causal: `sparse_mode = 0`; `pre_tokens` / `next_tokens` take
//               `window_left` / `window_right` when non-negative, else the
//               sentinel.
//   causal:     `next_tokens = 0` and `window_right` is ignored;
//               `sparse_mode = 4` with `pre_tokens = window_left` when
//               `window_left >= 0`, otherwise `sparse_mode = 3` with
//               `pre_tokens` set to the sentinel.
//
// NOTE(review): the numeric modes presumably follow the Ascend fused
// attention `sparse_mode` convention -- confirm against the operator docs.
inline void ResolveSparseMode(bool is_causal, int64_t window_left,
                              int64_t window_right, int64_t& sparse_mode,
                              int64_t& pre_tokens, int64_t& next_tokens) {
  constexpr int64_t kUnbounded = 2147483647;

  const bool bounded_left = window_left >= 0;
  const bool bounded_right = window_right >= 0;

  if (!is_causal) {
    sparse_mode = 0;
    pre_tokens = bounded_left ? window_left : kUnbounded;
    next_tokens = bounded_right ? window_right : kUnbounded;

    return;
  }

  sparse_mode = bounded_left ? 4 : 3;
  pre_tokens = bounded_left ? window_left : kUnbounded;
  next_tokens = 0;
}

} // namespace infini::ops::ascend::fia

#endif
Loading
Loading