-
Notifications
You must be signed in to change notification settings - Fork 150
Expand file tree
/
Copy pathdynamic_dispatch.cu
More file actions
342 lines (306 loc) · 14.6 KB
/
dynamic_dispatch.cu
File metadata and controls
342 lines (306 loc) · 14.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors
/// GPU kernel that decompresses a Vortex encoding tree in a single launch via dynamic dispatch.
///
/// Stages communicate through shared memory: early input stages populate
/// persistent smem regions (e.g., dictionary values, run-end endpoints) that
/// later stages reference via smem offsets.
///
/// The final output stage writes directly to global memory instead of back
/// to shared memory. Shared memory is dynamically sized at launch time to
/// fit all intermediate buffers that must coexist simultaneously.
#include <assert.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <stdint.h>
#include <thrust/binary_search.h>
#include <thrust/execution_policy.h>
#include "bit_unpack.cuh"
#include "dynamic_dispatch.h"
#include "types.cuh"
/// Returns the index of the first element in `data[0, len)` that compares
/// strictly greater than `value`, or `len` if there is none.
///
/// `data` must be sorted in non-decreasing order (the precondition of
/// `thrust::upper_bound`). The search runs sequentially inside the calling
/// thread via the `thrust::seq` execution policy.
template <typename T>
__device__ inline uint64_t upper_bound(const T *data, uint64_t len, uint64_t value) {
    const T *const last = data + len;
    const T *const first_greater = thrust::upper_bound(thrust::seq, data, last, value);
    return static_cast<uint64_t>(first_greater - data);
}
/// Executes a source operation to fill a shared memory region with decoded data.
///
/// This function handles the first phase of each stage's pipeline. It reads
/// compressed or raw data from global memory (or, for RUNEND, from smem
/// regions filled by earlier stages) and writes decoded elements into the
/// stage's shared memory region. No barrier is issued here; the caller must
/// synchronize before other threads read `smem_output`.
///
/// @param input       Global memory pointer to the stage's encoded input data
/// @param smem_output Shared memory pointer where decoded elements are written
/// @param chunk_start Absolute element index of the first element to produce
///                    (0 for input stages; block_start + tile offset for the
///                    output stage)
/// @param chunk_len   Number of elements to produce (for the output stage this
///                    is a tile length, which may be short at the array tail)
/// @param source_op   Source operation descriptor (BITUNPACK, LOAD, or RUNEND)
/// @param smem_base   Base of the entire dynamic shared memory pool, used by
///                    RUNEND to resolve offsets to ends/values decoded by
///                    earlier stages
template <typename T>
__device__ inline void dynamic_source_op(const T *__restrict input,
                                         T *__restrict smem_output,
                                         uint64_t chunk_start,
                                         uint32_t chunk_len,
                                         const struct SourceOp &source_op,
                                         T *__restrict smem_base) {
    constexpr uint32_t T_BITS = sizeof(T) * 8;
    switch (source_op.op_code) {
    case SourceOp::BITUNPACK: {
        constexpr uint32_t FL_CHUNK_SIZE = 1024;
        constexpr uint32_t LANES_PER_FL_BLOCK = FL_CHUNK_SIZE / T_BITS;
        const uint32_t bit_width = source_op.params.bitunpack.bit_width;
        const uint32_t element_offset = source_op.params.bitunpack.element_offset;
        const uint32_t packed_words_per_fl_block = LANES_PER_FL_BLOCK * bit_width;
        // Shift chunk_start by the sub-block element offset to find the first FL block.
        // NOTE(review): only the FL-block index is derived from element_offset.
        // The remainder (chunk_start + element_offset) % FL_CHUNK_SIZE is not
        // applied to where results land in smem_output, so smem_output[0] is the
        // first element of that FL block. Presumably the host guarantees this
        // remainder is zero; confirm against the plan builder.
        const uint64_t first_fl_block = (chunk_start + element_offset) / FL_CHUNK_SIZE;
        // FL blocks must divide evenly. Otherwise, the last unpack would overflow smem.
        static_assert((ELEMENTS_PER_BLOCK % FL_CHUNK_SIZE) == 0);
        const auto div_ceil = [](auto a, auto b) {
            return (a + b - 1) / b;
        };
        const uint32_t num_fl_chunks = div_ceil(chunk_len, FL_CHUNK_SIZE);
        // Flatten (FL chunk, lane) pairs into a single strided work range so
        // every thread has work even when blockDim.x > LANES_PER_FL_BLOCK
        // (e.g. only 32 lanes per chunk for 32-bit T). The units run without
        // any barrier between them: lanes of one chunk write disjoint
        // positions, and each chunk writes its own FL_CHUNK_SIZE window.
        const uint32_t total_units = num_fl_chunks * LANES_PER_FL_BLOCK;
        for (uint32_t unit = threadIdx.x; unit < total_units; unit += blockDim.x) {
            // LANES_PER_FL_BLOCK is a compile-time power of two, so these are shifts/masks.
            const uint32_t chunk_idx = unit / LANES_PER_FL_BLOCK;
            const uint32_t lane = unit % LANES_PER_FL_BLOCK;
            const T *packed_chunk = input + (first_fl_block + chunk_idx) * packed_words_per_fl_block;
            T *smem_lane = smem_output + chunk_idx * FL_CHUNK_SIZE;
            bit_unpack_lane<T>(packed_chunk, smem_lane, 0, lane, bit_width);
        }
        break;
    }
    case SourceOp::LOAD: {
        // Copy elements verbatim from global memory into shared memory.
        for (uint32_t i = threadIdx.x; i < chunk_len; i += blockDim.x) {
            smem_output[i] = input[chunk_start + i];
        }
        break;
    }
    case SourceOp::RUNEND: {
        // Ends and values were decoded into shared memory by earlier stages.
        const T *ends = &smem_base[source_op.params.runend.ends_smem_offset];
        const T *values = &smem_base[source_op.params.runend.values_smem_offset];
        const uint64_t num_runs = source_op.params.runend.num_runs;
        const uint64_t offset = source_op.params.runend.offset;
        // NOTE(review): num_runs == 0 would make `num_runs - 1` below wrap and
        // index far out of range; presumably the host never emits an empty
        // RUNEND with chunk_len > 0. Confirm.
        // Each thread binary-searches for its first position's run, then
        // forward-scans for subsequent positions. Strided positions are
        // monotonically increasing per thread, so current_run only advances.
        uint64_t current_run = upper_bound(ends, num_runs, chunk_start + threadIdx.x + offset);
        for (uint32_t i = threadIdx.x; i < chunk_len; i += blockDim.x) {
            uint64_t pos = chunk_start + i + offset;
            while (current_run < num_runs && static_cast<uint64_t>(ends[current_run]) <= pos) {
                current_run++;
            }
            // Positions past the last end clamp to the final run's value.
            smem_output[i] = values[min(current_run, num_runs - 1)];
        }
        break;
    }
    default:
        __builtin_unreachable();
    }
}
/// Transforms N register-resident values in place according to one scalar op.
///
/// Runs element-wise after the source op has filled shared memory; callers
/// chain any number of these in any order:
///   - FOR:    adds the frame-of-reference constant (wrapping in T).
///   - ZIGZAG: undoes zigzag encoding, `(v >> 1) ^ -(v & 1)`.
///   - ALP:    reads the value as int32, multiplies by `f` then `e` in float,
///             and stores the resulting float's bit pattern.
///   - DICT:   replaces each value with the dictionary entry it indexes,
///             where the dictionary lives in shared memory.
///
/// @param values    Array of N values, overwritten in place
/// @param op        Descriptor of the scalar op to apply
/// @param smem_base Base of the dynamic shared memory pool; DICT resolves its
///                  dictionary offset against it
template <typename T, uint32_t N>
__device__ inline void apply_scalar_op(T *values, const struct ScalarOp &op, T *__restrict smem_base) {
    switch (op.op_code) {
    case ScalarOp::FOR: {
        const T reference = static_cast<T>(op.params.frame_of_ref.reference);
        // clang-format off
        #pragma unroll
        // clang-format on
        for (uint32_t k = 0; k < N; ++k) {
            values[k] += reference;
        }
        break;
    }
    case ScalarOp::ZIGZAG: {
        // clang-format off
        #pragma unroll
        // clang-format on
        for (uint32_t k = 0; k < N; ++k) {
            const T encoded = values[k];
            // All ones when the low bit is set, zero otherwise.
            const T sign_mask = static_cast<T>(-(encoded & 1));
            values[k] = (encoded >> 1) ^ sign_mask;
        }
        break;
    }
    case ScalarOp::ALP: {
        const float factor_f = op.params.alp.f;
        const float factor_e = op.params.alp.e;
        // clang-format off
        #pragma unroll
        // clang-format on
        for (uint32_t k = 0; k < N; ++k) {
            const int32_t encoded = static_cast<int32_t>(values[k]);
            const float decoded = static_cast<float>(encoded) * factor_f * factor_e;
            values[k] = static_cast<T>(__float_as_uint(decoded));
        }
        break;
    }
    case ScalarOp::DICT: {
        const T *dictionary = smem_base + op.params.dict.values_smem_offset;
        // clang-format off
        #pragma unroll
        // clang-format on
        for (uint32_t k = 0; k < N; ++k) {
            const uint32_t code = static_cast<uint32_t>(values[k]);
            values[k] = dictionary[code];
        }
        break;
    }
    default:
        __builtin_unreachable();
    }
}
/// Selects how a stage writes its final values.
enum class StorePolicy : int {
    /// Plain stores through the destination pointer, using the hardware's
    /// default caching behavior.
    WRITEBACK,
    /// Cache-streaming stores via `__stcs` (PTX `st.global.cs`), which mark
    /// the written lines evict-first to limit cache pollution. Intended for
    /// write-only output that this kernel never reads back. `__stcs` is an
    /// ordinary synchronous store, not an asynchronous copy like `cp.async`,
    /// so the `__syncthreads()` issued after each tile already orders it.
    STREAMING,
};
/// Pushes `chunk_len` values from `smem_input` through the scalar-op chain
/// and stores result `i` at `write_dest[write_offset + i]` using policy `S`.
///
/// The bulk of the work is split into full tiles of `blockDim.x * REGS`
/// elements. In each tile a thread holds REGS values in registers, its j-th
/// register being element `tile_base + j * blockDim.x + threadIdx.x`, so
/// independent ops can overlap. Leftover elements after the last full tile
/// are handled one per thread per iteration.
///
/// @param smem_input     Shared memory holding the source-decoded values
/// @param write_dest     Destination base pointer (smem for input stages,
///                       global output for the output stage)
/// @param write_offset   Element offset added to every destination index
/// @param chunk_len      Number of elements to process
/// @param num_scalar_ops Length of `scalar_ops`
/// @param scalar_ops     Ops applied in array order to every element
/// @param smem_base      Base of the dynamic shared memory pool (for DICT)
template <typename T, StorePolicy S>
__device__ void apply_scalar_ops(const T *__restrict smem_input,
                                 T *__restrict write_dest,
                                 uint64_t write_offset,
                                 uint32_t chunk_len,
                                 uint8_t num_scalar_ops,
                                 const struct ScalarOp *scalar_ops,
                                 T *__restrict smem_base) {
    // 64 bytes of per-thread register payload per tile, whatever the element width.
    constexpr uint32_t REGS = 64 / sizeof(T);
    const uint32_t stride = blockDim.x;
    const uint32_t tid = threadIdx.x;
    const uint32_t tile_elems = stride * REGS;
    // First index not covered by a full tile.
    const uint32_t tail_begin = (chunk_len / tile_elems) * tile_elems;

    T *dst = write_dest + write_offset;
    const auto store = [dst](uint32_t index, T value) {
        if constexpr (S == StorePolicy::STREAMING) {
            __stcs(&dst[index], value);
        } else {
            dst[index] = value;
        }
    };

    for (uint32_t tile_base = 0; tile_base < tail_begin; tile_base += tile_elems) {
        T regs[REGS];
        // clang-format off
        #pragma unroll
        // clang-format on
        for (uint32_t r = 0; r < REGS; ++r) {
            regs[r] = smem_input[tile_base + r * stride + tid];
        }
        for (uint8_t op_i = 0; op_i < num_scalar_ops; ++op_i) {
            apply_scalar_op<T, REGS>(regs, scalar_ops[op_i], smem_base);
        }
        // clang-format off
        #pragma unroll
        // clang-format on
        for (uint32_t r = 0; r < REGS; ++r) {
            store(tile_base + r * stride + tid, regs[r]);
        }
    }

    for (uint32_t idx = tail_begin + tid; idx < chunk_len; idx += stride) {
        T value = smem_input[idx];
        for (uint8_t op_i = 0; op_i < num_scalar_ops; ++op_i) {
            apply_scalar_op<T, 1>(&value, scalar_ops[op_i], smem_base);
        }
        store(idx, value);
    }
}
/// Runs one stage end to end, in two phases.
///
/// 1. The stage's source op decodes `chunk_len` elements, starting at
///    `chunk_start`, into the stage's shared-memory region.
/// 2. The scalar-op chain reads them back and writes result `i` to
///    `write_dest[write_offset + i]`.
///
/// Input stages pass their own smem region as `write_dest`; the output stage
/// passes the global output buffer.
///
/// A block-wide barrier follows each phase. The first makes the decoded data
/// visible to every thread before it is read. The second ensures all reads
/// finish before any subsequent decode reuses shared memory.
template <typename T, StorePolicy S>
__device__ void execute_stage(const struct Stage &stage,
                              T *__restrict smem_base,
                              uint64_t chunk_start,
                              uint32_t chunk_len,
                              T *__restrict write_dest,
                              uint64_t write_offset) {
    T *stage_smem = smem_base + stage.smem_offset;
    const T *encoded = reinterpret_cast<const T *>(stage.input_ptr);

    // Phase 1: decode into shared memory.
    dynamic_source_op<T>(encoded, stage_smem, chunk_start, chunk_len, stage.source, smem_base);
    __syncthreads();

    // Phase 2: transform in registers and write to the destination.
    apply_scalar_ops<T, S>(stage_smem,
                           write_dest,
                           write_offset,
                           chunk_len,
                           stage.num_scalar_ops,
                           stage.scalar_ops,
                           smem_base);
    __syncthreads();
}
/// Entry point of the dynamic dispatch kernel.
///
/// Launch layout: a 1-D grid in which block `b` owns the output elements
/// [b * ELEMENTS_PER_BLOCK, min((b + 1) * ELEMENTS_PER_BLOCK, array_len)).
/// Blocks whose range starts at or past `array_len` exit immediately, so an
/// oversized grid is harmless.
///
/// Each block executes the plan's stages in order:
///   1. Input stages (all but the last) decode their full input into their
///      shared-memory regions for the output stage to reference.
///   2. The output stage decodes this block's range in tiles of at most
///      SMEM_TILE_SIZE elements and writes them to global memory with
///      streaming stores.
/// A plan with zero stages produces no output.
///
/// Shared memory: the dynamic allocation must hold every stage region that
/// has to coexist. The host computes this size and passes it at launch (see
/// DynamicDispatchPlan::shared_mem_bytes). A static copy of the plan is kept
/// in addition.
///
/// @param output    Global memory output buffer, written at indices [0, array_len)
/// @param array_len Total number of elements to produce
/// @param plan      Device pointer to the dispatch plan
template <typename T>
__device__ void dynamic_dispatch_impl(T *__restrict output,
                                      uint64_t array_len,
                                      const struct DynamicDispatchPlan *__restrict plan) {
    const uint64_t block_start = static_cast<uint64_t>(blockIdx.x) * ELEMENTS_PER_BLOCK;
    // A block starting at or past the end has nothing to write. Without this
    // guard, `block_end - block_start` below would underflow, and the
    // truncated block_len would drive the tile loop far beyond `array_len`.
    // The condition depends only on blockIdx.x, so the whole block returns
    // together and no __syncthreads() is left waiting.
    if (block_start >= array_len) {
        return;
    }

    // Dynamically-sized shared memory: The host computes the exact byte count
    // needed to hold all stage outputs that must coexist simultaneously, and
    // passes the count at kernel launch (see DynamicDispatchPlan::shared_mem_bytes).
    extern __shared__ char smem_bytes[];
    T *smem_base = reinterpret_cast<T *>(smem_bytes);
    __shared__ struct DynamicDispatchPlan smem_plan;
    if (threadIdx.x == 0) {
        smem_plan = *plan;
    }
    __syncthreads();

    // An empty plan has no output stage, and `num_stages - 1` would wrap.
    // smem_plan is block-shared and already synchronized, so this exit is uniform.
    if (smem_plan.num_stages == 0) {
        return;
    }
    const uint8_t last = smem_plan.num_stages - 1;

    // Input stages: Decode inputs into smem regions.
    for (uint8_t i = 0; i < last; ++i) {
        const struct Stage &stage = smem_plan.stages[i];
        T *smem_output = &smem_base[stage.smem_offset];
        execute_stage<T, StorePolicy::WRITEBACK>(stage, smem_base, 0, stage.len, smem_output, 0);
    }

    // Output stage: process in SMEM_TILE_SIZE tiles to reduce smem footprint.
    // Each tile decodes into the same smem region and writes to global memory.
    const struct Stage &output_stage = smem_plan.stages[last];
    const uint64_t block_end = min(block_start + ELEMENTS_PER_BLOCK, array_len);
    const uint32_t block_len = static_cast<uint32_t>(block_end - block_start);
    for (uint32_t tile_off = 0; tile_off < block_len; tile_off += SMEM_TILE_SIZE) {
        const uint32_t tile_len = min(SMEM_TILE_SIZE, block_len - tile_off);
        execute_stage<T, StorePolicy::STREAMING>(output_stage,
                                                 smem_base,
                                                 block_start + tile_off,
                                                 tile_len,
                                                 output,
                                                 block_start + tile_off);
    }
}
/// Stamps out one `extern "C"` kernel named `dynamic_dispatch_<SUFFIX>` for
/// element type ELEM_T. Each kernel only forwards its arguments to
/// dynamic_dispatch_impl<ELEM_T>.
#define GENERATE_DYNAMIC_DISPATCH_KERNEL(SUFFIX, ELEM_T)                 \
    extern "C" __global__ void dynamic_dispatch_##SUFFIX(                \
        ELEM_T *__restrict output,                                       \
        uint64_t array_len,                                              \
        const struct DynamicDispatchPlan *__restrict plan) {             \
        dynamic_dispatch_impl<ELEM_T>(output, array_len, plan);          \
    }
// Instantiate one kernel per (suffix, type) pair that FOR_EACH_UNSIGNED_INT enumerates.
FOR_EACH_UNSIGNED_INT(GENERATE_DYNAMIC_DISPATCH_KERNEL)