#include "max.cuh"
#include "utils.cuh"
#include <cassert>
#include <cstddef>
#include <cuda_bf16.h>
#include <limits>

namespace tensor::kernels {

using namespace dtype;

const int blockThreads = 256;
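
// One block computes one max-reduction: block b reduces the reduce_size elements
// of `input` that belong to output element b (they are spaced reduce_stride apart
// in memory) and writes the result to out[b]. num_reductions is the total number
// of reductions, i.e. the expected grid size.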
__global__ void max_float_kernel(Cuda<float>* out, Cuda<float>* input, size_t num_reductions, size_t reduce_size, size_t reduce_stride) {
  __shared__ Cuda<float> shmem[blockThreads]; // NOLINT
  size_t tid = threadIdx.x;

  size_t reduction_idx = blockIdx.x; // which reduction are we doing?
  if (reduction_idx >= num_reductions) { return; } // uniform per block, so safe before __syncthreads

  // decompose into outer and inner indices
  size_t outer_idx = reduction_idx / reduce_stride;
  size_t inner_idx = reduction_idx % reduce_stride;

  // base pointer for this reduction
  size_t base = (outer_idx * reduce_size * reduce_stride) + inner_idx;
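  // For a contiguous tensor viewed as (outer, reduce, inner), element (o, r, i)
  // sits at o * reduce_size * reduce_stride + r * reduce_stride + i, so the r-th
  // element of this reduction is at base + r * reduce_stride.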

  // reduce with a block-stride loop to handle reduce_size > blockThreads
  float thread_max = -std::numeric_limits<float>::infinity();
  for (size_t element = tid; element < reduce_size; element += blockDim.x) {
    thread_max = max(thread_max, input[base + (element * reduce_stride)]);
  }
  // now we only have to reduce 'blockThreads' elements, which is easy within a block

  // load partial maxs onto shmem
  shmem[tid] = thread_max;
  __syncthreads();

  // reduce in shared memory
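  // Tree reduction: each iteration halves the number of active threads, stopping
  // once 64 partial maxima remain so the final warp can finish with shuffles
  // instead of further __syncthreads.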
  for (int stride = blockDim.x / 2; stride > 32; stride >>= 1) { // NOLINT
    if (tid < stride) { shmem[tid] = max(shmem[tid], shmem[tid + stride]); }
    __syncthreads();
  }

  // warp shuffle for the final warp-level reduction
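  // The first warp folds the remaining 64 shared-memory values down to 32
  // registers, then __shfl_down_sync halves the distance (16, 8, 4, 2, 1)
  // until lane 0 holds the block-wide maximum.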
  if (tid < 32) {
    float val = max(shmem[tid], shmem[tid + 32]);
    for (int offset = 16; offset > 0; offset >>= 1) {
      val = max(val, __shfl_down_sync(0xffffffff, val, offset));
    }

    if (tid == 0) {
      out[reduction_idx] = val;
    }
  }
}
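
// Max-reduce a contiguous CUDA tensor over dimension `dim`.
// The shape is split into (outer, reduce, inner) around `dim`: e.g. shape
// (2, 3, 4) with dim = 1 gives outer_size = 2, reduce_size = 3, inner_size = 4,
// an output shape of (2, 4) (or (2, 1, 4) with keepdim), and
// outer_size * inner_size = 8 independent reductions.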
Tensor<float, CUDA> max_float(const TensorView<float, CUDA>& input, int dim, bool keepdim) {
  assert(input.is_contiguous() && "the tensor should be contiguous");

  auto shape = input.shape;

  if (dim < 0) {
    dim = shape.size() + dim;
  }

  assert(dim >= 0 && static_cast<size_t>(dim) < shape.size());

  size_t outer_size = 1;  // product of dims before `dim`
  size_t inner_size = 1;  // product of dims after `dim`; also the distance between elements to reduce
  size_t reduce_size = 1; // how many elements each reduction needs to reduce over

  bool found_dim = false;

  // Output shape
  Shape out_shape;
  for (size_t i = 0; i < shape.size(); ++i) {
    if (i == static_cast<size_t>(dim)) {
      if (keepdim) {
        out_shape.push_back(1);
      }
      reduce_size = shape[dim];
      found_dim = true;
    } else {
      if (!found_dim) {
        outer_size *= shape[i];
      } else {
        inner_size *= shape[i];
      }

      out_shape.push_back(shape[i]);
    }
  }

  if (out_shape.empty()) {
    out_shape.push_back(1);
  }

  auto n_elements = outer_size * inner_size;

  auto input_strides = get_all_strides(shape);

  TensorStorage<float, CUDA> storage(n_elements);
  Tensor<float, CUDA> out{out_shape, std::move(storage)};

  int block_size = blockThreads;

  // Convert to device-native types for kernel call
  auto* out_d = reinterpret_cast<Cuda<float>*>(out.data()); // NOLINT
  auto* input_d = reinterpret_cast<Cuda<float>*>(input.data); // NOLINT
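  // Launch one block per reduction: n_elements blocks of blockThreads threads,
  // each reducing reduce_size elements spaced inner_size apart.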
  max_float_kernel<<<n_elements, block_size>>>(out_d, input_d, n_elements, reduce_size, inner_size);

  return out;
}

} // namespace tensor::kernels