diff --git a/include/infiniop/ops/chunk_gated_delta_rule.h b/include/infiniop/ops/chunk_gated_delta_rule.h index a9a9c74aa..6caca6217 100644 --- a/include/infiniop/ops/chunk_gated_delta_rule.h +++ b/include/infiniop/ops/chunk_gated_delta_rule.h @@ -9,7 +9,7 @@ __INFINI_C __export infiniStatus_t infiniopCreateChunkGatedDeltaRuleDescriptor( infiniopHandle_t handle, infiniopChunkGatedDeltaRuleDescriptor_t *desc_ptr, infiniopTensorDescriptor_t out_desc, // padded: [B, T, Hv, Dv]; varlen: [1, total_tokens, Hv, Dv] - infiniopTensorDescriptor_t initial_state_desc, // legacy: [B, Hv, Dk, Dv]; indexed pool: [pool_size, Hv, Dv, Dk] + infiniopTensorDescriptor_t initial_state_desc, // legacy: [B, Hv, Dv, Dk]; indexed pool: [pool_size, Hv, Dv, Dk] infiniopTensorDescriptor_t final_state_desc, // null when final_state_indices_desc is provided infiniopTensorDescriptor_t q_desc, // padded: [B, T, Hk, Dk]; varlen: [1, total_tokens, Hk, Dk] infiniopTensorDescriptor_t k_desc, // same shape as q diff --git a/include/infiniop/ops/recurrent_gated_delta_rule.h b/include/infiniop/ops/recurrent_gated_delta_rule.h index 2865cc1f5..6a1b29c65 100644 --- a/include/infiniop/ops/recurrent_gated_delta_rule.h +++ b/include/infiniop/ops/recurrent_gated_delta_rule.h @@ -9,7 +9,7 @@ __INFINI_C __export infiniStatus_t infiniopCreateRecurrentGatedDeltaRuleDescript infiniopHandle_t handle, infiniopRecurrentGatedDeltaRuleDescriptor_t *desc_ptr, infiniopTensorDescriptor_t out_desc, // [B, T, Hv, Dv], T must be 1; last dim contiguous - infiniopTensorDescriptor_t initial_state_desc, // legacy: [B, Hv, Dk, Dv]; indexed pool: [pool_size, Hv, Dv, Dk] + infiniopTensorDescriptor_t initial_state_desc, // legacy: [B, Hv, Dv, Dk]; indexed pool: [pool_size, Hv, Dv, Dk] infiniopTensorDescriptor_t final_state_desc, // legacy/indexed out-of-place final state; null when final_state_indices_desc is provided infiniopTensorDescriptor_t q_desc, // [B, T, Hk, Dk], T must be 1; last dim contiguous infiniopTensorDescriptor_t k_desc, // [B, T, Hk, Dk], same shape as q; last dim contiguous diff --git a/python/infinicore/nn/functional/chunk_gated_delta_rule.py b/python/infinicore/nn/functional/chunk_gated_delta_rule.py index dbcf15e75..082743faa 100644 --- a/python/infinicore/nn/functional/chunk_gated_delta_rule.py +++ b/python/infinicore/nn/functional/chunk_gated_delta_rule.py @@ -22,7 +22,7 @@ def chunk_gated_delta_rule( q, k: ``[B, T, Hk, Dk]`` v, out: ``[B, T, Hv, Dv]`` g, beta: ``[B, T, Hv]`` - initial_state: ``[B, Hv, Dk, Dv]`` + initial_state: ``[B, Hv, Dv, Dk]`` Continuous-batch mode shapes: Pass ``cu_seqlens`` with shape ``[B + 1]`` and dtype int32/int64. diff --git a/src/infinicore/ops/chunk_gated_delta_rule/chunk_gated_delta_rule.cc b/src/infinicore/ops/chunk_gated_delta_rule/chunk_gated_delta_rule.cc index 2fccef444..82e3d5b1d 100644 --- a/src/infinicore/ops/chunk_gated_delta_rule/chunk_gated_delta_rule.cc +++ b/src/infinicore/ops/chunk_gated_delta_rule/chunk_gated_delta_rule.cc @@ -86,19 +86,14 @@ static void check_4d_sequence_tensor(const Tensor &x, const char *name) { static Shape chunk_final_state_shape(const Tensor &q, const Tensor &v, - const Tensor &initial_state, - std::optional cu_seqlens, - std::optional initial_state_indices) { + std::optional cu_seqlens) { const auto &q_shape = q->shape(); const auto &v_shape = v->shape(); size_t B = cu_seqlens.has_value() ? cu_seqlens.value()->shape()[0] - 1 : v_shape[0]; size_t Hv = v_shape[2]; size_t Dk = q_shape[3]; size_t Dv = v_shape[3]; - if (initial_state_indices.has_value()) { - return {B, Hv, Dv, Dk}; - } - return {B, Hv, Dk, Dv}; + return {B, Hv, Dv, Dk}; } Tensor chunk_gated_delta_rule(const Tensor &q, @@ -118,7 +113,7 @@ Tensor chunk_gated_delta_rule(const Tensor &q, auto out = Tensor::empty(v->shape(), v->dtype(), v->device()); std::optional final_state = std::nullopt; if (!final_state_indices.has_value()) { - final_state = Tensor::empty(chunk_final_state_shape(q, v, initial_state, cu_seqlens, initial_state_indices), + final_state = Tensor::empty(chunk_final_state_shape(q, v, cu_seqlens), initial_state->dtype(), initial_state->device()); } diff --git a/src/infinicore/ops/recurrent_gated_delta_rule/recurrent_gated_delta_rule.cc b/src/infinicore/ops/recurrent_gated_delta_rule/recurrent_gated_delta_rule.cc index 4fd7bb69a..06edd9955 100644 --- a/src/infinicore/ops/recurrent_gated_delta_rule/recurrent_gated_delta_rule.cc +++ b/src/infinicore/ops/recurrent_gated_delta_rule/recurrent_gated_delta_rule.cc @@ -93,7 +93,7 @@ Tensor recurrent_gated_delta_rule(const Tensor &q, Tensor k4 = ensure_4d_sequence_tensor(k, "k"); Tensor v4 = ensure_4d_sequence_tensor(v, "v"); auto out = Tensor::empty(recurrent_output_shape(v4), v4->dtype(), v4->device()); - Shape final_state_shape = {v4->shape()[0], v4->shape()[2], q4->shape()[3], v4->shape()[3]}; + Shape final_state_shape = {v4->shape()[0], v4->shape()[2], v4->shape()[3], q4->shape()[3]}; auto final_state = Tensor::empty(final_state_shape, initial_state->dtype(), initial_state->device()); recurrent_gated_delta_rule_(out, initial_state, diff --git a/src/infinicore/pybind11/ops/chunk_gated_delta_rule.hpp b/src/infinicore/pybind11/ops/chunk_gated_delta_rule.hpp index 6d8bf5cf2..c39174f86 100644 --- a/src/infinicore/pybind11/ops/chunk_gated_delta_rule.hpp +++ b/src/infinicore/pybind11/ops/chunk_gated_delta_rule.hpp @@ -27,7 +27,7 @@ inline void bind_chunk_gated_delta_rule(py::module &m) { Padded mode: q/k: [B, T, Hk, Dk], v/out: [B, T, Hv, Dv], g/beta: [B, T, Hv], - initial_state: [B, Hv, Dk, Dv]. + initial_state: [B, Hv, Dv, Dk]. Continuous-batch mode: pass cu_seqlens [B + 1]; q/k: [1, total_tokens, Hk, Dk], diff --git a/src/infiniop/ops/chunk_gated_delta_rule/cuda/kernel.cuh b/src/infiniop/ops/chunk_gated_delta_rule/cuda/kernel.cuh index cc167b6c7..d9bac786d 100644 --- a/src/infiniop/ops/chunk_gated_delta_rule/cuda/kernel.cuh +++ b/src/infiniop/ops/chunk_gated_delta_rule/cuda/kernel.cuh @@ -2,8 +2,6 @@ #define __CHUNK_GATED_DELTA_RULE_KERNEL_CUH__ #include -#include -#include __device__ inline int64_t loadOptionalIndex( const void *indices, @@ -29,7 +27,7 @@ __device__ inline float loadAsFloat(const half *ptr, ptrdiff_t offset) { } template <> -__device__ inline float loadAsFloat<__nv_bfloat16>(const __nv_bfloat16 *ptr, ptrdiff_t offset) { +__device__ inline float loadAsFloat(const cuda_bfloat16 *ptr, ptrdiff_t offset) { return __bfloat162float(ptr[offset]); } @@ -190,9 +188,7 @@ __device__ void chunkGatedDeltaRuleKernel( int dk = i / Dv; int dv = i % Dv; - ptrdiff_t read_idx = indexed_state_pool - ? initial_base + static_cast(dv) * initial_s2 + static_cast(dk) * initial_s3 - : initial_base + static_cast(dk) * initial_s2 + static_cast(dv) * initial_s3; + ptrdiff_t read_idx = initial_base + static_cast(dv) * initial_s2 + static_cast(dk) * initial_s3; state_local[i] = static_cast( loadAsFloat(initial_state, read_idx)); @@ -419,15 +415,9 @@ __device__ void chunkGatedDeltaRuleKernel( int dk = i / Dv; int dv = i % Dv; - ptrdiff_t write_idx; - if (indexed_state_pool) { - const ptrdiff_t s2 = final_state_indices != nullptr ? initial_s2 : final_s2; - const ptrdiff_t s3 = final_state_indices != nullptr ? initial_s3 : final_s3; - - write_idx = final_base + static_cast(dv) * s2 + static_cast(dk) * s3; - } else { - write_idx = final_base + static_cast(dk) * final_s2 + static_cast(dv) * final_s3; - } + const ptrdiff_t s2 = final_state_indices != nullptr ? initial_s2 : final_s2; + const ptrdiff_t s3 = final_state_indices != nullptr ? initial_s3 : final_s3; + ptrdiff_t write_idx = final_base + static_cast(dv) * s2 + static_cast(dk) * s3; final_state_target[write_idx] = static_cast(state_local[i]); } @@ -563,9 +553,7 @@ __device__ void chunkGatedDeltaRuleRecurrentKernel( int dk = i / Dv; int dv = i % Dv; - ptrdiff_t read_idx = indexed_state_pool - ? initial_base + static_cast(dv) * initial_s2 + static_cast(dk) * initial_s3 - : initial_base + static_cast(dk) * initial_s2 + static_cast(dv) * initial_s3; + ptrdiff_t read_idx = initial_base + static_cast(dv) * initial_s2 + static_cast(dk) * initial_s3; state_local[i] = static_cast( loadAsFloat(initial_state, read_idx)); @@ -650,15 +638,9 @@ __device__ void chunkGatedDeltaRuleRecurrentKernel( int dk = i / Dv; int dv = i % Dv; - ptrdiff_t write_idx; - if (indexed_state_pool) { - const ptrdiff_t s2 = final_state_indices != nullptr ? initial_s2 : final_s2; - const ptrdiff_t s3 = final_state_indices != nullptr ? initial_s3 : final_s3; - - write_idx = final_base + static_cast(dv) * s2 + static_cast(dk) * s3; - } else { - write_idx = final_base + static_cast(dk) * final_s2 + static_cast(dv) * final_s3; - } + const ptrdiff_t s2 = final_state_indices != nullptr ? initial_s2 : final_s2; + const ptrdiff_t s3 = final_state_indices != nullptr ? initial_s3 : final_s3; + ptrdiff_t write_idx = final_base + static_cast(dv) * s2 + static_cast(dk) * s3; final_state_target[write_idx] = static_cast(state_local[i]); } diff --git a/src/infiniop/ops/chunk_gated_delta_rule/info.h b/src/infiniop/ops/chunk_gated_delta_rule/info.h index 4dcb25319..101da34ae 100644 --- a/src/infiniop/ops/chunk_gated_delta_rule/info.h +++ b/src/infiniop/ops/chunk_gated_delta_rule/info.h @@ -125,7 +125,7 @@ class ChunkGatedDeltaRuleInfo { return INFINI_STATUS_BAD_TENSOR_SHAPE; } } else { - if (initial_shape[0] != B || initial_shape[1] != Hv || initial_shape[2] != Dk || initial_shape[3] != Dv) { + if (initial_shape[0] != B || initial_shape[1] != Hv || initial_shape[2] != Dv || initial_shape[3] != Dk) { return INFINI_STATUS_BAD_TENSOR_SHAPE; } } @@ -137,7 +137,7 @@ class ChunkGatedDeltaRuleInfo { return INFINI_STATUS_BAD_TENSOR_SHAPE; } } else { - if (final_shape[0] != B || final_shape[1] != Hv || final_shape[2] != Dk || final_shape[3] != Dv) { + if (final_shape[0] != B || final_shape[1] != Hv || final_shape[2] != Dv || final_shape[3] != Dk) { return INFINI_STATUS_BAD_TENSOR_SHAPE; } } diff --git a/src/infiniop/ops/recurrent_gated_delta_rule/cuda/kernel.cuh b/src/infiniop/ops/recurrent_gated_delta_rule/cuda/kernel.cuh index db9161626..211e508df 100644 --- a/src/infiniop/ops/recurrent_gated_delta_rule/cuda/kernel.cuh +++ b/src/infiniop/ops/recurrent_gated_delta_rule/cuda/kernel.cuh @@ -1,12 +1,8 @@ -// kernel.cuh (in op/recurrent_gated_delta_rule/cuda/) - #ifndef __RECURRENT_GATED_DELTA_RULE_KERNEL_CUH__ #define __RECURRENT_GATED_DELTA_RULE_KERNEL_CUH__ #include #include -#include -#include __device__ inline int64_t loadStateIndex( const void *indices, @@ -32,12 +28,19 @@ __device__ inline float loadAsFloat(const half *ptr, ptrdiff_t offset) { } template <> -__device__ inline float loadAsFloat<__nv_bfloat16>(const __nv_bfloat16 *ptr, ptrdiff_t offset) { +__device__ inline float loadAsFloat(const cuda_bfloat16 *ptr, ptrdiff_t offset) { return __bfloat162float(ptr[offset]); } -template -__device__ void recurrentGatedDeltaRuleKernel( +__device__ inline float warpReduceSum(float value) { +#pragma unroll + for (int offset = 16; offset > 0; offset >>= 1) { + value += __shfl_down_sync(0xffffffff, value, offset); + } + return __shfl_sync(0xffffffff, value, 0); +} +template +__device__ void recurrentGatedDeltaRuleIndexedPoolWarpKernel( Tdata *out, Tdata *initial_state, Tdata *final_state, @@ -51,7 +54,6 @@ __device__ void recurrentGatedDeltaRuleKernel( bool initial_state_indices_i64, bool final_state_indices_i64, bool use_qk_l2norm, - bool indexed_state_pool, size_t Hk, size_t value_heads_per_key_head, ptrdiff_t out_s0, @@ -80,10 +82,15 @@ __device__ void recurrentGatedDeltaRuleKernel( ptrdiff_t beta_s0, ptrdiff_t beta_s1, ptrdiff_t beta_s2) { + constexpr int WARP_SIZE = 32; + constexpr int NUM_THREADS = WARPS_PER_BLOCK * WARP_SIZE; + const int batch_idx = blockIdx.x; const int value_head_idx = blockIdx.y; + const int warp_idx = threadIdx.x / WARP_SIZE; + const int lane_idx = threadIdx.x & (WARP_SIZE - 1); + const int value_dim_idx = blockIdx.z * WARPS_PER_BLOCK + warp_idx; const int key_head_idx = value_head_idx / static_cast(value_heads_per_key_head); - const int thread_idx = threadIdx.x; if (key_head_idx >= static_cast(Hk)) { return; @@ -92,50 +99,14 @@ __device__ void recurrentGatedDeltaRuleKernel( constexpr int seq_idx = 0; const ptrdiff_t q_base = static_cast(batch_idx) * q_s0 + seq_idx * q_s1 + static_cast(key_head_idx) * q_s2; const ptrdiff_t k_base = static_cast(batch_idx) * k_s0 + seq_idx * k_s1 + static_cast(key_head_idx) * k_s2; - const ptrdiff_t v_base = static_cast(batch_idx) * v_s0 + seq_idx * v_s1 + static_cast(value_head_idx) * v_s2; - const ptrdiff_t out_base = static_cast(batch_idx) * out_s0 + seq_idx * out_s1 + static_cast(value_head_idx) * out_s2; - const ptrdiff_t gate_offset = static_cast(batch_idx) * g_s0 + seq_idx * g_s1 + static_cast(value_head_idx) * g_s2; - const ptrdiff_t beta_offset = static_cast(batch_idx) * beta_s0 + seq_idx * beta_s1 + static_cast(value_head_idx) * beta_s2; - - int64_t read_slot = static_cast(batch_idx); - int64_t write_slot = static_cast(batch_idx); - if (indexed_state_pool) { - read_slot = loadStateIndex(initial_state_indices, initial_state_indices_i64, batch_idx, batch_idx); - write_slot = final_state_indices == nullptr - ? static_cast(batch_idx) - : loadStateIndex(final_state_indices, final_state_indices_i64, batch_idx, batch_idx); - if (read_slot < 0 || write_slot < 0) { - for (int dv_idx = thread_idx; dv_idx < Dv; dv_idx += NUM_THREADS) { - out[out_base + dv_idx] = static_cast(0.0f); - } - return; - } - } - - const ptrdiff_t initial_base = indexed_state_pool - ? static_cast(read_slot) * initial_s0 + static_cast(value_head_idx) * initial_s1 - : static_cast(batch_idx) * initial_s0 + static_cast(value_head_idx) * initial_s1; - ptrdiff_t final_base = 0; - Tdata *final_state_target = nullptr; - if (indexed_state_pool && final_state_indices != nullptr) { - final_state_target = initial_state; - final_base = static_cast(write_slot) * initial_s0 + static_cast(value_head_idx) * initial_s1; - } else if (indexed_state_pool) { - final_state_target = final_state; - final_base = static_cast(batch_idx) * final_s0 + static_cast(value_head_idx) * final_s1; - } else { - final_state_target = final_state; - final_base = static_cast(batch_idx) * final_s0 + static_cast(value_head_idx) * final_s1; - } extern __shared__ char shared_mem_char[]; Tcompute *shared_mem = reinterpret_cast(shared_mem_char); - Tcompute *q_local = shared_mem; Tcompute *k_local = q_local + Dk; Tcompute *norm_val = k_local + Dk; - for (int i = thread_idx; i < Dk; i += NUM_THREADS) { + for (int i = threadIdx.x; i < static_cast(Dk); i += NUM_THREADS) { q_local[i] = static_cast(loadAsFloat(q, q_base + i)); k_local[i] = static_cast(loadAsFloat(k, k_base + i)); } @@ -143,12 +114,12 @@ __device__ void recurrentGatedDeltaRuleKernel( if (use_qk_l2norm) { __syncthreads(); Tcompute sum_sq = 0.0f; - for (int i = thread_idx; i < Dk; i += NUM_THREADS) { + for (int i = threadIdx.x; i < static_cast(Dk); i += NUM_THREADS) { sum_sq += q_local[i] * q_local[i]; } - norm_val[thread_idx] = sum_sq; + norm_val[threadIdx.x] = sum_sq; __syncthreads(); - if (thread_idx == 0) { + if (threadIdx.x == 0) { Tcompute total_sum_sq = 0.0f; for (int i = 0; i < NUM_THREADS; ++i) { total_sum_sq += norm_val[i]; @@ -156,19 +127,19 @@ __device__ void recurrentGatedDeltaRuleKernel( norm_val[0] = rsqrtf(total_sum_sq + 1e-6f); } __syncthreads(); - Tcompute r_norm_q = norm_val[0]; + const Tcompute r_norm_q = norm_val[0]; - for (int i = thread_idx; i < Dk; i += NUM_THREADS) { + for (int i = threadIdx.x; i < static_cast(Dk); i += NUM_THREADS) { q_local[i] *= r_norm_q; } sum_sq = 0.0f; - for (int i = thread_idx; i < Dk; i += NUM_THREADS) { + for (int i = threadIdx.x; i < static_cast(Dk); i += NUM_THREADS) { sum_sq += k_local[i] * k_local[i]; } - norm_val[thread_idx] = sum_sq; + norm_val[threadIdx.x] = sum_sq; __syncthreads(); - if (thread_idx == 0) { + if (threadIdx.x == 0) { Tcompute total_sum_sq = 0.0f; for (int i = 0; i < NUM_THREADS; ++i) { total_sum_sq += norm_val[i]; @@ -176,51 +147,84 @@ __device__ void recurrentGatedDeltaRuleKernel( norm_val[0] = rsqrtf(total_sum_sq + 1e-6f); } __syncthreads(); - Tcompute r_norm_k = norm_val[0]; + const Tcompute r_norm_k = norm_val[0]; - for (int i = thread_idx; i < Dk; i += NUM_THREADS) { + for (int i = threadIdx.x; i < static_cast(Dk); i += NUM_THREADS) { k_local[i] *= r_norm_k; } - __syncthreads(); } - Tcompute g_t = expf(static_cast(loadAsFloat(g, gate_offset))); - Tcompute beta_t = static_cast(loadAsFloat(beta, beta_offset)); - Tcompute scale = rsqrtf(static_cast(Dk)); - - for (int i = thread_idx; i < Dk; i += NUM_THREADS) { + const Tcompute scale = rsqrtf(static_cast(Dk)); + for (int i = threadIdx.x; i < static_cast(Dk); i += NUM_THREADS) { q_local[i] *= scale; } __syncthreads(); - for (int dv_idx = thread_idx; dv_idx < Dv; dv_idx += NUM_THREADS) { - Tcompute kv_mem = 0.0f; - for (int dk_idx = 0; dk_idx < Dk; ++dk_idx) { - ptrdiff_t state_idx = indexed_state_pool - ? initial_base + static_cast(dv_idx) * initial_s2 + static_cast(dk_idx) * initial_s3 - : initial_base + static_cast(dk_idx) * initial_s2 + static_cast(dv_idx) * initial_s3; - Tcompute h_prev = static_cast(loadAsFloat(initial_state, state_idx)); - kv_mem += (h_prev * g_t) * k_local[dk_idx]; - } + if (value_dim_idx >= static_cast(Dv)) { + return; + } + + int64_t read_slot = loadStateIndex(initial_state_indices, initial_state_indices_i64, batch_idx, batch_idx); + int64_t write_slot = final_state_indices == nullptr + ? static_cast(batch_idx) + : loadStateIndex(final_state_indices, final_state_indices_i64, batch_idx, batch_idx); - Tcompute v_t = static_cast(loadAsFloat(v, v_base + dv_idx)); - Tcompute delta = (v_t - kv_mem) * beta_t; - Tcompute out_val = 0.0f; - - for (int dk_idx = 0; dk_idx < Dk; ++dk_idx) { - ptrdiff_t read_state_idx = indexed_state_pool - ? initial_base + static_cast(dv_idx) * initial_s2 + static_cast(dk_idx) * initial_s3 - : initial_base + static_cast(dk_idx) * initial_s2 + static_cast(dv_idx) * initial_s3; - ptrdiff_t write_state_idx = indexed_state_pool - ? final_base + static_cast(dv_idx) * (final_state_indices != nullptr ? initial_s2 : final_s2) + static_cast(dk_idx) * (final_state_indices != nullptr ? initial_s3 : final_s3) - : final_base + static_cast(dk_idx) * final_s2 + static_cast(dv_idx) * final_s3; - Tcompute h_prev = static_cast(loadAsFloat(initial_state, read_state_idx)); - Tcompute h_final = (h_prev * g_t) + (k_local[dk_idx] * delta); - out_val += h_final * q_local[dk_idx]; - final_state_target[write_state_idx] = static_cast(h_final); + const ptrdiff_t out_base = static_cast(batch_idx) * out_s0 + seq_idx * out_s1 + static_cast(value_head_idx) * out_s2; + if (read_slot < 0 || write_slot < 0) { + if (lane_idx == 0) { + out[out_base + value_dim_idx] = static_cast(0.0f); } - out[out_base + dv_idx] = static_cast(out_val); + return; } -} + const ptrdiff_t v_base = static_cast(batch_idx) * v_s0 + seq_idx * v_s1 + static_cast(value_head_idx) * v_s2; + const ptrdiff_t gate_offset = static_cast(batch_idx) * g_s0 + seq_idx * g_s1 + static_cast(value_head_idx) * g_s2; + const ptrdiff_t beta_offset = static_cast(batch_idx) * beta_s0 + seq_idx * beta_s1 + static_cast(value_head_idx) * beta_s2; + + const ptrdiff_t initial_base = static_cast(read_slot) * initial_s0 + + static_cast(value_head_idx) * initial_s1 + + static_cast(value_dim_idx) * initial_s2; + + Tdata *final_state_target = final_state_indices == nullptr ? final_state : initial_state; + const ptrdiff_t final_base = final_state_indices == nullptr + ? static_cast(batch_idx) * final_s0 + + static_cast(value_head_idx) * final_s1 + + static_cast(value_dim_idx) * final_s2 + : static_cast(write_slot) * initial_s0 + + static_cast(value_head_idx) * initial_s1 + + static_cast(value_dim_idx) * initial_s2; + const ptrdiff_t final_k_stride = final_state_indices == nullptr ? final_s3 : initial_s3; + + const Tcompute g_t = expf(static_cast(loadAsFloat(g, gate_offset))); + const Tcompute beta_t = static_cast(loadAsFloat(beta, beta_offset)); + + Tcompute kv_mem = 0.0f; + Tcompute hq_mem = 0.0f; + Tcompute kq_mem = 0.0f; + for (int dk_idx = lane_idx; dk_idx < static_cast(Dk); dk_idx += WARP_SIZE) { + const Tcompute h_prev = static_cast(loadAsFloat(initial_state, initial_base + static_cast(dk_idx) * initial_s3)); + const Tcompute k_t = k_local[dk_idx]; + const Tcompute q_t = q_local[dk_idx]; + kv_mem += (h_prev * g_t) * k_t; + hq_mem += h_prev * q_t; + kq_mem += k_t * q_t; + } + kv_mem = warpReduceSum(kv_mem); + hq_mem = warpReduceSum(hq_mem); + kq_mem = warpReduceSum(kq_mem); + + const Tcompute v_t = static_cast(loadAsFloat(v, v_base + value_dim_idx)); + const Tcompute delta = (v_t - kv_mem) * beta_t; + + if (lane_idx == 0) { + const Tcompute out_val = g_t * hq_mem + delta * kq_mem; + out[out_base + value_dim_idx] = static_cast(out_val); + } + + for (int dk_idx = lane_idx; dk_idx < static_cast(Dk); dk_idx += WARP_SIZE) { + const Tcompute h_prev = static_cast(loadAsFloat(initial_state, initial_base + static_cast(dk_idx) * initial_s3)); + const Tcompute h_final = (h_prev * g_t) + (k_local[dk_idx] * delta); + final_state_target[final_base + static_cast(dk_idx) * final_k_stride] = static_cast(h_final); + } +} #endif // __RECURRENT_GATED_DELTA_RULE_KERNEL_CUH__ diff --git a/src/infiniop/ops/recurrent_gated_delta_rule/info.h b/src/infiniop/ops/recurrent_gated_delta_rule/info.h index 51644964a..81e399b00 100644 --- a/src/infiniop/ops/recurrent_gated_delta_rule/info.h +++ b/src/infiniop/ops/recurrent_gated_delta_rule/info.h @@ -104,8 +104,8 @@ class RecurrentGatedDeltaRuleInfo { return INFINI_STATUS_BAD_TENSOR_SHAPE; } } else { - // Legacy layout is [B, Hv, Dk, Dv]. - if (initial_shape[0] != B || initial_shape[1] != Hv || initial_shape[2] != Dk || initial_shape[3] != Dv) { + // State layout is [B, Hv, Dv, Dk]. + if (initial_shape[0] != B || initial_shape[1] != Hv || initial_shape[2] != Dv || initial_shape[3] != Dk) { return INFINI_STATUS_BAD_TENSOR_SHAPE; } } @@ -117,7 +117,7 @@ class RecurrentGatedDeltaRuleInfo { return INFINI_STATUS_BAD_TENSOR_SHAPE; } } else { - if (final_shape[0] != B || final_shape[1] != Hv || final_shape[2] != Dk || final_shape[3] != Dv) { + if (final_shape[0] != B || final_shape[1] != Hv || final_shape[2] != Dv || final_shape[3] != Dk) { return INFINI_STATUS_BAD_TENSOR_SHAPE; } } diff --git a/src/infiniop/ops/recurrent_gated_delta_rule/nvidia/recurrent_gated_delta_rule_nvidia.cu b/src/infiniop/ops/recurrent_gated_delta_rule/nvidia/recurrent_gated_delta_rule_nvidia.cu index 3c227266e..7697ad07d 100644 --- a/src/infiniop/ops/recurrent_gated_delta_rule/nvidia/recurrent_gated_delta_rule_nvidia.cu +++ b/src/infiniop/ops/recurrent_gated_delta_rule/nvidia/recurrent_gated_delta_rule_nvidia.cu @@ -1,5 +1,3 @@ -// recurrent_gated_delta_rule_nvidia.cu - #include "../../../devices/nvidia/nvidia_common.cuh" #include "recurrent_gated_delta_rule_nvidia.cuh" @@ -8,8 +6,8 @@ #include "../cuda/kernel.cuh" #include -template -INFINIOP_CUDA_KERNEL recurrentGatedDeltaRule( +template +INFINIOP_CUDA_KERNEL recurrentGatedDeltaRuleIndexedPoolWarp( Tdata *out, Tdata *initial_state, Tdata *final_state, const Tdata *q, const Tdata *k, const Tdata *v, const Tgate *g, const Tgate *beta, @@ -18,7 +16,6 @@ INFINIOP_CUDA_KERNEL recurrentGatedDeltaRule( bool initial_state_indices_i64, bool final_state_indices_i64, bool use_qk_l2norm, - bool indexed_state_pool, size_t Hk, size_t value_heads_per_key_head, ptrdiff_t out_s0, @@ -47,11 +44,11 @@ INFINIOP_CUDA_KERNEL recurrentGatedDeltaRule( ptrdiff_t beta_s0, ptrdiff_t beta_s1, ptrdiff_t beta_s2) { - recurrentGatedDeltaRuleKernel( + recurrentGatedDeltaRuleIndexedPoolWarpKernel( out, initial_state, final_state, q, k, v, g, beta, initial_state_indices, final_state_indices, initial_state_indices_i64, final_state_indices_i64, - use_qk_l2norm, indexed_state_pool, + use_qk_l2norm, Hk, value_heads_per_key_head, out_s0, out_s1, out_s2, initial_s0, initial_s1, initial_s2, initial_s3, @@ -62,7 +59,6 @@ INFINIOP_CUDA_KERNEL recurrentGatedDeltaRule( g_s0, g_s1, g_s2, beta_s0, beta_s1, beta_s2); } - namespace op { namespace recurrent_gated_delta_rule { namespace nvidia { @@ -104,8 +100,8 @@ infiniStatus_t Descriptor::create( return infiniStatus_t::INFINI_STATUS_SUCCESS; } -template -infiniStatus_t launchKernelTyped( +template +infiniStatus_t launchIndexedPoolWarpKernelTyped( const RecurrentGatedDeltaRuleInfo &_info, void *out, void *initial_state, void *final_state, const void *q, const void *k, const void *v, @@ -115,7 +111,8 @@ infiniStatus_t launchKernelTyped( bool initial_state_indices_i64, bool final_state_indices_i64, cudaStream_t stream) { - dim3 grid(uint32_t(_info.B), uint32_t(_info.Hv), 1); + constexpr size_t NUM_THREADS = WARPS_PER_BLOCK * 32; + dim3 grid(uint32_t(_info.B), uint32_t(_info.Hv), uint32_t((_info.Dv + WARPS_PER_BLOCK - 1) / WARPS_PER_BLOCK)); dim3 block(NUM_THREADS); size_t shared_mem_size = (Dk + Dk + NUM_THREADS) * sizeof(float); @@ -124,7 +121,7 @@ infiniStatus_t launchKernelTyped( auto final_s2 = _info.final_state_strides.empty() ? 0 : _info.final_state_strides[2]; auto final_s3 = _info.final_state_strides.empty() ? 0 : _info.final_state_strides[3]; - recurrentGatedDeltaRule + recurrentGatedDeltaRuleIndexedPoolWarp <<>>( static_cast(out), static_cast(initial_state), @@ -139,7 +136,6 @@ infiniStatus_t launchKernelTyped( initial_state_indices_i64, final_state_indices_i64, _info.use_qk_l2norm, - _info.indexed_state_pool, _info.Hk, _info.value_heads_per_key_head, _info.out_strides[0], @@ -170,9 +166,8 @@ infiniStatus_t launchKernelTyped( _info.beta_strides[2]); return infiniStatus_t::INFINI_STATUS_SUCCESS; } - -template -infiniStatus_t launchKernelForGate( +template +infiniStatus_t launchIndexedPoolWarpKernelForGate( const RecurrentGatedDeltaRuleInfo &_info, void *out, void *initial_state, void *final_state, const void *q, const void *k, const void *v, @@ -184,15 +179,15 @@ infiniStatus_t launchKernelForGate( cudaStream_t stream) { switch (_info.gate_dtype) { case INFINI_DTYPE_F16: - return launchKernelTyped( + return launchIndexedPoolWarpKernelTyped( _info, out, initial_state, final_state, q, k, v, g, beta, initial_state_indices, final_state_indices, initial_state_indices_i64, final_state_indices_i64, stream); case INFINI_DTYPE_BF16: - return launchKernelTyped( + return launchIndexedPoolWarpKernelTyped( _info, out, initial_state, final_state, q, k, v, g, beta, initial_state_indices, final_state_indices, initial_state_indices_i64, final_state_indices_i64, stream); case INFINI_DTYPE_F32: - return launchKernelTyped( + return launchIndexedPoolWarpKernelTyped( _info, out, initial_state, final_state, q, k, v, g, beta, initial_state_indices, final_state_indices, initial_state_indices_i64, final_state_indices_i64, stream); default: @@ -200,8 +195,8 @@ infiniStatus_t launchKernelForGate( } } -template -infiniStatus_t launchKernel( +template +infiniStatus_t launchIndexedPoolWarpKernel( const RecurrentGatedDeltaRuleInfo &_info, void *out, void *initial_state, void *final_state, const void *q, const void *k, const void *v, @@ -213,22 +208,21 @@ infiniStatus_t launchKernel( cudaStream_t stream) { switch (_info.data_dtype) { case INFINI_DTYPE_F16: - return launchKernelForGate( + return launchIndexedPoolWarpKernelForGate( _info, out, initial_state, final_state, q, k, v, g, beta, initial_state_indices, final_state_indices, initial_state_indices_i64, final_state_indices_i64, stream); case INFINI_DTYPE_BF16: - return launchKernelForGate<__nv_bfloat16, Dk, Dv, NUM_THREADS>( + return launchIndexedPoolWarpKernelForGate<__nv_bfloat16, Dk, Dv, WARPS_PER_BLOCK>( _info, out, initial_state, final_state, q, k, v, g, beta, initial_state_indices, final_state_indices, initial_state_indices_i64, final_state_indices_i64, stream); case INFINI_DTYPE_F32: - return launchKernelForGate( + return launchIndexedPoolWarpKernelForGate( _info, out, initial_state, final_state, q, k, v, g, beta, initial_state_indices, final_state_indices, initial_state_indices_i64, final_state_indices_i64, stream); default: return infiniStatus_t::INFINI_STATUS_BAD_TENSOR_DTYPE; } } - infiniStatus_t Descriptor::calculate( void *workspace, size_t workspace_size, void *out, void *initial_state, void *final_state, @@ -253,15 +247,15 @@ infiniStatus_t Descriptor::calculate( bool final_indices_i64 = _info.final_state_indices_dtype == INFINI_DTYPE_I64; if (_info.Dk == 128 && _info.Dv == 128) { - if (_opaque->internal->maxThreadsPerBlock() >= 128) { - return launchKernel<128, 128, 128>( + if (_opaque->internal->maxThreadsPerBlock() >= 256) { + return launchIndexedPoolWarpKernel<128, 128, 8>( _info, out, initial_state, final_state, q, k, v, g, beta, initial_state_indices, final_state_indices, initial_indices_i64, final_indices_i64, stream); } } else if (_info.Dk == 64 && _info.Dv == 64) { - if (_opaque->internal->maxThreadsPerBlock() >= 64) { - return launchKernel<64, 64, 64>( + if (_opaque->internal->maxThreadsPerBlock() >= 256) { + return launchIndexedPoolWarpKernel<64, 64, 8>( _info, out, initial_state, final_state, q, k, v, g, beta, initial_state_indices, final_state_indices, initial_indices_i64, final_indices_i64, stream); diff --git a/test/infinicore/ops/recurrent_gated_delta_rule.py b/test/infinicore/ops/recurrent_gated_delta_rule.py index 4fbc59490..88732dfcb 100644 --- a/test/infinicore/ops/recurrent_gated_delta_rule.py +++ b/test/infinicore/ops/recurrent_gated_delta_rule.py @@ -49,7 +49,7 @@ def ref_recurrent_gated_delta_rule( query, key, value, beta, g = [ x.contiguous().to(torch.float32) for x in (query, key, value, beta, g) ] - state = initial_state.contiguous().to(torch.float32).clone() + state = initial_state.transpose(-1, -2).contiguous().to(torch.float32).clone() batch_size, sequence_length, key_heads, _ = key.shape value_heads, v_head_dim = value.shape[2], value.shape[-1] value_heads_per_key_head = value_heads // key_heads @@ -82,7 +82,9 @@ def ref_recurrent_gated_delta_rule( state[:, vh] = state_t out[:, i, vh] = (state_t * q_t.unsqueeze(-1)).sum(dim=-2) - return out.contiguous().to(initial_dtype), state.contiguous().to(initial_dtype) + return out.contiguous().to(initial_dtype), state.transpose(-1, -2).contiguous().to( + initial_dtype + ) def strided_bthd_strides(shape): @@ -123,7 +125,7 @@ def parse_test_cases(): k_shape = (B, T, Hk, Dk) v_shape = (B, T, Hv, Dv) gate_shape = (B, T, Hv) - initial_state_shape = (B, Hv, Dk, Dv) + initial_state_shape = (B, Hv, Dv, Dk) pool_size = B * 2 + 3 state_pool_shape = (pool_size, Hv, Dv, Dk) q_strides = strided_bthd_strides(q_shape) if strided_qkv else None @@ -220,9 +222,7 @@ def torch_operator(self, q, k, v, g, beta, initial_state, *args, **kwargs): if mode == "indexed_pool": initial_state_indices, final_state_indices = args state_pool = initial_state.clone() - gathered_initial_state = ( - state_pool[initial_state_indices].transpose(-1, -2).contiguous() - ) + gathered_initial_state = state_pool[initial_state_indices].contiguous() out, final_state = ref_recurrent_gated_delta_rule( q, k, @@ -232,7 +232,7 @@ def torch_operator(self, q, k, v, g, beta, initial_state, *args, **kwargs): gathered_initial_state, use_qk_l2norm=use_qk_l2norm, ) - state_pool[final_state_indices] = final_state.transpose(-1, -2).contiguous() + state_pool[final_state_indices] = final_state.contiguous() return out, state_pool if mode == "user_3d": diff --git a/test/infiniop/chunk_gated_delta_rule.py b/test/infiniop/chunk_gated_delta_rule.py index 674ecad8e..3e2dc087d 100644 --- a/test/infiniop/chunk_gated_delta_rule.py +++ b/test/infiniop/chunk_gated_delta_rule.py @@ -39,7 +39,7 @@ def ref_chunk_gated_delta_rule( query, key, value, beta, g = [ x.contiguous().to(torch.float32) for x in (query, key, value, beta, g) ] - state = initial_state.contiguous().to(torch.float32).clone() + state = initial_state.transpose(-1, -2).contiguous().to(torch.float32).clone() if cu_seqlens is None: batch_size, sequence_length, key_heads, k_head_dim = key.shape @@ -77,7 +77,9 @@ def ref_chunk_gated_delta_rule( state[b, vh] = state_t out[token_b, t, vh] = (state_t * q_t.unsqueeze(-1)).sum(dim=-2) - return out.contiguous().to(initial_dtype), state.contiguous().to(initial_dtype) + return out.contiguous().to(initial_dtype), state.transpose(-1, -2).contiguous().to( + initial_dtype + ) _PADDED_TEST_CASES_DATA = [ @@ -99,7 +101,7 @@ def ref_chunk_gated_delta_rule( _TOLERANCE_MAP = { InfiniDtype.F16: {"atol": 1e-2, "rtol": 1e-2}, InfiniDtype.BF16: {"atol": 5e-2, "rtol": 5e-2}, - InfiniDtype.F32: {"atol": 1e-4, "rtol": 1e-4}, + InfiniDtype.F32: {"atol": 1e-3, "rtol": 1e-3}, } DEBUG = False @@ -245,8 +247,8 @@ def test_padded( ) g = make_gate((B, T, Hv), device) beta = make_beta((B, T, Hv), device) - initial_state = TestTensor((B, Hv, Dk, Dv), None, dtype, device) - final_state = TestTensor((B, Hv, Dk, Dv), None, dtype, device) + initial_state = TestTensor((B, Hv, Dv, Dk), None, dtype, device) + final_state = TestTensor((B, Hv, Dv, Dk), None, dtype, device) out = TestTensor( (B, T, Hv, Dv), bthd_strides(B, T, Hv, Dv, strided_qkv), @@ -355,11 +357,9 @@ def test_varlen_indexed_pool( final_state_indices_torch, InfiniDtype.I64, device ) - gathered_initial = ( - initial_state_pool.torch_tensor()[initial_state_indices_torch] - .transpose(-1, -2) - .contiguous() - ) + gathered_initial = initial_state_pool.torch_tensor()[ + initial_state_indices_torch + ].contiguous() ans_out, ans_final_state = ref_chunk_gated_delta_rule( q.torch_tensor(), k.torch_tensor(), @@ -371,7 +371,7 @@ def test_varlen_indexed_pool( use_qk_l2norm_in_kernel=use_qk_l2norm, ) ans_pool = initial_state_pool.torch_tensor().clone() - ans_pool[final_state_indices_torch] = ans_final_state.transpose(-1, -2).contiguous() + ans_pool[final_state_indices_torch] = ans_final_state.contiguous() if sync: sync() diff --git a/test/infiniop/recurrent_gated_delta_rule.py b/test/infiniop/recurrent_gated_delta_rule.py index 25a40911f..d08cda54f 100644 --- a/test/infiniop/recurrent_gated_delta_rule.py +++ b/test/infiniop/recurrent_gated_delta_rule.py @@ -42,7 +42,9 @@ def ref_recurrent_gated_delta_rule( query, key, value, beta, g = [ x.contiguous().to(torch.float32) for x in (query, key, value, beta, g) ] - initial_state = initial_state.contiguous().to(torch.float32).clone() + initial_state = ( + initial_state.transpose(-1, -2).contiguous().to(torch.float32).clone() + ) batch_size, sequence_length, key_heads, k_head_dim = key.shape value_heads, v_head_dim = value.shape[2], value.shape[-1] @@ -82,7 +84,9 @@ def ref_recurrent_gated_delta_rule( core_attn_out = core_attn_out.contiguous().to(initial_dtype) if last_recurrent_state is not None: - last_recurrent_state = last_recurrent_state.contiguous().to(initial_dtype) + last_recurrent_state = ( + last_recurrent_state.transpose(-1, -2).contiguous().to(initial_dtype) + ) return core_attn_out, last_recurrent_state @@ -160,7 +164,7 @@ def test( g = make_gate((B, T, Hv), device) beta = make_beta((B, T, Hv), device) - initial_state = TestTensor((B, Hv, Dk, Dv), None, dtype, device) + initial_state = TestTensor((B, Hv, Dv, Dk), None, dtype, device) out = TestTensor( (B, T, Hv, Dv), bthd_strides(B, T, Hv, Dv, strided_qkv), @@ -168,7 +172,7 @@ def test( device, mode="zeros", ) - final_state = TestTensor((B, Hv, Dk, Dv), None, dtype, device) + final_state = TestTensor((B, Hv, Dv, Dk), None, dtype, device) ans_out, ans_final_state = ref_recurrent_gated_delta_rule( q.torch_tensor(), @@ -341,11 +345,9 @@ def test_indexed_pool_inplace( mode="zeros", ) - gathered_initial_state = ( - initial_state_pool.torch_tensor()[initial_state_indices_torch] - .transpose(-1, -2) - .contiguous() - ) + gathered_initial_state = initial_state_pool.torch_tensor()[ + initial_state_indices_torch + ].contiguous() ans_out, ans_final_state = ref_recurrent_gated_delta_rule( q.torch_tensor(), k.torch_tensor(), @@ -357,9 +359,7 @@ def test_indexed_pool_inplace( use_qk_l2norm_in_kernel=use_qk_l2norm, ) ans_initial_state_pool = initial_state_pool.torch_tensor().clone() - ans_initial_state_pool[final_state_indices_torch] = ans_final_state.transpose( - -1, -2 - ).contiguous() + ans_initial_state_pool[final_state_indices_torch] = ans_final_state.contiguous() if sync: sync()