|
| 1 | +// Copyright © 2026 Apple Inc. |
| 2 | + |
| 3 | +#include <algorithm> |
| 4 | +#include <cstdint> |
| 5 | +#include <stdexcept> |
| 6 | + |
| 7 | +#include "mlx/backend/cuda/device.h" |
| 8 | +#include "mlx/backend/cuda/device/radix_select.cuh" |
| 9 | +#include "mlx/backend/cuda/kernel_utils.cuh" |
| 10 | +#include "mlx/dtype.h" |
| 11 | +#include "mlx/dtype_utils.h" |
| 12 | + |
| 13 | +namespace mlx::core { |
| 14 | + |
// Generic partition fallback, defined in another translation unit. Note the
// signature takes no `kth`: presumably it orders the whole axis (which makes
// any kth position correct) — confirm at the definition site.
void gpu_partition_fallback(
    const Stream& s,
    const array& in,
    array& out,
    int axis,
    bool arg_partition);
| 21 | + |
| 22 | +namespace { |
| 23 | + |
| 24 | +// Upper bound for small-kernel tiling. Keep this aligned with the |
| 25 | +// items-per-thread dispatch set and per-block shared-memory budget. |
| 26 | +constexpr int MAX_RADIX_ITEMS_PER_THREAD = 64; |
| 27 | +constexpr size_t RADIX_SMALL_SHARED_MEM_BUDGET_BYTES = 48 * 1024; |
| 28 | + |
| 29 | +template <typename F> |
| 30 | +void dispatch_radix_small_block_threads(int size_sorted_axis, F&& f) { |
| 31 | + if (size_sorted_axis <= 256) { |
| 32 | + f(std::integral_constant<int, 32>{}); |
| 33 | + } else if (size_sorted_axis <= 512) { |
| 34 | + f(std::integral_constant<int, 64>{}); |
| 35 | + } else if (size_sorted_axis <= 1024) { |
| 36 | + f(std::integral_constant<int, 128>{}); |
| 37 | + } else { |
| 38 | + f(std::integral_constant<int, 256>{}); |
| 39 | + } |
| 40 | +} |
| 41 | + |
| 42 | +template <typename F> |
| 43 | +void dispatch_radix_items_per_thread( |
| 44 | + int size_sorted_axis, |
| 45 | + int block_threads, |
| 46 | + F&& f) { |
| 47 | + int items_per_thread = (size_sorted_axis + block_threads - 1) / block_threads; |
| 48 | + if (items_per_thread <= 1) { |
| 49 | + f(std::integral_constant<int, 1>{}); |
| 50 | + } else if (items_per_thread <= 2) { |
| 51 | + f(std::integral_constant<int, 2>{}); |
| 52 | + } else if (items_per_thread <= 4) { |
| 53 | + f(std::integral_constant<int, 4>{}); |
| 54 | + } else if (items_per_thread <= 8) { |
| 55 | + f(std::integral_constant<int, 8>{}); |
| 56 | + } else if (items_per_thread <= 12) { |
| 57 | + f(std::integral_constant<int, 12>{}); |
| 58 | + } else if (items_per_thread <= 16) { |
| 59 | + f(std::integral_constant<int, 16>{}); |
| 60 | + } else if (items_per_thread <= 24) { |
| 61 | + f(std::integral_constant<int, 24>{}); |
| 62 | + } else { |
| 63 | + f(std::integral_constant<int, MAX_RADIX_ITEMS_PER_THREAD>{}); |
| 64 | + } |
| 65 | +} |
| 66 | + |
// Host-side mirror of the dynamic shared memory the small radix-select
// kernel requests at launch: one tile of keys, one tile of indices, one
// radix histogram, and a few counter/scratch slots.
size_t radix_small_shared_mem_bytes(
    size_t key_size,
    int block_threads,
    int items_per_thread) {
  auto tile = static_cast<size_t>(block_threads) *
      static_cast<size_t>(items_per_thread);
  auto warps = static_cast<size_t>(block_threads / WARP_SIZE);
  size_t bytes = tile * key_size; // shared_keys
  bytes += tile * sizeof(uint32_t); // shared_idxs
  bytes += cu::RADIX_SIZE * sizeof(int); // shared_hist for small kernel
  bytes += (2 + 3 * warps + 6) * sizeof(int); // shared_count + scatter scratch
  return bytes;
}
| 79 | + |
// Decide whether a row of `size_sorted_axis` elements of `dtype` can be
// handled by the single-block radix-select kernel: the row must fit within
// the per-thread item cap and the resulting tile's shared-memory footprint
// must stay within the budget.
bool radix_small_fits_shared_memory(Dtype dtype, int size_sorted_axis) {
  if (size_sorted_axis <= 0) {
    return false;
  }

  bool fits = false;
  dispatch_radix_small_block_threads(size_sorted_axis, [&](auto block_dim_tag) {
    constexpr int BLOCK_THREADS = block_dim_tag();
    // Rows longer than the largest supported tile take the fallback path.
    int required_items = (size_sorted_axis + BLOCK_THREADS - 1) / BLOCK_THREADS;
    if (required_items > MAX_RADIX_ITEMS_PER_THREAD) {
      return;
    }

    dispatch_radix_items_per_thread(
        size_sorted_axis, BLOCK_THREADS, [&](auto items_per_thread_tag) {
          constexpr int ITEMS_PER_THREAD = items_per_thread_tag();
          fits = radix_small_shared_mem_bytes(
                     size_of(dtype), BLOCK_THREADS, ITEMS_PER_THREAD) <=
              RADIX_SMALL_SHARED_MEM_BUDGET_BYTES;
        });
  });
  return fits;
}
| 105 | + |
// Partition (or arg-partition) `in` along `axis` into `out` using the
// single-block radix-select kernel: the launch grid is (1, n_rows, 1), so
// one CUDA block processes one row (one slice along `axis`).
// Caller (gpu_partition) guarantees the tile fits the shared-memory budget
// and that in.ndim() - 1 <= MAX_NDIM so the const_param metadata fits.
// `kth` is assumed already normalized to a non-negative index.
void gpu_radix_partition_small(
    const Stream& s,
    const array& in,
    array& out,
    int axis,
    int kth,
    bool arg_partition) {
  // Number of independent rows to partition.
  int n_rows = in.size() / in.shape(axis);

  // Shape/strides with the partitioned axis removed ("nc" = non-contracted):
  // these locate each row segment in memory for the strided kernel path.
  auto in_nc_str = in.strides();
  in_nc_str.erase(in_nc_str.begin() + axis);

  auto out_nc_str = out.strides();
  out_nc_str.erase(out_nc_str.begin() + axis);

  auto nc_shape = in.shape();
  nc_shape.erase(nc_shape.begin() + axis);

  int nc_dim = nc_shape.size();

  int size_sorted_axis = in.shape(axis);
  int64_t in_stride_sorted_axis = in.strides()[axis];
  int64_t out_stride_sorted_axis = out.strides()[axis];

  // "Contiguous" here means rows can be addressed with a single segment
  // stride: the buffer is contiguous and the partitioned axis has either the
  // smallest or largest stride in both arrays.
  bool contiguous = in.flags().contiguous;
  auto check_strides = [](const array& x, int64_t sort_stride) {
    int64_t min_stride =
        *std::min_element(x.strides().begin(), x.strides().end());
    int64_t max_stride =
        *std::max_element(x.strides().begin(), x.strides().end());
    return sort_stride == min_stride || sort_stride == max_stride;
  };
  contiguous &= check_strides(in, in_stride_sorted_axis);
  contiguous &= check_strides(out, out_stride_sorted_axis);

  // Allocate the output through the stream's encoder and register the
  // kernel's data dependencies before building the launch node.
  auto& encoder = cu::get_command_encoder(s);
  out.set_data(cu::malloc_async(out.nbytes(), encoder));
  encoder.set_input_array(in);
  encoder.set_output_array(out);

  // Fixed-size parameter packs for the non-axis shape/strides (capped at
  // MAX_NDIM by the caller).
  auto nc_shape_param = const_param(nc_shape);
  auto in_nc_strides_param = const_param(in_nc_str);
  auto out_nc_strides_param = const_param(out_nc_str);

  dispatch_all_types(in.dtype(), [&](auto type_tag) {
    using CTYPE = MLX_GET_TYPE(type_tag);
    if constexpr (!std::is_same_v<CTYPE, complex64_t>) {
      using ValT = cuda_type_t<CTYPE>;

      dispatch_bool(arg_partition, [&](auto arg_tag) {
        constexpr bool ARG_PARTITION = decltype(arg_tag)::value;
        // arg-partition writes indices (uint32_t); plain partition writes
        // values of the input type.
        using OutT = std::conditional_t<ARG_PARTITION, uint32_t, ValT>;

        // Smallest non-trivial segment stride, used by the simple-stride
        // kernel path. INT64_MAX doubles as the "unset" sentinel when the
        // arrays are not contiguous (the kernel then uses the full
        // shape/stride metadata instead).
        int64_t in_stride_segment_axis = INT64_MAX;
        int64_t out_stride_segment_axis = INT64_MAX;
        if (contiguous) {
          for (size_t i = 0; i < nc_shape.size(); i++) {
            if (nc_shape[i] == 1) {
              continue;
            }
            in_stride_segment_axis =
                std::min(in_stride_segment_axis, in_nc_str[i]);
            out_stride_segment_axis =
                std::min(out_stride_segment_axis, out_nc_str[i]);
          }
        }

        dispatch_radix_small_block_threads(
            size_sorted_axis, [&](auto block_dim_tag) {
              constexpr int BLOCK_THREADS = block_dim_tag();
              // One block per row; block x-dimension carries the threads.
              dim3 grid(1, n_rows, 1);
              dim3 block(BLOCK_THREADS, 1, 1);

              dispatch_radix_items_per_thread(
                  size_sorted_axis,
                  BLOCK_THREADS,
                  [&](auto items_per_thread_tag) {
                    constexpr int ITEMS_PER_THREAD = items_per_thread_tag();

                    dispatch_bool(contiguous, [&](auto contiguous_tag) {
                      constexpr bool USE_SIMPLE_STRIDE =
                          decltype(contiguous_tag)::value;

                      auto kernel = cu::radix_select_small_kernel<
                          ValT,
                          OutT,
                          ARG_PARTITION,
                          USE_SIMPLE_STRIDE,
                          BLOCK_THREADS,
                          ITEMS_PER_THREAD>;

                      // Calculate dynamic shared memory size. This must stay
                      // in sync with radix_small_shared_mem_bytes; note that
                      // the host-side gate sizes keys with size_of(dtype)
                      // while this uses sizeof(UnsignedT) — assumed equal in
                      // size for every dispatched dtype; TODO confirm in
                      // RadixTraits.
                      using UnsignedT =
                          typename cu::RadixTraits<ValT>::UnsignedT;
                      constexpr int TILE_SIZE_VAL =
                          BLOCK_THREADS * ITEMS_PER_THREAD;
                      constexpr int NUM_WARPS = BLOCK_THREADS / WARP_SIZE;
                      constexpr size_t shared_mem_bytes =
                          TILE_SIZE_VAL * sizeof(UnsignedT) + // shared_keys
                          TILE_SIZE_VAL * sizeof(uint32_t) + // shared_idxs
                          cu::RADIX_SIZE *
                              sizeof(int) + // shared_hist for small kernel
                          (2 + 3 * NUM_WARPS + 6) *
                              sizeof(int); // shared_count + scatter scratch

                      encoder.add_kernel_node_ex(
                          kernel,
                          grid,
                          block,
                          {},
                          static_cast<uint32_t>(shared_mem_bytes),
                          gpu_ptr<ValT>(in),
                          gpu_ptr<OutT>(out),
                          kth,
                          size_sorted_axis,
                          in_stride_sorted_axis,
                          out_stride_sorted_axis,
                          in_stride_segment_axis,
                          out_stride_segment_axis,
                          nc_shape_param,
                          in_nc_strides_param,
                          out_nc_strides_param,
                          nc_dim);
                    });
                  });
            });
      });
    } else {
      // complex64 has no radix ordering defined in this backend.
      throw std::runtime_error(
          "CUDA backend does not support sorting complex numbers");
    }
  });
}
| 239 | + |
| 240 | +} // namespace |
| 241 | + |
// Entry point for (arg-)partition on the CUDA backend. Normalizes negative
// `axis_`/`kth_` and routes to the small radix-select kernel when its tile
// fits in shared memory, otherwise to the generic fallback.
void gpu_partition(
    const Stream& s,
    const array& in,
    array& out,
    int axis_,
    int kth_,
    bool arg_partition) {
  // Wrap negative axis/kth the usual Python-style way.
  int axis = axis_;
  if (axis < 0) {
    axis += in.ndim();
  }
  int size_sorted_axis = in.shape(axis);
  int kth = kth_;
  if (kth < 0) {
    kth += size_sorted_axis;
  }

  // const_param packs at most MAX_NDIM dims of shape/stride metadata, so
  // higher-rank inputs must take the generic path. The && short-circuits,
  // matching the original early-out before the shared-memory check.
  int segment_dims = static_cast<int>(in.ndim()) - 1;
  bool use_small_kernel = segment_dims <= MAX_NDIM &&
      radix_small_fits_shared_memory(in.dtype(), size_sorted_axis);

  if (use_small_kernel) {
    gpu_radix_partition_small(s, in, out, axis, kth, arg_partition);
  } else {
    gpu_partition_fallback(s, in, out, axis, arg_partition);
  }
}
| 266 | + |
| 267 | +} // namespace mlx::core |
0 commit comments