diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 198a4861d9..1904f20ba3 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -1795,6 +1795,48 @@ void init_ops(nb::module_& m) { Raises: ValueError: If ``copy`` is ``False``. )pbdoc"); + m.def( + "from_dlpack", + [](const nb::object& x, + const nb::object& device, + std::optional copy) { + if (!device.is_none()) { + throw std::invalid_argument( + "[from_dlpack] device keyword is not supported."); + } + if (copy.has_value() && !*copy) { + throw std::invalid_argument( + "[from_dlpack] copy=False is not supported."); + } + return create_array(x, std::nullopt); + }, + nb::arg(), + nb::kw_only(), + "device"_a = nb::none(), + "copy"_a = nb::none(), + nb::sig( + "def from_dlpack(x: DLPackCompatible, /, *, device: Optional[Any] = None, copy: Optional[bool] = None) -> array"), + R"pbdoc( + Construct an mlx array from a DLPack-compatible object. + + The input must implement the ``__dlpack__`` and ``__dlpack_device__`` + methods as described in the + `Python array API `_. + + Args: + x: Input DLPack-compatible object. + device: Must be ``None``. Specifying a target device is not + supported. + copy (bool, optional): Must be ``True`` or unspecified. ``False`` + is not supported, since MLX has no in-place operations and + cannot return a non-copying view. + + Returns: + array: An mlx array constructed from the input. + + Raises: + ValueError: If ``device`` is not ``None`` or ``copy`` is ``False``. + )pbdoc"); m.def( "zeros_like", &mx::zeros_like, diff --git a/python/tests/test_array.py b/python/tests/test_array.py index 8258b38c94..bd4e4f3993 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -2175,6 +2175,22 @@ def test_array_namespace_asarray(self): arr_pass = xp.asarray(existing) self.assertEqual(arr_pass.tolist(), [4, 5, 6]) + def test_array_namespace_from_dlpack(self): + xp = mx.array(1.0).__array_namespace__() + self.assertTrue(hasattr(xp, "from_dlpack")) + + existing = mx.array([1, 2, 3]) + arr = xp.from_dlpack(existing) + self.assertIsInstance(arr, mx.array) + self.assertEqual(arr.tolist(), [1, 2, 3]) + + # device and copy=False are not supported + with self.assertRaises(ValueError): + xp.from_dlpack(existing, copy=False) + + with self.assertRaises(ValueError): + xp.from_dlpack(existing, device="cpu") + def test_asarray_copy(self): existing = mx.array([1, 2, 3])