Skip to content

Commit 174c1fb

Browse files
committed
Honor gpu_id in process_slices
1 parent ddf7687 commit 174c1fb

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

src/arraybridge/slice_processing.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,10 @@
77

88
from arraybridge.converters import detect_memory_type
99
from arraybridge.stack_utils import stack_slices, unstack_slices
10+
from arraybridge.utils import _get_device_id
1011

1112

12-
def process_slices(image, func, args, kwargs):
13+
def process_slices(image, func, args, kwargs, gpu_id=None):
1314
"""
1415
Process a 3D array slice-by-slice using the provided function.
1516
@@ -25,14 +26,18 @@ def process_slices(image, func, args, kwargs):
2526
func: Function to apply to each slice
2627
args: Positional arguments to pass to func
2728
kwargs: Keyword arguments to pass to func
29+
gpu_id: Optional GPU device ID override. If not provided, attempts
30+
to derive from the input image and falls back to 0.
2831
2932
Returns:
3033
Processed 3D array, or tuple of (processed_3d_array, special_outputs...)
3134
if func returns tuples
3235
"""
3336
# Detect memory type and use proper OpenHCS utilities
3437
memory_type = detect_memory_type(image)
35-
gpu_id = 0 # Default GPU ID for slice processing
38+
if gpu_id is None:
39+
detected_gpu_id = _get_device_id(image, memory_type)
40+
gpu_id = 0 if detected_gpu_id is None else detected_gpu_id
3641

3742
# Unstack 3D array into 2D slices
3843
slices_2d = unstack_slices(image, memory_type, gpu_id)

0 commit comments

Comments
 (0)