24 commits
7af5b6a
Support Metal DLPack zero-copy import
XXXXRT666 May 10, 2026
bbffe6a
Add from_dlpack copy controls
XXXXRT666 May 10, 2026
11cda58
Support Metal DLPack zero-copy sharing
XXXXRT666 May 11, 2026
361143f
Share DLPack arrays when dtype matches
XXXXRT666 May 11, 2026
206dce4
Clarify MPS DLPack host access test
XXXXRT666 May 19, 2026
a7c39c7
Pin nanobind with Metal DLPack support
XXXXRT666 May 19, 2026
dc90eba
Use host accessibility check for Metal raw pointer
XXXXRT666 May 19, 2026
2dea5e9
Add array copy for shared DLPack buffers
XXXXRT666 May 21, 2026
ef12528
Keep FFT torch baseline on CPU
XXXXRT666 May 21, 2026
9400ca4
Handle private Metal DLPack buffers with copies
XXXXRT666 May 21, 2026
7b7b4d0
Leave nanobind shallow clone note
XXXXRT666 May 21, 2026
310999b
Reduce Metal DLPack test redundancy
XXXXRT666 May 21, 2026
da3559a
Support DLPack copy export
XXXXRT666 May 21, 2026
3aaf373
Copy private Metal DLPack buffers on import
XXXXRT666 May 22, 2026
05a073e
Clean up Metal DLPack owner handling
XXXXRT666 May 22, 2026
6cab38d
Clarify DLPack conversion dtype names
XXXXRT666 May 22, 2026
f88366d
Clean up DLPack conversion paths
XXXXRT666 May 22, 2026
07784c2
Revert "Add array copy for shared DLPack buffers"
XXXXRT666 May 22, 2026
09fa65d
Address DLPack review cleanup
XXXXRT666 May 23, 2026
f5bab00
Address DLPack buffer reuse review
XXXXRT666 May 23, 2026
a4378cf
Add copy control to asarray
XXXXRT666 May 23, 2026
7e87757
Address DLPack review feedback
XXXXRT666 May 24, 2026
a7c4a44
Support strided DLPack arrays
XXXXRT666 May 24, 2026
30e1484
Make copied DLPack imports row-contiguous
XXXXRT666 May 25, 2026
4 changes: 2 additions & 2 deletions CMakeLists.txt
@@ -357,8 +357,8 @@ if(MLX_BUILD_PYTHON_BINDINGS)
FetchContent_Declare(
nanobind
GIT_REPOSITORY https://github.com/wjakob/nanobind.git
-GIT_TAG v2.12.0
-GIT_SHALLOW TRUE
Comment thread: zcbenz marked this conversation as resolved.
+GIT_TAG 117ba9ee0253e1bbada678bb4b5d6c6e4ed441eb
+# GIT_SHALLOW TRUE
EXCLUDE_FROM_ALL)
FetchContent_MakeAvailable(nanobind)
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/python/src)
1 change: 1 addition & 0 deletions docs/src/python/ops.rst
@@ -77,6 +77,7 @@ Operations
floor
floor_divide
full
+from_dlpack
gather_mm
gather_qmm
greater
53 changes: 50 additions & 3 deletions docs/src/usage/numpy.rst
@@ -76,15 +76,62 @@ PyTorch
-------

PyTorch supports DLPack inputs and can import MLX arrays directly.
+MLX can also import PyTorch tensors through DLPack with ``mx.asarray`` or
+``mx.from_dlpack``. Use ``torch.as_tensor`` to import an MLX array with
+DLPack; ``torch.tensor`` copies the data instead. Similarly, ``mx.asarray``
+can share DLPack inputs when possible, while ``mx.array`` copies:

.. code-block:: python

import mlx.core as mx
import torch

-a = mx.arange(3)
-b = torch.tensor(a)
-c = mx.array(b.cpu())
+a = mx.arange(3, dtype=mx.float32)
+mx.eval(a)
+
+shared = torch.as_tensor(a)
+copied = torch.tensor(a)
+
+Creating an MLX array from a CPU tensor copies the data into MLX-owned storage.
+The arrays do not share memory:
+
+.. code-block:: python
+
+b = torch.arange(3)
+c = mx.array(b)
+
+b += 10
+print(c.tolist()) # [0, 1, 2]
+
+Metal DLPack inputs are different. If a PyTorch MPS tensor is passed to
+``mx.asarray`` or to ``mx.from_dlpack`` with ``copy=None``, MLX imports it
+without a copy when the underlying Metal buffer is not private. Private Metal
+buffers are copied into MLX-managed storage instead. Passing ``copy=False``
+requires zero-copy import and raises an error if a copy would be needed.
+Passing ``copy=True`` asks MLX to create a new array instead of reusing the
+Metal buffer. Copied DLPack inputs are materialized as row-contiguous MLX
+arrays; zero-copy imports preserve the DLPack strides. ``mx.array`` also
+creates a new array instead of reusing the Metal buffer. MLX arrays exported to
+PyTorch with DLPack are exported without a copy on Metal.
+
+In particular, PyTorch 2.12 and later use shared storage for ordinary MPS
+tensors on Apple silicon, while older PyTorch versions may use private storage
+and require a copy on import. DLPack conversion does not synchronize pending
+Metal work; synchronize or evaluate the producing framework before reading the
+converted array.
+
+.. code-block:: python
+
+b = torch.arange(3, device="mps", dtype=torch.float32)
+torch.mps.synchronize()
+c = mx.asarray(b) # zero-copy if the Metal buffer can be reused
+d = mx.from_dlpack(b, copy=True) # explicit copy
+
+.. code-block:: python
+
+a = mx.arange(3, dtype=mx.float32)
+mx.eval(a)
+b = torch.as_tensor(a) # zero-copy DLPack import on Metal

JAX
---
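Reading aid for the PyTorch section of `docs/src/usage/numpy.rst` above: the `copy` behavior it describes for DLPack imports can be summarized as a small decision table. The sketch below is a pure-Python model of that description, not MLX code. `ImportAction`, `resolve_dlpack_import`, and the `buffer_is_private` flag are hypothetical names, and the `ValueError` is an assumption, since the documentation only says that an error is raised.

```python
from enum import Enum
from typing import Optional


class ImportAction(Enum):
    SHARE = "share"  # reuse the producer's buffer (zero-copy)
    COPY = "copy"    # materialize new storage


def resolve_dlpack_import(copy: Optional[bool], buffer_is_private: bool) -> ImportAction:
    """Model the documented copy semantics for a Metal DLPack import.

    copy=True  -> always copy
    copy=None  -> share when the buffer is not private, otherwise copy
    copy=False -> share, or fail if a copy would be needed
    """
    if copy is True:
        return ImportAction.COPY
    if not buffer_is_private:
        return ImportAction.SHARE
    if copy is False:
        # Assumed exception type; the docs only state that an error is raised.
        raise ValueError("zero-copy import requested but a copy is required")
    return ImportAction.COPY
```

Under this reading, `copy=None` never fails, and `copy=False` fails only for private buffers.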
2 changes: 2 additions & 0 deletions mlx/allocator.h
@@ -72,4 +72,6 @@ inline void release(Buffer buffer) {
allocator().release(buffer);
}

+MLX_API bool can_reuse_alien_buffer(void* ptr);
+
} // namespace mlx::core::allocator
4 changes: 3 additions & 1 deletion mlx/array.cpp
@@ -64,6 +64,7 @@ array array::unsafe_weak_copy(const array& other) {
other.data_size(),
other.strides(),
other.flags(),
+0,
[](auto) {});
cpy.array_desc_->offset = other.array_desc_->offset;
return cpy;
@@ -180,9 +181,10 @@ void array::set_data(
size_t data_size,
Strides strides,
Flags flags,
+int64_t offset,
Deleter d) {
array_desc_->data = std::make_shared<Data>(buffer, d);
-array_desc_->offset = 0;
+array_desc_->offset = offset;
array_desc_->data_size = data_size;
array_desc_->strides = std::move(strides);
array_desc_->flags = flags;
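Reading aid for the `set_data` change above: the `strides` and `offset` arguments together describe how a logical index maps into the underlying buffer. The sketch below is a pure-Python illustration of that general idea, not MLX code. The helper names are hypothetical, and both strides and offset are measured in elements here for simplicity, which may differ from the units MLX uses internally.

```python
from typing import Sequence, Tuple


def row_contiguous_strides(shape: Sequence[int]) -> Tuple[int, ...]:
    """Strides (in elements) of a row-major, densely packed array."""
    strides = []
    step = 1
    # The last axis varies fastest, so walk the shape from the end.
    for dim in reversed(shape):
        strides.append(step)
        step *= dim
    return tuple(reversed(strides))


def element_position(index: Sequence[int], strides: Sequence[int], offset: int = 0) -> int:
    """Position of a logical index inside the flat buffer."""
    if len(index) != len(strides):
        raise ValueError("index and strides must have the same length")
    return offset + sum(i * s for i, s in zip(index, strides))
```

For example, a shape of `(2, 3, 4)` gives row-contiguous strides `(12, 4, 1)`. A view that starts partway into a shared buffer keeps the same strides and uses a nonzero offset.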
1 change: 1 addition & 0 deletions mlx/array.h
@@ -443,6 +443,7 @@ class MLX_API array {
size_t data_size,
Strides strides,
Flags flags,
+int64_t offset = 0,
Deleter d = allocator::free);

void copy_shared_buffer(
4 changes: 4 additions & 0 deletions mlx/backend/cuda/allocator.cpp
@@ -416,6 +416,10 @@ void* Buffer::raw_ptr() {
return cbuf.data;
}

+bool can_reuse_alien_buffer(void* ptr) {
+return true;
+}
+
} // namespace allocator

size_t get_active_memory() {
13 changes: 12 additions & 1 deletion mlx/backend/metal/allocator.cpp
@@ -7,6 +7,7 @@

#include <mach/vm_page_size.h>
#include <unistd.h>
+#include <cassert>
#include <cstdlib>

namespace mlx::core {
@@ -24,7 +25,17 @@ void* Buffer::raw_ptr() {
if (!ptr_) {
return nullptr;
}
-return static_cast<MTL::Buffer*>(ptr_)->contents();
+auto* buf = static_cast<MTL::Buffer*>(ptr_);
+assert(buf->storageMode() != MTL::StorageModePrivate);
+return buf->contents();
}
+
+bool can_reuse_alien_buffer(void* ptr) {
+if (!ptr) {
+return true;
+}
+auto* buf = static_cast<MTL::Buffer*>(ptr);
+return buf->storageMode() != MTL::StorageModePrivate;
+}

} // namespace allocator
4 changes: 4 additions & 0 deletions mlx/backend/no_gpu/allocator.cpp
@@ -178,6 +178,10 @@ void* Buffer::raw_ptr() {
return static_cast<size_t*>(ptr_) + 1;
}

+bool can_reuse_alien_buffer(void*) {
+return true;
+}
+
} // namespace allocator

size_t get_active_memory() {
11 changes: 7 additions & 4 deletions mlx/ops.cpp
@@ -257,13 +257,16 @@ array linspace(
s);
}

-array astype(array a, Dtype dtype, StreamOrDevice s /* = {} */) {
-if (dtype == a.dtype()) {
+array astype(
+array a,
+Dtype dtype,
+std::optional<bool> copy,
+StreamOrDevice s /* = {} */) {
+if (dtype == a.dtype() && !copy.value_or(false)) {
return a;
}
-auto copied_shape = a.shape(); // |a| will be moved
return array(
-std::move(copied_shape),
+a.shape(),
dtype,
std::make_shared<AsType>(to_stream(s), dtype),
{std::move(a)});
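Reading aid for the `astype` hunk above: the early-return condition treats an unset `copy` the same as `copy = false`. The sketch below is a minimal Python model of that predicate only. The function name and the use of strings for dtypes are hypothetical, and it models nothing beyond the `if` shown in the hunk.

```python
from typing import Optional


def astype_returns_input(src_dtype: str, dst_dtype: str, copy: Optional[bool] = None) -> bool:
    """True when the modeled astype would hand back its input unchanged."""
    # Mirrors std::optional<bool>::value_or(false): None behaves like False.
    copy_requested = copy if copy is not None else False
    return src_dtype == dst_dtype and not copy_requested
```

Under this reading, only `copy=True` forces a new array when the dtype already matches, and a dtype change always produces a new array regardless of `copy`.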
6 changes: 5 additions & 1 deletion mlx/ops.h
@@ -47,7 +47,11 @@ MLX_API array linspace(
StreamOrDevice s = {});

/** Convert an array to the given data type. */
-MLX_API array astype(array a, Dtype dtype, StreamOrDevice s = {});
+MLX_API array
+astype(array a, Dtype dtype, std::optional<bool> copy, StreamOrDevice s = {});
+inline array astype(array a, Dtype dtype, StreamOrDevice s = {}) {
+return astype(std::move(a), dtype, std::nullopt, s);
+}

/** Create a view of an array with the given shape and strides. */
MLX_API array as_strided(
28 changes: 24 additions & 4 deletions python/src/array.cpp
@@ -2,11 +2,13 @@
#include <cstdint>
#include <cstring>
#include <sstream>
+#include <tuple>

#include <nanobind/ndarray.h>
#include <nanobind/stl/complex.h>
#include <nanobind/stl/optional.h>
#include <nanobind/stl/string.h>
+#include <nanobind/stl/tuple.h>
#include <nanobind/stl/variant.h>
#include <nanobind/stl/vector.h>
#include <nanobind/typing.h>
@@ -295,7 +297,7 @@ void init_array(nb::module_& m) {
.def(
"__init__",
[](mx::array* aptr, nb::object v, std::optional<mx::Dtype> t) {
-new (aptr) mx::array(create_array(v, t));
+new (aptr) mx::array(create_array(v, t, true));
},
"val"_a,
"dtype"_a = nb::none(),
@@ -373,7 +375,9 @@ void init_array(nb::module_& m) {
)pbdoc")
.def(
"astype",
-&mx::astype,
+[](const mx::array& a, mx::Dtype dtype, mx::StreamOrDevice s) {
+return mx::astype(a, dtype, s);
+},
"dtype"_a,
"stream"_a = nb::none(),
R"pbdoc(
@@ -479,7 +483,7 @@ void init_array(nb::module_& m) {
throw std::invalid_argument(
"Invalid pickle state: expected (ndarray, Dtype::Val)");
}
-using ND = nb::ndarray<nb::ro, nb::c_contig>;
+using ND = nb::ndarray<nb::ro>;
ND nd = nb::cast<ND>(state[0]);
auto val = static_cast<mx::Dtype::Val>(nb::cast<uint8_t>(state[1]));
if (val == mx::Dtype::Val::bfloat16) {
@@ -496,7 +500,23 @@ void init_array(nb::module_& m) {
new (&arr) mx::array(nd_array_to_mlx(nd, std::nullopt));
}
})
-.def("__dlpack__", [](const mx::array& a) { return mlx_to_dlpack(a); })
+.def(
+"__dlpack__",
+[](const mx::array& a,
+nb::object,
+nb::object,
+std::optional<std::tuple<int, int>> dl_device,
+std::optional<bool> copy) {
+if (copy.value_or(false)) {
+return mlx_to_dlpack(mx::astype(a, a.dtype(), true), dl_device);
+}
+return mlx_to_dlpack(a, dl_device);
+},
+nb::kw_only(),
+"stream"_a = nb::none(),
+"max_version"_a = nb::none(),
+"dl_device"_a = nb::none(),
+"copy"_a = nb::none())
.def(
"__dlpack_device__",
[](const mx::array& a) {