Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions cuda_bindings/tests/cufile.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
// e.g : export CUFILE_ENV_PATH_JSON="/home/<xxx>/cufile.json"


"properties" : {
"allow_compat_mode" : true
},

"execution" : {
// max number of workitems in the queue;
"max_io_queue_depth": 128,
Expand Down
55 changes: 39 additions & 16 deletions cuda_bindings/tests/test_cufile.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import platform
import subprocess
import tempfile
from contextlib import suppress
from contextlib import contextmanager, suppress
from functools import cache

import pytest
Expand All @@ -28,6 +28,16 @@
cufile = pytest.importorskip("cuda.bindings.cufile", reason="skipping tests on Windows")


@contextmanager
def _cufile_driver_session():
"""Open the cuFile driver for a block; always close in a finally (mirrors try/finally)."""
cufile.driver_open()
try:
yield
finally:
cufile.driver_close()


@pytest.fixture
def cufile_env_json(monkeypatch):
"""Set CUFILE_ENV_PATH_JSON environment variable for async tests."""
Expand Down Expand Up @@ -1422,7 +1432,7 @@ def test_batch_io_large_operations():
@pytest.mark.skipif(
cufileVersionLessThan(1140), reason="cuFile parameter APIs require cuFile library version 1.14.0 or later"
)
@pytest.mark.usefixtures("ctx")
@pytest.mark.usefixtures("ctx", "cufile_env_json")
def test_set_get_parameter_size_t():
"""Test setting and getting size_t parameters with cuFile validation."""
param_val_pairs = (
Expand All @@ -1439,8 +1449,13 @@ def test_set_get_parameter_size_t():
(cufile.SizeTConfigParameter.EXECUTION_MAX_REQUEST_PARALLELISM, 4), # Max 4 parallel requests
)

# Snapshot baselines after driver_open so getters reflect merged config (defaults + JSON),
# not pre-open pending state that could restore invalid values (e.g. 0 for per-buffer cache).
with _cufile_driver_session():
originals = {param: cufile.get_parameter_size_t(param) for param, _ in param_val_pairs}

Comment thread
rsarpangalav marked this conversation as resolved.
def test_param(param, val):
orig_val = cufile.get_parameter_size_t(param)
orig_val = originals[param]
cufile.set_parameter_size_t(param, val)
retrieved_val = cufile.get_parameter_size_t(param)
assert retrieved_val == val
Expand All @@ -1454,9 +1469,11 @@ def test_param(param, val):
@pytest.mark.skipif(
cufileVersionLessThan(1140), reason="cuFile parameter APIs require cuFile library version 1.14.0 or later"
)
@pytest.mark.usefixtures("ctx")
@pytest.mark.usefixtures("ctx", "cufile_env_json")
def test_set_get_parameter_bool():
"""Test setting and getting boolean parameters with cuFile validation."""
# Load the compat-enabled test config before the first driver_open so the compat
# bool params can still be round-tripped on systems without nvidia-fs.
param_val_pairs = (
(cufile.BoolConfigParameter.PROPERTIES_USE_POLL_MODE, True),
(cufile.BoolConfigParameter.PROPERTIES_ALLOW_COMPAT_MODE, False),
Expand All @@ -1471,28 +1488,29 @@ def test_set_get_parameter_bool():
(cufile.BoolConfigParameter.SKIP_TOPOLOGY_DETECTION, False),
(cufile.BoolConfigParameter.STREAM_MEMOPS_BYPASS, True),
)
# PROFILE_NVTX is deprecated (CTK 13.1.0+); cuFile >= 1.16 rejects bool getters for it.
if cufile.get_version() >= 1160:
param_val_pairs = tuple((p, v) for p, v in param_val_pairs if p is not cufile.BoolConfigParameter.PROFILE_NVTX)

with _cufile_driver_session():
originals = {param: cufile.get_parameter_bool(param) for param, _ in param_val_pairs}

def test_param(param, val):
orig_val = cufile.get_parameter_bool(param)
orig_val = originals[param]
cufile.set_parameter_bool(param, val)
retrieved_val = cufile.get_parameter_bool(param)
assert retrieved_val is val
cufile.set_parameter_bool(param, orig_val)

try:
# Test setting and getting various boolean parameters
for param, val in param_val_pairs:
test_param(param, val)
except cufile.cuFileError:
if cufile.get_version() < 1160:
raise
assert param is cufile.BoolConfigParameter.PROFILE_NVTX # Deprecated in CTK 13.1.0
# Test setting and getting various boolean parameters
for param, val in param_val_pairs:
test_param(param, val)


@pytest.mark.skipif(
cufileVersionLessThan(1140), reason="cuFile parameter APIs require cuFile library version 1.14.0 or later"
)
@pytest.mark.usefixtures("ctx")
@pytest.mark.usefixtures("ctx", "cufile_env_json")
def test_set_get_parameter_string(tmp_path):
"""Test setting and getting string parameters with cuFile validation."""
temp_dir = tempfile.gettempdir()
Expand All @@ -1513,8 +1531,11 @@ def test_set_get_parameter_string(tmp_path):
), # Test log directory
)

with _cufile_driver_session():
originals = {param: cufile.get_parameter_string(param, 256) for param, _, _ in param_val_pairs}

def test_param(param, val, default_val):
orig_val = cufile.get_parameter_string(param, 256)
orig_val = originals[param]

val_b = val.encode("utf-8")
val_buf = ctypes.create_string_buffer(val_b)
Expand Down Expand Up @@ -1951,7 +1972,9 @@ def test_set_parameter_posix_pool_slab_array(slab_sizes, slab_counts, driver_con
retrieved_sizes_addr = ctypes.addressof(retrieved_sizes)
retrieved_counts_addr = ctypes.addressof(retrieved_counts)

cufile.get_parameter_posix_pool_slab_array(retrieved_sizes_addr, retrieved_counts_addr, n_slab_sizes)
# Open cuFile driver AFTER setting parameters
with _cufile_driver_session():
cufile.get_parameter_posix_pool_slab_array(retrieved_sizes_addr, retrieved_counts_addr, n_slab_sizes)

# Verify they match what we set
assert list(retrieved_sizes) == slab_sizes
Expand Down
Loading