Skip to content

Commit e732b85

Browse files
committed
Decouple CommandEncoder from Device
1 parent df7f7db commit e732b85

34 files changed

Lines changed: 198 additions & 237 deletions

docs/src/dev/extensions.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -404,7 +404,7 @@ below.
404404
auto kernel = d.get_kernel(kname, lib);
405405

406406
// Prepare to encode kernel
407-
auto& compute_encoder = d.get_command_encoder(s.index);
407+
auto& compute_encoder = metal::get_command_encoder(s);
408408
compute_encoder.set_compute_pipeline_state(kernel);
409409

410410
// Kernel parameters are registered with buffer indices corresponding to
@@ -448,7 +448,7 @@ We can now call the :meth:`axpby` operation on both the CPU and the GPU!
448448

449449
A few things to note about MLX and Metal before moving on. MLX keeps track of
450450
the active ``command_buffer`` and the ``MTLCommandBuffer`` to which it is
451-
associated. We rely on :meth:`d.get_command_encoder` to give us the active
451+
associated. We rely on :meth:`metal::get_command_encoder` to give us the active
452452
metal compute command encoder instead of building a new one and calling
453453
:meth:`compute_encoder->end_encoding` at the end. MLX adds kernels (compute
454454
pipelines) to the active command buffer until some specified limit is hit or

