Add __array__ and DLPack protocols to OrtValue#27980
Add __array__ and DLPack protocols to OrtValue#27980tianleiwu merged 3 commits intomicrosoft:mainfrom
Conversation
…alue Expose standard Python protocols on the public OrtValue class so it interoperates with numpy, PyTorch, JAX, and other frameworks without requiring users to reach into the private _ortvalue attribute. - __array__(dtype, copy): supports np.asarray(ortvalue) and np.array() following the numpy array protocol, with numpy 2.0 copy semantics. - __dlpack__(stream): returns a DLPack capsule for zero-copy sharing. - __dlpack_device__(): returns (device_type, device_id) tuple. - from_dlpack(data): classmethod accepting any __dlpack__-compatible object or raw capsule, with automatic bool dtype detection from the source object's dtype to avoid the uint8/bool ambiguity in older DLPack versions. Fixes microsoft#24071
There was a problem hiding this comment.
Pull request overview
This PR adds standard Python interoperability protocols to the public onnxruntime.OrtValue wrapper so it can participate in NumPy’s __array__ protocol and the DLPack exchange protocol without users reaching into private _ortvalue bindings (addressing #24071).
Changes:
- Add
OrtValue.__array__to supportnp.asarray(ortvalue)/np.array(ortvalue)with optional dtype conversion and NumPy 2.0copysemantics. - Expose DLPack protocol methods on
OrtValue:__dlpack__,__dlpack_device__, andfrom_dlpackwith bool dtype auto-detection for protocol objects. - Add Python unit tests covering
__array__, public DLPack methods,from_dlpackobject/capsule inputs, and bool vs uint8 behavior.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| onnxruntime/python/onnxruntime_inference_collection.py | Adds __array__ and DLPack protocol wrappers (__dlpack__, __dlpack_device__, from_dlpack) to the public OrtValue class. |
| onnxruntime/test/python/onnxruntime_test_python.py | Adds tests validating NumPy array protocol behavior and public DLPack interop on OrtValue. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
…sertions Use np.asarray instead of np.array for the dtype-only branch in __array__, so that when the requested dtype already matches the underlying buffer, no unnecessary copy is made. Add data_ptr/ctypes.data assertions in the __array__ tests to verify zero-copy behavior for both the default and same-dtype paths.
tianleiwu
left a comment
There was a problem hiding this comment.
Code Review — PR #27980
Summary: Clean addition of __array__, __dlpack__, __dlpack_device__, and from_dlpack to the public OrtValue class. The wrapper pattern correctly delegates to the C-level bindings. Tests cover the main use cases including bool/uint8 disambiguation.
Python Interop Surface (onnxruntime_inference_collection.py)
Positive: The DLPack protocol wrapper is thin and correct — it forwards to proven C-level implementations without adding unnecessary logic.
Positive: Boolean auto-detection in from_dlpack addresses a real DLPack limitation (bool → uint8 encoding). Inspecting the source object's dtype before consuming the capsule is the right approach.
See inline comments for two suggestions.
Tests (onnxruntime_test_python.py)
Positive: Good coverage across float32, int64, bool, and uint8 dtypes. The uint8-not-falsely-detected-as-bool test is particularly important.
Positive: Shared-memory assertion via data_ptr() comparison in test_ort_value_from_dlpack_protocol_object directly validates zero-copy semantics.
str(data.dtype) returns framework-specific representations that don't always match simple string comparisons (e.g. TensorFlow returns "<dtype: 'bool'>"). Use the .name attribute when available (numpy, cupy, tensorflow) and fall back to str() for frameworks like PyTorch where .name is absent.
|
/azp run Linux QNN CI Pipeline, Win_TRT_Minimal_CUDA_Test_CI, Windows ARM64 QNN CI Pipeline, Windows GPU Doc Gen CI Pipeline |
|
Azure Pipelines successfully started running 4 pipeline(s). |
Summary
__array__,__dlpack__,__dlpack_device__, andfrom_dlpackto the publicOrtValueclassOrtValuefrom_dlpackto avoid the uint8/bool ambiguity in older DLPack versionsMotivation
Fixes #24071
The C-level
C.OrtValuealready supports__dlpack__,__dlpack_device__, andfrom_dlpack, but the public Python wrapperOrtValueclass does not expose them. Users currently have to access the private_ortvalueattribute (e.g.ortvalue._ortvalue.__dlpack__()) for DLPack interop. Similarly,np.asarray(ortvalue)doesn't work because__array__is not implemented.This makes
OrtValuea well-behaved tensor type that works out of the box with:np.asarray(ortvalue)/np.array(ortvalue)via__array__torch.from_dlpack(ortvalue)via__dlpack__/__dlpack_device__OrtValue.from_dlpack(torch_tensor)via thefrom_dlpackclassmethodChanges
onnxruntime/python/onnxruntime_inference_collection.py:__array__(dtype, copy): Delegates toself.numpy()with optional dtype conversion. Supports numpy 2.0copysemantics while remaining compatible with older numpy versions.__dlpack__(*, stream): Thin wrapper over the C-level__dlpack__.__dlpack_device__(): Thin wrapper over the C-level__dlpack_device__.from_dlpack(data): Classmethod that accepts any__dlpack__-compatible object or raw DLPack capsule. Detects boolean dtype from the source object'sdtypeattribute ordata_type()method, avoiding the uint8/bool false-positive thatis_dlpack_uint8_tensorwould produce on genuine uint8 data.onnxruntime/test/python/onnxruntime_test_python.py:test_ort_value_array_protocol: Testsnp.asarray/np.arraywith float32, int64, bool dtypes, and dtype conversion.test_ort_value_dlpack_protocol: Tests__dlpack__and__dlpack_device__on the public class.test_ort_value_from_dlpack_protocol_object: Testsfrom_dlpackwith numpy arrays and OrtValue-to-OrtValue round-trip, verifying zero-copy (shared memory).test_ort_value_from_dlpack_bool: Tests bool round-trip and verifies uint8 is not falsely detected as bool.Test Plan
ruff checkpasses on both modified filesruff format --checkpasses on both modified fileslintrunnerreports no issuestest_ort_value_dlpacktest continues to pass