From 7af5b6a075a621a9f6511c3008488975bb159c7d Mon Sep 17 00:00:00 2001 From: XXXXRT666 <157766680+XXXXRT666@users.noreply.github.com> Date: Sun, 10 May 2026 16:01:38 +0100 Subject: [PATCH 01/26] Support Metal DLPack zero-copy import --- docs/src/usage/numpy.rst | 49 +++++- mlx/allocator.h | 4 + mlx/backend/cuda/allocator.cpp | 4 + mlx/backend/metal/allocator.cpp | 18 ++- mlx/backend/no_gpu/allocator.cpp | 3 + mlx/ops.cpp | 10 ++ mlx/ops.h | 3 + mlx/utils.cpp | 16 +- python/src/buffer.h | 8 + python/src/convert.cpp | 257 +++++++++++++++++++++---------- python/src/convert.h | 2 + python/src/mlx_func.cpp | 2 + python/tests/test_array.py | 89 ++++++++++- 13 files changed, 372 insertions(+), 93 deletions(-) diff --git a/docs/src/usage/numpy.rst b/docs/src/usage/numpy.rst index bf71938dff..e455fbebda 100644 --- a/docs/src/usage/numpy.rst +++ b/docs/src/usage/numpy.rst @@ -76,6 +76,7 @@ PyTorch ------- PyTorch supports DLPack inputs and can import MLX arrays directly. +MLX can also import PyTorch tensors through DLPack with ``mx.array``. .. code-block:: python @@ -84,7 +85,53 @@ PyTorch supports DLPack inputs and can import MLX arrays directly. a = mx.arange(3) b = torch.tensor(a) - c = mx.array(b.cpu()) + c = mx.array(b) + +Creating an MLX array from a CPU tensor copies the data into MLX-owned storage. +The arrays do not share memory: + +.. code-block:: python + + b = torch.arange(3) + c = mx.array(b) + + b += 10 + print(c.tolist()) # [0, 1, 2] + +Metal DLPack inputs are different. If a PyTorch MPS tensor is passed to +``mx.array``, MLX imports the underlying Metal buffer without copying it. The +PyTorch tensor and the MLX array then share the same storage. + +Since the buffer is shared across frameworks, synchronization has to be managed +explicitly. After PyTorch writes to an MPS tensor, call +``torch.mps.synchronize()`` before reading the shared data from MLX. 
After MLX +writes to the shared array, call ``mx.eval`` on the MLX result before reading +the shared data from PyTorch. Without these synchronization points, the other +framework may read the shared buffer before the producer has finished writing, +so it can observe stale data. + +.. code-block:: python + + b = torch.arange(3, device="mps", dtype=torch.float32) + torch.mps.synchronize() + c = mx.array(b) # zero-copy Metal DLPack import + + b.add_(10) + torch.mps.synchronize() + print(c.tolist()) # [10.0, 11.0, 12.0] + +Updates made by MLX can also be observed from PyTorch after the MLX computation +has been evaluated: + +.. code-block:: python + + b = torch.arange(3, device="mps", dtype=torch.float32) + torch.mps.synchronize() + c = mx.array(b) + + c += 10 + mx.eval(c) + print(b.cpu()) # tensor([10., 11., 12.]) JAX --- diff --git a/mlx/allocator.h b/mlx/allocator.h index 824deac2c7..185ccd6d76 100644 --- a/mlx/allocator.h +++ b/mlx/allocator.h @@ -21,6 +21,10 @@ class MLX_API Buffer { // Get the raw data pointer from the buffer void* raw_ptr(); + // Whether raw_ptr() can return a host-accessible pointer without moving or + // copying the buffer. 
+ bool is_host_accessible() const; + // Get the buffer pointer from the buffer const void* ptr() const { return ptr_; diff --git a/mlx/backend/cuda/allocator.cpp b/mlx/backend/cuda/allocator.cpp index 718ae33e9c..37dc457a02 100644 --- a/mlx/backend/cuda/allocator.cpp +++ b/mlx/backend/cuda/allocator.cpp @@ -416,6 +416,10 @@ void* Buffer::raw_ptr() { return cbuf.data; } +bool Buffer::is_host_accessible() const { + return true; +} + } // namespace allocator size_t get_active_memory() { diff --git a/mlx/backend/metal/allocator.cpp b/mlx/backend/metal/allocator.cpp index 222c6fd9fa..5b24142fd5 100644 --- a/mlx/backend/metal/allocator.cpp +++ b/mlx/backend/metal/allocator.cpp @@ -24,7 +24,23 @@ void* Buffer::raw_ptr() { if (!ptr_) { return nullptr; } - return static_cast(ptr_)->contents(); + auto* buf = static_cast(ptr_); + auto* contents = buf->contents(); + if (!contents && buf->length() > 0) { + throw std::runtime_error( + "[metal::Buffer::raw_ptr] Cannot access Metal buffer contents on the " + "host. 
The buffer is not CPU-addressable, for example because it uses " + "private storage."); + } + return contents; +} + +bool Buffer::is_host_accessible() const { + if (!ptr_) { + return true; + } + auto* buf = static_cast(ptr_); + return buf->storageMode() != MTL::StorageModePrivate; } } // namespace allocator diff --git a/mlx/backend/no_gpu/allocator.cpp b/mlx/backend/no_gpu/allocator.cpp index 1fdcc262a4..359fa36018 100644 --- a/mlx/backend/no_gpu/allocator.cpp +++ b/mlx/backend/no_gpu/allocator.cpp @@ -178,6 +178,9 @@ void* Buffer::raw_ptr() { return static_cast(ptr_) + 1; } +bool Buffer::is_host_accessible() const { + return true; +} } // namespace allocator size_t get_active_memory() { diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 6ad41e2e38..4598529e6c 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -296,6 +296,16 @@ array copy(array a, StreamOrDevice s /* = {} */) { {std::move(a)}); } +array copy_to_new_buffer(array a, StreamOrDevice s /* = {} */) { + auto copied_shape = a.shape(); // |a| will be moved + auto dtype = a.dtype(); + return array( + std::move(copied_shape), + dtype, + std::make_shared(to_stream(s), dtype), + {std::move(a)}); +} + array full_impl(array vals, Dtype dtype, StreamOrDevice s /* = {} */) { return array( vals.shape(), diff --git a/mlx/ops.h b/mlx/ops.h index 208964d1aa..3084de3962 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -60,6 +60,9 @@ MLX_API array as_strided( /** Copy another array. */ MLX_API array copy(array a, StreamOrDevice s = {}); +/** Copy another array into newly allocated storage. */ +MLX_API array copy_to_new_buffer(array a, StreamOrDevice s = {}); + /** Fill an array of the given shape with the given value(s). 
*/ MLX_API array full(Shape shape, array vals, Dtype dtype, StreamOrDevice s = {}); MLX_API array full(Shape shape, array vals, StreamOrDevice s = {}); diff --git a/mlx/utils.cpp b/mlx/utils.cpp index 239e6603dd..39056ee79e 100644 --- a/mlx/utils.cpp +++ b/mlx/utils.cpp @@ -7,6 +7,7 @@ #include #include "mlx/dtype_utils.h" +#include "mlx/ops.h" #include "mlx/types/limits.h" #include "mlx/utils.h" @@ -212,6 +213,19 @@ std::ostream& operator<<(std::ostream& os, uint8_t x) { namespace { +array host_accessible_array(array a) { + a.eval(); + a.wait(); + if (a.buffer().is_host_accessible()) { + return a; + } + auto out = copy_to_new_buffer(std::move(a), Device::gpu); + out.eval(); + out.wait(); + out.detach(); + return out; +} + template void print_subarray(std::ostream& os, const array& a, size_t index, int dim) { int num_print = 3; @@ -277,7 +291,7 @@ std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k) { } std::ostream& operator<<(std::ostream& os, array a) { - a.eval(); + a = host_accessible_array(std::move(a)); dispatch_all_types(a.dtype(), [&](auto type_tag) { print_array(os, a); }); diff --git a/python/src/buffer.h b/python/src/buffer.h index 272a918883..8d01b82132 100644 --- a/python/src/buffer.h +++ b/python/src/buffer.h @@ -91,6 +91,14 @@ extern "C" inline int getbuffer(PyObject* obj, Py_buffer* view, int flags) { { nb::gil_scoped_release nogil; a.eval(); + a.wait(); + } + if (!a.buffer().is_host_accessible()) { + PyErr_SetString( + PyExc_BufferError, + "Cannot provide a buffer for an array whose storage is not " + "CPU-addressable."); + return -1; } std::vector shape(a.shape().begin(), a.shape().end()); diff --git a/python/src/convert.cpp b/python/src/convert.cpp index a5455c2b33..0b78a8fe20 100644 --- a/python/src/convert.cpp +++ b/python/src/convert.cpp @@ -1,13 +1,16 @@ // Copyright © 2024 Apple Inc. 
#include +#include #include #include +#include #include "python/src/convert.h" #include "python/src/utils.h" +#include "mlx/ops.h" #include "mlx/utils.h" enum PyScalarT { @@ -31,6 +34,16 @@ int check_shape_dim(int64_t dim) { return static_cast(dim); } +template +mx::Shape get_shape(const nb::ndarray& nd_array) { + mx::Shape shape; + shape.reserve(nd_array.ndim()); + for (int i = 0; i < nd_array.ndim(); i++) { + shape.push_back(check_shape_dim(nd_array.shape(i))); + } + return shape; +} + template mx::array nd_array_to_mlx_contiguous( nb::ndarray nd_array, @@ -42,88 +55,149 @@ mx::array nd_array_to_mlx_contiguous( return mx::array(static_cast(data_ptr), shape, dtype); } -mx::array nd_array_to_mlx( - nb::ndarray nd_array, - std::optional dtype, - std::optional nb_dtype) { - if (nd_array.device_type() != nb::device::cpu::value) { - throw std::invalid_argument( - "Cannot convert non-CPU DLPack array to mlx array."); - } - - // Compute the shape and size - mx::Shape shape; - shape.reserve(nd_array.ndim()); - for (int i = 0; i < nd_array.ndim(); i++) { - shape.push_back(check_shape_dim(nd_array.shape(i))); - } - auto type = nb_dtype.value_or(nd_array.dtype()); - - // Copy data and make array +template +auto dispatch_dlpack_dtype( + nb::dlpack::dtype type, + F&& f, + const char* error_message) { if (type == nb::dtype()) { - return nd_array_to_mlx_contiguous( - nd_array, shape, dtype.value_or(mx::bool_)); + return f.template operator()(mx::bool_); } else if (type == nb::dtype()) { - return nd_array_to_mlx_contiguous( - nd_array, shape, dtype.value_or(mx::uint8)); + return f.template operator()(mx::uint8); } else if (type == nb::dtype()) { - return nd_array_to_mlx_contiguous( - nd_array, shape, dtype.value_or(mx::uint16)); + return f.template operator()(mx::uint16); } else if (type == nb::dtype()) { - return nd_array_to_mlx_contiguous( - nd_array, shape, dtype.value_or(mx::uint32)); + return f.template operator()(mx::uint32); } else if (type == nb::dtype()) { - return 
nd_array_to_mlx_contiguous( - nd_array, shape, dtype.value_or(mx::uint64)); + return f.template operator()(mx::uint64); } else if (type == nb::dtype()) { - return nd_array_to_mlx_contiguous( - nd_array, shape, dtype.value_or(mx::int8)); + return f.template operator()(mx::int8); } else if (type == nb::dtype()) { - return nd_array_to_mlx_contiguous( - nd_array, shape, dtype.value_or(mx::int16)); + return f.template operator()(mx::int16); } else if (type == nb::dtype()) { - return nd_array_to_mlx_contiguous( - nd_array, shape, dtype.value_or(mx::int32)); + return f.template operator()(mx::int32); } else if (type == nb::dtype()) { - return nd_array_to_mlx_contiguous( - nd_array, shape, dtype.value_or(mx::int64)); + return f.template operator()(mx::int64); } else if (type == nb::dtype()) { - return nd_array_to_mlx_contiguous( - nd_array, shape, dtype.value_or(mx::float16)); + return f.template operator()(mx::float16); } else if (type == nb::dtype()) { - return nd_array_to_mlx_contiguous( - nd_array, shape, dtype.value_or(mx::bfloat16)); + return f.template operator()(mx::bfloat16); } else if (type == nb::dtype()) { - return nd_array_to_mlx_contiguous( - nd_array, shape, dtype.value_or(mx::float32)); + return f.template operator()(mx::float32); } else if (type == nb::dtype()) { - return nd_array_to_mlx_contiguous( - nd_array, shape, dtype.value_or(mx::float32)); + return f.template operator()(mx::float32); } else if (type == nb::dtype>()) { - return nd_array_to_mlx_contiguous( - nd_array, shape, dtype.value_or(mx::complex64)); + return f.template operator()(mx::complex64); } else if (type == nb::dtype>()) { - return nd_array_to_mlx_contiguous( - nd_array, shape, dtype.value_or(mx::complex64)); + return f.template operator()(mx::complex64); } else { - throw std::invalid_argument("Cannot convert numpy array to mlx array."); + throw std::invalid_argument(error_message); } } +mx::array metal_dlpack_to_mlx( + nb::ndarray nd_array, + std::optional dtype); + +mx::array 
host_accessible_array(mx::array a) { + a.eval(); + a.wait(); + if (a.buffer().is_host_accessible()) { + return a; + } + auto out = mx::copy_to_new_buffer(std::move(a), mx::Device::gpu); + out.eval(); + out.wait(); + out.detach(); + return out; +} + +mx::array nd_array_to_mlx( + nb::ndarray nd_array, + std::optional dtype, + std::optional nb_dtype) { + switch (nd_array.device_type()) { + case nb::device::cpu::value: { + auto shape = get_shape(nd_array); + auto type = nb_dtype.value_or(nd_array.dtype()); + return dispatch_dlpack_dtype( + type, + [&](mx::Dtype default_dtype) { + return nd_array_to_mlx_contiguous( + nd_array, shape, dtype.value_or(default_dtype)); + }, + "Cannot convert numpy array to mlx array."); + } + case nb::device::metal::value: + return metal_dlpack_to_mlx(std::move(nd_array), dtype); + default: + throw std::invalid_argument("Unsupported DLPack device."); + } +} + +template +mx::array metal_dlpack_to_mlx_contiguous( + std::shared_ptr> owner, + const mx::Shape& shape, + mx::Dtype type, + std::optional dtype) { + auto itemsize = mx::size_of(type); + if (owner->itemsize() != itemsize) { + throw std::invalid_argument( + "Cannot convert Metal DLPack dtype to mlx dtype."); + } + + auto byte_offset = owner->data_offset(); + if (byte_offset % itemsize != 0) { + throw std::invalid_argument( + "Metal DLPack byte offset is not aligned to dtype size."); + } + + auto out = mx::array( + mx::allocator::Buffer(owner->data_handle()), + shape, + type, + [](mx::allocator::Buffer) {}); + auto flags = out.flags(); + out.set_data( + out.buffer(), + out.data_size(), + out.strides(), + flags, + [owner = std::move(owner)](mx::allocator::Buffer) {}); + + auto offset = static_cast(byte_offset / itemsize); + if (offset != 0) { + out.copy_shared_buffer(out, out.strides(), flags, out.data_size(), offset); + } + + if (dtype) { + auto result = (*dtype == out.dtype()) + ? 
mx::copy_to_new_buffer(out, mx::Device::gpu) + : mx::astype(out, *dtype, mx::Device::gpu); + result.eval(); + result.wait(); + result.detach(); + return result; + } + return out; +} + template nb::ndarray mlx_to_nd_array_impl( mx::array a, std::optional t = {}) { { nb::gil_scoped_release nogil; - a.eval(); + a = host_accessible_array(std::move(a)); } std::vector shape(a.shape().begin(), a.shape().end()); + auto owner = nb::cast(a); return nb::ndarray( a.data(), a.ndim(), shape.data(), - /* owner= */ nb::none(), + /* owner= */ owner, a.strides().data(), t.value_or(nb::dtype())); } @@ -177,44 +251,60 @@ nb::object to_scalar(mx::array& a) { throw std::invalid_argument( "[convert] Only length-1 arrays can be converted to Python scalars."); } + auto host = mx::array(a); { nb::gil_scoped_release nogil; - a.eval(); + host = host_accessible_array(std::move(host)); } - switch (a.dtype()) { + switch (host.dtype()) { case mx::bool_: - return nb::cast(a.item()); + return nb::cast(host.item()); case mx::uint8: - return nb::cast(a.item()); + return nb::cast(host.item()); case mx::uint16: - return nb::cast(a.item()); + return nb::cast(host.item()); case mx::uint32: - return nb::cast(a.item()); + return nb::cast(host.item()); case mx::uint64: - return nb::cast(a.item()); + return nb::cast(host.item()); case mx::int8: - return nb::cast(a.item()); + return nb::cast(host.item()); case mx::int16: - return nb::cast(a.item()); + return nb::cast(host.item()); case mx::int32: - return nb::cast(a.item()); + return nb::cast(host.item()); case mx::int64: - return nb::cast(a.item()); + return nb::cast(host.item()); case mx::float16: - return nb::cast(static_cast(a.item())); + return nb::cast(static_cast(host.item())); case mx::float32: - return nb::cast(a.item()); + return nb::cast(host.item()); case mx::bfloat16: - return nb::cast(static_cast(a.item())); + return nb::cast(static_cast(host.item())); case mx::complex64: - return nb::cast(a.item>()); + return nb::cast(host.item>()); case 
mx::float64: - return nb::cast(a.item()); + return nb::cast(host.item()); default: throw nb::type_error("type cannot be converted to Python scalar."); } } +mx::array metal_dlpack_to_mlx( + nb::ndarray nd_array, + std::optional dtype) { + auto owner = + std::make_shared>(std::move(nd_array)); + auto shape = get_shape(*owner); + + return dispatch_dlpack_dtype( + owner->dtype(), + [&](mx::Dtype type) { + return metal_dlpack_to_mlx_contiguous(owner, shape, type, dtype); + }, + "Cannot convert Metal DLPack array to mlx array."); +} + template nb::list to_list(mx::array& a, size_t index, int dim) { nb::list pl; @@ -234,39 +324,40 @@ nb::object tolist(mx::array& a) { if (a.ndim() == 0) { return to_scalar(a); } + auto host = mx::array(a); { nb::gil_scoped_release nogil; - a.eval(); + host = host_accessible_array(std::move(host)); } - switch (a.dtype()) { + switch (host.dtype()) { case mx::bool_: - return to_list(a, 0, 0); + return to_list(host, 0, 0); case mx::uint8: - return to_list(a, 0, 0); + return to_list(host, 0, 0); case mx::uint16: - return to_list(a, 0, 0); + return to_list(host, 0, 0); case mx::uint32: - return to_list(a, 0, 0); + return to_list(host, 0, 0); case mx::uint64: - return to_list(a, 0, 0); + return to_list(host, 0, 0); case mx::int8: - return to_list(a, 0, 0); + return to_list(host, 0, 0); case mx::int16: - return to_list(a, 0, 0); + return to_list(host, 0, 0); case mx::int32: - return to_list(a, 0, 0); + return to_list(host, 0, 0); case mx::int64: - return to_list(a, 0, 0); + return to_list(host, 0, 0); case mx::float16: - return to_list(a, 0, 0); + return to_list(host, 0, 0); case mx::float32: - return to_list(a, 0, 0); + return to_list(host, 0, 0); case mx::bfloat16: - return to_list(a, 0, 0); + return to_list(host, 0, 0); case mx::float64: - return to_list(a, 0, 0); + return to_list(host, 0, 0); case mx::complex64: - return to_list>(a, 0, 0); + return to_list>(host, 0, 0); default: throw nb::type_error("data type cannot be converted to Python 
list."); } diff --git a/python/src/convert.h b/python/src/convert.h index 9341dd3122..a8e56d64f1 100644 --- a/python/src/convert.h +++ b/python/src/convert.h @@ -69,6 +69,8 @@ mx::array nd_array_to_mlx( nb::ndarray mlx_to_np_array(const mx::array& a); nb::ndarray<> mlx_to_dlpack(const mx::array& a); +mx::array host_accessible_array(mx::array a); + nb::object to_scalar(mx::array& a); nb::object tolist(mx::array& a); diff --git a/python/src/mlx_func.cpp b/python/src/mlx_func.cpp index 09aceabe9b..9955e134f9 100644 --- a/python/src/mlx_func.cpp +++ b/python/src/mlx_func.cpp @@ -2,6 +2,8 @@ #include "python/src/mlx_func.h" +#include + // A garbage collected function which wraps nb::cpp_function // See https://github.com/wjakob/nanobind/discussions/919 diff --git a/python/tests/test_array.py b/python/tests/test_array.py index 8258b38c94..243cd16752 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -735,7 +735,7 @@ def test_array_np_conversion(self): self.assertEqual(x.tolist(), cvals) def test_array_np_dtype_conversion(self): - dtypes_list = [ + to_mlx_dtypes_list = [ (mx.bool_, np.bool_), (mx.uint8, np.uint8), (mx.uint16, np.uint16), @@ -750,13 +750,14 @@ def test_array_np_dtype_conversion(self): (mx.complex64, np.complex64), ] - for mlx_dtype, np_dtype in dtypes_list: + for mlx_dtype, np_dtype in to_mlx_dtypes_list: a_npy = np.random.uniform(low=0, high=100, size=(32,)).astype(np_dtype) a_mlx = mx.array(a_npy) self.assertEqual(a_mlx.dtype, mlx_dtype) self.assertTrue(np.allclose(a_mlx, a_npy)) + for mlx_dtype, np_dtype in to_mlx_dtypes_list: b_mlx = mx.random.uniform( low=0, high=10, @@ -2048,19 +2049,93 @@ def test_dlpack(self): self.assertTrue(mx.array_equal(y, x)) @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") - def test_torch_mps_dlpack_non_cpu_error(self): + def test_torch_mps_dlpack_import(self): x = torch.arange(12, device="mps", dtype=torch.float32).reshape(3, 4) self.assertEqual(x.__dlpack_device__()[0], 8) - with 
self.assertRaisesRegex(ValueError, "non-CPU DLPack"): - mx.array(x) + y = mx.array(x) + self.assertEqual(y.dtype, mx.float32) + torch.mps.synchronize() + self.assertEqual(y.tolist(), x.cpu().numpy().tolist()) + + @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") + def test_torch_mps_dlpack_host_access(self): + x = torch.arange(12, device="mps", dtype=torch.float32).reshape(3, 4) + y = mx.array(x) + + torch.mps.synchronize() + self.assertIn("array(", repr(y)) + self.assertEqual(y.tolist(), x.cpu().numpy().tolist()) + try: + mv = memoryview(y) + except BufferError: + pass + else: + self.assertEqual(mv.tolist(), x.cpu().numpy().tolist()) + + @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") + def test_torch_mps_dlpack_zero_copy_reads_torch_updates(self): + x = torch.arange(12, device="mps", dtype=torch.float32).reshape(3, 4) + y = mx.array(x) + + x.add_(100) + torch.mps.synchronize() + self.assertEqual((y + 1).tolist(), (x + 1).cpu().numpy().tolist()) + + @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") + def test_torch_mps_dlpack_dtype_argument_copies(self): + x = torch.arange(12, device="mps", dtype=torch.float32).reshape(3, 4) + torch.mps.synchronize() + y_copy = mx.array(x, dtype=mx.float32) + expected = x.cpu().numpy().tolist() + x.add_(100) + torch.mps.synchronize() + self.assertEqual(y_copy.tolist(), expected) + + z = mx.array(x, dtype=mx.float16) + self.assertEqual(z.dtype, mx.float16) + self.assertEqual(z.tolist(), x.to(torch.float16).cpu().numpy().tolist()) + + @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") + def test_torch_mps_dlpack_data_offset(self): + view = torch.arange(12, device="mps", dtype=torch.float32)[3:9] + view_mx = mx.array(view) + torch.mps.synchronize() + self.assertEqual((view_mx + 1).tolist(), (view + 1).cpu().numpy().tolist()) + + @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") + def test_torch_mps_dlpack_bfloat16(self): + x = torch.arange(12, device="mps", 
dtype=torch.float32).reshape(3, 4) + bf = x.to(torch.bfloat16) + bf_mx = mx.array(bf) + + self.assertEqual(bf_mx.dtype, mx.bfloat16) + torch.mps.synchronize() + self.assertEqual( + bf_mx.astype(mx.float32).tolist(), + bf.to(torch.float32).cpu().numpy().tolist(), + ) + + @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") + def test_torch_mps_array_operand(self): a = mx.array([1]) b = torch.tensor([2]) self.assertTrue(mx.array_equal(a + b, mx.array([3]))) - with self.assertRaisesRegex(ValueError, "non-CPU DLPack"): - a + b.to("mps") + b_mps = b.to("mps") + torch.mps.synchronize() + self.assertTrue(mx.array_equal(a + b_mps, mx.array([3]))) + + @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") + def test_torch_mps_dlpack_mlx_update_writes_torch_buffer(self): + x = torch.arange(3, device="mps", dtype=torch.float32) + torch.mps.synchronize() + y = mx.array(x) + + y += 3 + mx.eval(y) + self.assertEqual(x.cpu().numpy().tolist(), [3, 4, 5]) def test_getitem_with_list(self): a = mx.array([1, 2, 3, 4, 5]) From bbffe6a1ae15998ce8a0ac19cc3f759049dad29a Mon Sep 17 00:00:00 2001 From: XXXXRT666 <157766680+XXXXRT666@users.noreply.github.com> Date: Sun, 10 May 2026 16:02:38 +0100 Subject: [PATCH 02/26] Add from_dlpack copy controls --- docs/src/usage/numpy.rst | 22 ++++++++++++++++--- python/src/convert.cpp | 29 +++++++++++++++++++++++++ python/src/convert.h | 1 + python/src/ops.cpp | 23 ++++++++++++++++++++ python/tests/test_array.py | 44 ++++++++++++++++++++++++++++++++++++++ 5 files changed, 116 insertions(+), 3 deletions(-) diff --git a/docs/src/usage/numpy.rst b/docs/src/usage/numpy.rst index e455fbebda..f5641b4be5 100644 --- a/docs/src/usage/numpy.rst +++ b/docs/src/usage/numpy.rst @@ -76,7 +76,8 @@ PyTorch ------- PyTorch supports DLPack inputs and can import MLX arrays directly. -MLX can also import PyTorch tensors through DLPack with ``mx.array``. +MLX can also import PyTorch tensors through DLPack with ``mx.array`` or +``mx.from_dlpack``. .. 
code-block:: python @@ -99,8 +100,9 @@ The arrays do not share memory: print(c.tolist()) # [0, 1, 2] Metal DLPack inputs are different. If a PyTorch MPS tensor is passed to -``mx.array``, MLX imports the underlying Metal buffer without copying it. The -PyTorch tensor and the MLX array then share the same storage. +``mx.array`` or to ``mx.from_dlpack`` with ``copy=None`` or ``copy=False``, MLX +imports the underlying Metal buffer without copying it. The PyTorch tensor and +the MLX array then share the same storage. Since the buffer is shared across frameworks, synchronization has to be managed explicitly. After PyTorch writes to an MPS tensor, call @@ -133,6 +135,20 @@ has been evaluated: mx.eval(c) print(b.cpu()) # tensor([10., 11., 12.]) +Use ``mx.from_dlpack`` when you need to control the copy behavior. Specifying +``copy=True`` asks MLX to create a new array instead of sharing the Metal +buffer: + +.. code-block:: python + + b = torch.arange(3, device="mps", dtype=torch.float32) + torch.mps.synchronize() + c = mx.from_dlpack(b, copy=True) + + b.add_(10) + torch.mps.synchronize() + print(c.tolist()) # [0.0, 1.0, 2.0] + JAX --- JAX fully supports the buffer protocol. 
diff --git a/python/src/convert.cpp b/python/src/convert.cpp index 0b78a8fe20..a19f1bf745 100644 --- a/python/src/convert.cpp +++ b/python/src/convert.cpp @@ -135,6 +135,35 @@ mx::array nd_array_to_mlx( } } +mx::array from_dlpack(nb::object v, std::optional copy) { + using ContigArray = nb::ndarray; + auto nd = nb::cast(v); + + switch (nd.device_type()) { + case nb::device::cpu::value: + if (copy == false) { + throw std::invalid_argument( + "Cannot import a CPU DLPack array without a copy."); + } + return nd_array_to_mlx(std::move(nd), std::nullopt); + case nb::device::metal::value: { + std::optional dtype; + if (copy == true) { + dtype = dispatch_dlpack_dtype( + nd.dtype(), + [](mx::Dtype dtype) { return dtype; }, + "Cannot convert Metal DLPack array to mlx array."); + } + return nd_array_to_mlx(std::move(nd), dtype); + } + case nb::device::cuda::value: + case nb::device::cuda_managed::value: + throw std::invalid_argument("CUDA DLPack import is not supported."); + default: + throw std::invalid_argument("Unsupported DLPack device."); + } +} + template mx::array metal_dlpack_to_mlx_contiguous( std::shared_ptr> owner, diff --git a/python/src/convert.h b/python/src/convert.h index a8e56d64f1..3ac6cb9f74 100644 --- a/python/src/convert.h +++ b/python/src/convert.h @@ -76,6 +76,7 @@ nb::object to_scalar(mx::array& a); nb::object tolist(mx::array& a); mx::array create_array(nb::object v, std::optional t); +mx::array from_dlpack(nb::object v, std::optional copy); mx::array array_from_list(nb::list pl, std::optional dtype); mx::array array_from_list(nb::tuple pl, std::optional dtype); diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 198a4861d9..8f70389983 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -1795,6 +1795,29 @@ void init_ops(nb::module_& m) { Raises: ValueError: If ``copy`` is ``False``. 
)pbdoc"); + m.def( + "from_dlpack", + [](const nb::object& x, std::optional copy) { + return from_dlpack(x, copy); + }, + nb::arg(), + nb::kw_only(), + "copy"_a = nb::none(), + nb::sig( + "def from_dlpack(x: DLPackCompatible, /, *, copy: Optional[bool] = None) -> array"), + R"pbdoc( + Create an array from an object that supports DLPack. + + Args: + x: Input object implementing ``__dlpack__`` and + ``__dlpack_device__``. + copy (bool, optional): Whether to copy the input. If ``True``, + always copy. If ``False``, never copy. If ``None``, share memory + when possible and copy otherwise. + + Returns: + array: An array containing the input data. + )pbdoc"); m.def( "zeros_like", &mx::zeros_like, diff --git a/python/tests/test_array.py b/python/tests/test_array.py index 243cd16752..de4e4795a4 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -2048,6 +2048,20 @@ def test_dlpack(self): y = np.from_dlpack(x) self.assertTrue(mx.array_equal(y, x)) + def test_from_dlpack_cpu(self): + x = np.arange(3, dtype=np.float32) + + y = mx.from_dlpack(x) + x += 10 + self.assertEqual(y.tolist(), [0.0, 1.0, 2.0]) + + y = mx.from_dlpack(x, copy=True) + x += 10 + self.assertEqual(y.tolist(), [10.0, 11.0, 12.0]) + + with self.assertRaises(ValueError): + mx.from_dlpack(x, copy=False) + @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") def test_torch_mps_dlpack_import(self): x = torch.arange(12, device="mps", dtype=torch.float32).reshape(3, 4) @@ -2137,6 +2151,36 @@ def test_torch_mps_dlpack_mlx_update_writes_torch_buffer(self): mx.eval(y) self.assertEqual(x.cpu().numpy().tolist(), [3, 4, 5]) + @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") + def test_from_dlpack_torch_mps_copy_none_shares(self): + x = torch.arange(3, device="mps", dtype=torch.float32) + torch.mps.synchronize() + y = mx.from_dlpack(x) + + x.add_(10) + torch.mps.synchronize() + self.assertEqual(y.tolist(), [10.0, 11.0, 12.0]) + + @unittest.skipUnless(has_torch_mps, "PyTorch 
MPS is required") + def test_from_dlpack_torch_mps_copy_false_shares(self): + x = torch.arange(3, device="mps", dtype=torch.float32) + torch.mps.synchronize() + y = mx.from_dlpack(x, copy=False) + + y += 10 + mx.eval(y) + self.assertEqual(x.cpu().numpy().tolist(), [10.0, 11.0, 12.0]) + + @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") + def test_from_dlpack_torch_mps_copy_true_copies(self): + x = torch.arange(3, device="mps", dtype=torch.float32) + torch.mps.synchronize() + y = mx.from_dlpack(x, copy=True) + + x.add_(10) + torch.mps.synchronize() + self.assertEqual(y.tolist(), [0.0, 1.0, 2.0]) + def test_getitem_with_list(self): a = mx.array([1, 2, 3, 4, 5]) idx = [0, 2, 4] From 11cda582d4c6a6a135f82b524b43d5929c609a04 Mon Sep 17 00:00:00 2001 From: XXXXRT666 <157766680+XXXXRT666@users.noreply.github.com> Date: Tue, 12 May 2026 01:11:23 +0800 Subject: [PATCH 03/26] Support Metal DLPack zero-copy sharing --- docs/src/usage/numpy.rst | 7 ++- mlx/utils.cpp | 4 +- mlx/utils.h | 2 + python/src/array.cpp | 24 ++++++++- python/src/convert.cpp | 105 +++++++++++++++++++++++++++++++------ python/src/convert.h | 7 ++- python/tests/test_array.py | 73 +++++++++++++++++++++++--- 7 files changed, 192 insertions(+), 30 deletions(-) diff --git a/docs/src/usage/numpy.rst b/docs/src/usage/numpy.rst index f5641b4be5..fc511b8098 100644 --- a/docs/src/usage/numpy.rst +++ b/docs/src/usage/numpy.rst @@ -102,7 +102,8 @@ The arrays do not share memory: Metal DLPack inputs are different. If a PyTorch MPS tensor is passed to ``mx.array`` or to ``mx.from_dlpack`` with ``copy=None`` or ``copy=False``, MLX imports the underlying Metal buffer without copying it. The PyTorch tensor and -the MLX array then share the same storage. +the MLX array then share the same storage. MLX arrays exported to PyTorch with +DLPack are also shared without a copy. Since the buffer is shared across frameworks, synchronization has to be managed explicitly. 
After PyTorch writes to an MPS tensor, call @@ -135,6 +136,10 @@ has been evaluated: mx.eval(c) print(b.cpu()) # tensor([10., 11., 12.]) +For MLX arrays exported to PyTorch, the share is tied to the exported buffer. +MLX updates after export may rebind the MLX array to a new buffer, while the +PyTorch tensor continues to reference the exported buffer. + Use ``mx.from_dlpack`` when you need to control the copy behavior. Specifying ``copy=True`` asks MLX to create a new array instead of sharing the Metal buffer: diff --git a/mlx/utils.cpp b/mlx/utils.cpp index 39056ee79e..ce5b8f3387 100644 --- a/mlx/utils.cpp +++ b/mlx/utils.cpp @@ -211,8 +211,6 @@ std::ostream& operator<<(std::ostream& os, uint8_t x) { return os; } -namespace { - array host_accessible_array(array a) { a.eval(); a.wait(); @@ -226,6 +224,8 @@ array host_accessible_array(array a) { return out; } +namespace { + template void print_subarray(std::ostream& os, const array& a, size_t index, int dim) { int num_print = 3; diff --git a/mlx/utils.h b/mlx/utils.h index d8b4c7ac99..b34ca1918f 100644 --- a/mlx/utils.h +++ b/mlx/utils.h @@ -69,6 +69,8 @@ MLX_API void set_printoptions(PrintOptions options); MLX_API PrintFormatter& get_global_formatter(); +MLX_API array host_accessible_array(array a); + /** Print the exception and then abort. 
*/ MLX_API void abort_with_exception(const std::exception& error); diff --git a/python/src/array.cpp b/python/src/array.cpp index 28c12f622c..013267b28e 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -496,7 +496,29 @@ void init_array(nb::module_& m) { new (&arr) mx::array(nd_array_to_mlx(nd, std::nullopt)); } }) - .def("__dlpack__", [](const mx::array& a) { return mlx_to_dlpack(a); }) + .def( + "__dlpack__", + [](const mx::array& a, + nb::object, + nb::object, + nb::object dl_device, + nb::object) { + std::optional dl_device_type; + if (!dl_device.is_none()) { + auto device = nb::cast(dl_device); + if (nb::len(device) != 2) { + throw nb::type_error( + "dl_device must be None or a tuple[int, int]"); + } + dl_device_type = nb::cast(device[0]); + } + return mlx_to_dlpack(a, dl_device_type); + }, + nb::kw_only(), + "stream"_a = nb::none(), + "max_version"_a = nb::none(), + "dl_device"_a = nb::none(), + "copy"_a = nb::none()) .def( "__dlpack_device__", [](const mx::array& a) { diff --git a/python/src/convert.cpp b/python/src/convert.cpp index a19f1bf745..d746eab34c 100644 --- a/python/src/convert.cpp +++ b/python/src/convert.cpp @@ -10,6 +10,8 @@ #include "python/src/convert.h" #include "python/src/utils.h" +#include "mlx/backend/cuda/cuda.h" +#include "mlx/backend/metal/metal.h" #include "mlx/ops.h" #include "mlx/utils.h" @@ -99,19 +101,6 @@ mx::array metal_dlpack_to_mlx( nb::ndarray nd_array, std::optional dtype); -mx::array host_accessible_array(mx::array a) { - a.eval(); - a.wait(); - if (a.buffer().is_host_accessible()) { - return a; - } - auto out = mx::copy_to_new_buffer(std::move(a), mx::Device::gpu); - out.eval(); - out.wait(); - out.detach(); - return out; -} - mx::array nd_array_to_mlx( nb::ndarray nd_array, std::optional dtype, @@ -176,7 +165,7 @@ mx::array metal_dlpack_to_mlx_contiguous( "Cannot convert Metal DLPack dtype to mlx dtype."); } - auto byte_offset = owner->data_offset(); + auto byte_offset = owner->byte_offset(); if (byte_offset 
% itemsize != 0) { throw std::invalid_argument( "Metal DLPack byte offset is not aligned to dtype size."); @@ -271,8 +260,92 @@ nb::ndarray mlx_to_np_array(const mx::array& a) { return mlx_to_nd_array(a); } -nb::ndarray<> mlx_to_dlpack(const mx::array& a) { - return mlx_to_nd_array<>(a); +template +nb::ndarray<> mlx_to_dlpack_impl(mx::array a, int dl_device_type) { + void* data = nullptr; + uint64_t byte_offset = 0; + { + nb::gil_scoped_release nogil; + a.eval(); + a.wait(); + if (dl_device_type == nb::device::cpu::value) { + a = host_accessible_array(std::move(a)); + data = a.data(); + } else { + data = a.buffer().ptr(); + byte_offset = a.offset(); + } + } + + std::vector shape(a.shape().begin(), a.shape().end()); + auto owner = nb::cast(a); + return nb::ndarray<>( + data, + a.ndim(), + shape.data(), + /* owner= */ owner, + a.strides().data(), + nb::dtype(), + dl_device_type, + 0, + '\0', + byte_offset); +} + +nb::ndarray<> mlx_to_dlpack( + const mx::array& a, + std::optional dl_device_type) { + int device_type = dl_device_type.value_or( + mx::metal::is_available() + ? nb::device::metal::value + : (mx::cu::is_available() ? 
nb::device::cuda_managed::value + : nb::device::cpu::value)); + + if (device_type == nb::device::cuda::value || + device_type == nb::device::cuda_managed::value) { + throw nb::buffer_error("CUDA DLPack export is not supported."); + } + if (device_type != nb::device::cpu::value && + device_type != nb::device::metal::value) { + throw nb::buffer_error( + "Cannot export mlx array to requested DLPack device."); + } + if (device_type == nb::device::metal::value && !mx::metal::is_available()) { + throw nb::buffer_error("Metal DLPack export is not available."); + } + + switch (a.dtype()) { + case mx::bool_: + return mlx_to_dlpack_impl(a, device_type); + case mx::uint8: + return mlx_to_dlpack_impl(a, device_type); + case mx::uint16: + return mlx_to_dlpack_impl(a, device_type); + case mx::uint32: + return mlx_to_dlpack_impl(a, device_type); + case mx::uint64: + return mlx_to_dlpack_impl(a, device_type); + case mx::int8: + return mlx_to_dlpack_impl(a, device_type); + case mx::int16: + return mlx_to_dlpack_impl(a, device_type); + case mx::int32: + return mlx_to_dlpack_impl(a, device_type); + case mx::int64: + return mlx_to_dlpack_impl(a, device_type); + case mx::float16: + return mlx_to_dlpack_impl(a, device_type); + case mx::bfloat16: + return mlx_to_dlpack_impl(a, device_type); + case mx::float32: + return mlx_to_dlpack_impl(a, device_type); + case mx::float64: + return mlx_to_dlpack_impl(a, device_type); + case mx::complex64: + return mlx_to_dlpack_impl>(a, device_type); + default: + throw nb::buffer_error("Cannot export mlx array with unsupported dtype."); + } } nb::object to_scalar(mx::array& a) { diff --git a/python/src/convert.h b/python/src/convert.h index 3ac6cb9f74..bf93540f2a 100644 --- a/python/src/convert.h +++ b/python/src/convert.h @@ -7,7 +7,6 @@ #include #include "mlx/array.h" -#include "mlx/ops.h" namespace mx = mlx::core; namespace nb = nanobind; @@ -67,9 +66,9 @@ mx::array nd_array_to_mlx( std::optional nb_dtype = std::nullopt); nb::ndarray 
mlx_to_np_array(const mx::array& a); -nb::ndarray<> mlx_to_dlpack(const mx::array& a); - -mx::array host_accessible_array(mx::array a); +nb::ndarray<> mlx_to_dlpack( + const mx::array& a, + std::optional dl_device_type = std::nullopt); nb::object to_scalar(mx::array& a); diff --git a/python/tests/test_array.py b/python/tests/test_array.py index de4e4795a4..cee6107a72 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -2035,17 +2035,28 @@ def test_add_numpy(self): self.assertEqual(z.item(), 3) def test_dlpack(self): + class CpuDLPack: + def __init__(self, array): + self.array = array + + def __dlpack_device__(self): + return (1, 0) + + def __dlpack__(self, *args, **kwargs): + kwargs["dl_device"] = (1, 0) + return self.array.__dlpack__(*args, **kwargs) + x = mx.array(1, dtype=mx.int32) - y = np.from_dlpack(x) + y = np.from_dlpack(CpuDLPack(x)) self.assertTrue(mx.array_equal(y, x)) x = mx.array([[1.0, 2.0], [3.0, 4.0]]) - y = np.from_dlpack(x) + y = np.from_dlpack(CpuDLPack(x)) self.assertTrue(mx.array_equal(y, x)) x = mx.arange(16).reshape(4, 4) x = x[::2, ::2] - y = np.from_dlpack(x) + y = np.from_dlpack(CpuDLPack(x)) self.assertTrue(mx.array_equal(y, x)) def test_from_dlpack_cpu(self): @@ -2088,14 +2099,19 @@ def test_torch_mps_dlpack_host_access(self): self.assertEqual(mv.tolist(), x.cpu().numpy().tolist()) @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") - def test_torch_mps_dlpack_zero_copy_reads_torch_updates(self): + def test_torch_mps_dlpack_zero_copy_shares_updates(self): x = torch.arange(12, device="mps", dtype=torch.float32).reshape(3, 4) + torch.mps.synchronize() y = mx.array(x) x.add_(100) torch.mps.synchronize() self.assertEqual((y + 1).tolist(), (x + 1).cpu().numpy().tolist()) + y += 10 + mx.eval(y) + self.assertEqual(x.cpu().numpy().tolist(), y.tolist()) + @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") def test_torch_mps_dlpack_dtype_argument_copies(self): x = torch.arange(12, device="mps", 
dtype=torch.float32).reshape(3, 4) @@ -2152,7 +2168,44 @@ def test_torch_mps_dlpack_mlx_update_writes_torch_buffer(self): self.assertEqual(x.cpu().numpy().tolist(), [3, 4, 5]) @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") - def test_from_dlpack_torch_mps_copy_none_shares(self): + def test_mlx_dlpack_exports_mps_tensor_to_torch(self): + x = mx.array([1]).astype(mx.float16) + mx.eval(x) + y = torch.utils.dlpack.from_dlpack(x) + torch.mps.synchronize() + + self.assertEqual(y.device.type, "mps") + self.assertEqual(y.dtype, torch.float16) + self.assertEqual(y.cpu().numpy().tolist(), [1.0]) + + @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") + def test_mlx_dlpack_exports_mps_tensor_to_torch_tensor(self): + x = mx.array([1]).astype(mx.float16) + mx.eval(x) + y = torch.tensor(x) + torch.mps.synchronize() + + self.assertEqual(y.device.type, "mps") + self.assertEqual(y.dtype, torch.float16) + self.assertEqual(y.cpu().numpy().tolist(), [1.0]) + + @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") + def test_mlx_dlpack_export_torch_update_writes_mlx_buffer(self): + x = mx.arange(8, dtype=mx.float32) + y = x[2:6] + mx.eval(y) + t = torch.utils.dlpack.from_dlpack(y) + + self.assertEqual(t.device.type, "mps") + self.assertEqual(t.cpu().numpy().tolist(), [2.0, 3.0, 4.0, 5.0]) + + t.add_(10) + torch.mps.synchronize() + self.assertEqual(y.tolist(), [12.0, 13.0, 14.0, 15.0]) + self.assertEqual(x.tolist(), [0.0, 1.0, 12.0, 13.0, 14.0, 15.0, 6.0, 7.0]) + + @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") + def test_from_dlpack_torch_mps_copy_none_shares_updates(self): x = torch.arange(3, device="mps", dtype=torch.float32) torch.mps.synchronize() y = mx.from_dlpack(x) @@ -2161,8 +2214,12 @@ def test_from_dlpack_torch_mps_copy_none_shares(self): torch.mps.synchronize() self.assertEqual(y.tolist(), [10.0, 11.0, 12.0]) + y += 10 + mx.eval(y) + self.assertEqual(x.cpu().numpy().tolist(), [20.0, 21.0, 22.0]) + 
@unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") - def test_from_dlpack_torch_mps_copy_false_shares(self): + def test_from_dlpack_torch_mps_copy_false_shares_updates(self): x = torch.arange(3, device="mps", dtype=torch.float32) torch.mps.synchronize() y = mx.from_dlpack(x, copy=False) @@ -2171,6 +2228,10 @@ def test_from_dlpack_torch_mps_copy_false_shares(self): mx.eval(y) self.assertEqual(x.cpu().numpy().tolist(), [10.0, 11.0, 12.0]) + x.add_(10) + torch.mps.synchronize() + self.assertEqual(y.tolist(), [20.0, 21.0, 22.0]) + @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") def test_from_dlpack_torch_mps_copy_true_copies(self): x = torch.arange(3, device="mps", dtype=torch.float32) From 361143f8f6d51d5417a22b02781b1e4b57b01b1f Mon Sep 17 00:00:00 2001 From: XXXXRT666 <157766680+XXXXRT666@users.noreply.github.com> Date: Tue, 12 May 2026 01:19:33 +0800 Subject: [PATCH 04/26] Share DLPack arrays when dtype matches --- python/src/convert.cpp | 21 ++++++++++++++++----- python/tests/test_array.py | 23 ++++++++++++++++++----- 2 files changed, 34 insertions(+), 10 deletions(-) diff --git a/python/src/convert.cpp b/python/src/convert.cpp index d746eab34c..0bb6045203 100644 --- a/python/src/convert.cpp +++ b/python/src/convert.cpp @@ -97,6 +97,13 @@ auto dispatch_dlpack_dtype( } } +mx::Dtype mlx_dtype_from_dlpack( + nb::dlpack::dtype type, + const char* error_message) { + return dispatch_dlpack_dtype( + type, [](mx::Dtype dtype) { return dtype; }, error_message); +} + mx::array metal_dlpack_to_mlx( nb::ndarray nd_array, std::optional dtype); @@ -138,10 +145,8 @@ mx::array from_dlpack(nb::object v, std::optional copy) { case nb::device::metal::value: { std::optional dtype; if (copy == true) { - dtype = dispatch_dlpack_dtype( - nd.dtype(), - [](mx::Dtype dtype) { return dtype; }, - "Cannot convert Metal DLPack array to mlx array."); + dtype = mlx_dtype_from_dlpack( + nd.dtype(), "Cannot convert Metal DLPack array to mlx array."); } return 
nd_array_to_mlx(std::move(nd), dtype); } @@ -704,7 +709,13 @@ mx::array create_array(nb::object v, std::optional t) { } else { nd = nb::cast(v); } - return nd_array_to_mlx(nd, t, nb_dtype); + auto type = nb_dtype.value_or(nd.dtype()); + std::optional copy_dtype; + if (t && + *t != mlx_dtype_from_dlpack(type, "Cannot convert array to mlx.")) { + copy_dtype = t; + } + return nd_array_to_mlx(nd, copy_dtype, nb_dtype); } else { auto arr = to_array_with_accessor(v); return mx::astype(arr, t.value_or(arr.dtype())); diff --git a/python/tests/test_array.py b/python/tests/test_array.py index cee6107a72..2988f93f89 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -2113,19 +2113,32 @@ def test_torch_mps_dlpack_zero_copy_shares_updates(self): self.assertEqual(x.cpu().numpy().tolist(), y.tolist()) @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") - def test_torch_mps_dlpack_dtype_argument_copies(self): + def test_torch_mps_dlpack_matching_dtype_argument_shares_updates(self): x = torch.arange(12, device="mps", dtype=torch.float32).reshape(3, 4) torch.mps.synchronize() - y_copy = mx.array(x, dtype=mx.float32) - expected = x.cpu().numpy().tolist() + y = mx.array(x, dtype=mx.float32) x.add_(100) torch.mps.synchronize() - self.assertEqual(y_copy.tolist(), expected) + self.assertEqual(y.tolist(), x.cpu().numpy().tolist()) + + y += 10 + mx.eval(y) + self.assertEqual(x.cpu().numpy().tolist(), y.tolist()) + @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") + def test_torch_mps_dlpack_different_dtype_argument_copies(self): + x = torch.arange(12, device="mps", dtype=torch.float32).reshape(3, 4) + torch.mps.synchronize() z = mx.array(x, dtype=mx.float16) + expected = x.to(torch.float16).cpu().numpy().tolist() + self.assertEqual(z.dtype, mx.float16) - self.assertEqual(z.tolist(), x.to(torch.float16).cpu().numpy().tolist()) + self.assertEqual(z.tolist(), expected) + + x.add_(100) + torch.mps.synchronize() + self.assertEqual(z.tolist(), 
expected) @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") def test_torch_mps_dlpack_data_offset(self): From 206dce437b4fcf48f36185bbba639230b2fb1ba2 Mon Sep 17 00:00:00 2001 From: XXXXRT666 <157766680+XXXXRT666@users.noreply.github.com> Date: Tue, 19 May 2026 15:59:51 +0800 Subject: [PATCH 05/26] Clarify MPS DLPack host access test --- python/tests/test_array.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/python/tests/test_array.py b/python/tests/test_array.py index 2988f93f89..5360edfcd1 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -735,7 +735,7 @@ def test_array_np_conversion(self): self.assertEqual(x.tolist(), cvals) def test_array_np_dtype_conversion(self): - to_mlx_dtypes_list = [ + dtypes_list = [ (mx.bool_, np.bool_), (mx.uint8, np.uint8), (mx.uint16, np.uint16), @@ -750,14 +750,13 @@ def test_array_np_dtype_conversion(self): (mx.complex64, np.complex64), ] - for mlx_dtype, np_dtype in to_mlx_dtypes_list: + for mlx_dtype, np_dtype in dtypes_list: a_npy = np.random.uniform(low=0, high=100, size=(32,)).astype(np_dtype) a_mlx = mx.array(a_npy) self.assertEqual(a_mlx.dtype, mlx_dtype) self.assertTrue(np.allclose(a_mlx, a_npy)) - for mlx_dtype, np_dtype in to_mlx_dtypes_list: b_mlx = mx.random.uniform( low=0, high=10, @@ -2091,6 +2090,9 @@ def test_torch_mps_dlpack_host_access(self): torch.mps.synchronize() self.assertIn("array(", repr(y)) self.assertEqual(y.tolist(), x.cpu().numpy().tolist()) + # PyTorch 2.12 allocates ordinary MPS tensors in shared/unified + # MTLBuffers on Apple silicon, while private Metal DLPack producers are + # not host-accessible. 
try: mv = memoryview(y) except BufferError: From a7c39c7f0d26808809159cb6399988379b70f78c Mon Sep 17 00:00:00 2001 From: XXXXRT666 <157766680+XXXXRT666@users.noreply.github.com> Date: Tue, 19 May 2026 16:05:38 +0800 Subject: [PATCH 06/26] Pin nanobind with Metal DLPack support --- CMakeLists.txt | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index a2395d02f6..ff110ec929 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -357,8 +357,7 @@ if(MLX_BUILD_PYTHON_BINDINGS) FetchContent_Declare( nanobind GIT_REPOSITORY https://github.com/wjakob/nanobind.git - GIT_TAG v2.12.0 - GIT_SHALLOW TRUE + GIT_TAG 117ba9ee0253e1bbada678bb4b5d6c6e4ed441eb EXCLUDE_FROM_ALL) FetchContent_MakeAvailable(nanobind) add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/python/src) From dc90ebad9df49c8eae7ed702faa853573c2dcbfa Mon Sep 17 00:00:00 2001 From: XXXXRT666 <157766680+XXXXRT666@users.noreply.github.com> Date: Tue, 19 May 2026 16:36:55 +0800 Subject: [PATCH 07/26] Use host accessibility check for Metal raw pointer --- mlx/backend/metal/allocator.cpp | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/mlx/backend/metal/allocator.cpp b/mlx/backend/metal/allocator.cpp index 5b24142fd5..89ad12be6f 100644 --- a/mlx/backend/metal/allocator.cpp +++ b/mlx/backend/metal/allocator.cpp @@ -24,15 +24,13 @@ void* Buffer::raw_ptr() { if (!ptr_) { return nullptr; } - auto* buf = static_cast(ptr_); - auto* contents = buf->contents(); - if (!contents && buf->length() > 0) { + if (!is_host_accessible()) { throw std::runtime_error( "[metal::Buffer::raw_ptr] Cannot access Metal buffer contents on the " "host. 
The buffer is not CPU-addressable, for example because it uses " "private storage."); } - return contents; + return static_cast(ptr_)->contents(); } bool Buffer::is_host_accessible() const { From 2dea5e97edb5368e58c60952ff096925c7f5dcfe Mon Sep 17 00:00:00 2001 From: XXXXRT666 <157766680+XXXXRT666@users.noreply.github.com> Date: Thu, 21 May 2026 14:48:17 +0800 Subject: [PATCH 08/26] Add array copy for shared DLPack buffers --- CMakeLists.txt | 1 + docs/src/python/array.rst | 1 + docs/src/usage/numpy.rst | 47 +++++++++++++++++++--------------- mlx/backend/cpu/primitives.cpp | 5 ++++ mlx/backend/cuda/allocator.cpp | 6 ++++- mlx/backend/gpu/primitives.cpp | 7 +++++ mlx/ops.cpp | 14 ++++++++++ mlx/ops.h | 4 +++ mlx/primitives.h | 11 ++++++++ python/src/array.cpp | 23 +++++++++++++++++ python/src/convert.cpp | 14 +++++----- python/tests/test_array.py | 40 +++++++++++++++++++++++++---- 12 files changed, 139 insertions(+), 34 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index ff110ec929..e33682b018 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -358,6 +358,7 @@ if(MLX_BUILD_PYTHON_BINDINGS) nanobind GIT_REPOSITORY https://github.com/wjakob/nanobind.git GIT_TAG 117ba9ee0253e1bbada678bb4b5d6c6e4ed441eb + GIT_SHALLOW TRUE EXCLUDE_FROM_ALL) FetchContent_MakeAvailable(nanobind) add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/python/src) diff --git a/docs/src/python/array.rst b/docs/src/python/array.rst index e68524d5a4..42ea2186f9 100644 --- a/docs/src/python/array.rst +++ b/docs/src/python/array.rst @@ -11,6 +11,7 @@ Array array array.astype array.at + array.copy array.item array.tolist array.dtype diff --git a/docs/src/usage/numpy.rst b/docs/src/usage/numpy.rst index fc511b8098..fc84122793 100644 --- a/docs/src/usage/numpy.rst +++ b/docs/src/usage/numpy.rst @@ -79,21 +79,15 @@ PyTorch supports DLPack inputs and can import MLX arrays directly. MLX can also import PyTorch tensors through DLPack with ``mx.array`` or ``mx.from_dlpack``. -.. 
code-block:: python - - import mlx.core as mx - import torch - - a = mx.arange(3) - b = torch.tensor(a) - c = mx.array(b) - Creating an MLX array from a CPU tensor copies the data into MLX-owned storage. The arrays do not share memory: .. code-block:: python - b = torch.arange(3) + import mlx.core as mx + import torch + + b = torch.arange(3).cpu() c = mx.array(b) b += 10 @@ -113,6 +107,12 @@ the shared data from PyTorch. Without these synchronization points, the other framework may read the shared buffer before the producer has finished writing, so it can observe stale data. +Do not rely on regular MLX update expressions to preserve sharing. MLX may +rebind an array to a new buffer when evaluating an expression, while PyTorch +continues to reference the original shared buffer. Use +:meth:`mlx.core.array.copy` to write into an array's existing storage when you +need the update to be visible to another framework sharing the same buffer: + .. code-block:: python b = torch.arange(3, device="mps", dtype=torch.float32) @@ -123,22 +123,27 @@ so it can observe stale data. torch.mps.synchronize() print(c.tolist()) # [10.0, 11.0, 12.0] -Updates made by MLX can also be observed from PyTorch after the MLX computation -has been evaluated: + c.copy(mx.array([4, 5, 6], dtype=mx.float32)) + mx.eval(c) + print(b.cpu()) # tensor([4., 5., 6.]) + +The same synchronization rules apply when PyTorch imports an MLX array through +DLPack. Use ``torch.as_tensor`` to share the buffer with PyTorch; ``torch.tensor`` +copies the data instead: .. code-block:: python - b = torch.arange(3, device="mps", dtype=torch.float32) - torch.mps.synchronize() - c = mx.array(b) + a = mx.arange(3, dtype=mx.float32) + mx.eval(a) + b = torch.as_tensor(a) - c += 10 - mx.eval(c) - print(b.cpu()) # tensor([10., 11., 12.]) + b.add_(10) + torch.mps.synchronize() + print(a.tolist()) # [10.0, 11.0, 12.0] -For MLX arrays exported to PyTorch, the share is tied to the exported buffer. 
-MLX updates after export may rebind the MLX array to a new buffer, while the -PyTorch tensor continues to reference the exported buffer. + a.copy(mx.array([4, 5, 6], dtype=mx.float32)) + mx.eval(a) + print(b.cpu()) # tensor([4., 5., 6.]) Use ``mx.from_dlpack`` when you need to control the copy behavior. Specifying ``copy=True`` asks MLX to create a new array instead of sharing the Metal diff --git a/mlx/backend/cpu/primitives.cpp b/mlx/backend/cpu/primitives.cpp index f1d83dd306..133dab2104 100644 --- a/mlx/backend/cpu/primitives.cpp +++ b/mlx/backend/cpu/primitives.cpp @@ -88,6 +88,11 @@ void BroadcastAxes::eval_cpu(const std::vector& inputs, array& out) { void Copy::eval_cpu(const std::vector& inputs, array& out) { eval(inputs, out); } +void CopyInto::eval_cpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 2); + out.copy_shared_buffer(inputs[0]); + copy_cpu_inplace(inputs[1], out, CopyType::GeneralGeneral, stream()); +} void CustomTransforms::eval_cpu( const std::vector& inputs, std::vector& outputs) { diff --git a/mlx/backend/cuda/allocator.cpp b/mlx/backend/cuda/allocator.cpp index 37dc457a02..4d67b2798e 100644 --- a/mlx/backend/cuda/allocator.cpp +++ b/mlx/backend/cuda/allocator.cpp @@ -417,7 +417,11 @@ void* Buffer::raw_ptr() { } bool Buffer::is_host_accessible() const { - return true; + if (!ptr_) { + return true; + } + auto& cbuf = *static_cast(ptr_); + return cbuf.device == -1; } } // namespace allocator diff --git a/mlx/backend/gpu/primitives.cpp b/mlx/backend/gpu/primitives.cpp index 268d6290bf..29f9c2d085 100644 --- a/mlx/backend/gpu/primitives.cpp +++ b/mlx/backend/gpu/primitives.cpp @@ -66,6 +66,13 @@ void Copy::eval_gpu(const std::vector& inputs, array& out) { eval(inputs, out); } +void CopyInto::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("CopyInto::eval_gpu"); + assert(inputs.size() == 2); + out.copy_shared_buffer(inputs[0]); + copy_gpu_inplace(inputs[1], out, CopyType::GeneralGeneral, stream()); +} + 
void CustomTransforms::eval_gpu( const std::vector& inputs, std::vector& outputs) { diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 4598529e6c..983e1bb116 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -296,6 +296,20 @@ array copy(array a, StreamOrDevice s /* = {} */) { {std::move(a)}); } +array copy_into( + const array& dst, + const array& src, + StreamOrDevice s /* = {} */) { + auto stream = to_stream(s); + auto update = + broadcast_to(astype(src, dst.dtype(), stream), dst.shape(), stream); + return array( + dst.shape(), + dst.dtype(), + std::make_shared(stream), + {dst, std::move(update)}); +} + array copy_to_new_buffer(array a, StreamOrDevice s /* = {} */) { auto copied_shape = a.shape(); // |a| will be moved auto dtype = a.dtype(); diff --git a/mlx/ops.h b/mlx/ops.h index 3084de3962..35edd17d88 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -60,6 +60,10 @@ MLX_API array as_strided( /** Copy another array. */ MLX_API array copy(array a, StreamOrDevice s = {}); +/** Copy an array into another array's storage. */ +MLX_API array +copy_into(const array& dst, const array& src, StreamOrDevice s = {}); + /** Copy another array into newly allocated storage. 
*/ MLX_API array copy_to_new_buffer(array a, StreamOrDevice s = {}); diff --git a/mlx/primitives.h b/mlx/primitives.h index 75fb978dce..9a8c455b52 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -803,6 +803,17 @@ class Copy : public UnaryPrimitive { void eval(const std::vector& inputs, array& out); }; +class CopyInto : public UnaryPrimitive { + public: + explicit CopyInto(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_NAME(CopyInto) + DEFINE_INPUT_OUTPUT_SHAPE() +}; + class Cos : public UnaryPrimitive { public: explicit Cos(Stream stream) : UnaryPrimitive(stream) {} diff --git a/python/src/array.cpp b/python/src/array.cpp index 013267b28e..3b2189ab01 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -538,6 +538,29 @@ void init_array(nb::module_& m) { "__deepcopy__", [](const mx::array& self, nb::dict) { return mx::array(self); }, "memo"_a) + .def( + "copy", + [](mx::array& a, + const ScalarOrArray& v, + mx::StreamOrDevice s) -> mx::array& { + a.eval(); + auto out = mx::copy_into(a, to_array(v, a.dtype()), s); + a.overwrite_descriptor(out); + return a; + }, + "src"_a, + "stream"_a = nb::none(), + nb::rv_policy::none, + nb::sig( + "def copy(self, src: Union[scalar, array, DLPackCompatible], /, *, stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + Copy ``src`` into the array's existing storage. + + ``src`` can be a scalar, an MLX array, or a DLPack-compatible + array. The destination array keeps its shape, dtype, and storage. + ``src`` is cast to the destination dtype and broadcast to the + destination shape before being written. 
+ )pbdoc") .def( "__add__", [](const mx::array& a, const ScalarOrArray v) { diff --git a/python/src/convert.cpp b/python/src/convert.cpp index 0bb6045203..43e8cbc8b3 100644 --- a/python/src/convert.cpp +++ b/python/src/convert.cpp @@ -60,8 +60,8 @@ mx::array nd_array_to_mlx_contiguous( template auto dispatch_dlpack_dtype( nb::dlpack::dtype type, - F&& f, - const char* error_message) { + const char* error_message, + F&& f) { if (type == nb::dtype()) { return f.template operator()(mx::bool_); } else if (type == nb::dtype()) { @@ -101,7 +101,7 @@ mx::Dtype mlx_dtype_from_dlpack( nb::dlpack::dtype type, const char* error_message) { return dispatch_dlpack_dtype( - type, [](mx::Dtype dtype) { return dtype; }, error_message); + type, error_message, [](mx::Dtype dtype) { return dtype; }); } mx::array metal_dlpack_to_mlx( @@ -118,11 +118,11 @@ mx::array nd_array_to_mlx( auto type = nb_dtype.value_or(nd_array.dtype()); return dispatch_dlpack_dtype( type, + "Cannot convert numpy array to mlx array.", [&](mx::Dtype default_dtype) { return nd_array_to_mlx_contiguous( nd_array, shape, dtype.value_or(default_dtype)); - }, - "Cannot convert numpy array to mlx array."); + }); } case nb::device::metal::value: return metal_dlpack_to_mlx(std::move(nd_array), dtype); @@ -406,10 +406,10 @@ mx::array metal_dlpack_to_mlx( return dispatch_dlpack_dtype( owner->dtype(), + "Cannot convert Metal DLPack array to mlx array.", [&](mx::Dtype type) { return metal_dlpack_to_mlx_contiguous(owner, shape, type, dtype); - }, - "Cannot convert Metal DLPack array to mlx array."); + }); } template diff --git a/python/tests/test_array.py b/python/tests/test_array.py index 5360edfcd1..6813bcc18e 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -956,6 +956,22 @@ def test_array_copy(self): y -= 1 self.assertEqualArray(y, x - 1) + def test_array_copy_method(self): + x = mx.array([1, 2, 3]) + y = x.copy(mx.array([4, 5, 6])) + self.assertIs(y, x) + mx.eval(x) + self.assertEqualArray(x, 
mx.array([4, 5, 6])) + + x.copy(9) + mx.eval(x) + self.assertEqualArray(x, mx.array([9, 9, 9])) + + x.copy(np.array([1.5, 2.5, 3.5])) + mx.eval(x) + self.assertEqual(x.dtype, mx.int32) + self.assertEqualArray(x, mx.array([1, 2, 3])) + def test_indexing(self): # Only ellipsis is a no-op a_mlx = mx.array([1])[...] @@ -2110,7 +2126,7 @@ def test_torch_mps_dlpack_zero_copy_shares_updates(self): torch.mps.synchronize() self.assertEqual((y + 1).tolist(), (x + 1).cpu().numpy().tolist()) - y += 10 + y.copy(mx.full(y.shape, 7.0)) mx.eval(y) self.assertEqual(x.cpu().numpy().tolist(), y.tolist()) @@ -2124,7 +2140,7 @@ def test_torch_mps_dlpack_matching_dtype_argument_shares_updates(self): torch.mps.synchronize() self.assertEqual(y.tolist(), x.cpu().numpy().tolist()) - y += 10 + y.copy(mx.full(y.shape, 7.0)) mx.eval(y) self.assertEqual(x.cpu().numpy().tolist(), y.tolist()) @@ -2178,10 +2194,24 @@ def test_torch_mps_dlpack_mlx_update_writes_torch_buffer(self): torch.mps.synchronize() y = mx.array(x) - y += 3 + y.copy(mx.array([3, 4, 5], dtype=mx.float32)) mx.eval(y) self.assertEqual(x.cpu().numpy().tolist(), [3, 4, 5]) + @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") + def test_torch_mps_dlpack_copy_writes_torch_buffer(self): + x = torch.arange(3, device="mps", dtype=torch.float32) + torch.mps.synchronize() + y = mx.array(x) + + y.copy(mx.array([4, 5, 6], dtype=mx.float32)) + mx.eval(y) + self.assertEqual(x.cpu().numpy().tolist(), [4.0, 5.0, 6.0]) + + x.add_(10) + torch.mps.synchronize() + self.assertEqual(y.tolist(), [14.0, 15.0, 16.0]) + @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") def test_mlx_dlpack_exports_mps_tensor_to_torch(self): x = mx.array([1]).astype(mx.float16) @@ -2229,7 +2259,7 @@ def test_from_dlpack_torch_mps_copy_none_shares_updates(self): torch.mps.synchronize() self.assertEqual(y.tolist(), [10.0, 11.0, 12.0]) - y += 10 + y.copy(mx.array([20, 21, 22], dtype=mx.float32)) mx.eval(y) self.assertEqual(x.cpu().numpy().tolist(), 
[20.0, 21.0, 22.0]) @@ -2239,7 +2269,7 @@ def test_from_dlpack_torch_mps_copy_false_shares_updates(self): torch.mps.synchronize() y = mx.from_dlpack(x, copy=False) - y += 10 + y.copy(mx.array([10, 11, 12], dtype=mx.float32)) mx.eval(y) self.assertEqual(x.cpu().numpy().tolist(), [10.0, 11.0, 12.0]) From ef1252833ab3d315ed2d055135be1862384b1591 Mon Sep 17 00:00:00 2001 From: XXXXRT666 <157766680+XXXXRT666@users.noreply.github.com> Date: Thu, 21 May 2026 15:50:34 +0800 Subject: [PATCH 09/26] Keep FFT torch baseline on CPU --- python/tests/test_fft.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/tests/test_fft.py b/python/tests/test_fft.py index 9c03530b04..0bb190bd99 100644 --- a/python/tests/test_fft.py +++ b/python/tests/test_fft.py @@ -398,12 +398,14 @@ def g(x): x = mx.random.normal(sh) else: x = mx.random.normal((*sh, 2)).view(mx.complex64).squeeze() + mx.eval(x) + x_torch = torch.tensor(x, device="cpu") fx = f(x) - gx = g(torch.tensor(x)) + gx = g(x_torch) self.assertLess((fx - gx).abs().max() / gx.abs().mean(), 1e-4) dfdx = mx.grad(f)(x) - dgdx = torch.func.grad(g)(torch.tensor(x)) + dgdx = torch.func.grad(g)(x_torch) self.assertLess((dfdx - dgdx).abs().max() / dgdx.abs().mean(), 1e-4) From 9400ca465011794b7856c679a1f30ccc2ddd12f5 Mon Sep 17 00:00:00 2001 From: XXXXRT666 <157766680+XXXXRT666@users.noreply.github.com> Date: Thu, 21 May 2026 16:59:09 +0800 Subject: [PATCH 10/26] Handle private Metal DLPack buffers with copies --- CMakeLists.txt | 1 - docs/src/usage/numpy.rst | 13 ++++-- mlx/utils.cpp | 15 +----- mlx/utils.h | 2 - python/src/convert.cpp | 99 ++++++++++++++++++++++------------------ 5 files changed, 63 insertions(+), 67 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index e33682b018..ff110ec929 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -358,7 +358,6 @@ if(MLX_BUILD_PYTHON_BINDINGS) nanobind GIT_REPOSITORY https://github.com/wjakob/nanobind.git GIT_TAG 117ba9ee0253e1bbada678bb4b5d6c6e4ed441eb - 
GIT_SHALLOW TRUE EXCLUDE_FROM_ALL) FetchContent_MakeAvailable(nanobind) add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/python/src) diff --git a/docs/src/usage/numpy.rst b/docs/src/usage/numpy.rst index fc84122793..17338cf97a 100644 --- a/docs/src/usage/numpy.rst +++ b/docs/src/usage/numpy.rst @@ -94,10 +94,13 @@ The arrays do not share memory: print(c.tolist()) # [0, 1, 2] Metal DLPack inputs are different. If a PyTorch MPS tensor is passed to -``mx.array`` or to ``mx.from_dlpack`` with ``copy=None`` or ``copy=False``, MLX -imports the underlying Metal buffer without copying it. The PyTorch tensor and -the MLX array then share the same storage. MLX arrays exported to PyTorch with -DLPack are also shared without a copy. +``mx.array`` or to ``mx.from_dlpack`` with ``copy=None``, MLX shares the +underlying Metal buffer when it is host-accessible. Private Metal buffers are +copied into MLX-managed storage instead. Passing ``copy=False`` requires +sharing and raises an error if a copy would be needed. MLX arrays exported to +PyTorch with DLPack are also shared without a copy. In particular, PyTorch 2.12 +and later use shared storage for ordinary MPS tensors on Apple silicon, while +older PyTorch versions may use private storage and require a copy on import. Since the buffer is shared across frameworks, synchronization has to be managed explicitly. 
After PyTorch writes to an MPS tensor, call @@ -117,7 +120,7 @@ need the update to be visible to another framework sharing the same buffer: b = torch.arange(3, device="mps", dtype=torch.float32) torch.mps.synchronize() - c = mx.array(b) # zero-copy Metal DLPack import + c = mx.array(b) # zero-copy if the Metal buffer is host-accessible b.add_(10) torch.mps.synchronize() diff --git a/mlx/utils.cpp b/mlx/utils.cpp index ce5b8f3387..2b45f52d88 100644 --- a/mlx/utils.cpp +++ b/mlx/utils.cpp @@ -211,19 +211,6 @@ std::ostream& operator<<(std::ostream& os, uint8_t x) { return os; } -array host_accessible_array(array a) { - a.eval(); - a.wait(); - if (a.buffer().is_host_accessible()) { - return a; - } - auto out = copy_to_new_buffer(std::move(a), Device::gpu); - out.eval(); - out.wait(); - out.detach(); - return out; -} - namespace { template @@ -291,7 +278,7 @@ std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k) { } std::ostream& operator<<(std::ostream& os, array a) { - a = host_accessible_array(std::move(a)); + a.eval(); dispatch_all_types(a.dtype(), [&](auto type_tag) { print_array(os, a); }); diff --git a/mlx/utils.h b/mlx/utils.h index b34ca1918f..d8b4c7ac99 100644 --- a/mlx/utils.h +++ b/mlx/utils.h @@ -69,8 +69,6 @@ MLX_API void set_printoptions(PrintOptions options); MLX_API PrintFormatter& get_global_formatter(); -MLX_API array host_accessible_array(array a); - /** Print the exception and then abort. 
*/ MLX_API void abort_with_exception(const std::exception& error); diff --git a/python/src/convert.cpp b/python/src/convert.cpp index 43e8cbc8b3..4bd1abc72f 100644 --- a/python/src/convert.cpp +++ b/python/src/convert.cpp @@ -106,7 +106,8 @@ mx::Dtype mlx_dtype_from_dlpack( mx::array metal_dlpack_to_mlx( nb::ndarray nd_array, - std::optional dtype); + std::optional dtype, + std::optional copy); mx::array nd_array_to_mlx( nb::ndarray nd_array, @@ -125,7 +126,7 @@ mx::array nd_array_to_mlx( }); } case nb::device::metal::value: - return metal_dlpack_to_mlx(std::move(nd_array), dtype); + return metal_dlpack_to_mlx(std::move(nd_array), dtype, std::nullopt); default: throw std::invalid_argument("Unsupported DLPack device."); } @@ -148,7 +149,7 @@ mx::array from_dlpack(nb::object v, std::optional copy) { dtype = mlx_dtype_from_dlpack( nd.dtype(), "Cannot convert Metal DLPack array to mlx array."); } - return nd_array_to_mlx(std::move(nd), dtype); + return metal_dlpack_to_mlx(std::move(nd), dtype, copy); } case nb::device::cuda::value: case nb::device::cuda_managed::value: @@ -163,7 +164,8 @@ mx::array metal_dlpack_to_mlx_contiguous( std::shared_ptr> owner, const mx::Shape& shape, mx::Dtype type, - std::optional dtype) { + std::optional dtype, + std::optional copy) { auto itemsize = mx::size_of(type); if (owner->itemsize() != itemsize) { throw std::invalid_argument( @@ -194,10 +196,18 @@ mx::array metal_dlpack_to_mlx_contiguous( out.copy_shared_buffer(out, out.strides(), flags, out.data_size(), offset); } - if (dtype) { - auto result = (*dtype == out.dtype()) + auto is_host_accessible = out.buffer().is_host_accessible(); + if (copy == false && !is_host_accessible) { + throw std::invalid_argument( + "Cannot import a non-host-accessible Metal DLPack buffer without a " + "copy."); + } + + if (dtype || !is_host_accessible) { + auto result_dtype = dtype.value_or(out.dtype()); + auto result = (result_dtype == out.dtype()) ? 
mx::copy_to_new_buffer(out, mx::Device::gpu) - : mx::astype(out, *dtype, mx::Device::gpu); + : mx::astype(out, result_dtype, mx::Device::gpu); result.eval(); result.wait(); result.detach(); @@ -212,7 +222,7 @@ nb::ndarray mlx_to_nd_array_impl( std::optional t = {}) { { nb::gil_scoped_release nogil; - a = host_accessible_array(std::move(a)); + a.eval(); } std::vector shape(a.shape().begin(), a.shape().end()); auto owner = nb::cast(a); @@ -274,7 +284,6 @@ nb::ndarray<> mlx_to_dlpack_impl(mx::array a, int dl_device_type) { a.eval(); a.wait(); if (dl_device_type == nb::device::cpu::value) { - a = host_accessible_array(std::move(a)); data = a.data(); } else { data = a.buffer().ptr(); @@ -358,40 +367,39 @@ nb::object to_scalar(mx::array& a) { throw std::invalid_argument( "[convert] Only length-1 arrays can be converted to Python scalars."); } - auto host = mx::array(a); { nb::gil_scoped_release nogil; - host = host_accessible_array(std::move(host)); + a.eval(); } - switch (host.dtype()) { + switch (a.dtype()) { case mx::bool_: - return nb::cast(host.item()); + return nb::cast(a.item()); case mx::uint8: - return nb::cast(host.item()); + return nb::cast(a.item()); case mx::uint16: - return nb::cast(host.item()); + return nb::cast(a.item()); case mx::uint32: - return nb::cast(host.item()); + return nb::cast(a.item()); case mx::uint64: - return nb::cast(host.item()); + return nb::cast(a.item()); case mx::int8: - return nb::cast(host.item()); + return nb::cast(a.item()); case mx::int16: - return nb::cast(host.item()); + return nb::cast(a.item()); case mx::int32: - return nb::cast(host.item()); + return nb::cast(a.item()); case mx::int64: - return nb::cast(host.item()); + return nb::cast(a.item()); case mx::float16: - return nb::cast(static_cast(host.item())); + return nb::cast(static_cast(a.item())); case mx::float32: - return nb::cast(host.item()); + return nb::cast(a.item()); case mx::bfloat16: - return nb::cast(static_cast(host.item())); + return 
nb::cast(static_cast(a.item())); case mx::complex64: - return nb::cast(host.item>()); + return nb::cast(a.item>()); case mx::float64: - return nb::cast(host.item()); + return nb::cast(a.item()); default: throw nb::type_error("type cannot be converted to Python scalar."); } @@ -399,7 +407,8 @@ nb::object to_scalar(mx::array& a) { mx::array metal_dlpack_to_mlx( nb::ndarray nd_array, - std::optional dtype) { + std::optional dtype, + std::optional copy) { auto owner = std::make_shared>(std::move(nd_array)); auto shape = get_shape(*owner); @@ -408,7 +417,8 @@ mx::array metal_dlpack_to_mlx( owner->dtype(), "Cannot convert Metal DLPack array to mlx array.", [&](mx::Dtype type) { - return metal_dlpack_to_mlx_contiguous(owner, shape, type, dtype); + return metal_dlpack_to_mlx_contiguous( + owner, shape, type, dtype, copy); }); } @@ -431,40 +441,39 @@ nb::object tolist(mx::array& a) { if (a.ndim() == 0) { return to_scalar(a); } - auto host = mx::array(a); { nb::gil_scoped_release nogil; - host = host_accessible_array(std::move(host)); + a.eval(); } - switch (host.dtype()) { + switch (a.dtype()) { case mx::bool_: - return to_list(host, 0, 0); + return to_list(a, 0, 0); case mx::uint8: - return to_list(host, 0, 0); + return to_list(a, 0, 0); case mx::uint16: - return to_list(host, 0, 0); + return to_list(a, 0, 0); case mx::uint32: - return to_list(host, 0, 0); + return to_list(a, 0, 0); case mx::uint64: - return to_list(host, 0, 0); + return to_list(a, 0, 0); case mx::int8: - return to_list(host, 0, 0); + return to_list(a, 0, 0); case mx::int16: - return to_list(host, 0, 0); + return to_list(a, 0, 0); case mx::int32: - return to_list(host, 0, 0); + return to_list(a, 0, 0); case mx::int64: - return to_list(host, 0, 0); + return to_list(a, 0, 0); case mx::float16: - return to_list(host, 0, 0); + return to_list(a, 0, 0); case mx::float32: - return to_list(host, 0, 0); + return to_list(a, 0, 0); case mx::bfloat16: - return to_list(host, 0, 0); + return to_list(a, 0, 0); case 
mx::float64: - return to_list(host, 0, 0); + return to_list(a, 0, 0); case mx::complex64: - return to_list>(host, 0, 0); + return to_list>(a, 0, 0); default: throw nb::type_error("data type cannot be converted to Python list."); } From 7b7b4d07b4a331a0e815aa2a5312b0e4cb8c2931 Mon Sep 17 00:00:00 2001 From: XXXXRT666 <157766680+XXXXRT666@users.noreply.github.com> Date: Thu, 21 May 2026 17:05:55 +0800 Subject: [PATCH 11/26] Leave nanobind shallow clone note --- CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index ff110ec929..9380748597 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -358,6 +358,7 @@ if(MLX_BUILD_PYTHON_BINDINGS) nanobind GIT_REPOSITORY https://github.com/wjakob/nanobind.git GIT_TAG 117ba9ee0253e1bbada678bb4b5d6c6e4ed441eb + # GIT_SHALLOW TRUE EXCLUDE_FROM_ALL) FetchContent_MakeAvailable(nanobind) add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/python/src) From 310999b63517c793cd5c0289e6046d8dd41762a5 Mon Sep 17 00:00:00 2001 From: XXXXRT666 <157766680+XXXXRT666@users.noreply.github.com> Date: Thu, 21 May 2026 17:34:16 +0800 Subject: [PATCH 12/26] Reduce Metal DLPack test redundancy --- python/tests/test_array.py | 54 +++----------------------------------- 1 file changed, 4 insertions(+), 50 deletions(-) diff --git a/python/tests/test_array.py b/python/tests/test_array.py index 6813bcc18e..5ca16667f8 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -2097,24 +2097,9 @@ def test_torch_mps_dlpack_import(self): self.assertEqual(y.dtype, mx.float32) torch.mps.synchronize() self.assertEqual(y.tolist(), x.cpu().numpy().tolist()) - - @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") - def test_torch_mps_dlpack_host_access(self): - x = torch.arange(12, device="mps", dtype=torch.float32).reshape(3, 4) - y = mx.array(x) - - torch.mps.synchronize() self.assertIn("array(", repr(y)) - self.assertEqual(y.tolist(), x.cpu().numpy().tolist()) - # PyTorch 2.12 allocates ordinary MPS 
tensors in shared/unified - # MTLBuffers on Apple silicon, while private Metal DLPack producers are - # not host-accessible. - try: - mv = memoryview(y) - except BufferError: - pass - else: - self.assertEqual(mv.tolist(), x.cpu().numpy().tolist()) + mv = memoryview(y) + self.assertEqual(mv.tolist(), x.cpu().numpy().tolist()) @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") def test_torch_mps_dlpack_zero_copy_shares_updates(self): @@ -2135,15 +2120,12 @@ def test_torch_mps_dlpack_matching_dtype_argument_shares_updates(self): x = torch.arange(12, device="mps", dtype=torch.float32).reshape(3, 4) torch.mps.synchronize() y = mx.array(x, dtype=mx.float32) + self.assertEqual(y.dtype, mx.float32) x.add_(100) torch.mps.synchronize() self.assertEqual(y.tolist(), x.cpu().numpy().tolist()) - y.copy(mx.full(y.shape, 7.0)) - mx.eval(y) - self.assertEqual(x.cpu().numpy().tolist(), y.tolist()) - @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") def test_torch_mps_dlpack_different_dtype_argument_copies(self): x = torch.arange(12, device="mps", dtype=torch.float32).reshape(3, 4) @@ -2188,30 +2170,6 @@ def test_torch_mps_array_operand(self): torch.mps.synchronize() self.assertTrue(mx.array_equal(a + b_mps, mx.array([3]))) - @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") - def test_torch_mps_dlpack_mlx_update_writes_torch_buffer(self): - x = torch.arange(3, device="mps", dtype=torch.float32) - torch.mps.synchronize() - y = mx.array(x) - - y.copy(mx.array([3, 4, 5], dtype=mx.float32)) - mx.eval(y) - self.assertEqual(x.cpu().numpy().tolist(), [3, 4, 5]) - - @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") - def test_torch_mps_dlpack_copy_writes_torch_buffer(self): - x = torch.arange(3, device="mps", dtype=torch.float32) - torch.mps.synchronize() - y = mx.array(x) - - y.copy(mx.array([4, 5, 6], dtype=mx.float32)) - mx.eval(y) - self.assertEqual(x.cpu().numpy().tolist(), [4.0, 5.0, 6.0]) - - x.add_(10) - torch.mps.synchronize() - 
self.assertEqual(y.tolist(), [14.0, 15.0, 16.0]) - @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") def test_mlx_dlpack_exports_mps_tensor_to_torch(self): x = mx.array([1]).astype(mx.float16) @@ -2269,13 +2227,9 @@ def test_from_dlpack_torch_mps_copy_false_shares_updates(self): torch.mps.synchronize() y = mx.from_dlpack(x, copy=False) - y.copy(mx.array([10, 11, 12], dtype=mx.float32)) - mx.eval(y) - self.assertEqual(x.cpu().numpy().tolist(), [10.0, 11.0, 12.0]) - x.add_(10) torch.mps.synchronize() - self.assertEqual(y.tolist(), [20.0, 21.0, 22.0]) + self.assertEqual(y.tolist(), [10.0, 11.0, 12.0]) @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") def test_from_dlpack_torch_mps_copy_true_copies(self): From da3559a70e9929aff845e3c39531e1e4f488e007 Mon Sep 17 00:00:00 2001 From: XXXXRT666 <157766680+XXXXRT666@users.noreply.github.com> Date: Thu, 21 May 2026 20:02:51 +0800 Subject: [PATCH 13/26] Support DLPack copy export --- python/src/array.cpp | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/src/array.cpp b/python/src/array.cpp index 3b2189ab01..386be3ec0b 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -502,7 +502,7 @@ void init_array(nb::module_& m) { nb::object, nb::object, nb::object dl_device, - nb::object) { + nb::object copy) { std::optional dl_device_type; if (!dl_device.is_none()) { auto device = nb::cast(dl_device); @@ -512,6 +512,9 @@ void init_array(nb::module_& m) { } dl_device_type = nb::cast(device[0]); } + if (!copy.is_none() && nb::cast(copy)) { + return mlx_to_dlpack(mx::copy_to_new_buffer(a), dl_device_type); + } return mlx_to_dlpack(a, dl_device_type); }, nb::kw_only(), From 3aaf373b6092ec30863c056f8dcbb59c90887404 Mon Sep 17 00:00:00 2001 From: XXXXRT666 <157766680+XXXXRT666@users.noreply.github.com> Date: Fri, 22 May 2026 13:53:00 +0800 Subject: [PATCH 14/26] Copy private Metal DLPack buffers on import --- docs/src/usage/numpy.rst | 14 +++++++------- mlx/allocator.h | 4 ---- 
mlx/backend/cuda/allocator.cpp | 8 -------- mlx/backend/metal/allocator.cpp | 13 +++---------- mlx/backend/no_gpu/allocator.cpp | 3 --- python/src/buffer.h | 8 -------- python/src/convert.cpp | 32 +++++++++++++++++++++++++++----- 7 files changed, 37 insertions(+), 45 deletions(-) diff --git a/docs/src/usage/numpy.rst b/docs/src/usage/numpy.rst index 17338cf97a..049f3cb1ca 100644 --- a/docs/src/usage/numpy.rst +++ b/docs/src/usage/numpy.rst @@ -95,12 +95,12 @@ The arrays do not share memory: Metal DLPack inputs are different. If a PyTorch MPS tensor is passed to ``mx.array`` or to ``mx.from_dlpack`` with ``copy=None``, MLX shares the -underlying Metal buffer when it is host-accessible. Private Metal buffers are -copied into MLX-managed storage instead. Passing ``copy=False`` requires -sharing and raises an error if a copy would be needed. MLX arrays exported to -PyTorch with DLPack are also shared without a copy. In particular, PyTorch 2.12 -and later use shared storage for ordinary MPS tensors on Apple silicon, while -older PyTorch versions may use private storage and require a copy on import. +underlying Metal buffer when it is not private. Private Metal buffers are copied +into MLX-managed storage instead. Passing ``copy=False`` requires sharing and +raises an error if a copy would be needed. MLX arrays exported to PyTorch with +DLPack are also shared without a copy. In particular, PyTorch 2.12 and later use +shared storage for ordinary MPS tensors on Apple silicon, while older PyTorch +versions may use private storage and require a copy on import. Since the buffer is shared across frameworks, synchronization has to be managed explicitly. 
After PyTorch writes to an MPS tensor, call @@ -120,7 +120,7 @@ need the update to be visible to another framework sharing the same buffer: b = torch.arange(3, device="mps", dtype=torch.float32) torch.mps.synchronize() - c = mx.array(b) # zero-copy if the Metal buffer is host-accessible + c = mx.array(b) # zero-copy if the Metal buffer is not private b.add_(10) torch.mps.synchronize() diff --git a/mlx/allocator.h b/mlx/allocator.h index 185ccd6d76..824deac2c7 100644 --- a/mlx/allocator.h +++ b/mlx/allocator.h @@ -21,10 +21,6 @@ class MLX_API Buffer { // Get the raw data pointer from the buffer void* raw_ptr(); - // Whether raw_ptr() can return a host-accessible pointer without moving or - // copying the buffer. - bool is_host_accessible() const; - // Get the buffer pointer from the buffer const void* ptr() const { return ptr_; diff --git a/mlx/backend/cuda/allocator.cpp b/mlx/backend/cuda/allocator.cpp index 4d67b2798e..718ae33e9c 100644 --- a/mlx/backend/cuda/allocator.cpp +++ b/mlx/backend/cuda/allocator.cpp @@ -416,14 +416,6 @@ void* Buffer::raw_ptr() { return cbuf.data; } -bool Buffer::is_host_accessible() const { - if (!ptr_) { - return true; - } - auto& cbuf = *static_cast(ptr_); - return cbuf.device == -1; -} - } // namespace allocator size_t get_active_memory() { diff --git a/mlx/backend/metal/allocator.cpp b/mlx/backend/metal/allocator.cpp index 89ad12be6f..14ca0388ab 100644 --- a/mlx/backend/metal/allocator.cpp +++ b/mlx/backend/metal/allocator.cpp @@ -24,21 +24,14 @@ void* Buffer::raw_ptr() { if (!ptr_) { return nullptr; } - if (!is_host_accessible()) { + auto* buf = static_cast(ptr_); + if (buf->storageMode() == MTL::StorageModePrivate) { throw std::runtime_error( "[metal::Buffer::raw_ptr] Cannot access Metal buffer contents on the " "host. 
The buffer is not CPU-addressable, for example because it uses " "private storage."); } - return static_cast(ptr_)->contents(); -} - -bool Buffer::is_host_accessible() const { - if (!ptr_) { - return true; - } - auto* buf = static_cast(ptr_); - return buf->storageMode() != MTL::StorageModePrivate; + return buf->contents(); } } // namespace allocator diff --git a/mlx/backend/no_gpu/allocator.cpp b/mlx/backend/no_gpu/allocator.cpp index 359fa36018..1fdcc262a4 100644 --- a/mlx/backend/no_gpu/allocator.cpp +++ b/mlx/backend/no_gpu/allocator.cpp @@ -178,9 +178,6 @@ void* Buffer::raw_ptr() { return static_cast(ptr_) + 1; } -bool Buffer::is_host_accessible() const { - return true; -} } // namespace allocator size_t get_active_memory() { diff --git a/python/src/buffer.h b/python/src/buffer.h index 8d01b82132..d7d4903565 100644 --- a/python/src/buffer.h +++ b/python/src/buffer.h @@ -93,14 +93,6 @@ extern "C" inline int getbuffer(PyObject* obj, Py_buffer* view, int flags) { a.eval(); a.wait(); } - if (!a.buffer().is_host_accessible()) { - PyErr_SetString( - PyExc_BufferError, - "Cannot provide a buffer for an array whose storage is not " - "CPU-addressable."); - return -1; - } - std::vector shape(a.shape().begin(), a.shape().end()); std::vector strides(a.strides().begin(), a.strides().end()); for (auto& s : strides) { diff --git a/python/src/convert.cpp b/python/src/convert.cpp index 4bd1abc72f..62c45fe2c2 100644 --- a/python/src/convert.cpp +++ b/python/src/convert.cpp @@ -7,6 +7,10 @@ #include #include +#if __has_include() +#include +#endif + #include "python/src/convert.h" #include "python/src/utils.h" @@ -22,6 +26,21 @@ enum PyScalarT { pycomplex = 3, }; +namespace { + +bool metal_buffer_is_private(void* ptr) { +#if __has_include() + if (!ptr) { + return false; + } + auto* buf = static_cast(ptr); + return buf->storageMode() == MTL::StorageModePrivate; +#endif + return false; +} + +} // namespace + int check_shape_dim(int64_t dim) { if (dim > std::numeric_limits::max() || 
dim < std::numeric_limits::min()) { @@ -178,6 +197,8 @@ mx::array metal_dlpack_to_mlx_contiguous( "Metal DLPack byte offset is not aligned to dtype size."); } + auto is_private_buffer = metal_buffer_is_private(owner->data_handle()); + auto out = mx::array( mx::allocator::Buffer(owner->data_handle()), shape, @@ -196,14 +217,12 @@ mx::array metal_dlpack_to_mlx_contiguous( out.copy_shared_buffer(out, out.strides(), flags, out.data_size(), offset); } - auto is_host_accessible = out.buffer().is_host_accessible(); - if (copy == false && !is_host_accessible) { + if (copy == false && is_private_buffer) { throw std::invalid_argument( - "Cannot import a non-host-accessible Metal DLPack buffer without a " - "copy."); + "Cannot import a private Metal DLPack buffer without a copy."); } - if (dtype || !is_host_accessible) { + if (dtype || is_private_buffer) { auto result_dtype = dtype.value_or(out.dtype()); auto result = (result_dtype == out.dtype()) ? mx::copy_to_new_buffer(out, mx::Device::gpu) @@ -409,6 +428,9 @@ mx::array metal_dlpack_to_mlx( nb::ndarray nd_array, std::optional dtype, std::optional copy) { + if (!mx::metal::is_available()) { + throw std::invalid_argument("Metal DLPack import is not available."); + } auto owner = std::make_shared>(std::move(nd_array)); auto shape = get_shape(*owner); From 05a073e9a701fce5df65fa15e37da9f2aec003af Mon Sep 17 00:00:00 2001 From: XXXXRT666 <157766680+XXXXRT666@users.noreply.github.com> Date: Fri, 22 May 2026 14:27:38 +0800 Subject: [PATCH 15/26] Clean up Metal DLPack owner handling --- python/src/convert.cpp | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/python/src/convert.cpp b/python/src/convert.cpp index 62c45fe2c2..0a15f105d6 100644 --- a/python/src/convert.cpp +++ b/python/src/convert.cpp @@ -1,7 +1,6 @@ // Copyright © 2024 Apple Inc. 
#include -#include #include #include @@ -180,27 +179,27 @@ mx::array from_dlpack(nb::object v, std::optional copy) { template mx::array metal_dlpack_to_mlx_contiguous( - std::shared_ptr> owner, + nb::ndarray owner, const mx::Shape& shape, mx::Dtype type, std::optional dtype, std::optional copy) { auto itemsize = mx::size_of(type); - if (owner->itemsize() != itemsize) { + if (owner.itemsize() != itemsize) { throw std::invalid_argument( "Cannot convert Metal DLPack dtype to mlx dtype."); } - auto byte_offset = owner->byte_offset(); + auto byte_offset = owner.byte_offset(); if (byte_offset % itemsize != 0) { throw std::invalid_argument( "Metal DLPack byte offset is not aligned to dtype size."); } - auto is_private_buffer = metal_buffer_is_private(owner->data_handle()); + auto is_private_buffer = metal_buffer_is_private(owner.data_handle()); auto out = mx::array( - mx::allocator::Buffer(owner->data_handle()), + mx::allocator::Buffer(owner.data_handle()), shape, type, [](mx::allocator::Buffer) {}); @@ -228,7 +227,6 @@ mx::array metal_dlpack_to_mlx_contiguous( ? 
mx::copy_to_new_buffer(out, mx::Device::gpu) : mx::astype(out, result_dtype, mx::Device::gpu); result.eval(); - result.wait(); result.detach(); return result; } @@ -301,7 +299,6 @@ nb::ndarray<> mlx_to_dlpack_impl(mx::array a, int dl_device_type) { { nb::gil_scoped_release nogil; a.eval(); - a.wait(); if (dl_device_type == nb::device::cpu::value) { data = a.data(); } else { @@ -431,12 +428,11 @@ mx::array metal_dlpack_to_mlx( if (!mx::metal::is_available()) { throw std::invalid_argument("Metal DLPack import is not available."); } - auto owner = - std::make_shared>(std::move(nd_array)); - auto shape = get_shape(*owner); + auto owner = std::move(nd_array); + auto shape = get_shape(owner); return dispatch_dlpack_dtype( - owner->dtype(), + owner.dtype(), "Cannot convert Metal DLPack array to mlx array.", [&](mx::Dtype type) { return metal_dlpack_to_mlx_contiguous( From 6cab38d45ccaa66a9006639cb193b07da979bbde Mon Sep 17 00:00:00 2001 From: XXXXRT666 <157766680+XXXXRT666@users.noreply.github.com> Date: Fri, 22 May 2026 15:07:37 +0800 Subject: [PATCH 16/26] Clarify DLPack conversion dtype names --- python/src/convert.cpp | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/python/src/convert.cpp b/python/src/convert.cpp index 0a15f105d6..4c064abb31 100644 --- a/python/src/convert.cpp +++ b/python/src/convert.cpp @@ -129,22 +129,22 @@ mx::array metal_dlpack_to_mlx( mx::array nd_array_to_mlx( nb::ndarray nd_array, - std::optional dtype, - std::optional nb_dtype) { + std::optional dst_dtype, + std::optional src_dtype) { switch (nd_array.device_type()) { case nb::device::cpu::value: { auto shape = get_shape(nd_array); - auto type = nb_dtype.value_or(nd_array.dtype()); + auto type = src_dtype.value_or(nd_array.dtype()); return dispatch_dlpack_dtype( type, "Cannot convert numpy array to mlx array.", [&](mx::Dtype default_dtype) { return nd_array_to_mlx_contiguous( - nd_array, shape, dtype.value_or(default_dtype)); + nd_array, shape, 
dst_dtype.value_or(default_dtype)); }); } case nb::device::metal::value: - return metal_dlpack_to_mlx(std::move(nd_array), dtype, std::nullopt); + return metal_dlpack_to_mlx(std::move(nd_array), dst_dtype, std::nullopt); default: throw std::invalid_argument("Unsupported DLPack device."); } @@ -156,14 +156,14 @@ mx::array from_dlpack(nb::object v, std::optional copy) { switch (nd.device_type()) { case nb::device::cpu::value: - if (copy == false) { + if (copy.has_value() && copy.value() == false) { throw std::invalid_argument( "Cannot import a CPU DLPack array without a copy."); } return nd_array_to_mlx(std::move(nd), std::nullopt); case nb::device::metal::value: { std::optional dtype; - if (copy == true) { + if (copy.has_value() && copy.value() == true) { dtype = mlx_dtype_from_dlpack( nd.dtype(), "Cannot convert Metal DLPack array to mlx array."); } @@ -216,7 +216,7 @@ mx::array metal_dlpack_to_mlx_contiguous( out.copy_shared_buffer(out, out.strides(), flags, out.data_size(), offset); } - if (copy == false && is_private_buffer) { + if (copy.has_value() && copy.value() == false && is_private_buffer) { throw std::invalid_argument( "Cannot import a private Metal DLPack buffer without a copy."); } From f88366d4d169bf8bf3bbce85268034ce368a9b2b Mon Sep 17 00:00:00 2001 From: XXXXRT666 <157766680+XXXXRT666@users.noreply.github.com> Date: Fri, 22 May 2026 15:29:36 +0800 Subject: [PATCH 17/26] Clean up DLPack conversion paths --- docs/src/python/ops.rst | 1 + python/src/convert.cpp | 117 ++++++++++++++++++++-------------------- python/src/convert.h | 3 +- 3 files changed, 60 insertions(+), 61 deletions(-) diff --git a/docs/src/python/ops.rst b/docs/src/python/ops.rst index 6aabfc6954..19608ff37b 100644 --- a/docs/src/python/ops.rst +++ b/docs/src/python/ops.rst @@ -77,6 +77,7 @@ Operations floor floor_divide full + from_dlpack gather_mm gather_qmm greater diff --git a/python/src/convert.cpp b/python/src/convert.cpp index 4c064abb31..59ab543f6d 100644 --- 
a/python/src/convert.cpp +++ b/python/src/convert.cpp @@ -122,6 +122,41 @@ mx::Dtype mlx_dtype_from_dlpack( type, error_message, [](mx::Dtype dtype) { return dtype; }); } +nb::dlpack::dtype mlx_dtype_to_dl_dtype(mx::Dtype dtype) { + switch (dtype) { + case mx::bool_: + return nb::dtype(); + case mx::uint8: + return nb::dtype(); + case mx::uint16: + return nb::dtype(); + case mx::uint32: + return nb::dtype(); + case mx::uint64: + return nb::dtype(); + case mx::int8: + return nb::dtype(); + case mx::int16: + return nb::dtype(); + case mx::int32: + return nb::dtype(); + case mx::int64: + return nb::dtype(); + case mx::float16: + return nb::dtype(); + case mx::bfloat16: + return nb::dtype(); + case mx::float32: + return nb::dtype(); + case mx::float64: + return nb::dtype(); + case mx::complex64: + return nb::dtype>(); + default: + throw nb::buffer_error("Cannot export mlx array with unsupported dtype."); + } +} + mx::array metal_dlpack_to_mlx( nb::ndarray nd_array, std::optional dtype, @@ -130,9 +165,14 @@ mx::array metal_dlpack_to_mlx( mx::array nd_array_to_mlx( nb::ndarray nd_array, std::optional dst_dtype, - std::optional src_dtype) { + std::optional src_dtype, + std::optional copy) { switch (nd_array.device_type()) { case nb::device::cpu::value: { + if (copy.has_value() && copy.value() == false) { + throw std::invalid_argument( + "Cannot import a CPU DLPack array without a copy."); + } auto shape = get_shape(nd_array); auto type = src_dtype.value_or(nd_array.dtype()); return dispatch_dlpack_dtype( @@ -144,7 +184,15 @@ mx::array nd_array_to_mlx( }); } case nb::device::metal::value: - return metal_dlpack_to_mlx(std::move(nd_array), dst_dtype, std::nullopt); + if (copy.has_value() && copy.value() == true && !dst_dtype) { + dst_dtype = mlx_dtype_from_dlpack( + nd_array.dtype(), + "Cannot convert Metal DLPack array to mlx array."); + } + return metal_dlpack_to_mlx(std::move(nd_array), dst_dtype, copy); + case nb::device::cuda::value: + case 
nb::device::cuda_managed::value: + throw std::invalid_argument("CUDA DLPack import is not supported."); default: throw std::invalid_argument("Unsupported DLPack device."); } @@ -153,28 +201,8 @@ mx::array nd_array_to_mlx( mx::array from_dlpack(nb::object v, std::optional copy) { using ContigArray = nb::ndarray; auto nd = nb::cast(v); - - switch (nd.device_type()) { - case nb::device::cpu::value: - if (copy.has_value() && copy.value() == false) { - throw std::invalid_argument( - "Cannot import a CPU DLPack array without a copy."); - } - return nd_array_to_mlx(std::move(nd), std::nullopt); - case nb::device::metal::value: { - std::optional dtype; - if (copy.has_value() && copy.value() == true) { - dtype = mlx_dtype_from_dlpack( - nd.dtype(), "Cannot convert Metal DLPack array to mlx array."); - } - return metal_dlpack_to_mlx(std::move(nd), dtype, copy); - } - case nb::device::cuda::value: - case nb::device::cuda_managed::value: - throw std::invalid_argument("CUDA DLPack import is not supported."); - default: - throw std::invalid_argument("Unsupported DLPack device."); - } + return nd_array_to_mlx( + std::move(nd), std::nullopt, std::nullopt, std::move(copy)); } template @@ -292,15 +320,15 @@ nb::ndarray mlx_to_np_array(const mx::array& a) { return mlx_to_nd_array(a); } -template -nb::ndarray<> mlx_to_dlpack_impl(mx::array a, int dl_device_type) { +nb::ndarray<> +mlx_to_dlpack_impl(mx::array a, int dl_device_type, nb::dlpack::dtype dtype) { void* data = nullptr; uint64_t byte_offset = 0; { nb::gil_scoped_release nogil; a.eval(); if (dl_device_type == nb::device::cpu::value) { - data = a.data(); + data = static_cast(a.buffer().raw_ptr()) + a.offset(); } else { data = a.buffer().ptr(); byte_offset = a.offset(); @@ -315,7 +343,7 @@ nb::ndarray<> mlx_to_dlpack_impl(mx::array a, int dl_device_type) { shape.data(), /* owner= */ owner, a.strides().data(), - nb::dtype(), + dtype, dl_device_type, 0, '\0', @@ -344,38 +372,7 @@ nb::ndarray<> mlx_to_dlpack( throw 
nb::buffer_error("Metal DLPack export is not available."); } - switch (a.dtype()) { - case mx::bool_: - return mlx_to_dlpack_impl(a, device_type); - case mx::uint8: - return mlx_to_dlpack_impl(a, device_type); - case mx::uint16: - return mlx_to_dlpack_impl(a, device_type); - case mx::uint32: - return mlx_to_dlpack_impl(a, device_type); - case mx::uint64: - return mlx_to_dlpack_impl(a, device_type); - case mx::int8: - return mlx_to_dlpack_impl(a, device_type); - case mx::int16: - return mlx_to_dlpack_impl(a, device_type); - case mx::int32: - return mlx_to_dlpack_impl(a, device_type); - case mx::int64: - return mlx_to_dlpack_impl(a, device_type); - case mx::float16: - return mlx_to_dlpack_impl(a, device_type); - case mx::bfloat16: - return mlx_to_dlpack_impl(a, device_type); - case mx::float32: - return mlx_to_dlpack_impl(a, device_type); - case mx::float64: - return mlx_to_dlpack_impl(a, device_type); - case mx::complex64: - return mlx_to_dlpack_impl>(a, device_type); - default: - throw nb::buffer_error("Cannot export mlx array with unsupported dtype."); - } + return mlx_to_dlpack_impl(a, device_type, mlx_dtype_to_dl_dtype(a.dtype())); } nb::object to_scalar(mx::array& a) { diff --git a/python/src/convert.h b/python/src/convert.h index bf93540f2a..cf17e1c7ce 100644 --- a/python/src/convert.h +++ b/python/src/convert.h @@ -63,7 +63,8 @@ struct ArrayLike { mx::array nd_array_to_mlx( nb::ndarray nd_array, std::optional mx_dtype, - std::optional nb_dtype = std::nullopt); + std::optional nb_dtype = std::nullopt, + std::optional copy = std::nullopt); nb::ndarray mlx_to_np_array(const mx::array& a); nb::ndarray<> mlx_to_dlpack( From 07784c274b475671fb3e79d81055d44c45a8f7c9 Mon Sep 17 00:00:00 2001 From: XXXXRT666 <157766680+XXXXRT666@users.noreply.github.com> Date: Fri, 22 May 2026 15:40:29 +0800 Subject: [PATCH 18/26] Revert "Add array copy for shared DLPack buffers" This reverts commit eb7469561d57460dc4efd595d2bfc37f44f9a8dd. 
--- docs/src/python/array.rst | 1 - docs/src/usage/numpy.rst | 90 +++++++++++----------------------- mlx/backend/cpu/primitives.cpp | 5 -- mlx/backend/gpu/primitives.cpp | 7 --- mlx/ops.cpp | 14 ------ mlx/ops.h | 4 -- mlx/primitives.h | 11 ----- python/src/array.cpp | 23 --------- python/src/convert.cpp | 14 +++--- python/tests/test_array.py | 20 +------- 10 files changed, 38 insertions(+), 151 deletions(-) diff --git a/docs/src/python/array.rst b/docs/src/python/array.rst index 42ea2186f9..e68524d5a4 100644 --- a/docs/src/python/array.rst +++ b/docs/src/python/array.rst @@ -11,7 +11,6 @@ Array array array.astype array.at - array.copy array.item array.tolist array.dtype diff --git a/docs/src/usage/numpy.rst b/docs/src/usage/numpy.rst index 049f3cb1ca..bfc762fcbf 100644 --- a/docs/src/usage/numpy.rst +++ b/docs/src/usage/numpy.rst @@ -77,90 +77,58 @@ PyTorch PyTorch supports DLPack inputs and can import MLX arrays directly. MLX can also import PyTorch tensors through DLPack with ``mx.array`` or -``mx.from_dlpack``. - -Creating an MLX array from a CPU tensor copies the data into MLX-owned storage. -The arrays do not share memory: +``mx.from_dlpack``. Use ``torch.as_tensor`` to import an MLX array with +DLPack; ``torch.tensor`` copies the data instead: .. code-block:: python import mlx.core as mx import torch - b = torch.arange(3).cpu() + a = mx.arange(3, dtype=mx.float32) + mx.eval(a) + + shared = torch.as_tensor(a) + copied = torch.tensor(a) + +Creating an MLX array from a CPU tensor copies the data into MLX-owned storage. +The arrays do not share memory: + +.. code-block:: python + + b = torch.arange(3) c = mx.array(b) b += 10 print(c.tolist()) # [0, 1, 2] Metal DLPack inputs are different. If a PyTorch MPS tensor is passed to -``mx.array`` or to ``mx.from_dlpack`` with ``copy=None``, MLX shares the -underlying Metal buffer when it is not private. Private Metal buffers are copied -into MLX-managed storage instead. 
Passing ``copy=False`` requires sharing and -raises an error if a copy would be needed. MLX arrays exported to PyTorch with -DLPack are also shared without a copy. In particular, PyTorch 2.12 and later use -shared storage for ordinary MPS tensors on Apple silicon, while older PyTorch -versions may use private storage and require a copy on import. - -Since the buffer is shared across frameworks, synchronization has to be managed -explicitly. After PyTorch writes to an MPS tensor, call -``torch.mps.synchronize()`` before reading the shared data from MLX. After MLX -writes to the shared array, call ``mx.eval`` on the MLX result before reading -the shared data from PyTorch. Without these synchronization points, the other -framework may read the shared buffer before the producer has finished writing, -so it can observe stale data. - -Do not rely on regular MLX update expressions to preserve sharing. MLX may -rebind an array to a new buffer when evaluating an expression, while PyTorch -continues to reference the original shared buffer. Use -:meth:`mlx.core.array.copy` to write into an array's existing storage when you -need the update to be visible to another framework sharing the same buffer: +``mx.array`` or to ``mx.from_dlpack`` with ``copy=None``, MLX imports it +without a copy when the underlying Metal buffer is not private. Private Metal +buffers are copied into MLX-managed storage instead. Passing ``copy=False`` +requires zero-copy import and raises an error if a copy would be needed. +Passing ``copy=True`` asks MLX to create a new array instead of reusing the +Metal buffer. MLX arrays exported to PyTorch with DLPack are also exported +without a copy on Metal. + +In particular, PyTorch 2.12 and later use shared storage for ordinary MPS +tensors on Apple silicon, while older PyTorch versions may use private storage +and require a copy on import. 
DLPack conversion does not synchronize pending +Metal work; synchronize or evaluate the producing framework before reading the +converted array. .. code-block:: python b = torch.arange(3, device="mps", dtype=torch.float32) torch.mps.synchronize() c = mx.array(b) # zero-copy if the Metal buffer is not private - - b.add_(10) - torch.mps.synchronize() - print(c.tolist()) # [10.0, 11.0, 12.0] - - c.copy(mx.array([4, 5, 6], dtype=mx.float32)) - mx.eval(c) - print(b.cpu()) # tensor([4., 5., 6.]) - -The same synchronization rules apply when PyTorch imports an MLX array through -DLPack. Use ``torch.as_tensor`` to share the buffer with PyTorch; ``torch.tensor`` -copies the data instead: + d = mx.from_dlpack(b, copy=True) # explicit copy .. code-block:: python a = mx.arange(3, dtype=mx.float32) mx.eval(a) - b = torch.as_tensor(a) - - b.add_(10) - torch.mps.synchronize() - print(a.tolist()) # [10.0, 11.0, 12.0] - - a.copy(mx.array([4, 5, 6], dtype=mx.float32)) - mx.eval(a) - print(b.cpu()) # tensor([4., 5., 6.]) - -Use ``mx.from_dlpack`` when you need to control the copy behavior. Specifying -``copy=True`` asks MLX to create a new array instead of sharing the Metal -buffer: - -.. 
code-block:: python - - b = torch.arange(3, device="mps", dtype=torch.float32) - torch.mps.synchronize() - c = mx.from_dlpack(b, copy=True) - - b.add_(10) - torch.mps.synchronize() - print(c.tolist()) # [0.0, 1.0, 2.0] + b = torch.as_tensor(a) # zero-copy DLPack import on Metal JAX --- diff --git a/mlx/backend/cpu/primitives.cpp b/mlx/backend/cpu/primitives.cpp index 133dab2104..f1d83dd306 100644 --- a/mlx/backend/cpu/primitives.cpp +++ b/mlx/backend/cpu/primitives.cpp @@ -88,11 +88,6 @@ void BroadcastAxes::eval_cpu(const std::vector& inputs, array& out) { void Copy::eval_cpu(const std::vector& inputs, array& out) { eval(inputs, out); } -void CopyInto::eval_cpu(const std::vector& inputs, array& out) { - assert(inputs.size() == 2); - out.copy_shared_buffer(inputs[0]); - copy_cpu_inplace(inputs[1], out, CopyType::GeneralGeneral, stream()); -} void CustomTransforms::eval_cpu( const std::vector& inputs, std::vector& outputs) { diff --git a/mlx/backend/gpu/primitives.cpp b/mlx/backend/gpu/primitives.cpp index 29f9c2d085..268d6290bf 100644 --- a/mlx/backend/gpu/primitives.cpp +++ b/mlx/backend/gpu/primitives.cpp @@ -66,13 +66,6 @@ void Copy::eval_gpu(const std::vector& inputs, array& out) { eval(inputs, out); } -void CopyInto::eval_gpu(const std::vector& inputs, array& out) { - MLX_PROFILER_RANGE("CopyInto::eval_gpu"); - assert(inputs.size() == 2); - out.copy_shared_buffer(inputs[0]); - copy_gpu_inplace(inputs[1], out, CopyType::GeneralGeneral, stream()); -} - void CustomTransforms::eval_gpu( const std::vector& inputs, std::vector& outputs) { diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 983e1bb116..4598529e6c 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -296,20 +296,6 @@ array copy(array a, StreamOrDevice s /* = {} */) { {std::move(a)}); } -array copy_into( - const array& dst, - const array& src, - StreamOrDevice s /* = {} */) { - auto stream = to_stream(s); - auto update = - broadcast_to(astype(src, dst.dtype(), stream), dst.shape(), stream); - return array( - 
dst.shape(), - dst.dtype(), - std::make_shared(stream), - {dst, std::move(update)}); -} - array copy_to_new_buffer(array a, StreamOrDevice s /* = {} */) { auto copied_shape = a.shape(); // |a| will be moved auto dtype = a.dtype(); diff --git a/mlx/ops.h b/mlx/ops.h index 35edd17d88..3084de3962 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -60,10 +60,6 @@ MLX_API array as_strided( /** Copy another array. */ MLX_API array copy(array a, StreamOrDevice s = {}); -/** Copy an array into another array's storage. */ -MLX_API array -copy_into(const array& dst, const array& src, StreamOrDevice s = {}); - /** Copy another array into newly allocated storage. */ MLX_API array copy_to_new_buffer(array a, StreamOrDevice s = {}); diff --git a/mlx/primitives.h b/mlx/primitives.h index 9a8c455b52..75fb978dce 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -803,17 +803,6 @@ class Copy : public UnaryPrimitive { void eval(const std::vector& inputs, array& out); }; -class CopyInto : public UnaryPrimitive { - public: - explicit CopyInto(Stream stream) : UnaryPrimitive(stream) {} - - void eval_cpu(const std::vector& inputs, array& out) override; - void eval_gpu(const std::vector& inputs, array& out) override; - - DEFINE_NAME(CopyInto) - DEFINE_INPUT_OUTPUT_SHAPE() -}; - class Cos : public UnaryPrimitive { public: explicit Cos(Stream stream) : UnaryPrimitive(stream) {} diff --git a/python/src/array.cpp b/python/src/array.cpp index 386be3ec0b..63d765a68f 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -541,29 +541,6 @@ void init_array(nb::module_& m) { "__deepcopy__", [](const mx::array& self, nb::dict) { return mx::array(self); }, "memo"_a) - .def( - "copy", - [](mx::array& a, - const ScalarOrArray& v, - mx::StreamOrDevice s) -> mx::array& { - a.eval(); - auto out = mx::copy_into(a, to_array(v, a.dtype()), s); - a.overwrite_descriptor(out); - return a; - }, - "src"_a, - "stream"_a = nb::none(), - nb::rv_policy::none, - nb::sig( - "def copy(self, src: Union[scalar, 
array, DLPackCompatible], /, *, stream: Union[None, Stream, Device] = None) -> array"), - R"pbdoc( - Copy ``src`` into the array's existing storage. - - ``src`` can be a scalar, an MLX array, or a DLPack-compatible - array. The destination array keeps its shape, dtype, and storage. - ``src`` is cast to the destination dtype and broadcast to the - destination shape before being written. - )pbdoc") .def( "__add__", [](const mx::array& a, const ScalarOrArray v) { diff --git a/python/src/convert.cpp b/python/src/convert.cpp index 59ab543f6d..ef931bf76f 100644 --- a/python/src/convert.cpp +++ b/python/src/convert.cpp @@ -78,8 +78,8 @@ mx::array nd_array_to_mlx_contiguous( template auto dispatch_dlpack_dtype( nb::dlpack::dtype type, - const char* error_message, - F&& f) { + F&& f, + const char* error_message) { if (type == nb::dtype()) { return f.template operator()(mx::bool_); } else if (type == nb::dtype()) { @@ -119,7 +119,7 @@ mx::Dtype mlx_dtype_from_dlpack( nb::dlpack::dtype type, const char* error_message) { return dispatch_dlpack_dtype( - type, error_message, [](mx::Dtype dtype) { return dtype; }); + type, [](mx::Dtype dtype) { return dtype; }, error_message); } nb::dlpack::dtype mlx_dtype_to_dl_dtype(mx::Dtype dtype) { @@ -177,11 +177,11 @@ mx::array nd_array_to_mlx( auto type = src_dtype.value_or(nd_array.dtype()); return dispatch_dlpack_dtype( type, - "Cannot convert numpy array to mlx array.", [&](mx::Dtype default_dtype) { return nd_array_to_mlx_contiguous( nd_array, shape, dst_dtype.value_or(default_dtype)); - }); + }, + "Cannot convert numpy array to mlx array."); } case nb::device::metal::value: if (copy.has_value() && copy.value() == true && !dst_dtype) { @@ -430,11 +430,11 @@ mx::array metal_dlpack_to_mlx( return dispatch_dlpack_dtype( owner.dtype(), - "Cannot convert Metal DLPack array to mlx array.", [&](mx::Dtype type) { return metal_dlpack_to_mlx_contiguous( owner, shape, type, dtype, copy); - }); + }, + "Cannot convert Metal DLPack array to mlx 
array."); } template diff --git a/python/tests/test_array.py b/python/tests/test_array.py index 5ca16667f8..be3061d46f 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -956,22 +956,6 @@ def test_array_copy(self): y -= 1 self.assertEqualArray(y, x - 1) - def test_array_copy_method(self): - x = mx.array([1, 2, 3]) - y = x.copy(mx.array([4, 5, 6])) - self.assertIs(y, x) - mx.eval(x) - self.assertEqualArray(x, mx.array([4, 5, 6])) - - x.copy(9) - mx.eval(x) - self.assertEqualArray(x, mx.array([9, 9, 9])) - - x.copy(np.array([1.5, 2.5, 3.5])) - mx.eval(x) - self.assertEqual(x.dtype, mx.int32) - self.assertEqualArray(x, mx.array([1, 2, 3])) - def test_indexing(self): # Only ellipsis is a no-op a_mlx = mx.array([1])[...] @@ -2111,7 +2095,7 @@ def test_torch_mps_dlpack_zero_copy_shares_updates(self): torch.mps.synchronize() self.assertEqual((y + 1).tolist(), (x + 1).cpu().numpy().tolist()) - y.copy(mx.full(y.shape, 7.0)) + y += 10 mx.eval(y) self.assertEqual(x.cpu().numpy().tolist(), y.tolist()) @@ -2217,7 +2201,7 @@ def test_from_dlpack_torch_mps_copy_none_shares_updates(self): torch.mps.synchronize() self.assertEqual(y.tolist(), [10.0, 11.0, 12.0]) - y.copy(mx.array([20, 21, 22], dtype=mx.float32)) + y += 10 mx.eval(y) self.assertEqual(x.cpu().numpy().tolist(), [20.0, 21.0, 22.0]) From 09fa65dd08ba97be41d87db66f25646d4971d282 Mon Sep 17 00:00:00 2001 From: XXXXRT666 <157766680+XXXXRT666@users.noreply.github.com> Date: Sat, 23 May 2026 13:44:42 +0800 Subject: [PATCH 19/26] Address DLPack review cleanup --- python/src/array.cpp | 23 ++--- python/src/buffer.h | 1 - python/src/convert.cpp | 192 +++++++++++++++++++---------------------- python/src/convert.h | 4 +- python/src/ops.cpp | 4 +- 5 files changed, 100 insertions(+), 124 deletions(-) diff --git a/python/src/array.cpp b/python/src/array.cpp index 63d765a68f..3accf6172b 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -2,11 +2,13 @@ #include #include #include +#include #include 
#include #include #include +#include #include #include #include @@ -501,21 +503,12 @@ void init_array(nb::module_& m) { [](const mx::array& a, nb::object, nb::object, - nb::object dl_device, - nb::object copy) { - std::optional dl_device_type; - if (!dl_device.is_none()) { - auto device = nb::cast(dl_device); - if (nb::len(device) != 2) { - throw nb::type_error( - "dl_device must be None or a tuple[int, int]"); - } - dl_device_type = nb::cast(device[0]); - } - if (!copy.is_none() && nb::cast(copy)) { - return mlx_to_dlpack(mx::copy_to_new_buffer(a), dl_device_type); - } - return mlx_to_dlpack(a, dl_device_type); + std::optional> dl_device, + std::optional copy) { + if (copy.value_or(false)) { + return mlx_to_dlpack(mx::copy_to_new_buffer(a), dl_device); + } + return mlx_to_dlpack(a, dl_device); }, nb::kw_only(), "stream"_a = nb::none(), diff --git a/python/src/buffer.h b/python/src/buffer.h index d7d4903565..49ad0d64d0 100644 --- a/python/src/buffer.h +++ b/python/src/buffer.h @@ -91,7 +91,6 @@ extern "C" inline int getbuffer(PyObject* obj, Py_buffer* view, int flags) { { nb::gil_scoped_release nogil; a.eval(); - a.wait(); } std::vector shape(a.shape().begin(), a.shape().end()); std::vector strides(a.strides().begin(), a.strides().end()); diff --git a/python/src/convert.cpp b/python/src/convert.cpp index ef931bf76f..c1bdcab04b 100644 --- a/python/src/convert.cpp +++ b/python/src/convert.cpp @@ -2,6 +2,7 @@ #include #include +#include #include #include @@ -157,54 +158,6 @@ nb::dlpack::dtype mlx_dtype_to_dl_dtype(mx::Dtype dtype) { } } -mx::array metal_dlpack_to_mlx( - nb::ndarray nd_array, - std::optional dtype, - std::optional copy); - -mx::array nd_array_to_mlx( - nb::ndarray nd_array, - std::optional dst_dtype, - std::optional src_dtype, - std::optional copy) { - switch (nd_array.device_type()) { - case nb::device::cpu::value: { - if (copy.has_value() && copy.value() == false) { - throw std::invalid_argument( - "Cannot import a CPU DLPack array without a copy."); 
- } - auto shape = get_shape(nd_array); - auto type = src_dtype.value_or(nd_array.dtype()); - return dispatch_dlpack_dtype( - type, - [&](mx::Dtype default_dtype) { - return nd_array_to_mlx_contiguous( - nd_array, shape, dst_dtype.value_or(default_dtype)); - }, - "Cannot convert numpy array to mlx array."); - } - case nb::device::metal::value: - if (copy.has_value() && copy.value() == true && !dst_dtype) { - dst_dtype = mlx_dtype_from_dlpack( - nd_array.dtype(), - "Cannot convert Metal DLPack array to mlx array."); - } - return metal_dlpack_to_mlx(std::move(nd_array), dst_dtype, copy); - case nb::device::cuda::value: - case nb::device::cuda_managed::value: - throw std::invalid_argument("CUDA DLPack import is not supported."); - default: - throw std::invalid_argument("Unsupported DLPack device."); - } -} - -mx::array from_dlpack(nb::object v, std::optional copy) { - using ContigArray = nb::ndarray; - auto nd = nb::cast(v); - return nd_array_to_mlx( - std::move(nd), std::nullopt, std::nullopt, std::move(copy)); -} - template mx::array metal_dlpack_to_mlx_contiguous( nb::ndarray owner, @@ -261,6 +214,60 @@ mx::array metal_dlpack_to_mlx_contiguous( return out; } +mx::array metal_dlpack_to_mlx( + nb::ndarray nd_array, + std::optional dtype, + std::optional copy) { + if (!mx::metal::is_available()) { + throw std::invalid_argument("Metal DLPack import is not available."); + } + auto shape = get_shape(nd_array); + + return dispatch_dlpack_dtype( + nd_array.dtype(), + [&](mx::Dtype type) { + return metal_dlpack_to_mlx_contiguous( + nd_array, shape, type, dtype, copy); + }, + "Cannot convert Metal DLPack array to mlx array."); +} + +mx::array nd_array_to_mlx( + nb::ndarray nd_array, + std::optional dst_dtype, + std::optional src_dtype, + std::optional copy) { + switch (nd_array.device_type()) { + case nb::device::cpu::value: { + if (copy.has_value() && copy.value() == false) { + throw std::invalid_argument( + "Cannot import a CPU DLPack array without a copy."); + } + auto 
shape = get_shape(nd_array); + auto type = src_dtype.value_or(nd_array.dtype()); + return dispatch_dlpack_dtype( + type, + [&](mx::Dtype default_dtype) { + return nd_array_to_mlx_contiguous( + nd_array, shape, dst_dtype.value_or(default_dtype)); + }, + "Cannot convert numpy array to mlx array."); + } + case nb::device::metal::value: + if (copy.has_value() && copy.value() == true && !dst_dtype) { + dst_dtype = mlx_dtype_from_dlpack( + nd_array.dtype(), + "Cannot convert Metal DLPack array to mlx array."); + } + return metal_dlpack_to_mlx(nd_array, dst_dtype, copy); + case nb::device::cuda::value: + case nb::device::cuda_managed::value: + throw std::invalid_argument("CUDA DLPack import is not supported."); + default: + throw std::invalid_argument("Unsupported DLPack device."); + } +} + template nb::ndarray mlx_to_nd_array_impl( mx::array a, @@ -320,44 +327,14 @@ nb::ndarray mlx_to_np_array(const mx::array& a) { return mlx_to_nd_array(a); } -nb::ndarray<> -mlx_to_dlpack_impl(mx::array a, int dl_device_type, nb::dlpack::dtype dtype) { - void* data = nullptr; - uint64_t byte_offset = 0; - { - nb::gil_scoped_release nogil; - a.eval(); - if (dl_device_type == nb::device::cpu::value) { - data = static_cast(a.buffer().raw_ptr()) + a.offset(); - } else { - data = a.buffer().ptr(); - byte_offset = a.offset(); - } - } - - std::vector shape(a.shape().begin(), a.shape().end()); - auto owner = nb::cast(a); - return nb::ndarray<>( - data, - a.ndim(), - shape.data(), - /* owner= */ owner, - a.strides().data(), - dtype, - dl_device_type, - 0, - '\0', - byte_offset); -} - nb::ndarray<> mlx_to_dlpack( const mx::array& a, - std::optional dl_device_type) { - int device_type = dl_device_type.value_or( - mx::metal::is_available() - ? nb::device::metal::value - : (mx::cu::is_available() ? nb::device::cuda_managed::value - : nb::device::cpu::value)); + std::optional> dl_device) { + auto default_device = mx::metal::is_available() + ? 
std::tuple{nb::device::metal::value, 0} + : (mx::cu::is_available() ? std::tuple{nb::device::cuda_managed::value, 0} + : std::tuple{nb::device::cpu::value, 0}); + auto [device_type, device_id] = dl_device.value_or(default_device); if (device_type == nb::device::cuda::value || device_type == nb::device::cuda_managed::value) { @@ -372,7 +349,33 @@ nb::ndarray<> mlx_to_dlpack( throw nb::buffer_error("Metal DLPack export is not available."); } - return mlx_to_dlpack_impl(a, device_type, mlx_dtype_to_dl_dtype(a.dtype())); + auto arr = a; + void* data = nullptr; + uint64_t byte_offset = 0; + { + nb::gil_scoped_release nogil; + arr.eval(); + if (device_type == nb::device::cpu::value) { + data = static_cast(arr.buffer().raw_ptr()) + arr.offset(); + } else { + data = arr.buffer().ptr(); + byte_offset = arr.offset(); + } + } + + std::vector shape(arr.shape().begin(), arr.shape().end()); + auto owner = nb::cast(arr); + return nb::ndarray<>( + data, + arr.ndim(), + shape.data(), + /* owner= */ owner, + arr.strides().data(), + mlx_dtype_to_dl_dtype(arr.dtype()), + device_type, + device_id, + '\0', + byte_offset); } nb::object to_scalar(mx::array& a) { @@ -418,25 +421,6 @@ nb::object to_scalar(mx::array& a) { } } -mx::array metal_dlpack_to_mlx( - nb::ndarray nd_array, - std::optional dtype, - std::optional copy) { - if (!mx::metal::is_available()) { - throw std::invalid_argument("Metal DLPack import is not available."); - } - auto owner = std::move(nd_array); - auto shape = get_shape(owner); - - return dispatch_dlpack_dtype( - owner.dtype(), - [&](mx::Dtype type) { - return metal_dlpack_to_mlx_contiguous( - owner, shape, type, dtype, copy); - }, - "Cannot convert Metal DLPack array to mlx array."); -} - template nb::list to_list(mx::array& a, size_t index, int dim) { nb::list pl; diff --git a/python/src/convert.h b/python/src/convert.h index cf17e1c7ce..17030be47c 100644 --- a/python/src/convert.h +++ b/python/src/convert.h @@ -2,6 +2,7 @@ #pragma once #include +#include 
#include #include @@ -69,14 +70,13 @@ mx::array nd_array_to_mlx( nb::ndarray mlx_to_np_array(const mx::array& a); nb::ndarray<> mlx_to_dlpack( const mx::array& a, - std::optional dl_device_type = std::nullopt); + std::optional> dl_device = std::nullopt); nb::object to_scalar(mx::array& a); nb::object tolist(mx::array& a); mx::array create_array(nb::object v, std::optional t); -mx::array from_dlpack(nb::object v, std::optional copy); mx::array array_from_list(nb::list pl, std::optional dtype); mx::array array_from_list(nb::tuple pl, std::optional dtype); diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 8f70389983..07279eeb42 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -1797,8 +1797,8 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "from_dlpack", - [](const nb::object& x, std::optional copy) { - return from_dlpack(x, copy); + [](nb::ndarray x, std::optional copy) { + return nd_array_to_mlx(x, std::nullopt, std::nullopt, copy); }, nb::arg(), nb::kw_only(), From f5bab00b42a487b4ce7e6f4e4558ac6515238371 Mon Sep 17 00:00:00 2001 From: XXXXRT666 <157766680+XXXXRT666@users.noreply.github.com> Date: Sat, 23 May 2026 14:26:13 +0800 Subject: [PATCH 20/26] Address DLPack buffer reuse review --- mlx/allocator.h | 2 ++ mlx/array.cpp | 4 ++- mlx/array.h | 1 + mlx/backend/cuda/allocator.cpp | 8 ++++++ mlx/backend/metal/allocator.cpp | 8 ++++++ mlx/backend/no_gpu/allocator.cpp | 4 +++ mlx/ops.cpp | 21 +++++---------- mlx/ops.h | 9 ++++--- python/src/array.cpp | 6 +++-- python/src/convert.cpp | 44 ++++++++------------------------ 10 files changed, 53 insertions(+), 54 deletions(-) diff --git a/mlx/allocator.h b/mlx/allocator.h index 824deac2c7..c831b7247d 100644 --- a/mlx/allocator.h +++ b/mlx/allocator.h @@ -72,4 +72,6 @@ inline void release(Buffer buffer) { allocator().release(buffer); } +MLX_API bool can_reuse_alien_buffer(void* ptr); + } // namespace mlx::core::allocator diff --git a/mlx/array.cpp b/mlx/array.cpp index e694ccfd67..c154d24fcf 100644 
--- a/mlx/array.cpp +++ b/mlx/array.cpp @@ -64,6 +64,7 @@ array array::unsafe_weak_copy(const array& other) { other.data_size(), other.strides(), other.flags(), + 0, [](auto) {}); cpy.array_desc_->offset = other.array_desc_->offset; return cpy; @@ -180,9 +181,10 @@ void array::set_data( size_t data_size, Strides strides, Flags flags, + int64_t offset, Deleter d) { array_desc_->data = std::make_shared(buffer, d); - array_desc_->offset = 0; + array_desc_->offset = itemsize() * offset; array_desc_->data_size = data_size; array_desc_->strides = std::move(strides); array_desc_->flags = flags; diff --git a/mlx/array.h b/mlx/array.h index 60d5e50bbd..d8082a4502 100644 --- a/mlx/array.h +++ b/mlx/array.h @@ -443,6 +443,7 @@ class MLX_API array { size_t data_size, Strides strides, Flags flags, + int64_t offset = 0, Deleter d = allocator::free); void copy_shared_buffer( diff --git a/mlx/backend/cuda/allocator.cpp b/mlx/backend/cuda/allocator.cpp index 718ae33e9c..3e1c7e2872 100644 --- a/mlx/backend/cuda/allocator.cpp +++ b/mlx/backend/cuda/allocator.cpp @@ -416,6 +416,14 @@ void* Buffer::raw_ptr() { return cbuf.data; } +bool can_reuse_alien_buffer(void* ptr) { + if (!ptr) { + return true; + } + auto& cbuf = *static_cast(ptr); + return cbuf.device == -1; +} + } // namespace allocator size_t get_active_memory() { diff --git a/mlx/backend/metal/allocator.cpp b/mlx/backend/metal/allocator.cpp index 14ca0388ab..5a65c1637b 100644 --- a/mlx/backend/metal/allocator.cpp +++ b/mlx/backend/metal/allocator.cpp @@ -34,6 +34,14 @@ void* Buffer::raw_ptr() { return buf->contents(); } +bool can_reuse_alien_buffer(void* ptr) { + if (!ptr) { + return true; + } + auto* buf = static_cast(ptr); + return buf->storageMode() != MTL::StorageModePrivate; +} + } // namespace allocator namespace metal { diff --git a/mlx/backend/no_gpu/allocator.cpp b/mlx/backend/no_gpu/allocator.cpp index 1fdcc262a4..a800e381a7 100644 --- a/mlx/backend/no_gpu/allocator.cpp +++ b/mlx/backend/no_gpu/allocator.cpp @@ 
-178,6 +178,10 @@ void* Buffer::raw_ptr() { return static_cast(ptr_) + 1; } +bool can_reuse_alien_buffer(void*) { + return true; +} + } // namespace allocator size_t get_active_memory() { diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 4598529e6c..3a093582b7 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -257,13 +257,16 @@ array linspace( s); } -array astype(array a, Dtype dtype, StreamOrDevice s /* = {} */) { - if (dtype == a.dtype()) { +array astype( + array a, + Dtype dtype, + std::optional copy, + StreamOrDevice s /* = {} */) { + if (dtype == a.dtype() && !copy.value_or(false)) { return a; } - auto copied_shape = a.shape(); // |a| will be moved return array( - std::move(copied_shape), + a.shape(), dtype, std::make_shared(to_stream(s), dtype), {std::move(a)}); @@ -296,16 +299,6 @@ array copy(array a, StreamOrDevice s /* = {} */) { {std::move(a)}); } -array copy_to_new_buffer(array a, StreamOrDevice s /* = {} */) { - auto copied_shape = a.shape(); // |a| will be moved - auto dtype = a.dtype(); - return array( - std::move(copied_shape), - dtype, - std::make_shared(to_stream(s), dtype), - {std::move(a)}); -} - array full_impl(array vals, Dtype dtype, StreamOrDevice s /* = {} */) { return array( vals.shape(), diff --git a/mlx/ops.h b/mlx/ops.h index 3084de3962..d047045a23 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -47,7 +47,11 @@ MLX_API array linspace( StreamOrDevice s = {}); /** Convert an array to the given data type. */ -MLX_API array astype(array a, Dtype dtype, StreamOrDevice s = {}); +MLX_API array +astype(array a, Dtype dtype, std::optional copy, StreamOrDevice s = {}); +inline array astype(array a, Dtype dtype, StreamOrDevice s = {}) { + return astype(std::move(a), dtype, std::nullopt, s); +} /** Create a view of an array with the given shape and strides. */ MLX_API array as_strided( @@ -60,9 +64,6 @@ MLX_API array as_strided( /** Copy another array. */ MLX_API array copy(array a, StreamOrDevice s = {}); -/** Copy another array into newly allocated storage. 
*/ -MLX_API array copy_to_new_buffer(array a, StreamOrDevice s = {}); - /** Fill an array of the given shape with the given value(s). */ MLX_API array full(Shape shape, array vals, Dtype dtype, StreamOrDevice s = {}); MLX_API array full(Shape shape, array vals, StreamOrDevice s = {}); diff --git a/python/src/array.cpp b/python/src/array.cpp index 3accf6172b..add17c2036 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -375,7 +375,9 @@ void init_array(nb::module_& m) { )pbdoc") .def( "astype", - &mx::astype, + [](const mx::array& a, mx::Dtype dtype, mx::StreamOrDevice s) { + return mx::astype(a, dtype, s); + }, "dtype"_a, "stream"_a = nb::none(), R"pbdoc( @@ -506,7 +508,7 @@ void init_array(nb::module_& m) { std::optional> dl_device, std::optional copy) { if (copy.value_or(false)) { - return mlx_to_dlpack(mx::copy_to_new_buffer(a), dl_device); + return mlx_to_dlpack(mx::astype(a, a.dtype(), true), dl_device); } return mlx_to_dlpack(a, dl_device); }, diff --git a/python/src/convert.cpp b/python/src/convert.cpp index c1bdcab04b..fdaa026771 100644 --- a/python/src/convert.cpp +++ b/python/src/convert.cpp @@ -7,13 +7,10 @@ #include #include -#if __has_include() -#include -#endif - #include "python/src/convert.h" #include "python/src/utils.h" +#include "mlx/allocator.h" #include "mlx/backend/cuda/cuda.h" #include "mlx/backend/metal/metal.h" #include "mlx/ops.h" @@ -26,21 +23,6 @@ enum PyScalarT { pycomplex = 3, }; -namespace { - -bool metal_buffer_is_private(void* ptr) { -#if __has_include() - if (!ptr) { - return false; - } - auto* buf = static_cast(ptr); - return buf->storageMode() == MTL::StorageModePrivate; -#endif - return false; -} - -} // namespace - int check_shape_dim(int64_t dim) { if (dim > std::numeric_limits::max() || dim < std::numeric_limits::min()) { @@ -177,7 +159,8 @@ mx::array metal_dlpack_to_mlx_contiguous( "Metal DLPack byte offset is not aligned to dtype size."); } - auto is_private_buffer = 
metal_buffer_is_private(owner.data_handle()); + auto can_reuse_buffer = + mx::allocator::can_reuse_alien_buffer(owner.data_handle()); auto out = mx::array( mx::allocator::Buffer(owner.data_handle()), @@ -185,28 +168,23 @@ mx::array metal_dlpack_to_mlx_contiguous( type, [](mx::allocator::Buffer) {}); auto flags = out.flags(); + auto offset = static_cast(byte_offset / itemsize); out.set_data( out.buffer(), out.data_size(), out.strides(), flags, + offset, [owner = std::move(owner)](mx::allocator::Buffer) {}); - auto offset = static_cast(byte_offset / itemsize); - if (offset != 0) { - out.copy_shared_buffer(out, out.strides(), flags, out.data_size(), offset); - } - - if (copy.has_value() && copy.value() == false && is_private_buffer) { + if (copy.has_value() && copy.value() == false && !can_reuse_buffer) { throw std::invalid_argument( "Cannot import a private Metal DLPack buffer without a copy."); } - if (dtype || is_private_buffer) { + if (dtype || !can_reuse_buffer) { auto result_dtype = dtype.value_or(out.dtype()); - auto result = (result_dtype == out.dtype()) - ? 
mx::copy_to_new_buffer(out, mx::Device::gpu) - : mx::astype(out, result_dtype, mx::Device::gpu); + auto result = mx::astype(out, result_dtype, true, mx::Device::gpu); result.eval(); result.detach(); return result; @@ -718,12 +696,12 @@ mx::array create_array(nb::object v, std::optional t) { nd = nb::cast(v); } auto type = nb_dtype.value_or(nd.dtype()); - std::optional copy_dtype; + std::optional dst_dtype; if (t && *t != mlx_dtype_from_dlpack(type, "Cannot convert array to mlx.")) { - copy_dtype = t; + dst_dtype = t; } - return nd_array_to_mlx(nd, copy_dtype, nb_dtype); + return nd_array_to_mlx(nd, dst_dtype, nb_dtype); } else { auto arr = to_array_with_accessor(v); return mx::astype(arr, t.value_or(arr.dtype())); From a4378cfa6027c71d58cd5cffdaa3254e862e7245 Mon Sep 17 00:00:00 2001 From: XXXXRT666 <157766680+XXXXRT666@users.noreply.github.com> Date: Sat, 23 May 2026 15:49:04 +0800 Subject: [PATCH 21/26] Add copy control to asarray --- docs/src/usage/numpy.rst | 14 +++++++----- python/src/array.cpp | 2 +- python/src/convert.cpp | 34 +++++++++++++++++++++++---- python/src/convert.h | 5 +++- python/src/ops.cpp | 18 ++++++--------- python/tests/test_array.py | 47 ++++++++++++++++++++++++++++++++------ 6 files changed, 90 insertions(+), 30 deletions(-) diff --git a/docs/src/usage/numpy.rst b/docs/src/usage/numpy.rst index bfc762fcbf..36154898f9 100644 --- a/docs/src/usage/numpy.rst +++ b/docs/src/usage/numpy.rst @@ -76,9 +76,10 @@ PyTorch ------- PyTorch supports DLPack inputs and can import MLX arrays directly. -MLX can also import PyTorch tensors through DLPack with ``mx.array`` or +MLX can also import PyTorch tensors through DLPack with ``mx.asarray`` or ``mx.from_dlpack``. Use ``torch.as_tensor`` to import an MLX array with -DLPack; ``torch.tensor`` copies the data instead: +DLPack; ``torch.tensor`` copies the data instead. Similarly, ``mx.asarray`` +can share DLPack inputs when possible, while ``mx.array`` copies: .. 
code-block:: python @@ -103,13 +104,14 @@ The arrays do not share memory: print(c.tolist()) # [0, 1, 2] Metal DLPack inputs are different. If a PyTorch MPS tensor is passed to -``mx.array`` or to ``mx.from_dlpack`` with ``copy=None``, MLX imports it +``mx.asarray`` or to ``mx.from_dlpack`` with ``copy=None``, MLX imports it without a copy when the underlying Metal buffer is not private. Private Metal buffers are copied into MLX-managed storage instead. Passing ``copy=False`` requires zero-copy import and raises an error if a copy would be needed. Passing ``copy=True`` asks MLX to create a new array instead of reusing the -Metal buffer. MLX arrays exported to PyTorch with DLPack are also exported -without a copy on Metal. +Metal buffer. ``mx.array`` also creates a new array instead of reusing the +Metal buffer. MLX arrays exported to PyTorch with DLPack are exported without a +copy on Metal. In particular, PyTorch 2.12 and later use shared storage for ordinary MPS tensors on Apple silicon, while older PyTorch versions may use private storage @@ -121,7 +123,7 @@ converted array. b = torch.arange(3, device="mps", dtype=torch.float32) torch.mps.synchronize() - c = mx.array(b) # zero-copy if the Metal buffer is not private + c = mx.asarray(b) # zero-copy if the Metal buffer can be reused d = mx.from_dlpack(b, copy=True) # explicit copy .. 
code-block:: python diff --git a/python/src/array.cpp b/python/src/array.cpp index add17c2036..db0d7508d1 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -297,7 +297,7 @@ void init_array(nb::module_& m) { .def( "__init__", [](mx::array* aptr, nb::object v, std::optional t) { - new (aptr) mx::array(create_array(v, t)); + new (aptr) mx::array(create_array(v, t, true)); }, "val"_a, "dtype"_a = nb::none(), diff --git a/python/src/convert.cpp b/python/src/convert.cpp index fdaa026771..799f7dcba5 100644 --- a/python/src/convert.cpp +++ b/python/src/convert.cpp @@ -215,6 +215,10 @@ mx::array nd_array_to_mlx( std::optional dst_dtype, std::optional src_dtype, std::optional copy) { + if (copy.has_value() && copy.value() == false && dst_dtype) { + throw std::invalid_argument( + "Cannot convert DLPack array to requested dtype without a copy."); + } switch (nd_array.device_type()) { case nb::device::cpu::value: { if (copy.has_value() && copy.value() == false) { @@ -655,10 +659,22 @@ mx::array array_from_list(nb::tuple pl, std::optional dtype) { return array_from_list_impl(pl, dtype); } -mx::array create_array(nb::object v, std::optional t) { +void check_copy_false(std::optional copy) { + if (copy.has_value() && copy.value() == false) { + throw std::invalid_argument( + "Unable to avoid copy while creating an array as requested."); + } +} + +mx::array create_array( + nb::object v, + std::optional t, + std::optional copy) { if (nb::isinstance(v)) { + check_copy_false(copy); return mx::array(nb::cast(v), t.value_or(mx::bool_)); } else if (nb::isinstance(v)) { + check_copy_false(copy); auto val = nb::cast(v); auto default_type = (val > std::numeric_limits::max() || val < std::numeric_limits::min()) @@ -666,6 +682,7 @@ mx::array create_array(nb::object v, std::optional t) { : mx::int32; return mx::array(val, t.value_or(default_type)); } else if (nb::isinstance(v)) { + check_copy_false(copy); auto out_type = t.value_or(mx::float32); if (out_type == mx::float64) { 
return mx::array(nb::cast(v), out_type); @@ -673,16 +690,24 @@ mx::array create_array(nb::object v, std::optional t) { return mx::array(nb::cast(v), out_type); } } else if (PyComplex_Check(v.ptr())) { + check_copy_false(copy); return mx::array( static_cast(nb::cast>(v)), t.value_or(mx::complex64)); } else if (nb::isinstance(v)) { + check_copy_false(copy); return array_from_list(nb::cast(v), t); } else if (nb::isinstance(v)) { + check_copy_false(copy); return array_from_list(nb::cast(v), t); } else if (nb::isinstance(v)) { auto arr = nb::cast(v); - return mx::astype(arr, t.value_or(arr.dtype())); + auto dtype = t.value_or(arr.dtype()); + if (copy.has_value() && copy.value() == false && dtype != arr.dtype()) { + throw std::invalid_argument( + "Unable to avoid copy while creating an array as requested."); + } + return mx::astype(arr, dtype, copy); } else if (nb::ndarray_check(v)) { using ContigArray = nb::ndarray; ContigArray nd; @@ -701,9 +726,10 @@ mx::array create_array(nb::object v, std::optional t) { *t != mlx_dtype_from_dlpack(type, "Cannot convert array to mlx.")) { dst_dtype = t; } - return nd_array_to_mlx(nd, dst_dtype, nb_dtype); + return nd_array_to_mlx(nd, dst_dtype, nb_dtype, copy); } else { + check_copy_false(copy); auto arr = to_array_with_accessor(v); - return mx::astype(arr, t.value_or(arr.dtype())); + return mx::astype(arr, t.value_or(arr.dtype()), copy); } } diff --git a/python/src/convert.h b/python/src/convert.h index 17030be47c..9546e88b9a 100644 --- a/python/src/convert.h +++ b/python/src/convert.h @@ -76,7 +76,10 @@ nb::object to_scalar(mx::array& a); nb::object tolist(mx::array& a); -mx::array create_array(nb::object v, std::optional t); +mx::array create_array( + nb::object v, + std::optional t, + std::optional copy = true); mx::array array_from_list(nb::list pl, std::optional dtype); mx::array array_from_list(nb::tuple pl, std::optional dtype); diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 07279eeb42..02524a9160 100644 --- 
a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -1767,33 +1767,29 @@ void init_ops(nb::module_& m) { "asarray", [](const nb::object& a, std::optional dtype, - std::optional copy) { - if (copy.has_value() && !*copy) { - throw std::invalid_argument("[asarray] copy=False is not supported."); - } - return create_array(a, dtype); - }, + std::optional copy) { return create_array(a, dtype, copy); }, nb::arg(), "dtype"_a = nb::none(), nb::kw_only(), "copy"_a = nb::none(), nb::sig( - "def asarray(a: Union[scalar, array, Sequence], dtype: Optional[Dtype] = None, *, copy: Optional[bool] = None) -> array"), + "def asarray(a: Union[scalar, array, Sequence], dtype: " + "Optional[Dtype] = None, *, copy: Optional[bool] = None) -> array"), R"pbdoc( Convert the input to an array. Args: a: Input data. dtype (Dtype, optional): The desired data-type for the array. - copy (bool, optional): Must be ``True`` or unspecified. ``False`` - is not supported, since MLX has no in-place operations and - cannot return a non-copying view. + copy (bool, optional): Whether to copy the input. If ``True``, + always copy. If ``False``, never copy. If ``None``, share memory + when possible and copy otherwise. Returns: array: An array interpretation of the input. Raises: - ValueError: If ``copy`` is ``False``. + ValueError: If ``copy`` is ``False`` and a copy is required. 
)pbdoc"); m.def( "from_dlpack", diff --git a/python/tests/test_array.py b/python/tests/test_array.py index be3061d46f..87cd28235b 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -2077,7 +2077,7 @@ def test_torch_mps_dlpack_import(self): x = torch.arange(12, device="mps", dtype=torch.float32).reshape(3, 4) self.assertEqual(x.__dlpack_device__()[0], 8) - y = mx.array(x) + y = mx.asarray(x) self.assertEqual(y.dtype, mx.float32) torch.mps.synchronize() self.assertEqual(y.tolist(), x.cpu().numpy().tolist()) @@ -2085,11 +2085,31 @@ def test_torch_mps_dlpack_import(self): mv = memoryview(y) self.assertEqual(mv.tolist(), x.cpu().numpy().tolist()) + @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") + def test_torch_mps_array_copies_dlpack_input(self): + x = torch.arange(3, device="mps", dtype=torch.float32) + torch.mps.synchronize() + y = mx.array(x) + + x.add_(10) + torch.mps.synchronize() + self.assertEqual(y.tolist(), [0.0, 1.0, 2.0]) + + @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") + def test_torch_mps_asarray_copy_true_copies_dlpack_input(self): + x = torch.arange(3, device="mps", dtype=torch.float32) + torch.mps.synchronize() + y = mx.asarray(x, copy=True) + + x.add_(10) + torch.mps.synchronize() + self.assertEqual(y.tolist(), [0.0, 1.0, 2.0]) + @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") def test_torch_mps_dlpack_zero_copy_shares_updates(self): x = torch.arange(12, device="mps", dtype=torch.float32).reshape(3, 4) torch.mps.synchronize() - y = mx.array(x) + y = mx.asarray(x) x.add_(100) torch.mps.synchronize() @@ -2103,7 +2123,7 @@ def test_torch_mps_dlpack_zero_copy_shares_updates(self): def test_torch_mps_dlpack_matching_dtype_argument_shares_updates(self): x = torch.arange(12, device="mps", dtype=torch.float32).reshape(3, 4) torch.mps.synchronize() - y = mx.array(x, dtype=mx.float32) + y = mx.asarray(x, dtype=mx.float32, copy=False) self.assertEqual(y.dtype, mx.float32) x.add_(100) @@ 
-2114,7 +2134,7 @@ def test_torch_mps_dlpack_matching_dtype_argument_shares_updates(self): def test_torch_mps_dlpack_different_dtype_argument_copies(self): x = torch.arange(12, device="mps", dtype=torch.float32).reshape(3, 4) torch.mps.synchronize() - z = mx.array(x, dtype=mx.float16) + z = mx.asarray(x, dtype=mx.float16) expected = x.to(torch.float16).cpu().numpy().tolist() self.assertEqual(z.dtype, mx.float16) @@ -2124,10 +2144,13 @@ def test_torch_mps_dlpack_different_dtype_argument_copies(self): torch.mps.synchronize() self.assertEqual(z.tolist(), expected) + with self.assertRaises(ValueError): + mx.asarray(x, dtype=mx.float16, copy=False) + @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") def test_torch_mps_dlpack_data_offset(self): view = torch.arange(12, device="mps", dtype=torch.float32)[3:9] - view_mx = mx.array(view) + view_mx = mx.asarray(view) torch.mps.synchronize() self.assertEqual((view_mx + 1).tolist(), (view + 1).cpu().numpy().tolist()) @@ -2135,7 +2158,7 @@ def test_torch_mps_dlpack_data_offset(self): def test_torch_mps_dlpack_bfloat16(self): x = torch.arange(12, device="mps", dtype=torch.float32).reshape(3, 4) bf = x.to(torch.bfloat16) - bf_mx = mx.array(bf) + bf_mx = mx.asarray(bf) self.assertEqual(bf_mx.dtype, mx.bfloat16) torch.mps.synchronize() @@ -2345,9 +2368,10 @@ def test_asarray_copy(self): self.assertEqual( mx.asarray(existing, dtype=mx.float32, copy=True).dtype, mx.float32 ) + self.assertEqual(mx.asarray(existing, copy=False).tolist(), [1, 2, 3]) with self.assertRaises(ValueError): - mx.asarray(existing, copy=False) + mx.asarray(existing, dtype=mx.float32, copy=False) def test_asarray(self): # List inputs @@ -2371,17 +2395,26 @@ def test_asarray(self): # MLX array inputs arr = mx.array([1, 2, 3]) self.assertEqual(mx.asarray(arr).tolist(), [1, 2, 3]) + self.assertEqual(mx.asarray(arr, copy=False).tolist(), [1, 2, 3]) + self.assertEqual(mx.asarray(arr, copy=True).tolist(), [1, 2, 3]) arr_int = mx.array([1, 2, 3], 
dtype=mx.int32) arr_float = mx.asarray(arr_int, dtype=mx.float32) self.assertEqual(arr_float.dtype, mx.float32) self.assertEqual(arr_float.tolist(), [1.0, 2.0, 3.0]) + with self.assertRaises(ValueError): + mx.asarray(arr_int, dtype=mx.float32, copy=False) # NumPy array inputs np_arr = np.array([1.0, 2.0, 3.0], dtype=np.float32) mx_arr = mx.asarray(np_arr) self.assertEqual(mx_arr.tolist(), [1.0, 2.0, 3.0]) self.assertEqual(mx_arr.dtype, mx.float32) + with self.assertRaises(ValueError): + mx.asarray(np_arr, copy=False) + + with self.assertRaises(ValueError): + mx.asarray([1, 2, 3], copy=False) # dtype parameter self.assertEqual(mx.asarray([1, 2, 3], dtype=mx.float32).dtype, mx.float32) From 7e877579f01e3728927acc320b75c9bb527a02ec Mon Sep 17 00:00:00 2001 From: XXXXRT666 <157766680+XXXXRT666@users.noreply.github.com> Date: Sun, 24 May 2026 17:28:36 +0800 Subject: [PATCH 22/26] Address DLPack review feedback --- mlx/array.cpp | 2 +- mlx/backend/cuda/allocator.cpp | 6 +- mlx/backend/metal/allocator.cpp | 8 +-- mlx/utils.cpp | 1 - python/src/buffer.h | 1 + python/src/convert.cpp | 122 +++++++++++--------------------- python/src/convert.h | 2 +- python/tests/test_array.py | 7 +- 8 files changed, 53 insertions(+), 96 deletions(-) diff --git a/mlx/array.cpp b/mlx/array.cpp index c154d24fcf..07d3acd616 100644 --- a/mlx/array.cpp +++ b/mlx/array.cpp @@ -184,7 +184,7 @@ void array::set_data( int64_t offset, Deleter d) { array_desc_->data = std::make_shared(buffer, d); - array_desc_->offset = itemsize() * offset; + array_desc_->offset = offset; array_desc_->data_size = data_size; array_desc_->strides = std::move(strides); array_desc_->flags = flags; diff --git a/mlx/backend/cuda/allocator.cpp b/mlx/backend/cuda/allocator.cpp index 3e1c7e2872..21358929da 100644 --- a/mlx/backend/cuda/allocator.cpp +++ b/mlx/backend/cuda/allocator.cpp @@ -417,11 +417,7 @@ void* Buffer::raw_ptr() { } bool can_reuse_alien_buffer(void* ptr) { - if (!ptr) { - return true; - } - auto& cbuf = 
*static_cast(ptr); - return cbuf.device == -1; + return true; } } // namespace allocator diff --git a/mlx/backend/metal/allocator.cpp b/mlx/backend/metal/allocator.cpp index 5a65c1637b..60459c67c2 100644 --- a/mlx/backend/metal/allocator.cpp +++ b/mlx/backend/metal/allocator.cpp @@ -7,6 +7,7 @@ #include #include +#include #include namespace mlx::core { @@ -25,12 +26,7 @@ void* Buffer::raw_ptr() { return nullptr; } auto* buf = static_cast(ptr_); - if (buf->storageMode() == MTL::StorageModePrivate) { - throw std::runtime_error( - "[metal::Buffer::raw_ptr] Cannot access Metal buffer contents on the " - "host. The buffer is not CPU-addressable, for example because it uses " - "private storage."); - } + assert(buf->storageMode() != MTL::StorageModePrivate); return buf->contents(); } diff --git a/mlx/utils.cpp b/mlx/utils.cpp index 2b45f52d88..239e6603dd 100644 --- a/mlx/utils.cpp +++ b/mlx/utils.cpp @@ -7,7 +7,6 @@ #include #include "mlx/dtype_utils.h" -#include "mlx/ops.h" #include "mlx/types/limits.h" #include "mlx/utils.h" diff --git a/python/src/buffer.h b/python/src/buffer.h index 49ad0d64d0..272a918883 100644 --- a/python/src/buffer.h +++ b/python/src/buffer.h @@ -92,6 +92,7 @@ extern "C" inline int getbuffer(PyObject* obj, Py_buffer* view, int flags) { nb::gil_scoped_release nogil; a.eval(); } + std::vector shape(a.shape().begin(), a.shape().end()); std::vector strides(a.strides().begin(), a.strides().end()); for (auto& s : strides) { diff --git a/python/src/convert.cpp b/python/src/convert.cpp index 799f7dcba5..c690a7c32f 100644 --- a/python/src/convert.cpp +++ b/python/src/convert.cpp @@ -140,82 +140,55 @@ nb::dlpack::dtype mlx_dtype_to_dl_dtype(mx::Dtype dtype) { } } -template -mx::array metal_dlpack_to_mlx_contiguous( - nb::ndarray owner, - const mx::Shape& shape, - mx::Dtype type, - std::optional dtype, - std::optional copy) { - auto itemsize = mx::size_of(type); - if (owner.itemsize() != itemsize) { - throw std::invalid_argument( - "Cannot convert Metal 
DLPack dtype to mlx dtype."); +mx::array metal_dlpack_to_mlx( + nb::ndarray nd_array, + mx::Dtype src_dtype, + mx::Dtype dst_dtype, + bool copy) { + if (!mx::metal::is_available()) { + throw std::invalid_argument("Metal DLPack import is not available."); } - - auto byte_offset = owner.byte_offset(); - if (byte_offset % itemsize != 0) { + auto shape = get_shape(nd_array); + if (nd_array.itemsize() != mx::size_of(src_dtype)) { throw std::invalid_argument( - "Metal DLPack byte offset is not aligned to dtype size."); + "Cannot convert Metal DLPack dtype to mlx dtype."); } - - auto can_reuse_buffer = - mx::allocator::can_reuse_alien_buffer(owner.data_handle()); - + auto data_handle = nd_array.data_handle(); auto out = mx::array( - mx::allocator::Buffer(owner.data_handle()), + mx::allocator::Buffer(data_handle), shape, - type, + src_dtype, [](mx::allocator::Buffer) {}); - auto flags = out.flags(); - auto offset = static_cast(byte_offset / itemsize); out.set_data( out.buffer(), out.data_size(), out.strides(), - flags, - offset, - [owner = std::move(owner)](mx::allocator::Buffer) {}); - - if (copy.has_value() && copy.value() == false && !can_reuse_buffer) { - throw std::invalid_argument( - "Cannot import a private Metal DLPack buffer without a copy."); - } + out.flags(), + nd_array.byte_offset(), + [owner = std::move(nd_array)](mx::allocator::Buffer) {}); - if (dtype || !can_reuse_buffer) { - auto result_dtype = dtype.value_or(out.dtype()); - auto result = mx::astype(out, result_dtype, true, mx::Device::gpu); + if (copy) { + auto result = mx::astype(out, dst_dtype, true, mx::Device::gpu); result.eval(); - result.detach(); return result; } return out; } -mx::array metal_dlpack_to_mlx( - nb::ndarray nd_array, - std::optional dtype, - std::optional copy) { - if (!mx::metal::is_available()) { - throw std::invalid_argument("Metal DLPack import is not available."); - } - auto shape = get_shape(nd_array); - - return dispatch_dlpack_dtype( - nd_array.dtype(), - [&](mx::Dtype type) 
{ - return metal_dlpack_to_mlx_contiguous( - nd_array, shape, type, dtype, copy); - }, - "Cannot convert Metal DLPack array to mlx array."); -} - mx::array nd_array_to_mlx( nb::ndarray nd_array, - std::optional dst_dtype, - std::optional src_dtype, + std::optional requested_dtype, + std::optional src_dlpack_dtype_override, std::optional copy) { - if (copy.has_value() && copy.value() == false && dst_dtype) { + auto src_dlpack_dtype = src_dlpack_dtype_override.value_or(nd_array.dtype()); + auto src_mlx_dtype = + mlx_dtype_from_dlpack(src_dlpack_dtype, "Cannot convert array to mlx."); + auto dst_dtype = requested_dtype.value_or(src_mlx_dtype); + bool can_reuse_buffer = + mx::allocator::can_reuse_alien_buffer(nd_array.data_handle()); + bool should_copy = + copy.value_or(false) || dst_dtype != src_mlx_dtype || !can_reuse_buffer; + if (copy.has_value() && copy.value() == false && dst_dtype != src_mlx_dtype) { throw std::invalid_argument( "Cannot convert DLPack array to requested dtype without a copy."); } @@ -226,22 +199,21 @@ mx::array nd_array_to_mlx( "Cannot import a CPU DLPack array without a copy."); } auto shape = get_shape(nd_array); - auto type = src_dtype.value_or(nd_array.dtype()); return dispatch_dlpack_dtype( - type, - [&](mx::Dtype default_dtype) { - return nd_array_to_mlx_contiguous( - nd_array, shape, dst_dtype.value_or(default_dtype)); + src_dlpack_dtype, + [&](mx::Dtype) { + return nd_array_to_mlx_contiguous(nd_array, shape, dst_dtype); }, "Cannot convert numpy array to mlx array."); } - case nb::device::metal::value: - if (copy.has_value() && copy.value() == true && !dst_dtype) { - dst_dtype = mlx_dtype_from_dlpack( - nd_array.dtype(), - "Cannot convert Metal DLPack array to mlx array."); + case nb::device::metal::value: { + if (copy.has_value() && copy.value() == false && !can_reuse_buffer) { + throw std::invalid_argument( + "Cannot import a private Metal DLPack buffer without a copy."); } - return metal_dlpack_to_mlx(nd_array, dst_dtype, copy); + 
return metal_dlpack_to_mlx( + nd_array, src_mlx_dtype, dst_dtype, should_copy); + } case nb::device::cuda::value: case nb::device::cuda_managed::value: throw std::invalid_argument("CUDA DLPack import is not supported."); @@ -338,7 +310,8 @@ nb::ndarray<> mlx_to_dlpack( nb::gil_scoped_release nogil; arr.eval(); if (device_type == nb::device::cpu::value) { - data = static_cast(arr.buffer().raw_ptr()) + arr.offset(); + data = arr.buffer().raw_ptr(); + byte_offset = arr.offset(); } else { data = arr.buffer().ptr(); byte_offset = arr.offset(); @@ -703,10 +676,7 @@ mx::array create_array( } else if (nb::isinstance(v)) { auto arr = nb::cast(v); auto dtype = t.value_or(arr.dtype()); - if (copy.has_value() && copy.value() == false && dtype != arr.dtype()) { - throw std::invalid_argument( - "Unable to avoid copy while creating an array as requested."); - } + check_copy_false(copy); return mx::astype(arr, dtype, copy); } else if (nb::ndarray_check(v)) { using ContigArray = nb::ndarray; @@ -720,13 +690,7 @@ mx::array create_array( } else { nd = nb::cast(v); } - auto type = nb_dtype.value_or(nd.dtype()); - std::optional dst_dtype; - if (t && - *t != mlx_dtype_from_dlpack(type, "Cannot convert array to mlx.")) { - dst_dtype = t; - } - return nd_array_to_mlx(nd, dst_dtype, nb_dtype, copy); + return nd_array_to_mlx(nd, t, nb_dtype, copy); } else { check_copy_false(copy); auto arr = to_array_with_accessor(v); diff --git a/python/src/convert.h b/python/src/convert.h index 9546e88b9a..4e2902ff76 100644 --- a/python/src/convert.h +++ b/python/src/convert.h @@ -64,7 +64,7 @@ struct ArrayLike { mx::array nd_array_to_mlx( nb::ndarray nd_array, std::optional mx_dtype, - std::optional nb_dtype = std::nullopt, + std::optional src_dlpack_dtype_override = std::nullopt, std::optional copy = std::nullopt); nb::ndarray mlx_to_np_array(const mx::array& a); diff --git a/python/tests/test_array.py b/python/tests/test_array.py index 87cd28235b..9a30b6e753 100644 --- a/python/tests/test_array.py +++ 
b/python/tests/test_array.py @@ -2368,8 +2368,8 @@ def test_asarray_copy(self): self.assertEqual( mx.asarray(existing, dtype=mx.float32, copy=True).dtype, mx.float32 ) - self.assertEqual(mx.asarray(existing, copy=False).tolist(), [1, 2, 3]) - + with self.assertRaises(ValueError): + mx.asarray(existing, copy=False) with self.assertRaises(ValueError): mx.asarray(existing, dtype=mx.float32, copy=False) @@ -2395,8 +2395,9 @@ def test_asarray(self): # MLX array inputs arr = mx.array([1, 2, 3]) self.assertEqual(mx.asarray(arr).tolist(), [1, 2, 3]) - self.assertEqual(mx.asarray(arr, copy=False).tolist(), [1, 2, 3]) self.assertEqual(mx.asarray(arr, copy=True).tolist(), [1, 2, 3]) + with self.assertRaises(ValueError): + mx.asarray(arr, copy=False) arr_int = mx.array([1, 2, 3], dtype=mx.int32) arr_float = mx.asarray(arr_int, dtype=mx.float32) From a7c4a44f2e8f5a475437e2f397f8727377084dd0 Mon Sep 17 00:00:00 2001 From: XXXXRT666 <157766680+XXXXRT666@users.noreply.github.com> Date: Mon, 25 May 2026 03:26:46 +0800 Subject: [PATCH 23/26] Support strided DLPack arrays --- python/src/array.cpp | 2 +- python/src/convert.cpp | 108 ++++++++++++++++++++++++++++--------- python/src/convert.h | 2 +- python/src/indexing.cpp | 2 +- python/src/ops.cpp | 2 +- python/src/utils.cpp | 2 +- python/src/utils.h | 2 +- python/tests/test_array.py | 72 +++++++++++++++++++++++++ 8 files changed, 162 insertions(+), 30 deletions(-) diff --git a/python/src/array.cpp b/python/src/array.cpp index db0d7508d1..ab846addda 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -483,7 +483,7 @@ void init_array(nb::module_& m) { throw std::invalid_argument( "Invalid pickle state: expected (ndarray, Dtype::Val)"); } - using ND = nb::ndarray; + using ND = nb::ndarray; ND nd = nb::cast(state[0]); auto val = static_cast(nb::cast(state[1])); if (val == mx::Dtype::Val::bfloat16) { diff --git a/python/src/convert.cpp b/python/src/convert.cpp index c690a7c32f..d245cebc9c 100644 --- a/python/src/convert.cpp +++ 
b/python/src/convert.cpp @@ -1,5 +1,6 @@ // Copyright © 2024 Apple Inc. +#include #include #include #include @@ -11,6 +12,7 @@ #include "python/src/utils.h" #include "mlx/allocator.h" +#include "mlx/backend/common/utils.h" #include "mlx/backend/cuda/cuda.h" #include "mlx/backend/metal/metal.h" #include "mlx/ops.h" @@ -47,15 +49,31 @@ mx::Shape get_shape(const nb::ndarray& nd_array) { return shape; } -template -mx::array nd_array_to_mlx_contiguous( - nb::ndarray nd_array, +template +mx::Strides get_strides(const nb::ndarray& nd_array) { + mx::Strides strides; + strides.reserve(nd_array.ndim()); + for (int i = 0; i < nd_array.ndim(); i++) { + strides.push_back(nd_array.stride(i)); + } + return strides; +} + +size_t strided_storage_size( const mx::Shape& shape, - mx::Dtype dtype) { - // Make a copy of the numpy buffer - // Get buffer ptr pass to array constructor - auto data_ptr = nd_array.data(); - return mx::array(static_cast(data_ptr), shape, dtype); + const mx::Strides& strides) { + size_t storage_size = 1; + for (int i = 0; i < shape.size(); i++) { + if (shape[i] == 0) { + return 0; + } + if (strides[i] < 0) { + throw std::invalid_argument( + "Cannot convert DLPack arrays with negative strides to mlx array."); + } + storage_size += (shape[i] - 1) * strides[i]; + } + return storage_size; } template @@ -140,8 +158,42 @@ nb::dlpack::dtype mlx_dtype_to_dl_dtype(mx::Dtype dtype) { } } +template +mx::array cpu_nd_array_to_mlx( + nb::ndarray nd_array, + const mx::Shape& shape, + mx::Dtype dst_dtype) { + return dispatch_dlpack_dtype( + mlx_dtype_to_dl_dtype(dst_dtype), + [&](mx::Dtype) { + auto out = mx::array(shape, dst_dtype, nullptr, {}); + auto strides = get_strides(nd_array); + auto storage_size = strided_storage_size(shape, strides); + auto [no_bsx_size, is_row_contiguous, is_col_contiguous] = shape.empty() + ? 
std::make_tuple(storage_size, true, true) + : mx::check_contiguity(shape, strides); + auto flags = out.flags(); + flags.contiguous = no_bsx_size == storage_size; + flags.row_contiguous = is_row_contiguous; + flags.col_contiguous = is_col_contiguous; + out.set_data( + mx::allocator::malloc(storage_size * mx::size_of(dst_dtype)), + storage_size, + std::move(strides), + flags); + if (storage_size > 0) { + auto src = static_cast(nd_array.data()); + auto dst = out.data(); + std::copy(src, src + storage_size, dst); + } + out.set_status(mx::array::Status::available); + return out; + }, + "Cannot convert numpy array to mlx array."); +} + mx::array metal_dlpack_to_mlx( - nb::ndarray nd_array, + nb::ndarray nd_array, mx::Dtype src_dtype, mx::Dtype dst_dtype, bool copy) { @@ -153,19 +205,25 @@ mx::array metal_dlpack_to_mlx( throw std::invalid_argument( "Cannot convert Metal DLPack dtype to mlx dtype."); } + auto strides = get_strides(nd_array); + auto storage_size = strided_storage_size(shape, strides); + auto [no_bsx_size, is_row_contiguous, is_col_contiguous] = shape.empty() + ? 
std::make_tuple(storage_size, true, true) + : mx::check_contiguity(shape, strides); auto data_handle = nd_array.data_handle(); - auto out = mx::array( - mx::allocator::Buffer(data_handle), - shape, - src_dtype, - [](mx::allocator::Buffer) {}); + mx::array out(shape, src_dtype, nullptr, {}); + auto flags = out.flags(); + flags.contiguous = no_bsx_size == storage_size; + flags.row_contiguous = is_row_contiguous; + flags.col_contiguous = is_col_contiguous; out.set_data( - out.buffer(), - out.data_size(), - out.strides(), - out.flags(), + mx::allocator::Buffer(data_handle), + storage_size, + std::move(strides), + flags, nd_array.byte_offset(), [owner = std::move(nd_array)](mx::allocator::Buffer) {}); + out.set_status(mx::array::Status::available); if (copy) { auto result = mx::astype(out, dst_dtype, true, mx::Device::gpu); @@ -176,7 +234,7 @@ mx::array metal_dlpack_to_mlx( } mx::array nd_array_to_mlx( - nb::ndarray nd_array, + nb::ndarray nd_array, std::optional requested_dtype, std::optional src_dlpack_dtype_override, std::optional copy) { @@ -184,15 +242,17 @@ mx::array nd_array_to_mlx( auto src_mlx_dtype = mlx_dtype_from_dlpack(src_dlpack_dtype, "Cannot convert array to mlx."); auto dst_dtype = requested_dtype.value_or(src_mlx_dtype); - bool can_reuse_buffer = + auto device_type = nd_array.device_type(); + bool can_reuse_buffer = device_type == nb::device::cpu::value || mx::allocator::can_reuse_alien_buffer(nd_array.data_handle()); + // Skip the check for CPU arrays, where it would raise; || short-circuits. bool should_copy = copy.value_or(false) || dst_dtype != src_mlx_dtype || !can_reuse_buffer; if (copy.has_value() && copy.value() == false && dst_dtype != src_mlx_dtype) { throw std::invalid_argument( "Cannot convert DLPack array to requested dtype without a copy."); } - switch (nd_array.device_type()) { + switch (device_type) { case nb::device::cpu::value: { if (copy.has_value() && copy.value() == false) { throw std::invalid_argument( @@ -201,8 +261,8 @@ mx::array
nd_array_to_mlx( auto shape = get_shape(nd_array); return dispatch_dlpack_dtype( src_dlpack_dtype, - [&](mx::Dtype) { - return nd_array_to_mlx_contiguous(nd_array, shape, dst_dtype); + [&](mx::Dtype src_dtype) { + return cpu_nd_array_to_mlx(nd_array, shape, dst_dtype); }, "Cannot convert numpy array to mlx array."); } @@ -679,7 +739,7 @@ mx::array create_array( check_copy_false(copy); return mx::astype(arr, dtype, copy); } else if (nb::ndarray_check(v)) { - using ContigArray = nb::ndarray; + using ContigArray = nb::ndarray; ContigArray nd; std::optional nb_dtype; // Nanobind does not recognize bfloat16 numpy array: diff --git a/python/src/convert.h b/python/src/convert.h index 4e2902ff76..0b8e569afb 100644 --- a/python/src/convert.h +++ b/python/src/convert.h @@ -62,7 +62,7 @@ struct ArrayLike { }; mx::array nd_array_to_mlx( - nb::ndarray nd_array, + nb::ndarray nd_array, std::optional mx_dtype, std::optional src_dlpack_dtype_override = std::nullopt, std::optional copy = std::nullopt); diff --git a/python/src/indexing.cpp b/python/src/indexing.cpp index c44cd9b943..3df4c96882 100644 --- a/python/src/indexing.cpp +++ b/python/src/indexing.cpp @@ -926,7 +926,7 @@ mlx_compute_slice_update_args( } std::optional extract_boolean_mask(const nb::object& obj) { - using NDArray = nb::ndarray; + using NDArray = nb::ndarray; if (nb::isinstance(obj)) { return mx::array(nb::cast(obj), mx::bool_); } else if (nb::isinstance(obj)) { diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 02524a9160..8104773346 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -1793,7 +1793,7 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "from_dlpack", - [](nb::ndarray x, std::optional copy) { + [](nb::ndarray x, std::optional copy) { return nd_array_to_mlx(x, std::nullopt, std::nullopt, copy); }, nb::arg(), diff --git a/python/src/utils.cpp b/python/src/utils.cpp index fa4a53428b..4416531dc8 100644 --- a/python/src/utils.cpp +++ b/python/src/utils.cpp @@ -38,7 +38,7 @@ mx::array 
to_array( return mx::array(static_cast(*pv), mx::complex64); } else if (auto pv = std::get_if(&v); pv) { return *pv; - } else if (auto pv = std::get_if>(&v); pv) { + } else if (auto pv = std::get_if>(&v); pv) { return nd_array_to_mlx(*pv, dtype); } else { return to_array_with_accessor(std::get(v).obj); diff --git a/python/src/utils.h b/python/src/utils.h index 8956fc210f..d2fcbd037e 100644 --- a/python/src/utils.h +++ b/python/src/utils.h @@ -24,7 +24,7 @@ using ScalarOrArray = std::variant< // Must be above ndarray mx::array, // Must be above complex - nb::ndarray, + nb::ndarray, std::complex, ArrayLike>; diff --git a/python/tests/test_array.py b/python/tests/test_array.py index 9a30b6e753..054542292e 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -2072,6 +2072,30 @@ def test_from_dlpack_cpu(self): with self.assertRaises(ValueError): mx.from_dlpack(x, copy=False) + def test_from_dlpack_cpu_strided(self): + x = np.arange(12, dtype=np.float32).reshape(3, 4) + view = x.T + y = mx.from_dlpack(view) + + self.assertEqual(y.tolist(), view.tolist()) + expected = view.copy().tolist() + x[0, 0] = 99 + self.assertEqual(y.tolist(), expected) + + stepped = np.arange(20, dtype=np.int32)[2:10:2] + y = mx.from_dlpack(stepped) + self.assertEqual(y.tolist(), [2, 4, 6, 8]) + stepped[0] = 99 + self.assertEqual(y.tolist(), [2, 4, 6, 8]) + + broadcast = np.broadcast_to(np.array([7], dtype=np.int32), (3,)) + y = mx.from_dlpack(broadcast) + self.assertEqual(y.tolist(), [7, 7, 7]) + + negative_stride = np.arange(5, dtype=np.float32)[::-1] + with self.assertRaises(ValueError): + mx.from_dlpack(negative_stride) + @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") def test_torch_mps_dlpack_import(self): x = torch.arange(12, device="mps", dtype=torch.float32).reshape(3, 4) @@ -2154,6 +2178,54 @@ def test_torch_mps_dlpack_data_offset(self): torch.mps.synchronize() self.assertEqual((view_mx + 1).tolist(), (view + 1).cpu().numpy().tolist()) + 
@unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") + def test_torch_mps_dlpack_strided_view(self): + x = torch.arange(12, device="mps", dtype=torch.float32).reshape(3, 4) + view = x.T + torch.mps.synchronize() + y = mx.asarray(view, copy=False) + self.assertEqual(y.tolist(), view.cpu().numpy().tolist()) + + x[0, 1] = 99 + torch.mps.synchronize() + self.assertEqual(y.tolist(), view.cpu().numpy().tolist()) + + y_copy = mx.asarray(view, copy=True) + expected = view.cpu().numpy().tolist() + x[0, 2] = 77 + torch.mps.synchronize() + self.assertEqual(y_copy.tolist(), expected) + + @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") + def test_torch_mps_dlpack_stepped_view(self): + x = torch.arange(20, device="mps", dtype=torch.int32) + view = x[2:10:2] + torch.mps.synchronize() + y = mx.asarray(view, copy=False) + self.assertEqual(y.tolist(), [2, 4, 6, 8]) + + x[4] = 99 + torch.mps.synchronize() + self.assertEqual(y.tolist(), [2, 99, 6, 8]) + + y_copy = mx.asarray(view, copy=True) + expected = y.tolist() + x[6] = 77 + torch.mps.synchronize() + self.assertEqual(y_copy.tolist(), expected) + + @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") + def test_torch_mps_dlpack_broadcast_stride(self): + x = torch.tensor([7], device="mps", dtype=torch.int32) + view = x.expand(3) + torch.mps.synchronize() + y = mx.asarray(view, copy=False) + self.assertEqual(y.tolist(), [7, 7, 7]) + + x.add_(1) + torch.mps.synchronize() + self.assertEqual(y.tolist(), [8, 8, 8]) + @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") def test_torch_mps_dlpack_bfloat16(self): x = torch.arange(12, device="mps", dtype=torch.float32).reshape(3, 4) From 30e1484afc141b9048b0a0968525981d37904781 Mon Sep 17 00:00:00 2001 From: XXXXRT666 <157766680+XXXXRT666@users.noreply.github.com> Date: Mon, 25 May 2026 17:41:18 +0800 Subject: [PATCH 24/26] Make copied DLPack imports row-contiguous --- docs/src/usage/numpy.rst | 7 ++++--- python/src/convert.cpp | 42 
+++++++++++++++++++++++--------------- python/src/ops.cpp | 8 ++++++-- python/tests/test_array.py | 13 ++++++++++++ 4 files changed, 49 insertions(+), 21 deletions(-) diff --git a/docs/src/usage/numpy.rst b/docs/src/usage/numpy.rst index 36154898f9..f7a2169660 100644 --- a/docs/src/usage/numpy.rst +++ b/docs/src/usage/numpy.rst @@ -109,9 +109,10 @@ without a copy when the underlying Metal buffer is not private. Private Metal buffers are copied into MLX-managed storage instead. Passing ``copy=False`` requires zero-copy import and raises an error if a copy would be needed. Passing ``copy=True`` asks MLX to create a new array instead of reusing the -Metal buffer. ``mx.array`` also creates a new array instead of reusing the -Metal buffer. MLX arrays exported to PyTorch with DLPack are exported without a -copy on Metal. +Metal buffer. Copied DLPack inputs are materialized as row-contiguous MLX +arrays; zero-copy imports preserve the DLPack strides. ``mx.array`` also +creates a new array instead of reusing the Metal buffer. MLX arrays exported to +PyTorch with DLPack are exported without a copy on Metal. 
In particular, PyTorch 2.12 and later use shared storage for ordinary MPS tensors on Apple silicon, while older PyTorch versions may use private storage diff --git a/python/src/convert.cpp b/python/src/convert.cpp index d245cebc9c..fc853fff31 100644 --- a/python/src/convert.cpp +++ b/python/src/convert.cpp @@ -76,6 +76,19 @@ size_t strided_storage_size( return storage_size; } +size_t strided_offset( + size_t index, + const mx::Shape& shape, + const mx::Strides& strides) { + size_t offset = 0; + for (size_t i = shape.size(); i-- > 0;) { + auto dim_index = index % shape[i]; + index /= shape[i]; + offset += dim_index * strides[i]; + } + return offset; +} + template auto dispatch_dlpack_dtype( nb::dlpack::dtype type, @@ -168,23 +181,14 @@ mx::array cpu_nd_array_to_mlx( [&](mx::Dtype) { auto out = mx::array(shape, dst_dtype, nullptr, {}); auto strides = get_strides(nd_array); - auto storage_size = strided_storage_size(shape, strides); - auto [no_bsx_size, is_row_contiguous, is_col_contiguous] = shape.empty() - ? 
std::make_tuple(storage_size, true, true) - : mx::check_contiguity(shape, strides); - auto flags = out.flags(); - flags.contiguous = no_bsx_size == storage_size; - flags.row_contiguous = is_row_contiguous; - flags.col_contiguous = is_col_contiguous; - out.set_data( - mx::allocator::malloc(storage_size * mx::size_of(dst_dtype)), - storage_size, - std::move(strides), - flags); - if (storage_size > 0) { + strided_storage_size(shape, strides); + out.set_data(mx::allocator::malloc(out.nbytes())); + if (out.size() > 0) { auto src = static_cast(nd_array.data()); auto dst = out.data(); - std::copy(src, src + storage_size, dst); + for (size_t i = 0; i < out.size(); ++i) { + dst[i] = static_cast(src[strided_offset(i, shape, strides)]); + } } out.set_status(mx::array::Status::available); return out; @@ -216,11 +220,17 @@ mx::array metal_dlpack_to_mlx( flags.contiguous = no_bsx_size == storage_size; flags.row_contiguous = is_row_contiguous; flags.col_contiguous = is_col_contiguous; + auto import_flags = flags; + if (copy && !import_flags.row_contiguous) { + // Force the copy primitive to materialize the virtual strided input into a + // row-contiguous output instead of preserving a dense non-row layout. + import_flags.contiguous = false; + } out.set_data( mx::allocator::Buffer(data_handle), storage_size, std::move(strides), - flags, + import_flags, nd_array.byte_offset(), [owner = std::move(nd_array)](mx::allocator::Buffer) {}); out.set_status(mx::array::Status::available); diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 8104773346..6398d8d330 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -1783,7 +1783,9 @@ void init_ops(nb::module_& m) { dtype (Dtype, optional): The desired data-type for the array. copy (bool, optional): Whether to copy the input. If ``True``, always copy. If ``False``, never copy. If ``None``, share memory - when possible and copy otherwise. + when possible and copy otherwise. 
Copied DLPack inputs are + materialized as row-contiguous MLX arrays, while zero-copy + imports preserve the DLPack strides. Returns: array: An array interpretation of the input. @@ -1809,7 +1811,9 @@ void init_ops(nb::module_& m) { ``__dlpack_device__``. copy (bool, optional): Whether to copy the input. If ``True``, always copy. If ``False``, never copy. If ``None``, share memory - when possible and copy otherwise. + when possible and copy otherwise. Copied inputs are materialized + as row-contiguous MLX arrays, while zero-copy imports preserve + the DLPack strides. Returns: array: An array containing the input data. diff --git a/python/tests/test_array.py b/python/tests/test_array.py index 054542292e..f9df665cae 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -2078,6 +2078,8 @@ def test_from_dlpack_cpu_strided(self): y = mx.from_dlpack(view) self.assertEqual(y.tolist(), view.tolist()) + self.assertTrue(memoryview(y).c_contiguous) + self.assertEqual(memoryview(y).strides, (12, 4)) expected = view.copy().tolist() x[0, 0] = 99 self.assertEqual(y.tolist(), expected) @@ -2085,12 +2087,16 @@ def test_from_dlpack_cpu_strided(self): stepped = np.arange(20, dtype=np.int32)[2:10:2] y = mx.from_dlpack(stepped) self.assertEqual(y.tolist(), [2, 4, 6, 8]) + self.assertTrue(memoryview(y).c_contiguous) + self.assertEqual(memoryview(y).strides, (4,)) stepped[0] = 99 self.assertEqual(y.tolist(), [2, 4, 6, 8]) broadcast = np.broadcast_to(np.array([7], dtype=np.int32), (3,)) y = mx.from_dlpack(broadcast) self.assertEqual(y.tolist(), [7, 7, 7]) + self.assertTrue(memoryview(y).c_contiguous) + self.assertEqual(memoryview(y).strides, (4,)) negative_stride = np.arange(5, dtype=np.float32)[::-1] with self.assertRaises(ValueError): @@ -2192,10 +2198,17 @@ def test_torch_mps_dlpack_strided_view(self): y_copy = mx.asarray(view, copy=True) expected = view.cpu().numpy().tolist() + self.assertTrue(memoryview(y_copy).c_contiguous) + 
self.assertEqual(memoryview(y_copy).strides, (12, 4)) x[0, 2] = 77 torch.mps.synchronize() self.assertEqual(y_copy.tolist(), expected) + z = mx.asarray(view, dtype=mx.float16) + self.assertEqual(z.dtype, mx.float16) + self.assertTrue(memoryview(z).c_contiguous) + self.assertEqual(memoryview(z).strides, (6, 2)) + @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") def test_torch_mps_dlpack_stepped_view(self): x = torch.arange(20, device="mps", dtype=torch.int32) From 744e6382bff8aebbf16219bd4a0b436131f4c0bc Mon Sep 17 00:00:00 2001 From: XXXXRT666 <157766680+XXXXRT666@users.noreply.github.com> Date: Tue, 26 May 2026 14:45:30 +0800 Subject: [PATCH 25/26] Fix DLPack copy build regressions --- mlx/array.h | 9 +++++++++ mlx/ops.cpp | 3 ++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/mlx/array.h b/mlx/array.h index d8082a4502..8e14ca4726 100644 --- a/mlx/array.h +++ b/mlx/array.h @@ -438,6 +438,15 @@ class MLX_API array { void set_data(allocator::Buffer buffer, Deleter d = allocator::free); + void set_data( + allocator::Buffer buffer, + size_t data_size, + Strides strides, + Flags flags, + Deleter d) { + set_data(buffer, data_size, std::move(strides), flags, 0, std::move(d)); + } + void set_data( allocator::Buffer buffer, size_t data_size, diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 3a093582b7..081cb61c7a 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -265,8 +265,9 @@ array astype( if (dtype == a.dtype() && !copy.value_or(false)) { return a; } + auto copied_shape = a.shape(); // |a| will be moved return array( - a.shape(), + std::move(copied_shape), dtype, std::make_shared(to_stream(s), dtype), {std::move(a)}); From f733754e6ce32e85bd3aa7677035ab3d1e0fe9c8 Mon Sep 17 00:00:00 2001 From: XXXXRT666 <157766680+XXXXRT666@users.noreply.github.com> Date: Tue, 26 May 2026 16:02:43 +0800 Subject: [PATCH 26/26] Avoid PyTorch MPS scalar updates in DLPack tests --- python/tests/test_array.py | 36 ++++++++++++++++++------------------ 1 file changed, 18 
insertions(+), 18 deletions(-) diff --git a/python/tests/test_array.py b/python/tests/test_array.py index f9df665cae..545ab64c3a 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -2121,7 +2121,7 @@ def test_torch_mps_array_copies_dlpack_input(self): torch.mps.synchronize() y = mx.array(x) - x.add_(10) + x.zero_() torch.mps.synchronize() self.assertEqual(y.tolist(), [0.0, 1.0, 2.0]) @@ -2131,7 +2131,7 @@ def test_torch_mps_asarray_copy_true_copies_dlpack_input(self): torch.mps.synchronize() y = mx.asarray(x, copy=True) - x.add_(10) + x.zero_() torch.mps.synchronize() self.assertEqual(y.tolist(), [0.0, 1.0, 2.0]) @@ -2141,9 +2141,9 @@ def test_torch_mps_dlpack_zero_copy_shares_updates(self): torch.mps.synchronize() y = mx.asarray(x) - x.add_(100) + x.zero_() torch.mps.synchronize() - self.assertEqual((y + 1).tolist(), (x + 1).cpu().numpy().tolist()) + self.assertEqual(y.tolist(), x.cpu().numpy().tolist()) y += 10 mx.eval(y) @@ -2156,7 +2156,7 @@ def test_torch_mps_dlpack_matching_dtype_argument_shares_updates(self): y = mx.asarray(x, dtype=mx.float32, copy=False) self.assertEqual(y.dtype, mx.float32) - x.add_(100) + x.zero_() torch.mps.synchronize() self.assertEqual(y.tolist(), x.cpu().numpy().tolist()) @@ -2170,7 +2170,7 @@ def test_torch_mps_dlpack_different_dtype_argument_copies(self): self.assertEqual(z.dtype, mx.float16) self.assertEqual(z.tolist(), expected) - x.add_(100) + x.zero_() torch.mps.synchronize() self.assertEqual(z.tolist(), expected) @@ -2182,7 +2182,7 @@ def test_torch_mps_dlpack_data_offset(self): view = torch.arange(12, device="mps", dtype=torch.float32)[3:9] view_mx = mx.asarray(view) torch.mps.synchronize() - self.assertEqual((view_mx + 1).tolist(), (view + 1).cpu().numpy().tolist()) + self.assertEqual(view_mx.tolist(), view.cpu().numpy().tolist()) @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") def test_torch_mps_dlpack_strided_view(self): @@ -2235,9 +2235,9 @@ def 
test_torch_mps_dlpack_broadcast_stride(self): y = mx.asarray(view, copy=False) self.assertEqual(y.tolist(), [7, 7, 7]) - x.add_(1) + x.zero_() torch.mps.synchronize() - self.assertEqual(y.tolist(), [8, 8, 8]) + self.assertEqual(y.tolist(), [0, 0, 0]) @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") def test_torch_mps_dlpack_bfloat16(self): @@ -2294,10 +2294,10 @@ def test_mlx_dlpack_export_torch_update_writes_mlx_buffer(self): self.assertEqual(t.device.type, "mps") self.assertEqual(t.cpu().numpy().tolist(), [2.0, 3.0, 4.0, 5.0]) - t.add_(10) + t.zero_() torch.mps.synchronize() - self.assertEqual(y.tolist(), [12.0, 13.0, 14.0, 15.0]) - self.assertEqual(x.tolist(), [0.0, 1.0, 12.0, 13.0, 14.0, 15.0, 6.0, 7.0]) + self.assertEqual(y.tolist(), [0.0, 0.0, 0.0, 0.0]) + self.assertEqual(x.tolist(), [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 6.0, 7.0]) @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") def test_from_dlpack_torch_mps_copy_none_shares_updates(self): @@ -2305,13 +2305,13 @@ def test_from_dlpack_torch_mps_copy_none_shares_updates(self): torch.mps.synchronize() y = mx.from_dlpack(x) - x.add_(10) + x.zero_() torch.mps.synchronize() - self.assertEqual(y.tolist(), [10.0, 11.0, 12.0]) + self.assertEqual(y.tolist(), [0.0, 0.0, 0.0]) y += 10 mx.eval(y) - self.assertEqual(x.cpu().numpy().tolist(), [20.0, 21.0, 22.0]) + self.assertEqual(x.cpu().numpy().tolist(), [10.0, 10.0, 10.0]) @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") def test_from_dlpack_torch_mps_copy_false_shares_updates(self): @@ -2319,9 +2319,9 @@ def test_from_dlpack_torch_mps_copy_false_shares_updates(self): torch.mps.synchronize() y = mx.from_dlpack(x, copy=False) - x.add_(10) + x.zero_() torch.mps.synchronize() - self.assertEqual(y.tolist(), [10.0, 11.0, 12.0]) + self.assertEqual(y.tolist(), [0.0, 0.0, 0.0]) @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") def test_from_dlpack_torch_mps_copy_true_copies(self): @@ -2329,7 +2329,7 @@ def 
test_from_dlpack_torch_mps_copy_true_copies(self): torch.mps.synchronize() y = mx.from_dlpack(x, copy=True) - x.add_(10) + x.zero_() torch.mps.synchronize() self.assertEqual(y.tolist(), [0.0, 1.0, 2.0])