Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion cuda_core/cuda/core/_device_resources.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -106,11 +106,18 @@ cdef class SMResourceOptions:
Preferred co-scheduled SM count; the driver tries to satisfy
this but may fall back to ``coscheduled_sm_count``.
(Default to ``None``)
backfill : bool or Sequence[bool], optional
If ``True``, allow the driver to relax the co-scheduling
constraint when assigning SMs. This enables requesting
arbitrary aligned SM counts that the driver would otherwise
reject due to hardware topology constraints. A single ``bool``
applies to every group; a sequence gives one flag per group and
must match the number of groups.
(Default to ``False``)
"""

count: int | SequenceABC | None = None
coscheduled_sm_count: int | SequenceABC | None = None
preferred_coscheduled_sm_count: int | SequenceABC | None = None
backfill: bool | SequenceABC = False


@dataclass
Expand Down Expand Up @@ -172,6 +179,12 @@ cdef inline int _resolve_group_count(SMResourceOptions options) except?-1:
n_groups,
count_is_scalar,
)
_validate_split_field_length(
options.backfill,
"backfill",
n_groups,
count_is_scalar,
)
return n_groups


Expand Down Expand Up @@ -243,6 +256,7 @@ IF CUDA_CORE_BUILD_MAJOR >= 13:
cdef list counts = _broadcast_field(options.count, n_groups)
cdef list coscheduled = _broadcast_field(options.coscheduled_sm_count, n_groups)
cdef list preferred = _broadcast_field(options.preferred_coscheduled_sm_count, n_groups)
cdef list backfills = _broadcast_field(options.backfill, n_groups)
cdef int i

for i in range(n_groups):
Expand All @@ -252,7 +266,10 @@ IF CUDA_CORE_BUILD_MAJOR >= 13:
params[i].coscheduledSmCount = <unsigned int>(coscheduled[i])
if preferred[i] is not None:
params[i].preferredCoscheduledSmCount = <unsigned int>(preferred[i])
params[i].flags = 0
params[i].flags = (
cydriver.CUdevSmResourceGroup_flags.CU_DEV_SM_RESOURCE_GROUP_BACKFILL
if backfills[i] else 0
)
return 0


Expand Down
56 changes: 32 additions & 24 deletions cuda_core/tests/test_green_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,16 @@ def fill_kernel(init_cuda):
return mod.get_kernel("fill")


def _aligned_half(sm):
"""Compute half the SM count, rounded down to min_partition_size alignment."""
def _safe_two_group_count(sm):
"""Return a safe per-group SM count for a 2-group split.

Uses min_partition_size, which is always a valid split size regardless
of hardware topology. Returns None if the device has fewer than
2 * min_partition_size SMs.
"""
min_size = sm.min_partition_size
half = (sm.sm_count // 2 // min_size) * min_size
return half
if sm.sm_count < 2 * min_size:
return None
return min_size


@contextlib.contextmanager
Expand Down Expand Up @@ -238,30 +243,33 @@ def test_discovery_respects_alignment(self, sm_resource):
assert groups[0].sm_count % sm_resource.coscheduled_alignment == 0

def test_two_groups(self, sm_resource):
"""Two-group split with explicit aligned counts."""
half = _aligned_half(sm_resource)
if half < sm_resource.min_partition_size:
"""Two-group split with min_partition_size (always topology-safe)."""
count = _safe_two_group_count(sm_resource)
if count is None:
pytest.skip("Not enough SMs for a 2-group split")

groups, rem = sm_resource.split(SMResourceOptions(count=(half, half)))
groups, rem = sm_resource.split(SMResourceOptions(count=(count, count)))

assert len(groups) == 2
assert groups[0].sm_count > 0
assert groups[1].sm_count > 0
assert groups[0].sm_count >= count
assert groups[1].sm_count >= count
total = groups[0].sm_count + groups[1].sm_count + rem.sm_count
assert total <= sm_resource.sm_count

def test_two_groups_each_meets_request(self, sm_resource):
min_size = sm_resource.min_partition_size
half = _aligned_half(sm_resource)
if half < min_size:
pytest.skip("Not enough SMs for a 2-group split")
def test_two_groups_backfill(self, sm_resource):
"""Two-group split with backfill allows larger partitions."""
align = sm_resource.coscheduled_alignment
if align == 0:
align = sm_resource.min_partition_size
half = (sm_resource.sm_count // 2 // align) * align
if half < sm_resource.min_partition_size:
pytest.skip("Not enough SMs for a 2-group backfill split")

groups, _ = sm_resource.split(SMResourceOptions(count=(min_size, min_size)))
groups, rem = sm_resource.split(SMResourceOptions(count=(half, half), backfill=True))

assert len(groups) == 2
assert groups[0].sm_count >= min_size
assert groups[1].sm_count >= min_size
assert groups[0].sm_count >= half
assert groups[1].sm_count >= half

def test_dry_run_matches_real(self, sm_resource):
"""Dry-run reports the same SM counts as a real split."""
Expand Down Expand Up @@ -352,11 +360,11 @@ def test_green_ctx_sm_resources(self, green_ctx, sm_resource):

def test_green_ctx_resources_reflect_partition(self, init_cuda, sm_resource):
"""Two green contexts should have disjoint SM partitions."""
half = _aligned_half(sm_resource)
if half < sm_resource.min_partition_size:
count = _safe_two_group_count(sm_resource)
if count is None:
pytest.skip("Not enough SMs for a 2-group split")

groups, _ = sm_resource.split(SMResourceOptions(count=(half, half)))
groups, _ = sm_resource.split(SMResourceOptions(count=(count, count)))

ctx_a = ctx_b = None
try:
Expand Down Expand Up @@ -425,11 +433,11 @@ def test_launch_and_verify(self, init_cuda, green_ctx, fill_kernel):
def test_two_green_contexts_independent(self, init_cuda, sm_resource, fill_kernel):
"""Two SM groups -> two green contexts -> two independent kernels."""
dev = init_cuda
half = _aligned_half(sm_resource)
if half < sm_resource.min_partition_size:
count = _safe_two_group_count(sm_resource)
if count is None:
pytest.skip("Not enough SMs for a 2-group split")

groups, _ = sm_resource.split(SMResourceOptions(count=(half, half)))
groups, _ = sm_resource.split(SMResourceOptions(count=(count, count)))
assert len(groups) == 2

ctx_a = ctx_b = None
Expand Down
Loading