Skip to content

Add Metal DLPack zero-copy sharing#3531

Open
XXXXRT666 wants to merge 24 commits into
ml-explore:mainfrom
XXXXRT666:metal-dlpack-zero-copy-draft
Open

Add Metal DLPack zero-copy sharing#3531
XXXXRT666 wants to merge 24 commits into
ml-explore:mainfrom
XXXXRT666:metal-dlpack-zero-copy-draft

Conversation

@XXXXRT666
Copy link
Copy Markdown
Contributor

@XXXXRT666 XXXXRT666 commented May 11, 2026

Proposed changes

This draft adds zero-copy Metal DLPack sharing for MLX arrays and PyTorch MPS tensors.

This PR builds on the merged DLPack import PR #3495 and requires nanobind support.

The main changes are:

  • Import Metal DLPack arrays by wrapping the underlying Metal buffer instead of copying through CPU.
  • Export MLX arrays to Metal DLPack using the MLX Metal buffer and DLPack byte_offset.
  • Add mx.from_dlpack(..., copy=...) controls for Metal DLPack inputs.
  • Keep mx.array(...) zero-copy for Metal DLPack inputs unless an explicit different dtype is requested.
  • Document the explicit synchronization requirements between PyTorch MPS and MLX.

The shared lifetime is tied to the exported or imported buffer. Synchronization remains explicit: PyTorch writes require torch.mps.synchronize() before MLX reads, and MLX writes require mx.eval(...) before PyTorch reads.

For MLX arrays exported to PyTorch, later MLX updates may rebind the MLX array to a new buffer while the PyTorch tensor continues to reference the exported buffer.

Checklist

Put an x in the boxes that apply.

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

@megacpp
Copy link
Copy Markdown

megacpp commented May 13, 2026

Hi @XXXXRT666 — read through this PR after @awni redirected us here from #3548. The nb::ndarray<nb::ro, nb::c_contig> approach over the in-flight nanobind PR (#1338) is materially cleaner than the manual capsule parsing we had in our downstream PoC, and lifting is_host_accessible() into mlx/allocator.h is the right level of abstraction. Closed the RFC; happy with this being the path forward for #2848.

Wanted to offer some testing help that complements the PyTorch MPS bring-up you have:

