
Commit 53dc5f5

feat: implement Windows read limiter to prevent crashes from concurrent TensorStore reads
- Add `read_limiter.py` with a semaphore-based context manager for managing TensorStore reads on Windows.
- Update `CellMapDataset` to use the read limiter in `__getitem__` methods.
- Introduce logging warnings for potential threading issues on Windows with multiple DataLoader workers.
- Implement `close()` method in `CellMapDataset` for safe cleanup.
- Add tests in `test_windows_stress.py` to validate read limiter functionality and executor lifecycle.
- Enhance README with Windows compatibility guidelines and environment variable configurations.
1 parent d09f8b3 commit 53dc5f5

7 files changed

Lines changed: 660 additions & 13 deletions


CHANGELOG.md

Lines changed: 50 additions & 0 deletions
# CHANGELOG

## Unreleased

### Fix

* fix: prevent Windows hard-crash from concurrent TensorStore reads

  - Add `src/cellmap_data/read_limiter.py`: global `threading.Semaphore` that gates TensorStore materializations on Windows with the TensorStore backend. No-op on Linux/macOS and when using the Dask backend.
  - Wrap the three read-triggering lines in `CellMapDataset.__getitem__` (`get_input_array`, `get_target_array` raw-only, `get_label_array`) with the `limit_tensorstore_reads()` context manager. Torch-only operations (`infer_label_array`, stacking, `.to(device)`) are left unconstrained.
  - Configure via the `CELLMAP_MAX_CONCURRENT_READS` env var (default `"1"` on Windows; unlimited elsewhere). Must be set before importing `cellmap_data`.

### Feature

* feat: add `CellMapDataset.close()` and `atexit` registration

  - `close()` calls `executor.shutdown(wait=True, cancel_futures=True)` and resets `_executor` to `None`, enabling safe deterministic cleanup.
  - `atexit.register(self.close)` ensures the executor is always shut down at interpreter exit, even when `__del__` is not called.

* feat: add nested-worker warning on Windows

  - When `CellMapDataset.executor` is lazily created inside a DataLoader worker process (`torch.utils.data.get_worker_info() is not None`) on Windows with `max_workers > 1`, a `logger.warning` is emitted.
  - `CellMapDataLoader` warns when `num_workers > 0` on Windows.

* feat: improve init logging

  - Replaced `logger.debug` with `logger.info` at dataset construction time, now including OS, backend, `max_workers`, and `max_concurrent_reads`.

### Test

* test: add `tests/test_windows_stress.py`

  - `TestReadLimiterUnit`: semaphore state, context manager correctness, exception propagation, 50-thread deadlock test.
  - `TestExecutorLifecycle`: `close()` idempotency, executor recreation.
  - `TestConcurrentGetitem`: 200-iteration serial tests (both multi-class and raw-only paths); multi-thread tests where each thread has its own dataset instance, accurately mirroring DataLoader `num_workers > 0` behavior.
  - `test_windows_high_concurrency_no_crash`: 8 simulated workers × 100 iterations each; skipped on non-Windows.

## v0.1.0 (2024-09-06)

### Build

README.md

Lines changed: 48 additions & 3 deletions
## Windows Compatibility

CellMap-Data includes specific hardening for Windows to prevent native hard-crashes caused by concurrent TensorStore reads from multiple threads.

### TensorStore Read Limiter

On Windows, concurrent materializations of TensorStore-backed xarray arrays (triggered by `source[center]`, `.interp`, `.__array__`, etc.) can cause the Python process to abort. A global semaphore serializes these reads automatically:

```python
# The limiter activates automatically on Windows with the default TensorStore backend.
# No code changes required; it is transparent to all callers.

# Override the concurrency limit (default is 1 on Windows):
import os
os.environ["CELLMAP_MAX_CONCURRENT_READS"] = "2"  # set BEFORE importing cellmap_data

from cellmap_data import CellMapDataset
```
### Environment Variables

| Variable | Default | Description |
|---|---|---|
| `CELLMAP_DATA_BACKEND` | `"tensorstore"` | Backend for array reads (`"tensorstore"` or `"dask"`) |
| `CELLMAP_MAX_WORKERS` | `8` | Max threads in the internal `ThreadPoolExecutor` |
| `CELLMAP_MAX_CONCURRENT_READS` | `1` (Windows) / unlimited | Max concurrent TensorStore reads (Windows+TensorStore only) |
### Recommendations for Windows
430+
431+
- Keep the default `num_workers=0` in `CellMapDataLoader` (safest on Windows); the internal executor still parallelizes per-array I/O within each `__getitem__` call.
432+
- If you need `num_workers > 0`, each DataLoader worker process gets its own dataset copy and its own read semaphore — this is safe.
433+
- Do **not** share a single `CellMapDataset` instance across multiple threads that each call `__getitem__` concurrently. Use separate dataset instances instead (which is exactly what DataLoader workers do).
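The per-instance pattern above can be sketched with plain threads. This is illustrative only: `make_dataset` is a placeholder standing in for constructing a fresh `CellMapDataset` per worker, not the library API.

```python
import threading

# Placeholder for building a fresh dataset per thread (NOT the real API):
# each instance carries its own state, so threads never contend.
def make_dataset():
    return {"reads": 0}

results = []
results_lock = threading.Lock()

def worker(n_items):
    dataset = make_dataset()  # one instance per thread, as DataLoader workers do
    for _ in range(n_items):
        dataset["reads"] += 1  # stands in for dataset[idx]
    with results_lock:
        results.append(dataset["reads"])

threads = [threading.Thread(target=worker, args=(100,)) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()

assert results == [100, 100, 100, 100]  # no cross-thread interference
```

Because no instance is shared, no locking is needed around the reads themselves; only the result list is guarded.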
### Explicit Shutdown

`CellMapDataset` registers an `atexit` handler and exposes an explicit `close()` method for deterministic cleanup:

```python
dataset = CellMapDataset(...)
try:
    ...  # training loop
finally:
    dataset.close()  # shuts down the internal ThreadPoolExecutor immediately
```
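Since `close()` follows the standard close protocol, `contextlib.closing` also works as a tidier alternative to `try`/`finally`. A minimal sketch, using a stub in place of a real `CellMapDataset`:

```python
from contextlib import closing

# Stub standing in for CellMapDataset; only the close() protocol matters here.
class StubDataset:
    def __init__(self):
        self.closed = False

    def close(self):
        self.closed = True

with closing(StubDataset()) as dataset:
    assert not dataset.closed  # dataset is usable inside the block

assert dataset.closed  # close() was called on exit, even if the block raised
```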
## Performance Optimization

### Memory Management

- Automatic GPU memory management
- Streaming data loading for large volumes

### Parallel Processing

- Multi-threaded data loading via persistent `ThreadPoolExecutor`
- CUDA streams for GPU optimization
- Process-safe dataset pickling

### Caching Strategy

- Persistent `ThreadPoolExecutor` per process (lazy-initialized, PID-tracked)
- Optimized coordinate transformations
- Minimal redundant computations

src/cellmap_data/dataloader.py

Lines changed: 9 additions & 0 deletions
```diff
@@ -78,6 +78,15 @@ def __init__(
         self.is_train = is_train
         self.rng = rng

+        if platform.system() == "Windows" and num_workers > 0:
+            logger.warning(
+                "CellMapDataLoader: num_workers=%d on Windows may cause nested "
+                "threading x multiprocessing issues with TensorStore. "
+                "The internal read limiter serializes reads, but num_workers=0 "
+                "is safer if crashes occur.",
+                num_workers,
+            )
+
         # Set device
         if device is None:
             if torch.cuda.is_available():
```

src/cellmap_data/dataset.py

Lines changed: 49 additions & 8 deletions
```diff
@@ -1,7 +1,9 @@
 # %%
+import atexit
 import functools
 import logging
 import os
+import platform
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from typing import Any, Callable, Mapping, Optional, Sequence

@@ -15,12 +17,17 @@
 from .empty_image import EmptyImage
 from .image import CellMapImage
 from .mutable_sampler import MutableSubsetRandomSampler
+from .read_limiter import MAX_CONCURRENT_READS, limit_tensorstore_reads
 from .utils import get_sliced_shape, is_array_2D, min_redundant_inds, split_target_path

 logger = logging.getLogger(__name__)
 if logger.level == logging.NOTSET:
     logger.setLevel(logging.INFO)

+# Cache system values to avoid repeated calls during dataset instantiation
+_OS_NAME = platform.system()
+_DATA_BACKEND = os.environ.get("CELLMAP_DATA_BACKEND", "tensorstore")
+

 # %%
 class CellMapDataset(CellMapBaseDataset, Dataset):
@@ -154,14 +161,22 @@ def __init__(
             int(os.environ.get("CELLMAP_MAX_WORKERS", 8)),  # Cap at 8 by default
         )

-        logger.debug(
-            "CellMapDataset initialized with %d inputs, %d targets, %d classes. "
-            "Using ThreadPoolExecutor with %d workers for parallel I/O.",
+        logger.info(
+            "CellMapDataset: OS=%s backend=%s max_workers=%d max_concurrent_reads=%s "
+            "inputs=%d targets=%d classes=%d",
+            _OS_NAME,
+            _DATA_BACKEND,
+            self._max_workers,
+            (
+                str(MAX_CONCURRENT_READS)
+                if MAX_CONCURRENT_READS is not None
+                else "unlimited"
+            ),
             len(self.input_arrays),
             len(self.target_arrays),
             len(self.classes),
-            self._max_workers,
         )
+        atexit.register(self.close)

     @property
     def executor(self) -> ThreadPoolExecutor:
@@ -177,6 +192,19 @@ def executor(self) -> ThreadPoolExecutor:
             self._executor_pid = current_pid

         if self._executor is None:
+            if _OS_NAME == "Windows":
+                worker_info = torch.utils.data.get_worker_info()
+                if worker_info is not None and self._max_workers > 1:
+                    logger.warning(
+                        "CellMapDataset running inside DataLoader worker "
+                        "(id=%d, total=%d) on Windows with max_workers=%d. "
+                        "Prefer max_workers=1 or num_workers=0 to avoid nested "
+                        "threading x multiprocessing crashes. "
+                        "TensorStore reads are still serialized by the read limiter.",
+                        worker_info.id,
+                        worker_info.num_workers,
+                        self._max_workers,
+                    )
             self._executor = ThreadPoolExecutor(max_workers=self._max_workers)
         return self._executor

@@ -188,6 +216,16 @@ def __del__(self):
         if hasattr(self, "_executor") and self._executor is not None:
             self._executor.shutdown(wait=True)

+    def close(self) -> None:
+        """Shut down the ThreadPoolExecutor and release resources.
+
+        Called automatically via atexit to ensure clean shutdown at interpreter
+        exit, regardless of whether __del__ is invoked.
+        """
+        if hasattr(self, "_executor") and self._executor is not None:
+            self._executor.shutdown(wait=True, cancel_futures=True)
+            self._executor = None
+
     def __new__(
         cls,
         raw_path: str,
@@ -218,7 +256,7 @@ def __new__(
         ):
             from cellmap_data.multidataset import CellMapMultiDataset

-            logger.warning(
+            logger.info(
                 "2D arrays requested without slicing axis. Creating datasets "
                 "that each slice along one axis. If this is not intended, "
                 "specify the slicing axis in the input and target arrays."
@@ -549,7 +587,8 @@ def __getitem__(self, idx: ArrayLike) -> dict[str, torch.Tensor]:

         def get_input_array(array_name: str) -> tuple[str, torch.Tensor]:
             self.input_sources[array_name].set_spatial_transforms(spatial_transforms)
-            array = self.input_sources[array_name][center]
+            with limit_tensorstore_reads():
+                array = self.input_sources[array_name][center]
             return array_name, array.squeeze()[None, ...]

         futures = [
@@ -563,7 +602,8 @@ def get_target_array(array_name: str) -> tuple[str, torch.Tensor]:
             self.target_sources[array_name].set_spatial_transforms(
                 spatial_transforms
             )
-            array = self.target_sources[array_name][center]
+            with limit_tensorstore_reads():
+                array = self.target_sources[array_name][center]
             return array_name, array.squeeze()[None, ...]

         else:
@@ -578,7 +618,8 @@ def get_label_array(
             source = self.target_sources[array_name].get(label)
             if isinstance(source, (CellMapImage, EmptyImage)):
                 source.set_spatial_transforms(spatial_transforms)
-                array = source[center].squeeze()
+                with limit_tensorstore_reads():
+                    array = source[center].squeeze()
             else:
                 array = None
             return label, array
```

src/cellmap_data/read_limiter.py

Lines changed: 70 additions & 0 deletions
```python
"""
Global TensorStore read limiter for Windows crash prevention.

On Windows, concurrent TensorStore materializations from multiple threads
(triggered by source[center], .interp, ._TensorStoreAdapter.__array__, etc.)
cause native hard crashes / aborts. This module provides a semaphore-backed
context manager that serializes those reads on Windows+TensorStore while
acting as a no-op on all other platforms.

Configuration
-------------
CELLMAP_DATA_BACKEND : str
    Set to "tensorstore" (default) to enable the limiter on Windows.
    Set to anything else (e.g. "dask") to disable it entirely.

CELLMAP_MAX_CONCURRENT_READS : int
    Maximum concurrent TensorStore reads allowed on Windows.
    Defaults to 1 (fully serialized). Increase cautiously.

Notes
-----
Both environment variables must be set **before** this module is imported,
as the semaphore is created once at import time.
"""

import os
import platform
import threading
from contextlib import contextmanager

_IS_WINDOWS = platform.system() == "Windows"
_IS_TENSORSTORE = (
    os.environ.get("CELLMAP_DATA_BACKEND", "tensorstore").lower() == "tensorstore"
)

MAX_CONCURRENT_READS: int | None
_read_semaphore: threading.Semaphore | None

if _IS_WINDOWS and _IS_TENSORSTORE:
    MAX_CONCURRENT_READS = int(os.environ.get("CELLMAP_MAX_CONCURRENT_READS", "1"))
    _read_semaphore = threading.Semaphore(MAX_CONCURRENT_READS)
else:
    MAX_CONCURRENT_READS = None
    _read_semaphore = None


@contextmanager
def limit_tensorstore_reads():
    """Context manager that gates TensorStore reads on Windows.

    On Windows with the TensorStore backend, at most ``MAX_CONCURRENT_READS``
    threads may be inside this context at once. On all other platforms (or
    when using the Dask backend) this is a true no-op with zero overhead.

    Usage
    -----
    ::

        with limit_tensorstore_reads():
            array = source[center]  # the unsafe read
        # torch-only work continues here unconstrained
    """
    if _read_semaphore is not None:
        _read_semaphore.acquire()
        try:
            yield
        finally:
            _read_semaphore.release()
    else:
        yield
```
