Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
215 changes: 215 additions & 0 deletions src/base/mha_fwd_kvcache.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
#ifndef INFINI_OPS_BASE_MHA_FWD_KVCACHE_H_
#define INFINI_OPS_BASE_MHA_FWD_KVCACHE_H_

#include <cassert>
#include <cmath>
#include <optional>

#include "data_type.h"
#include "operator.h"

namespace infini::ops {

// FlashAttention-compatible forward attention over an existing KV cache.
//
// Layout follows `flash_attn` `fwd_kvcache`:
// `q`: `[batch, seqlen_q, num_heads, head_size]`.
// `kcache` / `vcache`: `[batch_cache, seqlen_k, num_kv_heads, head_size]`
// or paged `[num_blocks, page_block_size, num_kv_heads, head_size]` when
// `block_table` is supplied.
//
// InfiniOps is an in-place operator API, so `out` must be supplied even
// though FlashAttention can allocate it when omitted.
// This signature intentionally keeps `out` near the FlashAttention argument
// position for compatibility with existing callers; this is an exception to
// the normal InfiniOps output-last convention.
class MhaFwdKvcache : public Operator<MhaFwdKvcache> {
public:
MhaFwdKvcache(
const Tensor q, const Tensor kcache, const Tensor vcache,
std::optional<Tensor> k, std::optional<Tensor> v,
std::optional<Tensor> seqlens_k, std::optional<Tensor> rotary_cos,
std::optional<Tensor> rotary_sin, std::optional<Tensor> cache_batch_idx,
std::optional<Tensor> leftpad_k, std::optional<Tensor> block_table,
std::optional<Tensor> alibi_slopes, Tensor out, float softmax_scale,
bool is_causal, int64_t window_size_left, int64_t window_size_right,
float softcap, bool is_rotary_interleaved, int64_t num_splits = 0)
: batch_size_{q.size(0)},
seqlen_q_{q.size(1)},
num_heads_{q.size(2)},
head_size_{q.size(3)},
num_kv_heads_{kcache.size(2)},
page_block_size_{block_table.has_value() ? kcache.size(1) : 0},
softmax_scale_{softmax_scale},
is_causal_{is_causal},
window_size_left_{window_size_left},
window_size_right_{window_size_right},
softcap_{softcap},
is_rotary_interleaved_{is_rotary_interleaved},
num_splits_{num_splits},
q_dtype_{q.dtype()},
has_k_{k.has_value()},
has_v_{v.has_value()},
has_seqlens_k_{seqlens_k.has_value()},
has_rotary_cos_{rotary_cos.has_value()},
has_rotary_sin_{rotary_sin.has_value()},
has_cache_batch_idx_{cache_batch_idx.has_value()},
has_leftpad_k_{leftpad_k.has_value()},
has_block_table_{block_table.has_value()},
has_alibi_slopes_{alibi_slopes.has_value()} {
assert(q.ndim() == 4 &&
"`MhaFwdKvcache` requires `q` to be `[batch, seq, heads, dim]`");
assert(kcache.ndim() == 4 && vcache.ndim() == 4 &&
"`MhaFwdKvcache` requires `kcache` / `vcache` to be 4D.");
assert(kcache.shape() == vcache.shape() &&
"`MhaFwdKvcache` requires `kcache` and `vcache` same shape");
assert(kcache.dtype() == q.dtype() && vcache.dtype() == q.dtype() &&
"`MhaFwdKvcache` requires `q`, `kcache`, and `vcache` same dtype");
assert(out.dtype() == q.dtype() &&
"`MhaFwdKvcache` requires `out` to have same dtype as `q`.");
assert(kcache.size(3) == head_size_ &&
"`MhaFwdKvcache` requires cache head dim to match `q`.");
assert(q.stride(-1) == 1 && kcache.stride(-1) == 1 &&
vcache.stride(-1) == 1 &&
"`MhaFwdKvcache` requires contiguous last dimension");
assert(num_heads_ % num_kv_heads_ == 0 &&
"`MhaFwdKvcache` requires `num_heads` divisible by `num_kv_heads`");
assert(head_size_ <= 256 &&
"`MhaFwdKvcache` supports `head_size` up to 256");
assert(out.shape() == q.shape() &&
"`MhaFwdKvcache` requires `out` to match `q` shape");
assert(out.stride(-1) == 1 &&
"`MhaFwdKvcache` requires `out` contiguous last dimension");
assert(std::isfinite(softmax_scale_) &&
"`MhaFwdKvcache` requires finite `softmax_scale`.");

if (k.has_value() || v.has_value()) {
assert(k.has_value() && v.has_value() &&
"`MhaFwdKvcache` requires `k` and `v` together.");
if (k.has_value() && v.has_value()) {
assert(k->ndim() == 4 && v->ndim() == 4 &&
"`MhaFwdKvcache` requires appended `k` / `v` to be 4D.");
assert(k->shape() == v->shape() &&
"`MhaFwdKvcache` requires appended `k` and `v` same shape.");
assert(k->dtype() == q.dtype() && v->dtype() == q.dtype() &&
"`MhaFwdKvcache` requires appended `k` / `v` same dtype as "
"`q`.");
assert(k->size(0) == batch_size_ && k->size(2) == num_kv_heads_ &&
k->size(3) == head_size_ &&
"`MhaFwdKvcache` requires appended `k` / `v` to match cache "
"batch, heads, and dim.");
assert(k->stride(-1) == 1 && v->stride(-1) == 1 &&
"`MhaFwdKvcache` requires appended `k` / `v` contiguous last "
"dimension.");
}
}

if (seqlens_k.has_value()) {
assert(seqlens_k->ndim() == 1 &&
"`MhaFwdKvcache` requires `seqlens_k` to be 1D.");
assert((seqlens_k->dtype() == DataType::kInt32 ||
seqlens_k->dtype() == DataType::kInt64) &&
"`MhaFwdKvcache` requires `seqlens_k` to be `int32` or `int64`.");
}

if (block_table.has_value()) {
assert(block_table->ndim() == 2 &&
"`MhaFwdKvcache` requires `block_table` to be 2D.");
assert(block_table->dtype() == DataType::kInt32 &&
"`MhaFwdKvcache` requires `block_table` to be `int32`.");
}

if (cache_batch_idx.has_value()) {
assert(cache_batch_idx->ndim() == 1 &&
"`MhaFwdKvcache` requires `cache_batch_idx` to be 1D.");
assert(cache_batch_idx->dtype() == DataType::kInt32 &&
"`MhaFwdKvcache` requires `cache_batch_idx` to be `int32`.");
}

if (leftpad_k.has_value()) {
assert(leftpad_k->ndim() == 1 &&
"`MhaFwdKvcache` requires `leftpad_k` to be 1D.");
assert(leftpad_k->dtype() == DataType::kInt32 &&
"`MhaFwdKvcache` requires `leftpad_k` to be `int32`.");
}

if (rotary_cos.has_value() || rotary_sin.has_value()) {
assert(rotary_cos.has_value() && rotary_sin.has_value() &&
"`MhaFwdKvcache` requires `rotary_cos` and `rotary_sin` "
"together.");
if (rotary_cos.has_value() && rotary_sin.has_value()) {
assert(rotary_cos->shape() == rotary_sin->shape() &&
"`MhaFwdKvcache` requires rotary tensors to have same shape.");
}
}

if (alibi_slopes.has_value()) {
assert((alibi_slopes->ndim() == 1 || alibi_slopes->ndim() == 2) &&
"`MhaFwdKvcache` requires `alibi_slopes` to be 1D or 2D.");
assert(alibi_slopes->dtype() == DataType::kFloat32 &&
"`MhaFwdKvcache` requires `alibi_slopes` to be `float32`.");
}
}

virtual void operator()(
const Tensor q, const Tensor kcache, const Tensor vcache,
std::optional<Tensor> k, std::optional<Tensor> v,
std::optional<Tensor> seqlens_k, std::optional<Tensor> rotary_cos,
std::optional<Tensor> rotary_sin, std::optional<Tensor> cache_batch_idx,
std::optional<Tensor> leftpad_k, std::optional<Tensor> block_table,
std::optional<Tensor> alibi_slopes, Tensor out, float softmax_scale,
bool is_causal, int64_t window_size_left, int64_t window_size_right,
float softcap, bool is_rotary_interleaved,
int64_t num_splits = 0) const = 0;

protected:
Tensor::Size batch_size_{0};

Tensor::Size seqlen_q_{0};

Tensor::Size num_heads_{0};

Tensor::Size head_size_{0};

Tensor::Size num_kv_heads_{0};

Tensor::Size page_block_size_{0};

float softmax_scale_{0.0f};

bool is_causal_{false};

int64_t window_size_left_{-1};

int64_t window_size_right_{-1};

float softcap_{0.0f};

bool is_rotary_interleaved_{false};

int64_t num_splits_{0};

const DataType q_dtype_;

bool has_k_{false};

bool has_v_{false};

bool has_seqlens_k_{false};

bool has_rotary_cos_{false};

bool has_rotary_sin_{false};

bool has_cache_batch_idx_{false};

bool has_leftpad_k_{false};

bool has_block_table_{false};

bool has_alibi_slopes_{false};
};

} // namespace infini::ops

