104 changes: 100 additions & 4 deletions src/cellmap_data/dataloader.py
@@ -1,3 +1,4 @@
import os
import platform
import logging
from typing import Callable, Optional, Sequence, Union
@@ -7,13 +8,61 @@

from .dataset import CellMapDataset
from .dataset_writer import CellMapDatasetWriter
from .image import CellMapImage
from .multidataset import CellMapMultiDataset
from .mutable_sampler import MutableSubsetRandomSampler
from .subdataset import CellMapSubset

logger = logging.getLogger(__name__)


def _set_tensorstore_context(dataset, context) -> None:
"""
Recursively set a TensorStore context on every CellMapImage in the dataset tree.

This must be called before workers are spawned so the bounded cache_pool
limit is picked up by every worker process (via fork inheritance on Linux,
or via pickle on Windows/macOS spawn).

If an image's TensorStore array has already been opened (``_array`` cached),
the new context cannot affect that array; a warning is emitted.
"""
if isinstance(dataset, CellMapMultiDataset):
for ds in dataset.datasets:
_set_tensorstore_context(ds, context)
elif isinstance(dataset, CellMapSubset):
_set_tensorstore_context(dataset.dataset, context)
elif isinstance(dataset, CellMapDataset):
dataset.context = context
all_sources = list(dataset.input_sources.values()) + list(
dataset.target_sources.values()
)
for source in all_sources:
if isinstance(source, CellMapImage):
_apply_context_to_image(source, context)
elif isinstance(source, dict):
for sub_source in source.values():
if isinstance(sub_source, CellMapImage):
_apply_context_to_image(sub_source, context)
else:
logger.warning(
"Unsupported dataset type %s in _set_tensorstore_context; "
"TensorStore context was not applied.",
type(dataset).__name__,
)


def _apply_context_to_image(image: "CellMapImage", context) -> None:
"""Set the TensorStore context on a single CellMapImage, warning if already opened."""
if "_array" in getattr(image, "__dict__", {}):
logger.warning(
"TensorStore array already opened for %s; "
"cache_pool limit will not apply to this image.",
getattr(image, "path", image),
)
image.context = context
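
# Usage sketch (illustrative, not part of the module's public API): build a
# bounded TensorStore context and apply it to a dataset tree before
# DataLoader workers spawn. The 512 MiB figure is an arbitrary example, and
# `dataset` stands for an already-constructed CellMapDataset.
#
#   import tensorstore as ts
#   ctx = ts.Context({"cache_pool": {"total_bytes_limit": 512 * 2**20}})
#   _set_tensorstore_context(dataset, ctx)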


class CellMapDataLoader:
"""
Optimized DataLoader wrapper for CellMapDataset that uses PyTorch's native DataLoader.
@@ -50,6 +99,7 @@ def __init__(
rng: Optional[torch.Generator] = None,
device: Optional[str | torch.device] = None,
iterations_per_epoch: Optional[int] = None,
tensorstore_cache_bytes: Optional[int] = None,
**kwargs,
):
"""
@@ -67,6 +117,14 @@
rng: The random number generator.
device: The device to use ("cuda", "mps", or "cpu").
iterations_per_epoch: Iterations per epoch for large datasets.
tensorstore_cache_bytes: Total TensorStore chunk-cache budget in bytes
shared across all worker processes. The budget is split evenly:
``per_worker = tensorstore_cache_bytes // max(1, num_workers)``.
Defaults to the ``CELLMAP_TENSORSTORE_CACHE_BYTES`` environment
variable if set, otherwise no limit is applied (TensorStore's
default unbounded cache). Set to ``0`` to disable caching
entirely. Bounding this value prevents persistent worker
processes from accumulating chunk data unboundedly across epochs.
**kwargs: Additional PyTorch DataLoader arguments.
"""
self.dataset = dataset
@@ -80,10 +138,10 @@

if platform.system() == "Windows" and num_workers > 0:
            logger.warning(
                "CellMapDataLoader: num_workers=%d on Windows. "
                "The dataset uses a synchronous (single-thread) executor internally "
                "so TensorStore reads are never dispatched to ThreadPoolExecutor "
                "worker threads. If crashes persist, try num_workers=0.",
                num_workers,
            )

@@ -98,6 +156,44 @@
self.device = device
self.iterations_per_epoch = iterations_per_epoch

# Bound TensorStore chunk-cache to prevent unbounded RAM growth in
# persistent worker processes (Linux fork, Windows/macOS spawn).
# Resolve from parameter, then env var, then leave unconfigured.
if tensorstore_cache_bytes is None:
_env = os.environ.get("CELLMAP_TENSORSTORE_CACHE_BYTES")
if _env is not None:
try:
tensorstore_cache_bytes = int(_env)
except ValueError as exc:
raise ValueError(
"Invalid value for environment variable "
"CELLMAP_TENSORSTORE_CACHE_BYTES: "
f"{_env!r}. Expected an integer number of bytes."
) from exc
if tensorstore_cache_bytes is not None and tensorstore_cache_bytes < 0:
raise ValueError(
f"tensorstore_cache_bytes must be >= 0 when set; got {tensorstore_cache_bytes}"
)
self.tensorstore_cache_bytes = tensorstore_cache_bytes

if tensorstore_cache_bytes is not None and not isinstance(
dataset, CellMapDatasetWriter
):
import tensorstore as ts

effective_workers = max(1, num_workers)
per_worker_bytes = tensorstore_cache_bytes // effective_workers
bounded_ctx = ts.Context(
{"cache_pool": {"total_bytes_limit": per_worker_bytes}}
)
_set_tensorstore_context(dataset, bounded_ctx)
logger.info(
"TensorStore cache bounded: total=%d bytes / %d worker(s) = %d bytes each",
tensorstore_cache_bytes,
effective_workers,
per_worker_bytes,
)

# Extract DataLoader parameters with optimized defaults
# pin_memory only works with CUDA, so default to True only when CUDA is available
# and device is CUDA
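
A minimal usage sketch for the new knob (illustrative: `my_dataset` stands for an already-constructed `CellMapDataset`, the import path and keyword names follow the diff above, and the 2 GiB budget and worker count are example values):

```python
import os

from cellmap_data.dataloader import CellMapDataLoader

# Option A: pass the budget directly. 2 GiB split across 4 workers gives
# each worker process a 512 MiB cache_pool limit (2 * 2**30 // 4 bytes).
loader = CellMapDataLoader(
    my_dataset,
    num_workers=4,
    tensorstore_cache_bytes=2 * 2**30,
)

# Option B: set the environment variable instead; it is consulted only
# when tensorstore_cache_bytes is not passed explicitly.
os.environ["CELLMAP_TENSORSTORE_CACHE_BYTES"] = str(2 * 2**30)
loader = CellMapDataLoader(my_dataset, num_workers=4)

# Passing tensorstore_cache_bytes=0 disables chunk caching entirely.
```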
81 changes: 63 additions & 18 deletions src/cellmap_data/dataset.py
@@ -4,6 +4,7 @@
import logging
import os
import platform
from concurrent.futures import Future as _ConcurrentFuture
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Callable, Mapping, Optional, Sequence

@@ -28,6 +29,44 @@
_OS_NAME = platform.system()
_DATA_BACKEND = os.environ.get("CELLMAP_DATA_BACKEND", "tensorstore")

# On Windows + TensorStore, calling tensorstore's .read().result() from a
# Python ThreadPoolExecutor worker thread causes a hard native crash
# (STATUS_STACK_BUFFER_OVERRUN / abort, exit code 0xC0000409). The
# limit_tensorstore_reads semaphore only prevents *concurrent* Python reads
# but does not fix the per-thread crash. The safest fix is to run all
# dataset __getitem__ work synchronously in the calling thread so that
# TensorStore is never invoked from a ThreadPoolExecutor worker on Windows.
_USE_IMMEDIATE_EXECUTOR = (
_OS_NAME == "Windows" and _DATA_BACKEND.lower() == "tensorstore"
)


class _ImmediateExecutor:
"""Drop-in for ThreadPoolExecutor that runs tasks in the calling thread.

On Windows + TensorStore the real ThreadPoolExecutor causes native crashes.
This executor avoids that by executing every submitted callable synchronously
before returning, so the returned Future is already resolved.
``as_completed`` handles pre-resolved futures correctly (yields immediately).
``shutdown`` is a no-op because there are no threads to join.
"""

def submit(self, fn, /, *args, **kwargs):
f = _ConcurrentFuture()
try:
f.set_result(fn(*args, **kwargs))
except Exception as exc: # noqa: BLE001
f.set_exception(exc)
return f

def shutdown(self, wait=True, *, cancel_futures=False):
pass # nothing to shut down


_IMMEDIATE_EXECUTOR: _ImmediateExecutor | None = (
_ImmediateExecutor() if _USE_IMMEDIATE_EXECUTOR else None
)
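
# Illustrative check (not part of the module): a Future pre-resolved by
# _ImmediateExecutor.submit() is yielded immediately by as_completed, so
# call sites written against ThreadPoolExecutor need no special-casing.
#
#   ex = _ImmediateExecutor()
#   fut = ex.submit(lambda x: x + 1, 41)
#   assert next(as_completed([fut])).result() == 42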


# %%
class CellMapDataset(CellMapBaseDataset, Dataset):
@@ -179,32 +218,38 @@ def __init__(
atexit.register(self.close)

    @property
    def executor(self) -> ThreadPoolExecutor | _ImmediateExecutor:
        """
        Lazy initialization of persistent executor.

On Windows + TensorStore returns a module-level ``_ImmediateExecutor``
that runs every submitted callable synchronously in the calling thread.
This avoids the native crash (0xC0000409 / STATUS_STACK_BUFFER_OVERRUN)
that occurs when TensorStore's ``.read().result()`` is called from a
Python ``ThreadPoolExecutor`` worker thread on Windows.

On all other platforms returns the usual persistent ``ThreadPoolExecutor``.

In both cases ``self._executor`` and ``self._executor_pid`` are kept in
sync so that ``close()``, ``__del__``, and tests can inspect them
consistently regardless of platform.
"""
# Add pid tracking to detect process forking and prevent shared executors
current_pid = os.getpid()

if _USE_IMMEDIATE_EXECUTOR:
# Use the module-level singleton but still track state so that
# _executor / _executor_pid are never left as None after first access.
if self._executor is None or self._executor_pid != current_pid:
self._executor = _IMMEDIATE_EXECUTOR
self._executor_pid = current_pid
return self._executor # type: ignore[return-value]

# Non-Windows path: detect process forking and create a fresh executor.
if self._executor_pid != current_pid:
# Process was forked, need new executor
self._executor = None
self._executor_pid = current_pid

        if self._executor is None:
            self._executor = ThreadPoolExecutor(max_workers=self._max_workers)
        return self._executor

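
For readers unfamiliar with the PID-tracking pattern the `executor` property relies on, here is a minimal standalone sketch of the contract (a simplification for illustration, not the module's code): a forked child process sees a different `os.getpid()`, so any executor inherited from the parent is discarded and rebuilt rather than shared across processes.

```python
import os
from concurrent.futures import ThreadPoolExecutor


class LazyExecutorHolder:
    """Toy holder demonstrating per-process executor recreation."""

    def __init__(self, max_workers: int = 2) -> None:
        self._max_workers = max_workers
        self._executor: ThreadPoolExecutor | None = None
        self._executor_pid: int | None = None

    @property
    def executor(self) -> ThreadPoolExecutor:
        pid = os.getpid()
        if self._executor_pid != pid:
            # First access in this process, or state inherited via fork:
            # never reuse a thread pool created in another process.
            self._executor = None
            self._executor_pid = pid
        if self._executor is None:
            self._executor = ThreadPoolExecutor(max_workers=self._max_workers)
        return self._executor
```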