We maintain a downstream TileLang fork (https://github.com/DatasunriseOU/tilelang) whose TVM-FFI bridge exports kDLMetal DLPack capsules for tensors backed by id<MTLBuffer>. That gives a non-PyTorch Metal DLPack producer that exercises the same import path you're adding here. Specifically it covers:

  • kDLMetal producers that do not require the PyTorch-MPS workaround for __dlpack__ — exercises the import path directly.
  • Round-trip mx.array → DLPack → TVM-FFI Metal kernel → DLPack → mx.array zero-copy.
  • Custom Metal kernels (via mlx.fast.metal_kernel) consuming an imported mx.array whose underlying MTLBuffer was allocated outside MLX.
  • storageMode matrix (we hit Shared and Managed; Private is the obvious edge case the spec needs to nail down — your is_host_accessible() decision likely answers this implicitly but worth a sanity check from the producer side).
  • byte_offset != 0 cases that we'd previously rejected outright in our PoC — your PR seems to handle these via byte_offset-aware import; happy to write a TileLang-side test for it.

If useful, once the PR converges I can:

  1. Pull this branch into our TileLang test matrix and report back on any rough edges (CI on macOS Metal hardware).
  2. Send a minimal standalone repro (no TileLang dependency) for any of the above scenarios if you'd like them as additions to python/tests/test_array.py.
  3. Beta-test the mx.from_dlpack(..., copy=...) semantics against the dtype-mismatch-shares case (002360faa) once the API stabilizes.

Tag me here when you'd like input — no rush, just don't want this to slip past once it's review-ready.

(For the orthogonal mx.empty() piece that was also in our PoC, opened it as a separate issue per @awni's guidance.)

megacpp pushed a commit to DatasunriseOU/mlx that referenced this pull request May 13, 2026
Required by ml-explore/mlx PR ml-explore#3531 (Metal DLPack zero-copy sharing).
SHA: 33f52e635db5e6229060481d16a167230a1a474b
PR:   wjakob/nanobind#1338
Branch: metal-dlpack-cast
@XXXXRT666 XXXXRT666 force-pushed the metal-dlpack-zero-copy-draft branch from 002360f to 4e16f1d Compare May 14, 2026 04:39
@McPatate
Copy link
Copy Markdown

This would be super cool if it landed for end to end "0-copy" support in safetensors! I'm working (safetensors/safetensors#767) on adding reading bytes from disk in raw MTLBuffers, which can then be handed to the framework via dlpack with 0-copy. Works well with torch, would be happy to see that land in mlx!

Also, support for byte_offset !=0 would be nice (already in the PR but commenting to notify it's useful) since we can go one step further: currently the mps path is pread -> MTLBuffer, but that goes through kernel pages before hitting userspace buffer. Having byte_offset non zero support would enable mmap-ing the file and creating MTLBuffers that reference specific slices of the mmap, which would demand-fault pages from disk into the page cache on first access and give userspace access directly, leaving only the disk -> kernel-side copy.

Quick question on the dl_tensor.data convention, torch's mps treats it as id<MTLBuffer>, passing the contents segfaults. Curious to know which direction MLX will be taking, as it impacts us downstream!

megacpp pushed a commit to DatasunriseOU/mlx that referenced this pull request May 14, 2026
megacpp pushed a commit to DatasunriseOU/mlx that referenced this pull request May 14, 2026
Required by ml-explore/mlx PR ml-explore#3531 (Metal DLPack zero-copy sharing).
SHA: 33f52e635db5e6229060481d16a167230a1a474b
PR:   wjakob/nanobind#1338
Branch: metal-dlpack-cast
@XXXXRT666
Copy link
Copy Markdown
Contributor Author

Quick question on the dl_tensor.data convention, torch's mps treats it as id<MTLBuffer>, passing the contents segfaults. Curious to know which direction MLX will be taking, as it impacts us downstream!

https://dmlc.github.io/dlpack/latest/c_api.html#c.DLTensor.data

The data pointer points to the allocated data. This will be CUDA device pointer, cl_mem handle in OpenCL, or id<MTLBuffer> for Metal.

@XXXXRT666 XXXXRT666 force-pushed the metal-dlpack-zero-copy-draft branch from 4e16f1d to a17cd99 Compare May 19, 2026 07:44
@XXXXRT666 XXXXRT666 marked this pull request as ready for review May 19, 2026 08:40
Comment thread docs/src/usage/numpy.rst Outdated
Comment thread mlx/backend/cuda/allocator.cpp
Comment thread mlx/backend/metal/allocator.cpp
Comment thread CMakeLists.txt
Comment thread python/src/convert.cpp
Comment thread docs/src/usage/numpy.rst Outdated
Comment thread python/src/convert.cpp Outdated
Comment thread python/src/convert.cpp Outdated
Comment thread python/src/convert.cpp Outdated
Comment thread python/src/convert.cpp Outdated
Comment thread python/src/convert.cpp Outdated
Comment thread python/src/convert.cpp Outdated
Comment thread python/src/convert.cpp Outdated
Comment thread python/src/convert.cpp Outdated
@XXXXRT666
Copy link
Copy Markdown
Contributor Author

One API question: should mx.array(...) always copy DLPack inputs, and should zero-copy / copy control live in mx.asarray(..., copy=...) instead?

That would match the mental model used by NumPy/PyTorch more closely: array creates a new array, while asarray may avoid a copy depending on copy. In that design, mx.from_dlpack(..., copy=...) could remain the explicit DLPack entry point, while mx.array(torch_mps_tensor) would not unexpectedly share the underlying Metal buffer by default.

Comment thread python/src/array.cpp Outdated
Comment thread python/src/buffer.h Outdated
Comment thread python/src/convert.h Outdated
Comment thread python/src/ops.cpp Outdated
Comment thread python/src/convert.cpp Outdated
Comment thread python/src/convert.cpp Outdated
Comment thread mlx/ops.cpp Outdated
Comment thread mlx/ops.cpp Outdated
Comment thread python/src/mlx_func.cpp
Comment thread python/src/convert.cpp Outdated
@zcbenz
Copy link
Copy Markdown
Collaborator

zcbenz commented May 23, 2026

One API question: should mx.array(...) always copy DLPack inputs, and should zero-copy / copy control live in mx.asarray(..., copy=...) instead?

That would match the mental model used by NumPy/PyTorch more closely: array creates a new array, while asarray may avoid a copy depending on copy. In that design, mx.from_dlpack(..., copy=...) could remain the explicit DLPack entry point, while mx.array(torch_mps_tensor) would not unexpectedly share the underlying Metal buffer by default.

I think this is very good design.

@XXXXRT666
Copy link
Copy Markdown
Contributor Author

mlx 0.31.2

mlx_to_torch  (median µs / call, lower is better)
  shape               current → mps   dlpack → mps   dlpack stay-on-cpu
  16K  f32                      378            191                   16
  1M   f32                      404            402                   17
  16M  f32 (64MB)              2717           2501                   16
  1M  bf16                      276 n/a (TypeError)      n/a (TypeError)
  16M bf16 (32MB)              1302 n/a (TypeError)      n/a (TypeError)

torch_to_mlx  (median µs / call, lower is better)
  shape                current (.cpu+numpy)      dlpack (from MPS)
  16K  f32                              175                    172
  1M   f32                              691                    666
  16M  f32 (64MB)                      8216                   8470
  1M  bf16                              522                    422
  16M bf16 (32MB)                      2023                   2037

mlx 0.32.0.dev20260523+4e8decde9

mlx_to_torch  (median µs / call, lower is better)
  shape               current → mps   dlpack → mps   dlpack stay-on-cpu
  16K  f32                      237             18                   15
  1M   f32                      371             16                   17
  16M  f32 (64MB)              2688             15                   15
  1M  bf16                      279             19                   16
  16M bf16 (32MB)              1340             16                   14

torch_to_mlx  (median µs / call, lower is better)
  shape                current (.cpu+numpy)      dlpack (from MPS)
  16K  f32                              189                     16
  1M   f32                              676                     16
  16M  f32 (64MB)                      8806                     17
  1M  bf16                              484                     16
  16M bf16 (32MB)                      2106                     16
benchmark function
# --- mlx -> torch candidates ---------------------------------------------------


def mlx_to_torch_current(arr: mx.array, device: torch.device) -> torch.Tensor:
    arr = mx.contiguous(arr)
    mx.eval(arr)
    buf = memoryview(arr)
    dtype_map = {
        mx.float32: torch.float32,
        mx.float16: torch.float16,
        mx.bfloat16: torch.bfloat16,
    }
    t = torch.frombuffer(buf, dtype=dtype_map[arr.dtype]).reshape(arr.shape)
    if device.type == "mps":
        t = t.to(device)
    return t


def mlx_to_torch_dlpack_mps(arr: mx.array, device: torch.device) -> torch.Tensor:
    mx.eval(arr)
    t = torch.from_dlpack(arr)
    if device.type == "mps":
        t = t.to(device)
    return t


def mlx_to_torch_dlpack_cpu(arr: mx.array, device: torch.device) -> torch.Tensor:
    """Force a CPU-typed capsule via `dl_device=(kDLCPU, 0)` (Phase 2+).
    Falls back to the no-kwarg form for builds that don't accept it."""
    mx.eval(arr)
    try:
        cap = arr.__dlpack__(dl_device=(1, 0))
    except TypeError:
        # Older builds: zero-arg lambda. Capsule is already kDLCPU there.
        cap = arr.__dlpack__()
    return torch.from_dlpack(cap)


# --- torch -> mlx candidates ---------------------------------------------------


def torch_to_mlx_current(t: torch.Tensor) -> mx.array:
    if t.device.type != "cpu":
        t = t.cpu()
    t = t.detach()
    if t.dtype == torch.bfloat16:
        return mx.array(t)
    return mx.array(t.numpy())


def torch_to_mlx_dlpack(t: torch.Tensor) -> mx.array:
    """Use mx.from_dlpack when the API exists; old MLX falls back to CPU copy."""
    if hasattr(mx, "from_dlpack"):
        return mx.from_dlpack(t)
    return mx.array(t.detach().cpu())

@XXXXRT666
Copy link
Copy Markdown
Contributor Author

I think this is very good design.

I updated the PR to follow this design: mx.array(...) now copies DLPack inputs, while mx.asarray(..., copy=...) and mx.from_dlpack(..., copy=...) provide the copy-control paths.

@XXXXRT666 XXXXRT666 force-pushed the metal-dlpack-zero-copy-draft branch from 0607c24 to a4378cf Compare May 23, 2026 08:43
Copy link
Copy Markdown
Collaborator

@zcbenz zcbenz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This basically looks good to me, thanks for the nice work!

Comment thread mlx/backend/metal/allocator.cpp Outdated
Comment thread mlx/utils.cpp Outdated
Comment thread python/src/buffer.h
Comment thread mlx/backend/cuda/allocator.cpp Outdated
Comment thread mlx/array.cpp Outdated
Comment thread python/src/convert.cpp Outdated
Comment thread python/src/convert.cpp Outdated
Comment thread python/src/convert.cpp Outdated
Comment thread python/src/convert.cpp Outdated
Comment thread python/src/convert.cpp Outdated
@XXXXRT666
Copy link
Copy Markdown
Contributor Author

I just realized that nanobind does not handle c_contig arrays very well. For torch.Tensor, it calls contiguous(), but it does not perform torch.accelerator.synchronization afterwards.

@XXXXRT666 XXXXRT666 force-pushed the metal-dlpack-zero-copy-draft branch from 47326da to 4eaea96 Compare May 24, 2026 19:31
Comment thread python/src/convert.cpp Outdated
Comment thread python/src/convert.cpp Outdated
Comment thread python/src/convert.cpp Outdated
@XXXXRT666 XXXXRT666 force-pushed the metal-dlpack-zero-copy-draft branch from 4eaea96 to a7c4a44 Compare May 25, 2026 04:49
Comment thread python/src/convert.cpp
auto src = static_cast<const SrcT*>(nd_array.data());
auto dst = out.data<DstT>();
for (size_t i = 0; i < out.size(); ++i) {
dst[i] = static_cast<DstT>(src[strided_offset(i, shape, strides)]);
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is going to be really slow, would it be practical to preserve the original strides and and just do memcpy?

out.set_data(
    mx::allocator::malloc(data_size * itemsize),
    data_size,
    strides);
std::copy(src, src + data_size, dst);

Comment thread python/src/convert.cpp
if (copy && !import_flags.row_contiguous) {
// Force the copy primitive to materialize the virtual strided input into a
// row-contiguous output instead of preserving a dense non-row layout.
import_flags.contiguous = false;
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you collaborate on this? If the input is truly contiguous we shouldn't need this for copying, otherwise the flag was set for non-contiguous input.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants