Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/infiniop/ops/chunk_gated_delta_rule.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ __INFINI_C __export infiniStatus_t infiniopCreateChunkGatedDeltaRuleDescriptor(
infiniopHandle_t handle,
infiniopChunkGatedDeltaRuleDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t out_desc, // padded: [B, T, Hv, Dv]; varlen: [1, total_tokens, Hv, Dv]
infiniopTensorDescriptor_t initial_state_desc, // legacy: [B, Hv, Dk, Dv]; indexed pool: [pool_size, Hv, Dv, Dk]
infiniopTensorDescriptor_t initial_state_desc, // legacy: [B, Hv, Dv, Dk]; indexed pool: [pool_size, Hv, Dv, Dk]
infiniopTensorDescriptor_t final_state_desc, // null when final_state_indices_desc is provided
infiniopTensorDescriptor_t q_desc, // padded: [B, T, Hk, Dk]; varlen: [1, total_tokens, Hk, Dk]
infiniopTensorDescriptor_t k_desc, // same shape as q
Expand Down
2 changes: 1 addition & 1 deletion include/infiniop/ops/recurrent_gated_delta_rule.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ __INFINI_C __export infiniStatus_t infiniopCreateRecurrentGatedDeltaRuleDescript
infiniopHandle_t handle,
infiniopRecurrentGatedDeltaRuleDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t out_desc, // [B, T, Hv, Dv], T must be 1; last dim contiguous
infiniopTensorDescriptor_t initial_state_desc, // legacy: [B, Hv, Dk, Dv]; indexed pool: [pool_size, Hv, Dv, Dk]
infiniopTensorDescriptor_t initial_state_desc, // legacy: [B, Hv, Dv, Dk]; indexed pool: [pool_size, Hv, Dv, Dk]
infiniopTensorDescriptor_t final_state_desc, // legacy/indexed out-of-place final state; null when final_state_indices_desc is provided
infiniopTensorDescriptor_t q_desc, // [B, T, Hk, Dk], T must be 1; last dim contiguous
infiniopTensorDescriptor_t k_desc, // [B, T, Hk, Dk], same shape as q; last dim contiguous
Expand Down
2 changes: 1 addition & 1 deletion python/infinicore/nn/functional/chunk_gated_delta_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def chunk_gated_delta_rule(
q, k: ``[B, T, Hk, Dk]``
v, out: ``[B, T, Hv, Dv]``
g, beta: ``[B, T, Hv]``
initial_state: ``[B, Hv, Dk, Dv]``
initial_state: ``[B, Hv, Dv, Dk]``

Continuous-batch mode shapes:
Pass ``cu_seqlens`` with shape ``[B + 1]`` and dtype int32/int64.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,19 +86,14 @@ static void check_4d_sequence_tensor(const Tensor &x, const char *name) {

static Shape chunk_final_state_shape(const Tensor &q,
const Tensor &v,
const Tensor &initial_state,
std::optional<Tensor> cu_seqlens,
std::optional<Tensor> initial_state_indices) {
std::optional<Tensor> cu_seqlens) {
const auto &q_shape = q->shape();
const auto &v_shape = v->shape();
size_t B = cu_seqlens.has_value() ? cu_seqlens.value()->shape()[0] - 1 : v_shape[0];
size_t Hv = v_shape[2];
size_t Dk = q_shape[3];
size_t Dv = v_shape[3];
if (initial_state_indices.has_value()) {
return {B, Hv, Dv, Dk};
}
return {B, Hv, Dk, Dv};
return {B, Hv, Dv, Dk};
}

Tensor chunk_gated_delta_rule(const Tensor &q,
Expand All @@ -118,7 +113,7 @@ Tensor chunk_gated_delta_rule(const Tensor &q,
auto out = Tensor::empty(v->shape(), v->dtype(), v->device());
std::optional<Tensor> final_state = std::nullopt;
if (!final_state_indices.has_value()) {
final_state = Tensor::empty(chunk_final_state_shape(q, v, initial_state, cu_seqlens, initial_state_indices),
final_state = Tensor::empty(chunk_final_state_shape(q, v, cu_seqlens),
initial_state->dtype(),
initial_state->device());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ Tensor recurrent_gated_delta_rule(const Tensor &q,
Tensor k4 = ensure_4d_sequence_tensor(k, "k");
Tensor v4 = ensure_4d_sequence_tensor(v, "v");
auto out = Tensor::empty(recurrent_output_shape(v4), v4->dtype(), v4->device());
Shape final_state_shape = {v4->shape()[0], v4->shape()[2], q4->shape()[3], v4->shape()[3]};
Shape final_state_shape = {v4->shape()[0], v4->shape()[2], v4->shape()[3], q4->shape()[3]};
auto final_state = Tensor::empty(final_state_shape, initial_state->dtype(), initial_state->device());
recurrent_gated_delta_rule_(out,
initial_state,
Expand Down
2 changes: 1 addition & 1 deletion src/infinicore/pybind11/ops/chunk_gated_delta_rule.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ inline void bind_chunk_gated_delta_rule(py::module &m) {

Padded mode:
q/k: [B, T, Hk, Dk], v/out: [B, T, Hv, Dv], g/beta: [B, T, Hv],
initial_state: [B, Hv, Dk, Dv].
initial_state: [B, Hv, Dv, Dk].

Continuous-batch mode:
pass cu_seqlens [B + 1]; q/k: [1, total_tokens, Hk, Dk],
Expand Down
36 changes: 9 additions & 27 deletions src/infiniop/ops/chunk_gated_delta_rule/cuda/kernel.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
#define __CHUNK_GATED_DELTA_RULE_KERNEL_CUH__

#include <cstdint>
#include <cuda_bf16.h>
#include <cuda_fp16.h>

__device__ inline int64_t loadOptionalIndex(
const void *indices,
Expand All @@ -29,7 +27,7 @@ __device__ inline float loadAsFloat<half>(const half *ptr, ptrdiff_t offset) {
}

template <>
__device__ inline float loadAsFloat<__nv_bfloat16>(const __nv_bfloat16 *ptr, ptrdiff_t offset) {
__device__ inline float loadAsFloat<cuda_bfloat16>(const cuda_bfloat16 *ptr, ptrdiff_t offset) {
return __bfloat162float(ptr[offset]);
}

Expand Down Expand Up @@ -190,9 +188,7 @@ __device__ void chunkGatedDeltaRuleKernel(
int dk = i / Dv;
int dv = i % Dv;

ptrdiff_t read_idx = indexed_state_pool
? initial_base + static_cast<ptrdiff_t>(dv) * initial_s2 + static_cast<ptrdiff_t>(dk) * initial_s3
: initial_base + static_cast<ptrdiff_t>(dk) * initial_s2 + static_cast<ptrdiff_t>(dv) * initial_s3;
ptrdiff_t read_idx = initial_base + static_cast<ptrdiff_t>(dv) * initial_s2 + static_cast<ptrdiff_t>(dk) * initial_s3;

state_local[i] = static_cast<Tcompute>(
loadAsFloat(initial_state, read_idx));
Expand Down Expand Up @@ -419,15 +415,9 @@ __device__ void chunkGatedDeltaRuleKernel(
int dk = i / Dv;
int dv = i % Dv;

ptrdiff_t write_idx;
if (indexed_state_pool) {
const ptrdiff_t s2 = final_state_indices != nullptr ? initial_s2 : final_s2;
const ptrdiff_t s3 = final_state_indices != nullptr ? initial_s3 : final_s3;

write_idx = final_base + static_cast<ptrdiff_t>(dv) * s2 + static_cast<ptrdiff_t>(dk) * s3;
} else {
write_idx = final_base + static_cast<ptrdiff_t>(dk) * final_s2 + static_cast<ptrdiff_t>(dv) * final_s3;
}
const ptrdiff_t s2 = final_state_indices != nullptr ? initial_s2 : final_s2;
const ptrdiff_t s3 = final_state_indices != nullptr ? initial_s3 : final_s3;
ptrdiff_t write_idx = final_base + static_cast<ptrdiff_t>(dv) * s2 + static_cast<ptrdiff_t>(dk) * s3;

final_state_target[write_idx] = static_cast<Tdata>(state_local[i]);
}
Expand Down Expand Up @@ -563,9 +553,7 @@ __device__ void chunkGatedDeltaRuleRecurrentKernel(
int dk = i / Dv;
int dv = i % Dv;

ptrdiff_t read_idx = indexed_state_pool
? initial_base + static_cast<ptrdiff_t>(dv) * initial_s2 + static_cast<ptrdiff_t>(dk) * initial_s3
: initial_base + static_cast<ptrdiff_t>(dk) * initial_s2 + static_cast<ptrdiff_t>(dv) * initial_s3;
ptrdiff_t read_idx = initial_base + static_cast<ptrdiff_t>(dv) * initial_s2 + static_cast<ptrdiff_t>(dk) * initial_s3;

state_local[i] = static_cast<Tcompute>(
loadAsFloat(initial_state, read_idx));
Expand Down Expand Up @@ -650,15 +638,9 @@ __device__ void chunkGatedDeltaRuleRecurrentKernel(
int dk = i / Dv;
int dv = i % Dv;

ptrdiff_t write_idx;
if (indexed_state_pool) {
const ptrdiff_t s2 = final_state_indices != nullptr ? initial_s2 : final_s2;
const ptrdiff_t s3 = final_state_indices != nullptr ? initial_s3 : final_s3;

write_idx = final_base + static_cast<ptrdiff_t>(dv) * s2 + static_cast<ptrdiff_t>(dk) * s3;
} else {
write_idx = final_base + static_cast<ptrdiff_t>(dk) * final_s2 + static_cast<ptrdiff_t>(dv) * final_s3;
}
const ptrdiff_t s2 = final_state_indices != nullptr ? initial_s2 : final_s2;
const ptrdiff_t s3 = final_state_indices != nullptr ? initial_s3 : final_s3;
ptrdiff_t write_idx = final_base + static_cast<ptrdiff_t>(dv) * s2 + static_cast<ptrdiff_t>(dk) * s3;

final_state_target[write_idx] = static_cast<Tdata>(state_local[i]);
}
Expand Down
4 changes: 2 additions & 2 deletions src/infiniop/ops/chunk_gated_delta_rule/info.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ class ChunkGatedDeltaRuleInfo {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
} else {
if (initial_shape[0] != B || initial_shape[1] != Hv || initial_shape[2] != Dk || initial_shape[3] != Dv) {
if (initial_shape[0] != B || initial_shape[1] != Hv || initial_shape[2] != Dv || initial_shape[3] != Dk) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
}
Expand All @@ -137,7 +137,7 @@ class ChunkGatedDeltaRuleInfo {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
} else {
if (final_shape[0] != B || final_shape[1] != Hv || final_shape[2] != Dk || final_shape[3] != Dv) {
if (final_shape[0] != B || final_shape[1] != Hv || final_shape[2] != Dv || final_shape[3] != Dk) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
}
Expand Down
Loading
Loading