Skip to content

Commit 16cc8b6

Browse files
committed
apply pr comment
1 parent d924266 commit 16cc8b6

3 files changed

Lines changed: 292 additions & 247 deletions

File tree

mlx/backend/cuda/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ target_sources(
3939
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
4040
${CMAKE_CURRENT_SOURCE_DIR}/layer_norm.cu
4141
${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cu
42+
${CMAKE_CURRENT_SOURCE_DIR}/partition.cu
4243
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
4344
${CMAKE_CURRENT_SOURCE_DIR}/random.cu
4445
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cu

mlx/backend/cuda/partition.cu

Lines changed: 267 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,267 @@
1+
// Copyright © 2026 Apple Inc.
2+
3+
#include <algorithm>
4+
#include <cstdint>
5+
#include <stdexcept>
6+
7+
#include "mlx/backend/cuda/device.h"
8+
#include "mlx/backend/cuda/device/radix_select.cuh"
9+
#include "mlx/backend/cuda/kernel_utils.cuh"
10+
#include "mlx/dtype.h"
11+
#include "mlx/dtype_utils.h"
12+
13+
namespace mlx::core {
14+
15+
void gpu_partition_fallback(
16+
const Stream& s,
17+
const array& in,
18+
array& out,
19+
int axis,
20+
bool arg_partition);
21+
22+
namespace {
23+
24+
// Upper bound for small-kernel tiling. Keep this aligned with the
25+
// items-per-thread dispatch set and per-block shared-memory budget.
26+
constexpr int MAX_RADIX_ITEMS_PER_THREAD = 64;
27+
constexpr size_t RADIX_SMALL_SHARED_MEM_BUDGET_BYTES = 48 * 1024;
28+
29+
// Invoke `f` with a compile-time block size (as std::integral_constant)
// selected from the length of the axis being partitioned. Smaller axes get
// smaller blocks; anything above 1024 elements uses the largest block size.
template <typename F>
void dispatch_radix_small_block_threads(int size_sorted_axis, F&& f) {
  if (size_sorted_axis > 1024) {
    f(std::integral_constant<int, 256>{});
  } else if (size_sorted_axis > 512) {
    f(std::integral_constant<int, 128>{});
  } else if (size_sorted_axis > 256) {
    f(std::integral_constant<int, 64>{});
  } else {
    f(std::integral_constant<int, 32>{});
  }
}
41+
42+
template <typename F>
43+
void dispatch_radix_items_per_thread(
44+
int size_sorted_axis,
45+
int block_threads,
46+
F&& f) {
47+
int items_per_thread = (size_sorted_axis + block_threads - 1) / block_threads;
48+
if (items_per_thread <= 1) {
49+
f(std::integral_constant<int, 1>{});
50+
} else if (items_per_thread <= 2) {
51+
f(std::integral_constant<int, 2>{});
52+
} else if (items_per_thread <= 4) {
53+
f(std::integral_constant<int, 4>{});
54+
} else if (items_per_thread <= 8) {
55+
f(std::integral_constant<int, 8>{});
56+
} else if (items_per_thread <= 12) {
57+
f(std::integral_constant<int, 12>{});
58+
} else if (items_per_thread <= 16) {
59+
f(std::integral_constant<int, 16>{});
60+
} else if (items_per_thread <= 24) {
61+
f(std::integral_constant<int, 24>{});
62+
} else {
63+
f(std::integral_constant<int, MAX_RADIX_ITEMS_PER_THREAD>{});
64+
}
65+
}
66+
67+
size_t radix_small_shared_mem_bytes(
68+
size_t key_size,
69+
int block_threads,
70+
int items_per_thread) {
71+
size_t tile_size = static_cast<size_t>(block_threads) *
72+
static_cast<size_t>(items_per_thread);
73+
size_t num_warps = static_cast<size_t>(block_threads / WARP_SIZE);
74+
return tile_size * key_size + // shared_keys
75+
tile_size * sizeof(uint32_t) + // shared_idxs
76+
cu::RADIX_SIZE * sizeof(int) + // shared_hist for small kernel
77+
(2 + 3 * num_warps + 6) * sizeof(int); // shared_count + scatter scratch
78+
}
79+
80+
// Whether the single-block radix-select kernel can process an axis of
// `size_sorted_axis` elements of `dtype` within the per-block shared-memory
// budget. Returns false for empty axes and for axes whose tile would exceed
// the supported items-per-thread ceiling.
bool radix_small_fits_shared_memory(Dtype dtype, int size_sorted_axis) {
  if (size_sorted_axis <= 0) {
    return false;
  }

  // The block-size dispatcher invokes its callback exactly once, so `fits`
  // reflects the single (block size, items-per-thread) combination chosen
  // for this axis length.
  bool fits = false;
  dispatch_radix_small_block_threads(size_sorted_axis, [&](auto block_tag) {
    constexpr int kBlockThreads = block_tag();
    const int needed_items =
        (size_sorted_axis + kBlockThreads - 1) / kBlockThreads;
    if (needed_items > MAX_RADIX_ITEMS_PER_THREAD) {
      // Tile cannot cover the axis at any supported items-per-thread count.
      return;
    }

    dispatch_radix_items_per_thread(
        size_sorted_axis, kBlockThreads, [&](auto items_tag) {
          constexpr int kItemsPerThread = items_tag();
          fits = radix_small_shared_mem_bytes(
                     size_of(dtype), kBlockThreads, kItemsPerThread) <=
              RADIX_SMALL_SHARED_MEM_BUDGET_BYTES;
        });
  });
  return fits;
}
105+
106+
// Launch the single-block radix-select kernel: one thread block per row
// (grid = (1, n_rows)), where a "row" is one slice along the partition axis.
// The dynamic shared-memory tile is sized to hold the whole row, so this
// path is only valid when radix_small_fits_shared_memory() returned true.
// Complex inputs are rejected.
void gpu_radix_partition_small(
    const Stream& s,
    const array& in,
    array& out,
    int axis,
    int kth,
    bool arg_partition) {
  int n_rows = in.size() / in.shape(axis);

  // Strides/shape with the partition axis removed ("nc" = non-contracted
  // segment dimensions), passed to the kernel for row addressing.
  auto in_nc_str = in.strides();
  in_nc_str.erase(in_nc_str.begin() + axis);

  auto out_nc_str = out.strides();
  out_nc_str.erase(out_nc_str.begin() + axis);

  auto nc_shape = in.shape();
  nc_shape.erase(nc_shape.begin() + axis);

  int nc_dim = nc_shape.size();

  int size_sorted_axis = in.shape(axis);
  int64_t in_stride_sorted_axis = in.strides()[axis];
  int64_t out_stride_sorted_axis = out.strides()[axis];

  // Take the fast "simple stride" path only when both arrays are contiguous
  // and the partition axis is the innermost or outermost stride of each.
  bool contiguous = in.flags().contiguous;
  auto check_strides = [](const array& x, int64_t sort_stride) {
    int64_t min_stride =
        *std::min_element(x.strides().begin(), x.strides().end());
    int64_t max_stride =
        *std::max_element(x.strides().begin(), x.strides().end());
    return sort_stride == min_stride || sort_stride == max_stride;
  };
  contiguous &= check_strides(in, in_stride_sorted_axis);
  contiguous &= check_strides(out, out_stride_sorted_axis);

  auto& encoder = cu::get_command_encoder(s);
  // Allocate the output before recording the kernel node.
  out.set_data(cu::malloc_async(out.nbytes(), encoder));
  encoder.set_input_array(in);
  encoder.set_output_array(out);

  // Fixed-size parameter structs; caller guarantees nc_dim <= MAX_NDIM.
  auto nc_shape_param = const_param(nc_shape);
  auto in_nc_strides_param = const_param(in_nc_str);
  auto out_nc_strides_param = const_param(out_nc_str);

  dispatch_all_types(in.dtype(), [&](auto type_tag) {
    using CTYPE = MLX_GET_TYPE(type_tag);
    if constexpr (!std::is_same_v<CTYPE, complex64_t>) {
      using ValT = cuda_type_t<CTYPE>;

      dispatch_bool(arg_partition, [&](auto arg_tag) {
        constexpr bool ARG_PARTITION = decltype(arg_tag)::value;
        // argpartition writes uint32 indices; partition writes values.
        using OutT = std::conditional_t<ARG_PARTITION, uint32_t, ValT>;

        // Smallest non-trivial segment stride, used by the simple-stride
        // kernel path. Stays at the INT64_MAX sentinel when no segment
        // dimension exceeds size 1 — presumably the kernel ignores it in
        // that case (single row); TODO(review): confirm in radix_select.cuh.
        int64_t in_stride_segment_axis = INT64_MAX;
        int64_t out_stride_segment_axis = INT64_MAX;
        if (contiguous) {
          for (size_t i = 0; i < nc_shape.size(); i++) {
            if (nc_shape[i] == 1) {
              continue;
            }
            in_stride_segment_axis =
                std::min(in_stride_segment_axis, in_nc_str[i]);
            out_stride_segment_axis =
                std::min(out_stride_segment_axis, out_nc_str[i]);
          }
        }

        dispatch_radix_small_block_threads(
            size_sorted_axis, [&](auto block_dim_tag) {
              constexpr int BLOCK_THREADS = block_dim_tag();
              dim3 grid(1, n_rows, 1);
              dim3 block(BLOCK_THREADS, 1, 1);

              dispatch_radix_items_per_thread(
                  size_sorted_axis,
                  BLOCK_THREADS,
                  [&](auto items_per_thread_tag) {
                    constexpr int ITEMS_PER_THREAD = items_per_thread_tag();

                    dispatch_bool(contiguous, [&](auto contiguous_tag) {
                      constexpr bool USE_SIMPLE_STRIDE =
                          decltype(contiguous_tag)::value;

                      auto kernel = cu::radix_select_small_kernel<
                          ValT,
                          OutT,
                          ARG_PARTITION,
                          USE_SIMPLE_STRIDE,
                          BLOCK_THREADS,
                          ITEMS_PER_THREAD>;

                      // Calculate dynamic shared memory size. Must agree
                      // with radix_small_shared_mem_bytes(), which gated
                      // this launch against the shared-memory budget.
                      using UnsignedT =
                          typename cu::RadixTraits<ValT>::UnsignedT;
                      constexpr int TILE_SIZE_VAL =
                          BLOCK_THREADS * ITEMS_PER_THREAD;
                      constexpr int NUM_WARPS = BLOCK_THREADS / WARP_SIZE;
                      constexpr size_t shared_mem_bytes =
                          TILE_SIZE_VAL * sizeof(UnsignedT) + // shared_keys
                          TILE_SIZE_VAL * sizeof(uint32_t) + // shared_idxs
                          cu::RADIX_SIZE *
                              sizeof(int) + // shared_hist for small kernel
                          (2 + 3 * NUM_WARPS + 6) *
                              sizeof(int); // shared_count + scatter scratch

                      encoder.add_kernel_node_ex(
                          kernel,
                          grid,
                          block,
                          {},
                          static_cast<uint32_t>(shared_mem_bytes),
                          gpu_ptr<ValT>(in),
                          gpu_ptr<OutT>(out),
                          kth,
                          size_sorted_axis,
                          in_stride_sorted_axis,
                          out_stride_sorted_axis,
                          in_stride_segment_axis,
                          out_stride_segment_axis,
                          nc_shape_param,
                          in_nc_strides_param,
                          out_nc_strides_param,
                          nc_dim);
                    });
                  });
            });
      });
    } else {
      throw std::runtime_error(
          "CUDA backend does not support sorting complex numbers");
    }
  });
}
239+
240+
} // namespace
241+
242+
void gpu_partition(
243+
const Stream& s,
244+
const array& in,
245+
array& out,
246+
int axis_,
247+
int kth_,
248+
bool arg_partition) {
249+
int axis = axis_ < 0 ? axis_ + in.ndim() : axis_;
250+
int size_sorted_axis = in.shape(axis);
251+
int kth = kth_ < 0 ? kth_ + size_sorted_axis : kth_;
252+
int nc_dim = static_cast<int>(in.ndim()) - 1;
253+
254+
// Fixed-size const_param metadata is capped by MAX_NDIM.
255+
if (nc_dim > MAX_NDIM) {
256+
return gpu_partition_fallback(s, in, out, axis, arg_partition);
257+
}
258+
259+
// Dispatch based on whether the small kernel tile fits in shared memory.
260+
if (radix_small_fits_shared_memory(in.dtype(), size_sorted_axis)) {
261+
return gpu_radix_partition_small(s, in, out, axis, kth, arg_partition);
262+
} else {
263+
return gpu_partition_fallback(s, in, out, axis, arg_partition);
264+
}
265+
}
266+
267+
} // namespace mlx::core

0 commit comments

Comments
 (0)