diff --git a/CMakeLists.txt b/CMakeLists.txt index a2395d02f6..9380748597 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -357,8 +357,8 @@ if(MLX_BUILD_PYTHON_BINDINGS) FetchContent_Declare( nanobind GIT_REPOSITORY https://github.com/wjakob/nanobind.git - GIT_TAG v2.12.0 - GIT_SHALLOW TRUE + GIT_TAG 117ba9ee0253e1bbada678bb4b5d6c6e4ed441eb + # GIT_SHALLOW TRUE EXCLUDE_FROM_ALL) FetchContent_MakeAvailable(nanobind) add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/python/src) diff --git a/docs/src/python/ops.rst b/docs/src/python/ops.rst index 6aabfc6954..19608ff37b 100644 --- a/docs/src/python/ops.rst +++ b/docs/src/python/ops.rst @@ -77,6 +77,7 @@ Operations floor floor_divide full + from_dlpack gather_mm gather_qmm greater diff --git a/docs/src/usage/numpy.rst b/docs/src/usage/numpy.rst index bf71938dff..f7a2169660 100644 --- a/docs/src/usage/numpy.rst +++ b/docs/src/usage/numpy.rst @@ -76,15 +76,62 @@ PyTorch ------- PyTorch supports DLPack inputs and can import MLX arrays directly. +MLX can also import PyTorch tensors through DLPack with ``mx.asarray`` or +``mx.from_dlpack``. Use ``torch.as_tensor`` to import an MLX array with +DLPack; ``torch.tensor`` copies the data instead. Similarly, ``mx.asarray`` +can share DLPack inputs when possible, while ``mx.array`` copies: .. code-block:: python import mlx.core as mx import torch - a = mx.arange(3) - b = torch.tensor(a) - c = mx.array(b.cpu()) + a = mx.arange(3, dtype=mx.float32) + mx.eval(a) + + shared = torch.as_tensor(a) + copied = torch.tensor(a) + +Creating an MLX array from a CPU tensor copies the data into MLX-owned storage. +The arrays do not share memory: + +.. code-block:: python + + b = torch.arange(3) + c = mx.array(b) + + b += 10 + print(c.tolist()) # [0, 1, 2] + +Metal DLPack inputs are different. If a PyTorch MPS tensor is passed to +``mx.asarray`` or to ``mx.from_dlpack`` with ``copy=None``, MLX imports it +without a copy when the underlying Metal buffer is not private. 
Private Metal +buffers are copied into MLX-managed storage instead. Passing ``copy=False`` +requires zero-copy import and raises an error if a copy would be needed. +Passing ``copy=True`` asks MLX to create a new array instead of reusing the +Metal buffer. Copied DLPack inputs are materialized as row-contiguous MLX +arrays; zero-copy imports preserve the DLPack strides. ``mx.array`` also +creates a new array instead of reusing the Metal buffer. MLX arrays exported to +PyTorch with DLPack are exported without a copy on Metal. + +In particular, PyTorch 2.12 and later use shared storage for ordinary MPS +tensors on Apple silicon, while older PyTorch versions may use private storage +and require a copy on import. DLPack conversion does not synchronize pending +Metal work; synchronize or evaluate the producing framework before reading the +converted array. + +.. code-block:: python + + b = torch.arange(3, device="mps", dtype=torch.float32) + torch.mps.synchronize() + c = mx.asarray(b) # zero-copy if the Metal buffer can be reused + d = mx.from_dlpack(b, copy=True) # explicit copy + +.. 
code-block:: python + + a = mx.arange(3, dtype=mx.float32) + mx.eval(a) + b = torch.as_tensor(a) # zero-copy DLPack import on Metal JAX --- diff --git a/mlx/allocator.h b/mlx/allocator.h index 824deac2c7..c831b7247d 100644 --- a/mlx/allocator.h +++ b/mlx/allocator.h @@ -72,4 +72,6 @@ inline void release(Buffer buffer) { allocator().release(buffer); } +MLX_API bool can_reuse_alien_buffer(void* ptr); + } // namespace mlx::core::allocator diff --git a/mlx/array.cpp b/mlx/array.cpp index e694ccfd67..07d3acd616 100644 --- a/mlx/array.cpp +++ b/mlx/array.cpp @@ -64,6 +64,7 @@ array array::unsafe_weak_copy(const array& other) { other.data_size(), other.strides(), other.flags(), + 0, [](auto) {}); cpy.array_desc_->offset = other.array_desc_->offset; return cpy; @@ -180,9 +181,10 @@ void array::set_data( size_t data_size, Strides strides, Flags flags, + int64_t offset, Deleter d) { array_desc_->data = std::make_shared(buffer, d); - array_desc_->offset = 0; + array_desc_->offset = offset; array_desc_->data_size = data_size; array_desc_->strides = std::move(strides); array_desc_->flags = flags; diff --git a/mlx/array.h b/mlx/array.h index 60d5e50bbd..8e14ca4726 100644 --- a/mlx/array.h +++ b/mlx/array.h @@ -443,6 +443,16 @@ class MLX_API array { size_t data_size, Strides strides, Flags flags, + Deleter d) { + set_data(buffer, data_size, std::move(strides), flags, 0, std::move(d)); + } + + void set_data( + allocator::Buffer buffer, + size_t data_size, + Strides strides, + Flags flags, + int64_t offset = 0, Deleter d = allocator::free); void copy_shared_buffer( diff --git a/mlx/backend/cuda/allocator.cpp b/mlx/backend/cuda/allocator.cpp index 718ae33e9c..21358929da 100644 --- a/mlx/backend/cuda/allocator.cpp +++ b/mlx/backend/cuda/allocator.cpp @@ -416,6 +416,10 @@ void* Buffer::raw_ptr() { return cbuf.data; } +bool can_reuse_alien_buffer(void* ptr) { + return true; +} + } // namespace allocator size_t get_active_memory() { diff --git a/mlx/backend/metal/allocator.cpp 
b/mlx/backend/metal/allocator.cpp index 222c6fd9fa..60459c67c2 100644 --- a/mlx/backend/metal/allocator.cpp +++ b/mlx/backend/metal/allocator.cpp @@ -7,6 +7,7 @@ #include #include +#include #include namespace mlx::core { @@ -24,7 +25,17 @@ void* Buffer::raw_ptr() { if (!ptr_) { return nullptr; } - return static_cast(ptr_)->contents(); + auto* buf = static_cast(ptr_); + assert(buf->storageMode() != MTL::StorageModePrivate); + return buf->contents(); +} + +bool can_reuse_alien_buffer(void* ptr) { + if (!ptr) { + return true; + } + auto* buf = static_cast(ptr); + return buf->storageMode() != MTL::StorageModePrivate; } } // namespace allocator diff --git a/mlx/backend/no_gpu/allocator.cpp b/mlx/backend/no_gpu/allocator.cpp index 1fdcc262a4..a800e381a7 100644 --- a/mlx/backend/no_gpu/allocator.cpp +++ b/mlx/backend/no_gpu/allocator.cpp @@ -178,6 +178,10 @@ void* Buffer::raw_ptr() { return static_cast(ptr_) + 1; } +bool can_reuse_alien_buffer(void*) { + return true; +} + } // namespace allocator size_t get_active_memory() { diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 6ad41e2e38..081cb61c7a 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -257,8 +257,12 @@ array linspace( s); } -array astype(array a, Dtype dtype, StreamOrDevice s /* = {} */) { - if (dtype == a.dtype()) { +array astype( + array a, + Dtype dtype, + std::optional copy, + StreamOrDevice s /* = {} */) { + if (dtype == a.dtype() && !copy.value_or(false)) { return a; } auto copied_shape = a.shape(); // |a| will be moved diff --git a/mlx/ops.h b/mlx/ops.h index 208964d1aa..d047045a23 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -47,7 +47,11 @@ MLX_API array linspace( StreamOrDevice s = {}); /** Convert an array to the given data type. 
*/ -MLX_API array astype(array a, Dtype dtype, StreamOrDevice s = {}); +MLX_API array +astype(array a, Dtype dtype, std::optional copy, StreamOrDevice s = {}); +inline array astype(array a, Dtype dtype, StreamOrDevice s = {}) { + return astype(std::move(a), dtype, std::nullopt, s); +} /** Create a view of an array with the given shape and strides. */ MLX_API array as_strided( diff --git a/python/src/array.cpp b/python/src/array.cpp index 28c12f622c..ab846addda 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -2,11 +2,13 @@ #include #include #include +#include #include #include #include #include +#include #include #include #include @@ -295,7 +297,7 @@ void init_array(nb::module_& m) { .def( "__init__", [](mx::array* aptr, nb::object v, std::optional t) { - new (aptr) mx::array(create_array(v, t)); + new (aptr) mx::array(create_array(v, t, true)); }, "val"_a, "dtype"_a = nb::none(), @@ -373,7 +375,9 @@ void init_array(nb::module_& m) { )pbdoc") .def( "astype", - &mx::astype, + [](const mx::array& a, mx::Dtype dtype, mx::StreamOrDevice s) { + return mx::astype(a, dtype, s); + }, "dtype"_a, "stream"_a = nb::none(), R"pbdoc( @@ -479,7 +483,7 @@ void init_array(nb::module_& m) { throw std::invalid_argument( "Invalid pickle state: expected (ndarray, Dtype::Val)"); } - using ND = nb::ndarray; + using ND = nb::ndarray; ND nd = nb::cast(state[0]); auto val = static_cast(nb::cast(state[1])); if (val == mx::Dtype::Val::bfloat16) { @@ -496,7 +500,23 @@ void init_array(nb::module_& m) { new (&arr) mx::array(nd_array_to_mlx(nd, std::nullopt)); } }) - .def("__dlpack__", [](const mx::array& a) { return mlx_to_dlpack(a); }) + .def( + "__dlpack__", + [](const mx::array& a, + nb::object, + nb::object, + std::optional> dl_device, + std::optional copy) { + if (copy.value_or(false)) { + return mlx_to_dlpack(mx::astype(a, a.dtype(), true), dl_device); + } + return mlx_to_dlpack(a, dl_device); + }, + nb::kw_only(), + "stream"_a = nb::none(), + "max_version"_a = nb::none(), + 
"dl_device"_a = nb::none(), + "copy"_a = nb::none()) .def( "__dlpack_device__", [](const mx::array& a) { diff --git a/python/src/convert.cpp b/python/src/convert.cpp index a5455c2b33..fc853fff31 100644 --- a/python/src/convert.cpp +++ b/python/src/convert.cpp @@ -1,13 +1,21 @@ // Copyright © 2024 Apple Inc. +#include #include #include +#include #include +#include #include "python/src/convert.h" #include "python/src/utils.h" +#include "mlx/allocator.h" +#include "mlx/backend/common/utils.h" +#include "mlx/backend/cuda/cuda.h" +#include "mlx/backend/metal/metal.h" +#include "mlx/ops.h" #include "mlx/utils.h" enum PyScalarT { @@ -31,82 +39,256 @@ int check_shape_dim(int64_t dim) { return static_cast(dim); } -template -mx::array nd_array_to_mlx_contiguous( - nb::ndarray nd_array, - const mx::Shape& shape, - mx::Dtype dtype) { - // Make a copy of the numpy buffer - // Get buffer ptr pass to array constructor - auto data_ptr = nd_array.data(); - return mx::array(static_cast(data_ptr), shape, dtype); -} - -mx::array nd_array_to_mlx( - nb::ndarray nd_array, - std::optional dtype, - std::optional nb_dtype) { - if (nd_array.device_type() != nb::device::cpu::value) { - throw std::invalid_argument( - "Cannot convert non-CPU DLPack array to mlx array."); - } - - // Compute the shape and size +template +mx::Shape get_shape(const nb::ndarray& nd_array) { mx::Shape shape; shape.reserve(nd_array.ndim()); for (int i = 0; i < nd_array.ndim(); i++) { shape.push_back(check_shape_dim(nd_array.shape(i))); } - auto type = nb_dtype.value_or(nd_array.dtype()); + return shape; +} + +template +mx::Strides get_strides(const nb::ndarray& nd_array) { + mx::Strides strides; + strides.reserve(nd_array.ndim()); + for (int i = 0; i < nd_array.ndim(); i++) { + strides.push_back(nd_array.stride(i)); + } + return strides; +} - // Copy data and make array +size_t strided_storage_size( + const mx::Shape& shape, + const mx::Strides& strides) { + size_t storage_size = 1; + for (int i = 0; i < shape.size(); 
i++) { + if (shape[i] == 0) { + return 0; + } + if (strides[i] < 0) { + throw std::invalid_argument( + "Cannot convert DLPack arrays with negative strides to mlx array."); + } + storage_size += (shape[i] - 1) * strides[i]; + } + return storage_size; +} + +size_t strided_offset( + size_t index, + const mx::Shape& shape, + const mx::Strides& strides) { + size_t offset = 0; + for (size_t i = shape.size(); i-- > 0;) { + auto dim_index = index % shape[i]; + index /= shape[i]; + offset += dim_index * strides[i]; + } + return offset; +} + +template +auto dispatch_dlpack_dtype( + nb::dlpack::dtype type, + F&& f, + const char* error_message) { if (type == nb::dtype()) { - return nd_array_to_mlx_contiguous( - nd_array, shape, dtype.value_or(mx::bool_)); + return f.template operator()(mx::bool_); } else if (type == nb::dtype()) { - return nd_array_to_mlx_contiguous( - nd_array, shape, dtype.value_or(mx::uint8)); + return f.template operator()(mx::uint8); } else if (type == nb::dtype()) { - return nd_array_to_mlx_contiguous( - nd_array, shape, dtype.value_or(mx::uint16)); + return f.template operator()(mx::uint16); } else if (type == nb::dtype()) { - return nd_array_to_mlx_contiguous( - nd_array, shape, dtype.value_or(mx::uint32)); + return f.template operator()(mx::uint32); } else if (type == nb::dtype()) { - return nd_array_to_mlx_contiguous( - nd_array, shape, dtype.value_or(mx::uint64)); + return f.template operator()(mx::uint64); } else if (type == nb::dtype()) { - return nd_array_to_mlx_contiguous( - nd_array, shape, dtype.value_or(mx::int8)); + return f.template operator()(mx::int8); } else if (type == nb::dtype()) { - return nd_array_to_mlx_contiguous( - nd_array, shape, dtype.value_or(mx::int16)); + return f.template operator()(mx::int16); } else if (type == nb::dtype()) { - return nd_array_to_mlx_contiguous( - nd_array, shape, dtype.value_or(mx::int32)); + return f.template operator()(mx::int32); } else if (type == nb::dtype()) { - return nd_array_to_mlx_contiguous( 
- nd_array, shape, dtype.value_or(mx::int64)); + return f.template operator()(mx::int64); } else if (type == nb::dtype()) { - return nd_array_to_mlx_contiguous( - nd_array, shape, dtype.value_or(mx::float16)); + return f.template operator()(mx::float16); } else if (type == nb::dtype()) { - return nd_array_to_mlx_contiguous( - nd_array, shape, dtype.value_or(mx::bfloat16)); + return f.template operator()(mx::bfloat16); } else if (type == nb::dtype()) { - return nd_array_to_mlx_contiguous( - nd_array, shape, dtype.value_or(mx::float32)); + return f.template operator()(mx::float32); } else if (type == nb::dtype()) { - return nd_array_to_mlx_contiguous( - nd_array, shape, dtype.value_or(mx::float32)); + return f.template operator()(mx::float32); } else if (type == nb::dtype>()) { - return nd_array_to_mlx_contiguous( - nd_array, shape, dtype.value_or(mx::complex64)); + return f.template operator()(mx::complex64); } else if (type == nb::dtype>()) { - return nd_array_to_mlx_contiguous( - nd_array, shape, dtype.value_or(mx::complex64)); + return f.template operator()(mx::complex64); } else { - throw std::invalid_argument("Cannot convert numpy array to mlx array."); + throw std::invalid_argument(error_message); + } +} + +mx::Dtype mlx_dtype_from_dlpack( + nb::dlpack::dtype type, + const char* error_message) { + return dispatch_dlpack_dtype( + type, [](mx::Dtype dtype) { return dtype; }, error_message); +} + +nb::dlpack::dtype mlx_dtype_to_dl_dtype(mx::Dtype dtype) { + switch (dtype) { + case mx::bool_: + return nb::dtype(); + case mx::uint8: + return nb::dtype(); + case mx::uint16: + return nb::dtype(); + case mx::uint32: + return nb::dtype(); + case mx::uint64: + return nb::dtype(); + case mx::int8: + return nb::dtype(); + case mx::int16: + return nb::dtype(); + case mx::int32: + return nb::dtype(); + case mx::int64: + return nb::dtype(); + case mx::float16: + return nb::dtype(); + case mx::bfloat16: + return nb::dtype(); + case mx::float32: + return nb::dtype(); + case 
mx::float64: + return nb::dtype(); + case mx::complex64: + return nb::dtype>(); + default: + throw nb::buffer_error("Cannot export mlx array with unsupported dtype."); + } +} + +template +mx::array cpu_nd_array_to_mlx( + nb::ndarray nd_array, + const mx::Shape& shape, + mx::Dtype dst_dtype) { + return dispatch_dlpack_dtype( + mlx_dtype_to_dl_dtype(dst_dtype), + [&](mx::Dtype) { + auto out = mx::array(shape, dst_dtype, nullptr, {}); + auto strides = get_strides(nd_array); + strided_storage_size(shape, strides); + out.set_data(mx::allocator::malloc(out.nbytes())); + if (out.size() > 0) { + auto src = static_cast(nd_array.data()); + auto dst = out.data(); + for (size_t i = 0; i < out.size(); ++i) { + dst[i] = static_cast(src[strided_offset(i, shape, strides)]); + } + } + out.set_status(mx::array::Status::available); + return out; + }, + "Cannot convert numpy array to mlx array."); +} + +mx::array metal_dlpack_to_mlx( + nb::ndarray nd_array, + mx::Dtype src_dtype, + mx::Dtype dst_dtype, + bool copy) { + if (!mx::metal::is_available()) { + throw std::invalid_argument("Metal DLPack import is not available."); + } + auto shape = get_shape(nd_array); + if (nd_array.itemsize() != mx::size_of(src_dtype)) { + throw std::invalid_argument( + "Cannot convert Metal DLPack dtype to mlx dtype."); + } + auto strides = get_strides(nd_array); + auto storage_size = strided_storage_size(shape, strides); + auto [no_bsx_size, is_row_contiguous, is_col_contiguous] = shape.empty() + ? 
std::make_tuple(storage_size, true, true) + : mx::check_contiguity(shape, strides); + auto data_handle = nd_array.data_handle(); + mx::array out(shape, src_dtype, nullptr, {}); + auto flags = out.flags(); + flags.contiguous = no_bsx_size == storage_size; + flags.row_contiguous = is_row_contiguous; + flags.col_contiguous = is_col_contiguous; + auto import_flags = flags; + if (copy && !import_flags.row_contiguous) { + // Force the copy primitive to materialize the virtual strided input into a + // row-contiguous output instead of preserving a dense non-row layout. + import_flags.contiguous = false; + } + out.set_data( + mx::allocator::Buffer(data_handle), + storage_size, + std::move(strides), + import_flags, + nd_array.byte_offset(), + [owner = std::move(nd_array)](mx::allocator::Buffer) {}); + out.set_status(mx::array::Status::available); + + if (copy) { + auto result = mx::astype(out, dst_dtype, true, mx::Device::gpu); + result.eval(); + return result; + } + return out; +} + +mx::array nd_array_to_mlx( + nb::ndarray nd_array, + std::optional requested_dtype, + std::optional src_dlpack_dtype_override, + std::optional copy) { + auto src_dlpack_dtype = src_dlpack_dtype_override.value_or(nd_array.dtype()); + auto src_mlx_dtype = + mlx_dtype_from_dlpack(src_dlpack_dtype, "Cannot convert array to mlx."); + auto dst_dtype = requested_dtype.value_or(src_mlx_dtype); + auto device_type = nd_array.device_type(); + bool can_reuse_buffer = device_type == nb::device::cpu::value || + mx::allocator::can_reuse_alien_buffer(nd_array.data_handle()); + // Calling this on a CPU array raises an error; the || short-circuit avoids it + bool should_copy = + copy.value_or(false) || dst_dtype != src_mlx_dtype || !can_reuse_buffer; + if (copy.has_value() && copy.value() == false && dst_dtype != src_mlx_dtype) { + throw std::invalid_argument( + "Cannot convert DLPack array to requested dtype without a copy."); + } + switch (device_type) { + case nb::device::cpu::value: { + if (copy.has_value() && 
copy.value() == false) { + throw std::invalid_argument( + "Cannot import a CPU DLPack array without a copy."); + } + auto shape = get_shape(nd_array); + return dispatch_dlpack_dtype( + src_dlpack_dtype, + [&](mx::Dtype src_dtype) { + return cpu_nd_array_to_mlx(nd_array, shape, dst_dtype); + }, + "Cannot convert numpy array to mlx array."); + } + case nb::device::metal::value: { + if (copy.has_value() && copy.value() == false && !can_reuse_buffer) { + throw std::invalid_argument( + "Cannot import a private Metal DLPack buffer without a copy."); + } + return metal_dlpack_to_mlx( + nd_array, src_mlx_dtype, dst_dtype, should_copy); + } + case nb::device::cuda::value: + case nb::device::cuda_managed::value: + throw std::invalid_argument("CUDA DLPack import is not supported."); + default: + throw std::invalid_argument("Unsupported DLPack device."); } } @@ -119,11 +301,12 @@ nb::ndarray mlx_to_nd_array_impl( a.eval(); } std::vector shape(a.shape().begin(), a.shape().end()); + auto owner = nb::cast(a); return nb::ndarray( a.data(), a.ndim(), shape.data(), - /* owner= */ nb::none(), + /* owner= */ owner, a.strides().data(), t.value_or(nb::dtype())); } @@ -168,8 +351,56 @@ nb::ndarray mlx_to_np_array(const mx::array& a) { return mlx_to_nd_array(a); } -nb::ndarray<> mlx_to_dlpack(const mx::array& a) { - return mlx_to_nd_array<>(a); +nb::ndarray<> mlx_to_dlpack( + const mx::array& a, + std::optional> dl_device) { + auto default_device = mx::metal::is_available() + ? std::tuple{nb::device::metal::value, 0} + : (mx::cu::is_available() ? 
std::tuple{nb::device::cuda_managed::value, 0} + : std::tuple{nb::device::cpu::value, 0}); + auto [device_type, device_id] = dl_device.value_or(default_device); + + if (device_type == nb::device::cuda::value || + device_type == nb::device::cuda_managed::value) { + throw nb::buffer_error("CUDA DLPack export is not supported."); + } + if (device_type != nb::device::cpu::value && + device_type != nb::device::metal::value) { + throw nb::buffer_error( + "Cannot export mlx array to requested DLPack device."); + } + if (device_type == nb::device::metal::value && !mx::metal::is_available()) { + throw nb::buffer_error("Metal DLPack export is not available."); + } + + auto arr = a; + void* data = nullptr; + uint64_t byte_offset = 0; + { + nb::gil_scoped_release nogil; + arr.eval(); + if (device_type == nb::device::cpu::value) { + data = arr.buffer().raw_ptr(); + byte_offset = arr.offset(); + } else { + data = arr.buffer().ptr(); + byte_offset = arr.offset(); + } + } + + std::vector shape(arr.shape().begin(), arr.shape().end()); + auto owner = nb::cast(arr); + return nb::ndarray<>( + data, + arr.ndim(), + shape.data(), + /* owner= */ owner, + arr.strides().data(), + mlx_dtype_to_dl_dtype(arr.dtype()), + device_type, + device_id, + '\0', + byte_offset); } nb::object to_scalar(mx::array& a) { @@ -471,10 +702,22 @@ mx::array array_from_list(nb::tuple pl, std::optional dtype) { return array_from_list_impl(pl, dtype); } -mx::array create_array(nb::object v, std::optional t) { +void check_copy_false(std::optional copy) { + if (copy.has_value() && copy.value() == false) { + throw std::invalid_argument( + "Unable to avoid copy while creating an array as requested."); + } +} + +mx::array create_array( + nb::object v, + std::optional t, + std::optional copy) { if (nb::isinstance(v)) { + check_copy_false(copy); return mx::array(nb::cast(v), t.value_or(mx::bool_)); } else if (nb::isinstance(v)) { + check_copy_false(copy); auto val = nb::cast(v); auto default_type = (val > 
std::numeric_limits::max() || val < std::numeric_limits::min()) @@ -482,6 +725,7 @@ mx::array create_array(nb::object v, std::optional t) { : mx::int32; return mx::array(val, t.value_or(default_type)); } else if (nb::isinstance(v)) { + check_copy_false(copy); auto out_type = t.value_or(mx::float32); if (out_type == mx::float64) { return mx::array(nb::cast(v), out_type); @@ -489,18 +733,23 @@ mx::array create_array(nb::object v, std::optional t) { return mx::array(nb::cast(v), out_type); } } else if (PyComplex_Check(v.ptr())) { + check_copy_false(copy); return mx::array( static_cast(nb::cast>(v)), t.value_or(mx::complex64)); } else if (nb::isinstance(v)) { + check_copy_false(copy); return array_from_list(nb::cast(v), t); } else if (nb::isinstance(v)) { + check_copy_false(copy); return array_from_list(nb::cast(v), t); } else if (nb::isinstance(v)) { auto arr = nb::cast(v); - return mx::astype(arr, t.value_or(arr.dtype())); + auto dtype = t.value_or(arr.dtype()); + check_copy_false(copy); + return mx::astype(arr, dtype, copy); } else if (nb::ndarray_check(v)) { - using ContigArray = nb::ndarray; + using ContigArray = nb::ndarray; ContigArray nd; std::optional nb_dtype; // Nanobind does not recognize bfloat16 numpy array: @@ -511,9 +760,10 @@ mx::array create_array(nb::object v, std::optional t) { } else { nd = nb::cast(v); } - return nd_array_to_mlx(nd, t, nb_dtype); + return nd_array_to_mlx(nd, t, nb_dtype, copy); } else { + check_copy_false(copy); auto arr = to_array_with_accessor(v); - return mx::astype(arr, t.value_or(arr.dtype())); + return mx::astype(arr, t.value_or(arr.dtype()), copy); } } diff --git a/python/src/convert.h b/python/src/convert.h index 9341dd3122..0b8e569afb 100644 --- a/python/src/convert.h +++ b/python/src/convert.h @@ -2,12 +2,12 @@ #pragma once #include +#include #include #include #include "mlx/array.h" -#include "mlx/ops.h" namespace mx = mlx::core; namespace nb = nanobind; @@ -62,18 +62,24 @@ struct ArrayLike { }; mx::array 
nd_array_to_mlx( - nb::ndarray nd_array, + nb::ndarray nd_array, std::optional mx_dtype, - std::optional nb_dtype = std::nullopt); + std::optional src_dlpack_dtype_override = std::nullopt, + std::optional copy = std::nullopt); nb::ndarray mlx_to_np_array(const mx::array& a); -nb::ndarray<> mlx_to_dlpack(const mx::array& a); +nb::ndarray<> mlx_to_dlpack( + const mx::array& a, + std::optional> dl_device = std::nullopt); nb::object to_scalar(mx::array& a); nb::object tolist(mx::array& a); -mx::array create_array(nb::object v, std::optional t); +mx::array create_array( + nb::object v, + std::optional t, + std::optional copy = true); mx::array array_from_list(nb::list pl, std::optional dtype); mx::array array_from_list(nb::tuple pl, std::optional dtype); diff --git a/python/src/indexing.cpp b/python/src/indexing.cpp index c44cd9b943..3df4c96882 100644 --- a/python/src/indexing.cpp +++ b/python/src/indexing.cpp @@ -926,7 +926,7 @@ mlx_compute_slice_update_args( } std::optional extract_boolean_mask(const nb::object& obj) { - using NDArray = nb::ndarray; + using NDArray = nb::ndarray; if (nb::isinstance(obj)) { return mx::array(nb::cast(obj), mx::bool_); } else if (nb::isinstance(obj)) { diff --git a/python/src/mlx_func.cpp b/python/src/mlx_func.cpp index 09aceabe9b..9955e134f9 100644 --- a/python/src/mlx_func.cpp +++ b/python/src/mlx_func.cpp @@ -2,6 +2,8 @@ #include "python/src/mlx_func.h" +#include + // A garbage collected function which wraps nb::cpp_function // See https://github.com/wjakob/nanobind/discussions/919 diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 198a4861d9..6398d8d330 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -1767,33 +1767,56 @@ void init_ops(nb::module_& m) { "asarray", [](const nb::object& a, std::optional dtype, - std::optional copy) { - if (copy.has_value() && !*copy) { - throw std::invalid_argument("[asarray] copy=False is not supported."); - } - return create_array(a, dtype); - }, + std::optional copy) { return 
create_array(a, dtype, copy); }, nb::arg(), "dtype"_a = nb::none(), nb::kw_only(), "copy"_a = nb::none(), nb::sig( - "def asarray(a: Union[scalar, array, Sequence], dtype: Optional[Dtype] = None, *, copy: Optional[bool] = None) -> array"), + "def asarray(a: Union[scalar, array, Sequence], dtype: " + "Optional[Dtype] = None, *, copy: Optional[bool] = None) -> array"), R"pbdoc( Convert the input to an array. Args: a: Input data. dtype (Dtype, optional): The desired data-type for the array. - copy (bool, optional): Must be ``True`` or unspecified. ``False`` - is not supported, since MLX has no in-place operations and - cannot return a non-copying view. + copy (bool, optional): Whether to copy the input. If ``True``, + always copy. If ``False``, never copy. If ``None``, share memory + when possible and copy otherwise. Copied DLPack inputs are + materialized as row-contiguous MLX arrays, while zero-copy + imports preserve the DLPack strides. Returns: array: An array interpretation of the input. Raises: - ValueError: If ``copy`` is ``False``. + ValueError: If ``copy`` is ``False`` and a copy is required. + )pbdoc"); + m.def( + "from_dlpack", + [](nb::ndarray x, std::optional copy) { + return nd_array_to_mlx(x, std::nullopt, std::nullopt, copy); + }, + nb::arg(), + nb::kw_only(), + "copy"_a = nb::none(), + nb::sig( + "def from_dlpack(x: DLPackCompatible, /, *, copy: Optional[bool] = None) -> array"), + R"pbdoc( + Create an array from an object that supports DLPack. + + Args: + x: Input object implementing ``__dlpack__`` and + ``__dlpack_device__``. + copy (bool, optional): Whether to copy the input. If ``True``, + always copy. If ``False``, never copy. If ``None``, share memory + when possible and copy otherwise. Copied inputs are materialized + as row-contiguous MLX arrays, while zero-copy imports preserve + the DLPack strides. + + Returns: + array: An array containing the input data. 
)pbdoc"); m.def( "zeros_like", diff --git a/python/src/utils.cpp b/python/src/utils.cpp index fa4a53428b..4416531dc8 100644 --- a/python/src/utils.cpp +++ b/python/src/utils.cpp @@ -38,7 +38,7 @@ mx::array to_array( return mx::array(static_cast(*pv), mx::complex64); } else if (auto pv = std::get_if(&v); pv) { return *pv; - } else if (auto pv = std::get_if>(&v); pv) { + } else if (auto pv = std::get_if>(&v); pv) { return nd_array_to_mlx(*pv, dtype); } else { return to_array_with_accessor(std::get(v).obj); diff --git a/python/src/utils.h b/python/src/utils.h index 8956fc210f..d2fcbd037e 100644 --- a/python/src/utils.h +++ b/python/src/utils.h @@ -24,7 +24,7 @@ using ScalarOrArray = std::variant< // Must be above ndarray mx::array, // Must be above complex - nb::ndarray, + nb::ndarray, std::complex, ArrayLike>; diff --git a/python/tests/test_array.py b/python/tests/test_array.py index 8258b38c94..545ab64c3a 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -2034,33 +2034,304 @@ def test_add_numpy(self): self.assertEqual(z.item(), 3) def test_dlpack(self): + class CpuDLPack: + def __init__(self, array): + self.array = array + + def __dlpack_device__(self): + return (1, 0) + + def __dlpack__(self, *args, **kwargs): + kwargs["dl_device"] = (1, 0) + return self.array.__dlpack__(*args, **kwargs) + x = mx.array(1, dtype=mx.int32) - y = np.from_dlpack(x) + y = np.from_dlpack(CpuDLPack(x)) self.assertTrue(mx.array_equal(y, x)) x = mx.array([[1.0, 2.0], [3.0, 4.0]]) - y = np.from_dlpack(x) + y = np.from_dlpack(CpuDLPack(x)) self.assertTrue(mx.array_equal(y, x)) x = mx.arange(16).reshape(4, 4) x = x[::2, ::2] - y = np.from_dlpack(x) + y = np.from_dlpack(CpuDLPack(x)) self.assertTrue(mx.array_equal(y, x)) + def test_from_dlpack_cpu(self): + x = np.arange(3, dtype=np.float32) + + y = mx.from_dlpack(x) + x += 10 + self.assertEqual(y.tolist(), [0.0, 1.0, 2.0]) + + y = mx.from_dlpack(x, copy=True) + x += 10 + self.assertEqual(y.tolist(), [10.0, 11.0, 12.0]) 
+ + with self.assertRaises(ValueError): + mx.from_dlpack(x, copy=False) + + def test_from_dlpack_cpu_strided(self): + x = np.arange(12, dtype=np.float32).reshape(3, 4) + view = x.T + y = mx.from_dlpack(view) + + self.assertEqual(y.tolist(), view.tolist()) + self.assertTrue(memoryview(y).c_contiguous) + self.assertEqual(memoryview(y).strides, (12, 4)) + expected = view.copy().tolist() + x[0, 0] = 99 + self.assertEqual(y.tolist(), expected) + + stepped = np.arange(20, dtype=np.int32)[2:10:2] + y = mx.from_dlpack(stepped) + self.assertEqual(y.tolist(), [2, 4, 6, 8]) + self.assertTrue(memoryview(y).c_contiguous) + self.assertEqual(memoryview(y).strides, (4,)) + stepped[0] = 99 + self.assertEqual(y.tolist(), [2, 4, 6, 8]) + + broadcast = np.broadcast_to(np.array([7], dtype=np.int32), (3,)) + y = mx.from_dlpack(broadcast) + self.assertEqual(y.tolist(), [7, 7, 7]) + self.assertTrue(memoryview(y).c_contiguous) + self.assertEqual(memoryview(y).strides, (4,)) + + negative_stride = np.arange(5, dtype=np.float32)[::-1] + with self.assertRaises(ValueError): + mx.from_dlpack(negative_stride) + @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") - def test_torch_mps_dlpack_non_cpu_error(self): + def test_torch_mps_dlpack_import(self): x = torch.arange(12, device="mps", dtype=torch.float32).reshape(3, 4) self.assertEqual(x.__dlpack_device__()[0], 8) - with self.assertRaisesRegex(ValueError, "non-CPU DLPack"): - mx.array(x) + y = mx.asarray(x) + self.assertEqual(y.dtype, mx.float32) + torch.mps.synchronize() + self.assertEqual(y.tolist(), x.cpu().numpy().tolist()) + self.assertIn("array(", repr(y)) + mv = memoryview(y) + self.assertEqual(mv.tolist(), x.cpu().numpy().tolist()) + + @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") + def test_torch_mps_array_copies_dlpack_input(self): + x = torch.arange(3, device="mps", dtype=torch.float32) + torch.mps.synchronize() + y = mx.array(x) + + x.zero_() + torch.mps.synchronize() + self.assertEqual(y.tolist(), [0.0, 
1.0, 2.0]) + + @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") + def test_torch_mps_asarray_copy_true_copies_dlpack_input(self): + x = torch.arange(3, device="mps", dtype=torch.float32) + torch.mps.synchronize() + y = mx.asarray(x, copy=True) + + x.zero_() + torch.mps.synchronize() + self.assertEqual(y.tolist(), [0.0, 1.0, 2.0]) + + @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") + def test_torch_mps_dlpack_zero_copy_shares_updates(self): + x = torch.arange(12, device="mps", dtype=torch.float32).reshape(3, 4) + torch.mps.synchronize() + y = mx.asarray(x) + + x.zero_() + torch.mps.synchronize() + self.assertEqual(y.tolist(), x.cpu().numpy().tolist()) + + y += 10 + mx.eval(y) + self.assertEqual(x.cpu().numpy().tolist(), y.tolist()) + + @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") + def test_torch_mps_dlpack_matching_dtype_argument_shares_updates(self): + x = torch.arange(12, device="mps", dtype=torch.float32).reshape(3, 4) + torch.mps.synchronize() + y = mx.asarray(x, dtype=mx.float32, copy=False) + self.assertEqual(y.dtype, mx.float32) + + x.zero_() + torch.mps.synchronize() + self.assertEqual(y.tolist(), x.cpu().numpy().tolist()) + + @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") + def test_torch_mps_dlpack_different_dtype_argument_copies(self): + x = torch.arange(12, device="mps", dtype=torch.float32).reshape(3, 4) + torch.mps.synchronize() + z = mx.asarray(x, dtype=mx.float16) + expected = x.to(torch.float16).cpu().numpy().tolist() + + self.assertEqual(z.dtype, mx.float16) + self.assertEqual(z.tolist(), expected) + + x.zero_() + torch.mps.synchronize() + self.assertEqual(z.tolist(), expected) + + with self.assertRaises(ValueError): + mx.asarray(x, dtype=mx.float16, copy=False) + + @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") + def test_torch_mps_dlpack_data_offset(self): + view = torch.arange(12, device="mps", dtype=torch.float32)[3:9] + view_mx = mx.asarray(view) + 
torch.mps.synchronize() + self.assertEqual(view_mx.tolist(), view.cpu().numpy().tolist()) + + @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") + def test_torch_mps_dlpack_strided_view(self): + x = torch.arange(12, device="mps", dtype=torch.float32).reshape(3, 4) + view = x.T + torch.mps.synchronize() + y = mx.asarray(view, copy=False) + self.assertEqual(y.tolist(), view.cpu().numpy().tolist()) + + x[0, 1] = 99 + torch.mps.synchronize() + self.assertEqual(y.tolist(), view.cpu().numpy().tolist()) + + y_copy = mx.asarray(view, copy=True) + expected = view.cpu().numpy().tolist() + self.assertTrue(memoryview(y_copy).c_contiguous) + self.assertEqual(memoryview(y_copy).strides, (12, 4)) + x[0, 2] = 77 + torch.mps.synchronize() + self.assertEqual(y_copy.tolist(), expected) + + z = mx.asarray(view, dtype=mx.float16) + self.assertEqual(z.dtype, mx.float16) + self.assertTrue(memoryview(z).c_contiguous) + self.assertEqual(memoryview(z).strides, (6, 2)) + + @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") + def test_torch_mps_dlpack_stepped_view(self): + x = torch.arange(20, device="mps", dtype=torch.int32) + view = x[2:10:2] + torch.mps.synchronize() + y = mx.asarray(view, copy=False) + self.assertEqual(y.tolist(), [2, 4, 6, 8]) + + x[4] = 99 + torch.mps.synchronize() + self.assertEqual(y.tolist(), [2, 99, 6, 8]) + + y_copy = mx.asarray(view, copy=True) + expected = y.tolist() + x[6] = 77 + torch.mps.synchronize() + self.assertEqual(y_copy.tolist(), expected) + + @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") + def test_torch_mps_dlpack_broadcast_stride(self): + x = torch.tensor([7], device="mps", dtype=torch.int32) + view = x.expand(3) + torch.mps.synchronize() + y = mx.asarray(view, copy=False) + self.assertEqual(y.tolist(), [7, 7, 7]) + + x.zero_() + torch.mps.synchronize() + self.assertEqual(y.tolist(), [0, 0, 0]) + + @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") + def test_torch_mps_dlpack_bfloat16(self): + x = 
torch.arange(12, device="mps", dtype=torch.float32).reshape(3, 4) + bf = x.to(torch.bfloat16) + bf_mx = mx.asarray(bf) + + self.assertEqual(bf_mx.dtype, mx.bfloat16) + torch.mps.synchronize() + self.assertEqual( + bf_mx.astype(mx.float32).tolist(), + bf.to(torch.float32).cpu().numpy().tolist(), + ) + @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") + def test_torch_mps_array_operand(self): a = mx.array([1]) b = torch.tensor([2]) self.assertTrue(mx.array_equal(a + b, mx.array([3]))) - with self.assertRaisesRegex(ValueError, "non-CPU DLPack"): - a + b.to("mps") + b_mps = b.to("mps") + torch.mps.synchronize() + self.assertTrue(mx.array_equal(a + b_mps, mx.array([3]))) + + @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") + def test_mlx_dlpack_exports_mps_tensor_to_torch(self): + x = mx.array([1]).astype(mx.float16) + mx.eval(x) + y = torch.utils.dlpack.from_dlpack(x) + torch.mps.synchronize() + + self.assertEqual(y.device.type, "mps") + self.assertEqual(y.dtype, torch.float16) + self.assertEqual(y.cpu().numpy().tolist(), [1.0]) + + @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") + def test_mlx_dlpack_exports_mps_tensor_to_torch_tensor(self): + x = mx.array([1]).astype(mx.float16) + mx.eval(x) + y = torch.tensor(x) + torch.mps.synchronize() + + self.assertEqual(y.device.type, "mps") + self.assertEqual(y.dtype, torch.float16) + self.assertEqual(y.cpu().numpy().tolist(), [1.0]) + + @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") + def test_mlx_dlpack_export_torch_update_writes_mlx_buffer(self): + x = mx.arange(8, dtype=mx.float32) + y = x[2:6] + mx.eval(y) + t = torch.utils.dlpack.from_dlpack(y) + + self.assertEqual(t.device.type, "mps") + self.assertEqual(t.cpu().numpy().tolist(), [2.0, 3.0, 4.0, 5.0]) + + t.zero_() + torch.mps.synchronize() + self.assertEqual(y.tolist(), [0.0, 0.0, 0.0, 0.0]) + self.assertEqual(x.tolist(), [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 6.0, 7.0]) + + @unittest.skipUnless(has_torch_mps, "PyTorch 
MPS is required") + def test_from_dlpack_torch_mps_copy_none_shares_updates(self): + x = torch.arange(3, device="mps", dtype=torch.float32) + torch.mps.synchronize() + y = mx.from_dlpack(x) + + x.zero_() + torch.mps.synchronize() + self.assertEqual(y.tolist(), [0.0, 0.0, 0.0]) + + y += 10 + mx.eval(y) + self.assertEqual(x.cpu().numpy().tolist(), [10.0, 10.0, 10.0]) + + @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") + def test_from_dlpack_torch_mps_copy_false_shares_updates(self): + x = torch.arange(3, device="mps", dtype=torch.float32) + torch.mps.synchronize() + y = mx.from_dlpack(x, copy=False) + + x.zero_() + torch.mps.synchronize() + self.assertEqual(y.tolist(), [0.0, 0.0, 0.0]) + + @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") + def test_from_dlpack_torch_mps_copy_true_copies(self): + x = torch.arange(3, device="mps", dtype=torch.float32) + torch.mps.synchronize() + y = mx.from_dlpack(x, copy=True) + + x.zero_() + torch.mps.synchronize() + self.assertEqual(y.tolist(), [0.0, 1.0, 2.0]) def test_getitem_with_list(self): a = mx.array([1, 2, 3, 4, 5]) @@ -2182,9 +2453,10 @@ def test_asarray_copy(self): self.assertEqual( mx.asarray(existing, dtype=mx.float32, copy=True).dtype, mx.float32 ) - with self.assertRaises(ValueError): mx.asarray(existing, copy=False) + with self.assertRaises(ValueError): + mx.asarray(existing, dtype=mx.float32, copy=False) def test_asarray(self): # List inputs @@ -2208,17 +2480,27 @@ def test_asarray(self): # MLX array inputs arr = mx.array([1, 2, 3]) self.assertEqual(mx.asarray(arr).tolist(), [1, 2, 3]) + self.assertEqual(mx.asarray(arr, copy=True).tolist(), [1, 2, 3]) + with self.assertRaises(ValueError): + mx.asarray(arr, copy=False) arr_int = mx.array([1, 2, 3], dtype=mx.int32) arr_float = mx.asarray(arr_int, dtype=mx.float32) self.assertEqual(arr_float.dtype, mx.float32) self.assertEqual(arr_float.tolist(), [1.0, 2.0, 3.0]) + with self.assertRaises(ValueError): + mx.asarray(arr_int, dtype=mx.float32, 
copy=False) # NumPy array inputs np_arr = np.array([1.0, 2.0, 3.0], dtype=np.float32) mx_arr = mx.asarray(np_arr) self.assertEqual(mx_arr.tolist(), [1.0, 2.0, 3.0]) self.assertEqual(mx_arr.dtype, mx.float32) + with self.assertRaises(ValueError): + mx.asarray(np_arr, copy=False) + + with self.assertRaises(ValueError): + mx.asarray([1, 2, 3], copy=False) # dtype parameter self.assertEqual(mx.asarray([1, 2, 3], dtype=mx.float32).dtype, mx.float32) diff --git a/python/tests/test_fft.py b/python/tests/test_fft.py index 9c03530b04..0bb190bd99 100644 --- a/python/tests/test_fft.py +++ b/python/tests/test_fft.py @@ -398,12 +398,14 @@ def g(x): x = mx.random.normal(sh) else: x = mx.random.normal((*sh, 2)).view(mx.complex64).squeeze() + mx.eval(x) + x_torch = torch.tensor(x, device="cpu") fx = f(x) - gx = g(torch.tensor(x)) + gx = g(x_torch) self.assertLess((fx - gx).abs().max() / gx.abs().mean(), 1e-4) dfdx = mx.grad(f)(x) - dgdx = torch.func.grad(g)(torch.tensor(x)) + dgdx = torch.func.grad(g)(x_torch) self.assertLess((dfdx - dgdx).abs().max() / dgdx.abs().mean(), 1e-4)