forked from data-apis/array-api-compat
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path_helpers.py
More file actions
34 lines (27 loc) · 1.15 KB
/
_helpers.py
File metadata and controls
34 lines (27 loc) · 1.15 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
from importlib import import_module
import pytest
# Array libraries that array_api_compat wraps: import_(name, wrapper=True)
# resolves each of these to the ``array_api_compat.<name>`` namespace.
wrapped_libraries = ["numpy", "cupy", "torch", "dask.array"]

# Every array library these helpers know about: the wrapped ones above,
# followed by libraries that are not routed through array_api_compat.
all_libraries = [
    *wrapped_libraries,
    "array_api_strict",
    "jax.numpy",
    "ndonnx",
    "sparse",
]
def import_(library, wrapper=False):
    """Import and return the module named *library*.

    ``cupy`` and ``ndonnx`` are optional dependencies: for them
    ``pytest.importorskip`` is called first, so a missing install skips the
    calling test instead of erroring. Any other missing library propagates
    the import error from ``import_module``.

    When *wrapper* is true, the module name may be redirected:

    * a name containing ``"jax"`` is kept as-is if ``jax.numpy`` exposes
      ``__array_api_version__`` (JAX >= 0.4.32); otherwise
      ``jax.experimental.array_api`` is imported instead;
    * a name listed in ``wrapped_libraries`` is replaced by its
      ``array_api_compat.<name>`` wrapper module;
    * any other name is imported unchanged.
    """
    if library in ('cupy', 'ndonnx'):
        pytest.importorskip(library)

    target = library
    if wrapper and 'jax' in library:
        # Since JAX v0.4.32 jax.numpy implements the array API itself;
        # older releases only ship it as jax.experimental.array_api.
        jnp = import_module("jax.numpy")
        if not hasattr(jnp, "__array_api_version__"):
            target = 'jax.experimental.array_api'
    elif wrapper and library in wrapped_libraries:
        target = 'array_api_compat.' + library

    return import_module(target)
def xfail(request: pytest.FixtureRequest, reason: str) -> None:
    """
    Mark the currently running test as an expected failure (XFAIL).

    ``pytest.xfail`` aborts the test on the spot. This helper instead only
    attaches an ``xfail`` marker (with the given *reason*) to the test's
    node, so the remainder of the test keeps executing and the outcome may
    be reported as XPASS.

    xref https://github.com/pandas-dev/pandas/issues/38902
    """
    attach = request.node.add_marker
    attach(pytest.mark.xfail(reason=reason))