examples/extensions/axpby/axpby.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ void Axpby::eval_gpu(
192192
auto kernel = d.get_kernel(kname, lib);
193193

194194
// Prepare to encode kernel
195-
auto& compute_encoder = d.get_command_encoder(s.index);
195+
auto& compute_encoder = metal::get_command_encoder(s);
196196
compute_encoder.set_compute_pipeline_state(kernel);
197197

198198
// Kernel parameters are registered with buffer indices corresponding to

mlx/backend/cuda/quantized/quantized.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,7 @@ void fast::Quantize::eval_gpu(
109109
std::vector<array>& outputs) {
110110
nvtx3::scoped_range r("Quantize::eval_gpu");
111111
auto& s = stream();
112-
auto& d = cu::device(s.device);
113-
auto& enc = d.get_command_encoder(s);
112+
auto& enc = cu::get_command_encoder(s);
114113
if (dequantize_) {
115114
auto wq = ensure_row_contiguous(inputs[0], enc, s);
116115
auto scales = ensure_row_contiguous(inputs[1], enc, s);

mlx/backend/metal/allocator.cpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,9 @@ void* Buffer::raw_ptr() {
3131

3232
namespace metal {
3333

34-
MetalAllocator::MetalAllocator()
35-
: device_(device(mlx::core::Device::gpu).mtl_device()),
34+
MetalAllocator::MetalAllocator(Device& d)
35+
: device_(d.mtl_device()),
36+
residency_set_(d.residency_set()),
3637
buffer_cache_(
3738
vm_page_size,
3839
[](MTL::Buffer* buf) { return buf->length(); },
@@ -42,8 +43,7 @@ MetalAllocator::MetalAllocator()
4243
}
4344
auto pool = metal::new_scoped_memory_pool();
4445
buf->release();
45-
}),
46-
residency_set_(device_) {
46+
}) {
4747
const auto& info = gpu::device_info(0);
4848
auto memsize = std::get<size_t>(info.at("memory_size"));
4949
auto max_rec_size =
@@ -52,8 +52,6 @@ MetalAllocator::MetalAllocator()
5252
block_limit_ = std::min(1.5 * max_rec_size, 0.95 * memsize);
5353
gc_limit_ = std::min(static_cast<size_t>(0.95 * max_rec_size), block_limit_);
5454
max_pool_size_ = block_limit_;
55-
device(mlx::core::Device::gpu)
56-
.set_residency_set(residency_set_.mtl_residency_set());
5755
bool is_vm = std::get<std::string>(info.at("device_name")) ==
5856
"Apple Paravirtual device";
5957
if (is_vm) {
@@ -226,7 +224,8 @@ MetalAllocator& allocator() {
226224
// By creating the |allocator_| on heap, the destructor of MetalAllocator
227225
// will not be called on exit and buffers in the cache will be leaked. This
228226
// can save some time at program exit.
229-
static MetalAllocator* allocator_ = new MetalAllocator;
227+
static MetalAllocator* allocator_ =
228+
new MetalAllocator(device(mlx::core::Device::gpu));
230229
return *allocator_;
231230
}
232231

mlx/backend/metal/allocator.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
#include "mlx/allocator.h"
1010
#include "mlx/backend/common/buffer_cache.h"
1111
#include "mlx/backend/metal/device.h"
12-
#include "mlx/backend/metal/resident.h"
1312

1413
namespace mlx::core::metal {
1514

@@ -52,13 +51,13 @@ class MetalAllocator : public allocator::Allocator {
5251
static constexpr int small_size_ = 256;
5352
static constexpr int heap_size_ = 1 << 20;
5453

55-
MetalAllocator();
54+
MetalAllocator(Device& d);
5655
~MetalAllocator();
5756

5857
friend MetalAllocator& allocator();
5958

6059
NS::SharedPtr<MTL::Heap> heap_;
61-
ResidencySet residency_set_;
60+
ResidencySet& residency_set_;
6261

6362
// Caching allocator
6463
BufferCache<MTL::Buffer> buffer_cache_;

mlx/backend/metal/binary.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ void binary_op_gpu_inplace(
106106
auto kernel = outputs.size() == 2
107107
? get_binary_two_kernel(d, kernel_name, a.dtype(), out.dtype(), op)
108108
: get_binary_kernel(d, kernel_name, a.dtype(), out.dtype(), op);
109-
auto& compute_encoder = d.get_command_encoder(s.index);
109+
auto& compute_encoder = metal::get_command_encoder(s);
110110
compute_encoder.set_compute_pipeline_state(kernel);
111111

112112
int arg_idx = 0;

mlx/backend/metal/compiled.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,7 @@ void Compiled::eval_gpu(
389389
kernel_name += "_large";
390390
}
391391
auto kernel = d.get_kernel(kernel_name, lib);
392-
auto& compute_encoder = d.get_command_encoder(s.index);
392+
auto& compute_encoder = metal::get_command_encoder(s);
393393
compute_encoder.set_compute_pipeline_state(kernel);
394394

395395
// Put the inputs in

mlx/backend/metal/conv.cpp

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ ensure_row_contiguous(const array& x, metal::Device& d, const Stream& s) {
2626
return x;
2727
}
2828
auto result = contiguous_copy_gpu(x, s);
29-
d.add_temporary(result, s.index);
29+
metal::get_command_encoder(s).add_temporary(result);
3030
return result;
3131
}
3232

@@ -52,7 +52,7 @@ void explicit_gemm_conv_ND_gpu(
5252
std::string kname;
5353
kname.reserve(32);
5454
concatenate(kname, "naive_unfold_nd_", type_to_name(in_unfolded), "_", N);
55-
auto& compute_encoder = d.get_command_encoder(s.index);
55+
auto& compute_encoder = metal::get_command_encoder(s);
5656
auto kernel = d.get_kernel(kname);
5757
compute_encoder.set_compute_pipeline_state(kernel);
5858

@@ -132,7 +132,7 @@ void explicit_gemm_conv_group_ND_gpu(
132132
kname.reserve(32);
133133
concatenate(
134134
kname, "naive_unfold_transpose_nd_", type_to_name(in_unfolded), "_", N);
135-
auto& compute_encoder = d.get_command_encoder(s.index);
135+
auto& compute_encoder = metal::get_command_encoder(s);
136136
auto kernel = d.get_kernel(kname);
137137
compute_encoder.set_compute_pipeline_state(kernel);
138138

@@ -286,7 +286,7 @@ void implicit_gemm_conv_2D_gpu(
286286
small_filter ? 's' : 'l');
287287

288288
// Encode and dispatch kernel
289-
auto& compute_encoder = d.get_command_encoder(s.index);
289+
auto& compute_encoder = metal::get_command_encoder(s);
290290
auto kernel = get_steel_conv_kernel(
291291
d,
292292
kname,
@@ -469,7 +469,7 @@ void implicit_gemm_conv_2D_general_gpu(
469469
};
470470

471471
// Encode and dispatch kernel
472-
auto& compute_encoder = d.get_command_encoder(s.index);
472+
auto& compute_encoder = metal::get_command_encoder(s);
473473
auto kernel = get_steel_conv_general_kernel(
474474
d, kname, hash_name, func_consts, out, bm, bn, bk, wm, wn);
475475
compute_encoder.set_compute_pipeline_state(kernel);
@@ -595,7 +595,7 @@ void implicit_gemm_conv_3D_gpu(
595595
small_filter ? 's' : 'l');
596596

597597
// Encode and dispatch kernel
598-
auto& compute_encoder = d.get_command_encoder(s.index);
598+
auto& compute_encoder = metal::get_command_encoder(s);
599599
auto kernel =
600600
get_steel_conv_3d_kernel(d, kname, out, bm, bn, bk, wm, wn, small_filter);
601601
compute_encoder.set_compute_pipeline_state(kernel);
@@ -644,7 +644,7 @@ void pad_and_slice_conv_3D_gpu(
644644
array x_copy(xshape, x.dtype(), nullptr, {});
645645
array zero(0, x.dtype());
646646
pad_gpu(x, zero, x_copy, {0, -1}, {0, 0}, s);
647-
d.add_temporary(x_copy, s.index);
647+
metal::get_command_encoder(s).add_temporary(x_copy);
648648

649649
return x_copy;
650650
};
@@ -804,7 +804,7 @@ void winograd_conv_2D_gpu(
804804
type_to_name(out),
805805
"_bc",
806806
bc);
807-
auto& compute_encoder = d.get_command_encoder(s.index);
807+
auto& compute_encoder = metal::get_command_encoder(s);
808808
auto kernel = d.get_kernel(kname);
809809
compute_encoder.set_compute_pipeline_state(kernel);
810810

@@ -837,7 +837,7 @@ void winograd_conv_2D_gpu(
837837
type_to_name(out),
838838
"_bc",
839839
bc);
840-
auto& compute_encoder = d.get_command_encoder(s.index);
840+
auto& compute_encoder = metal::get_command_encoder(s);
841841
auto kernel = d.get_kernel(kname);
842842
compute_encoder.set_compute_pipeline_state(kernel);
843843

@@ -889,7 +889,7 @@ void winograd_conv_2D_gpu(
889889
type_to_name(out),
890890
"_bo",
891891
bc);
892-
auto& compute_encoder = d.get_command_encoder(s.index);
892+
auto& compute_encoder = metal::get_command_encoder(s);
893893
auto kernel = d.get_kernel(kname);
894894
compute_encoder.set_compute_pipeline_state(kernel);
895895

@@ -950,7 +950,7 @@ void depthwise_conv_2D_gpu(
950950
"_tgp_w_", tw,
951951
"_do_flip_", do_flip ? 't' : 'n'); // clang-format on
952952

953-
auto& compute_encoder = d.get_command_encoder(s.index);
953+
auto& compute_encoder = metal::get_command_encoder(s);
954954
auto kernel = d.get_kernel(base_name, hash_name, func_consts);
955955
compute_encoder.set_compute_pipeline_state(kernel);
956956

@@ -1044,7 +1044,7 @@ void depthwise_conv_1D_gpu(
10441044
type_to_name(out),
10451045
large ? "_large" : "");
10461046

1047-
auto& compute_encoder = d.get_command_encoder(s.index);
1047+
auto& compute_encoder = metal::get_command_encoder(s);
10481048
auto kernel = d.get_kernel(base_name);
10491049
compute_encoder.set_compute_pipeline_state(kernel);
10501050

@@ -1348,7 +1348,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
13481348

13491349
// Record copies
13501350
if (!copies.empty()) {
1351-
d.add_temporaries(std::move(copies), s.index);
1351+
metal::get_command_encoder(s).add_temporaries(std::move(copies));
13521352
}
13531353
}
13541354

mlx/backend/metal/copy.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ void copy_gpu_inplace(
107107
auto kernel = dynamic ? get_dynamic_copy_kernel(d, kernel_name, in, out)
108108
: get_copy_kernel(d, kernel_name, in, out);
109109

110-
auto& compute_encoder = d.get_command_encoder(s.index);
110+
auto& compute_encoder = metal::get_command_encoder(s);
111111
compute_encoder.set_compute_pipeline_state(kernel);
112112

113113
inp_offset *= size_of(in.dtype());
@@ -190,7 +190,7 @@ void fill_gpu(const array& val, array& out, const Stream& s) {
190190
std::string kernel_name = large ? "s2" : (work_per_thread > 1 ? "sn" : "s");
191191
concatenate(kernel_name, "_copy", type_to_name(val), type_to_name(out));
192192
auto kernel = get_copy_kernel(d, kernel_name, val, out);
193-
auto& compute_encoder = d.get_command_encoder(s.index);
193+
auto& compute_encoder = metal::get_command_encoder(s);
194194
compute_encoder.set_compute_pipeline_state(kernel);
195195

196196
compute_encoder.set_input_array(val, 0);

mlx/backend/metal/custom_kernel.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -378,7 +378,7 @@ void CustomKernel::eval_gpu(
378378

379379
auto lib = d.get_library(name_, [this] { return metal::utils() + source_; });
380380
auto kernel = d.get_kernel(name_, lib);
381-
auto& compute_encoder = d.get_command_encoder(s.index);
381+
auto& compute_encoder = metal::get_command_encoder(s);
382382
compute_encoder.set_compute_pipeline_state(kernel);
383383
int index = 0;
384384
for (int i = 0; i < checked_inputs.size(); i++) {
@@ -424,7 +424,7 @@ void CustomKernel::eval_gpu(
424424
MTL::Size grid_dims = MTL::Size(gx, gy, gz);
425425
compute_encoder.dispatch_threads(grid_dims, group_dims);
426426

427-
d.add_temporaries(std::move(copies), s.index);
427+
compute_encoder.add_temporaries(std::move(copies));
428428
}
429429

430430
} // namespace mlx::core::fast

0 commit comments

Comments
 (0)