#endif
157 changes: 157 additions & 0 deletions src/native/ascend/ops/fia_common_.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
#ifndef INFINI_OPS_ASCEND_FIA_COMMON__H_
#define INFINI_OPS_ASCEND_FIA_COMMON__H_

#include <cassert>
#include <cstdint>
#include <vector>

#include "acl/acl.h"
#include "aclnn/acl_meta.h"
#include "data_type.h"
#include "native/ascend/common.h"
#include "tensor.h"

namespace infini::ops::ascend::fia {

inline void AssertSupportedDtype(const DataType dtype, const char* op_name) {
(void)op_name;
assert((dtype == DataType::kFloat16 || dtype == DataType::kBFloat16) &&
"`FIA`: only `float16` and `bfloat16` are supported");
}

inline std::vector<int64_t> ReadIntTensor(const Tensor& tensor,
aclrtStream stream) {
assert(tensor.ndim() == 1 && "`FIA`: sequence tensor must be 1D");

const auto n = tensor.numel();
std::vector<int64_t> result(n);

if (tensor.dtype() == DataType::kInt32) {
std::vector<int32_t> tmp(n);
const int32_t* src = nullptr;

if (tensor.device().type() == Device::Type::kCpu) {
src = static_cast<const int32_t*>(tensor.data());
} else {
auto ret = aclrtMemcpyAsync(tmp.data(), n * sizeof(int32_t),
tensor.data(), n * sizeof(int32_t),
ACL_MEMCPY_DEVICE_TO_HOST, stream);
assert(ret == ACL_SUCCESS && "`FIA`: D2H copy failed");
ret = aclrtSynchronizeStream(stream);
assert(ret == ACL_SUCCESS && "`FIA`: stream synchronize failed");
src = tmp.data();
}

for (std::size_t i = 0; i < n; ++i) {
result[i] = static_cast<int64_t>(src[i]);
}

return result;
}

if (tensor.dtype() == DataType::kInt64) {
if (tensor.device().type() == Device::Type::kCpu) {
const auto* src = static_cast<const int64_t*>(tensor.data());
result.assign(src, src + n);
} else {
auto ret = aclrtMemcpyAsync(result.data(), n * sizeof(int64_t),
tensor.data(), n * sizeof(int64_t),
ACL_MEMCPY_DEVICE_TO_HOST, stream);
assert(ret == ACL_SUCCESS && "`FIA`: D2H copy failed");
ret = aclrtSynchronizeStream(stream);
assert(ret == ACL_SUCCESS && "`FIA`: stream synchronize failed");
}

return result;
}

assert(false && "`FIA`: sequence tensor must be `int32` or `int64`");

return result;
}

inline aclIntArray* CreateCumSeqLengths(const Tensor& cu_seqlens,
aclrtStream stream) {
auto values = ReadIntTensor(cu_seqlens, stream);
assert(values.size() > 1 && "`FIA`: `cu_seqlens` must contain a batch");

return aclCreateIntArray(values.data() + 1,
static_cast<int64_t>(values.size() - 1));
}

inline aclIntArray* CreateDiffSeqLengths(const Tensor& cu_seqlens,
aclrtStream stream) {
auto values = ReadIntTensor(cu_seqlens, stream);
assert(values.size() > 1 && "`FIA`: `cu_seqlens` must contain a batch");

std::vector<int64_t> lengths(values.size() - 1);
for (std::size_t i = 0; i < lengths.size(); ++i) {
lengths[i] = values[i + 1] - values[i];
}

return aclCreateIntArray(lengths.data(),
static_cast<int64_t>(lengths.size()));
}

inline aclIntArray* CreateSeqLengths(const Tensor& seqlens,
aclrtStream stream) {
auto values = ReadIntTensor(seqlens, stream);

return aclCreateIntArray(values.data(), static_cast<int64_t>(values.size()));
}

inline aclTensor* MakeCausalMask(void** mask_buf) {
constexpr int64_t kMaskDim = 2048;
const int64_t mask_elems = kMaskDim * kMaskDim;
const auto mask_bytes = static_cast<size_t>(mask_elems);

auto ret = aclrtMalloc(mask_buf, mask_bytes, ACL_MEM_MALLOC_NORMAL_ONLY);
assert(ret == ACL_SUCCESS && "`FIA`: causal mask allocation failed");

std::vector<uint8_t> host_mask(mask_elems);
for (int64_t row = 0; row < kMaskDim; ++row) {
for (int64_t col = 0; col < kMaskDim; ++col) {
host_mask[row * kMaskDim + col] = (col > row) ? 1 : 0;
}
}

ret = aclrtMemcpy(*mask_buf, mask_bytes, host_mask.data(), mask_bytes,
ACL_MEMCPY_HOST_TO_DEVICE);
assert(ret == ACL_SUCCESS && "`FIA`: causal mask upload failed");

std::vector<int64_t> shape = {kMaskDim, kMaskDim};
std::vector<int64_t> strides = {kMaskDim, 1};
std::vector<int64_t> storage = {mask_elems};

return aclCreateTensor(shape.data(), static_cast<int64_t>(shape.size()),
ACL_UINT8, strides.data(), 0, ACL_FORMAT_ND,
storage.data(), static_cast<int64_t>(storage.size()),
*mask_buf);
}

inline void ResolveSparseMode(bool is_causal, int64_t window_left,
int64_t window_right, int64_t& sparse_mode,
int64_t& pre_tokens, int64_t& next_tokens) {
sparse_mode = 0;
pre_tokens = 2147483647;
next_tokens = 2147483647;

if (is_causal) {
if (window_left >= 0) {
sparse_mode = 4;
pre_tokens = window_left;
} else {
sparse_mode = 3;
}
next_tokens = 0;

return;
}

if (window_left >= 0) pre_tokens = window_left;
if (window_right >= 0) next_tokens = window_right;
}

} // namespace infini::ops::ascend::fia

#endif
Loading
Loading