77
88from arraybridge .converters import detect_memory_type
99from arraybridge .stack_utils import stack_slices , unstack_slices
10+ from arraybridge .utils import _get_device_id
1011
1112
12- def process_slices (image , func , args , kwargs ):
13+ def process_slices (image , func , args , kwargs , gpu_id = None ):
1314 """
1415 Process a 3D array slice-by-slice using the provided function.
1516
@@ -25,14 +26,18 @@ def process_slices(image, func, args, kwargs):
2526 func: Function to apply to each slice
2627 args: Positional arguments to pass to func
2728 kwargs: Keyword arguments to pass to func
29+ gpu_id: Optional GPU device ID override. If not provided, attempts
30+ to derive from the input image and falls back to 0.
2831
2932 Returns:
3033 Processed 3D array, or tuple of (processed_3d_array, special_outputs...)
3134 if func returns tuples
3235 """
3336 # Detect memory type and use proper OpenHCS utilities
3437 memory_type = detect_memory_type (image )
35- gpu_id = 0 # Default GPU ID for slice processing
38+ if gpu_id is None :
39+ detected_gpu_id = _get_device_id (image , memory_type )
40+ gpu_id = 0 if detected_gpu_id is None else detected_gpu_id
3641
3742 # Unstack 3D array into 2D slices
3843 slices_2d = unstack_slices (image , memory_type , gpu_id )
0 commit comments