From 4d21d22ab5d775a7b3609f4417053ee454ea4a2a Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Thu, 8 Jan 2026 10:03:53 +0000 Subject: [PATCH 001/156] :new: Define New MultiTaskSegmentor --- requirements/requirements.txt | 3 +- tiatoolbox/models/engine/engine_abc.py | 15 ++++++++-- .../models/engine/multi_task_segmentor.py | 28 +++---------------- .../models/engine/semantic_segmentor.py | 6 ++-- 4 files changed, 21 insertions(+), 31 deletions(-) diff --git a/requirements/requirements.txt b/requirements/requirements.txt index f6d707a9e..195e4e9c0 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -2,7 +2,7 @@ aiohttp>=3.8.1 albumentations>=1.3.0 bokeh>=3.1.1, <3.6.0 Click>=8.1.3, <8.2.0 -dask>=2025.10.0 +dask>=2025.12.0 defusedxml>=0.7.1 filelock>=3.9.0 flask>=2.2.2 @@ -20,6 +20,7 @@ openslide-bin>=4.0.0.2 openslide-python>=1.4.0 pandas>=2.0.0 pillow>=9.3.0 +pyarrow>=22.0.0 pydicom>=2.3.1 # Used by wsidicom pyyaml>=6.0 requests>=2.28.1 diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index c44a8a870..633d0753e 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -45,7 +45,7 @@ import torch import zarr from dask import compute -from dask.diagnostics.progress import ProgressBar +from dask.diagnostics import ProgressBar from numcodecs import Pickle from torch import nn from typing_extensions import Unpack @@ -72,6 +72,8 @@ from tiatoolbox.models.models_abc import ModelABC from tiatoolbox.type_hints import IntPair, Resolution, Units +dask.config.set({"dataframe.convert-string": False}) + class EngineABCRunParams(TypedDict, total=False): """Parameters for configuring the :func:`EngineABC.run()` method. @@ -519,7 +521,7 @@ def infer_patches( coordinates = [] # Main output dictionary - raw_predictions = dict(zip(keys, [[]] * len(keys), strict=False)) + raw_predictions = {key: [] for key in keys} # Inference loop tqdm = get_tqdm() @@ -777,7 +779,7 @@ def save_predictions_as_zarr( url=save_path, component=f"{key}/{i}", compute=False, - object_codec=object_codec, + zarr_array_kwargs={"object_codec": object_codec}, ) write_tasks.append(task) @@ -1211,6 +1213,9 @@ def _update_run_params( If an unsupported output_type is provided. ValueError: If required configuration or input parameters are missing. + ValueError: + If save_dir is not provided and output_type is "zarr" + or "annotationstore". """ for key in kwargs: @@ -1251,6 +1256,10 @@ def _update_run_params( ) logger.info(msg) + if save_dir is None and output_type.lower() in ["zarr", "annotationstore"]: + msg = f"Please provide save_dir for output_type={output_type}" + raise ValueError(msg) + self.images = self._validate_images_masks(images=images) if masks is not None: diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index 7293e78cc..5efeca3b1 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -1,31 +1,10 @@ -# ***** BEGIN GPL LICENSE BLOCK ***** -# -# This program is free software; you can redistribute it and/or -# modify it under the terms of the GNU General Public License -# as published by the Free Software Foundation; either version 2 -# of the License, or (at your option) any later version. 
-# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with this program; if not, write to the Free Software Foundation, -# Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. -# -# The Original Code is Copyright (C) 2021, TIA Centre, University of Warwick -# All rights reserved. -# ***** END GPL LICENSE BLOCK ***** - -"""This module enables multi-task segmentors.""" +"""This module enables multi-task segmentor.""" from __future__ import annotations import shutil from typing import TYPE_CHECKING -# replace with the sql database once the PR in place import joblib import numpy as np from shapely.geometry import box as shapely_box @@ -33,10 +12,11 @@ from tiatoolbox.models.dataset.dataset_abc import WSIStreamDataset from tiatoolbox.models.engine.nucleus_instance_segmentor import ( - NucleusInstanceSegmentor, _process_instance_predictions, ) +from .semantic_segmentor import SemanticSegmentor + if TYPE_CHECKING: # pragma: no cover from collections.abc import Callable @@ -180,7 +160,7 @@ def _process_tile_predictions( # skipcq: PY-R1000 return new_inst_dicts, remove_insts_in_origs, sem_maps, tile_bounds -class MultiTaskSegmentor(NucleusInstanceSegmentor): +class MultiTaskSegmentor(SemanticSegmentor): """An engine specifically designed to handle tiles or WSIs inference. Note, if `model` is supplied in the arguments, it will ignore the diff --git a/tiatoolbox/models/engine/semantic_segmentor.py b/tiatoolbox/models/engine/semantic_segmentor.py index 0bf1be496..9b1a2964c 100644 --- a/tiatoolbox/models/engine/semantic_segmentor.py +++ b/tiatoolbox/models/engine/semantic_segmentor.py @@ -587,7 +587,7 @@ def save_predictions( output_type: str, save_path: Path | None = None, **kwargs: Unpack[SemanticSegmentorRunParams], - ) -> dict | AnnotationStore | Path: + ) -> dict | AnnotationStore | Path | list[Path]: """Save semantic segmentation predictions to disk or return them in memory. This method saves predictions in one of the supported formats: @@ -645,11 +645,11 @@ def save_predictions( Whether to enable verbose logging. Returns: - dict | AnnotationStore | Path: + dict | AnnotationStore | Path | list[Path]: - If output_type is "dict": returns predictions as a dictionary. - If output_type is "zarr": returns path to saved Zarr file. - If output_type is "annotationstore": returns AnnotationStore - or path to .db file. + or path or list of paths to .db file. 
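+
+        Examples:
+            A minimal sketch of a typical call (illustrative only; assumes
+            ``processed_predictions`` holds the in-memory predictions to be
+            saved and ``save_path`` points at the target output file):
+
+            >>> out = segmentor.save_predictions(
+            ...     processed_predictions,
+            ...     output_type="zarr",
+            ...     save_path=save_path,
+            ... )
+            >>> out  # path to the saved .zarr file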
""" # Conversion to annotationstore uses a different function for SemanticSegmentor From fef42ea40f88b06bf3505ea84d1051b0f82af458 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Thu, 8 Jan 2026 10:30:37 +0000 Subject: [PATCH 002/156] :construction: Add functionalities --- .../models/engine/multi_task_segmentor.py | 186 ++++++++++++++---- 1 file changed, 144 insertions(+), 42 deletions(-) diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index 5efeca3b1..e26e4dfa6 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -3,26 +3,28 @@ from __future__ import annotations import shutil -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Unpack import joblib import numpy as np from shapely.geometry import box as shapely_box from shapely.strtree import STRtree -from tiatoolbox.models.dataset.dataset_abc import WSIStreamDataset from tiatoolbox.models.engine.nucleus_instance_segmentor import ( _process_instance_predictions, ) -from .semantic_segmentor import SemanticSegmentor +from .semantic_segmentor import SemanticSegmentor, SemanticSegmentorRunParams if TYPE_CHECKING: # pragma: no cover + import os from collections.abc import Callable + from pathlib import Path - import torch - - from tiatoolbox.type_hints import IntBounds + from tiatoolbox.annotation import AnnotationStore + from tiatoolbox.models.models_abc import ModelABC + from tiatoolbox.type_hints import IntBounds, IntPair, Resolution, Units + from tiatoolbox.wsicore import WSIReader from .io_config import IOInstanceSegmentorConfig, IOSegmentorConfig @@ -226,51 +228,24 @@ class MultiTaskSegmentor(SemanticSegmentor): def __init__( self: MultiTaskSegmentor, + model: str | ModelABC, batch_size: int = 8, - num_loader_workers: int = 0, - num_postproc_workers: int = 0, - model: torch.nn.Module | None = None, - pretrained_model: str | None = None, - pretrained_weights: str | None = None, - dataset_class: Callable = WSIStreamDataset, - output_types: list | None = None, + num_workers: int = 0, + weights: str | Path | None = None, *, + device: str = "cpu", verbose: bool = True, - auto_generate_mask: bool = False, ) -> None: - """Initialize :class:`MultiTaskSegmentor`.""" + """Initialize :class:`NucleusInstanceSegmentor`.""" super().__init__( - batch_size=batch_size, - num_loader_workers=num_loader_workers, - num_postproc_workers=num_postproc_workers, model=model, - pretrained_model=pretrained_model, - pretrained_weights=pretrained_weights, + batch_size=batch_size, + num_workers=num_workers, + weights=weights, + device=device, verbose=verbose, - auto_generate_mask=auto_generate_mask, - dataset_class=dataset_class, ) - self.output_types = output_types - self._futures = None - - if "hovernetplus" in str(pretrained_model): - self.output_types = ["instance", "semantic"] - elif "hovernet" in str(pretrained_model): - self.output_types = ["instance"] - - # adding more runtime placeholder - if self.output_types is not None: - if "semantic" in self.output_types: - self.wsi_layers = [] - if "instance" in self.output_types: - self._wsi_inst_info = [] - else: - msg = "Output type must be specified for instance or semantic segmentation." 
-            raise ValueError(
-                msg,
-            )
-
     def _predict_one_wsi(
         self: MultiTaskSegmentor,
         wsi_idx: int,
@@ -445,3 +420,130 @@ def callback(
         # manually call the callback rather than
         # attaching it when receiving/creating the future
         callback(*future.result())
+
+    def run(
+        self: MultiTaskSegmentor,
+        images: list[os.PathLike | Path | WSIReader] | np.ndarray,
+        *,
+        masks: list[os.PathLike | Path] | np.ndarray | None = None,
+        input_resolutions: list[dict[Units, Resolution]] | None = None,
+        patch_input_shape: IntPair | None = None,
+        ioconfig: IOSegmentorConfig | None = None,
+        patch_mode: bool = True,
+        save_dir: os.PathLike | Path | None = None,
+        overwrite: bool = False,
+        output_type: str = "dict",
+        **kwargs: Unpack[SemanticSegmentorRunParams],
+    ) -> AnnotationStore | Path | str | dict | list[Path]:
+        """Run the multi-task segmentation engine on input images.
+
+        This method orchestrates the full inference pipeline, including preprocessing,
+        model inference, post-processing, and saving results. It supports both
+        patch-level and whole slide image (WSI) modes.
+
+        Args:
+            images (list[PathLike | WSIReader] | np.ndarray):
+                Input images or patches. Can be a list of file paths, WSIReader objects,
+                or a NumPy array of image patches.
+            masks (list[PathLike] | np.ndarray | None):
+                Optional masks for WSI processing. Only used when `patch_mode` is False.
+            input_resolutions (list[dict[Units, Resolution]] | None):
+                Resolution settings for input heads. Supported units are `level`,
+                `power`, and `mpp`. Keys should be "units" and "resolution", e.g.,
+                [{"units": "mpp", "resolution": 0.25}]. See :class:`WSIReader` for
+                details.
+            patch_input_shape (IntPair | None):
+                Shape of input patches (height, width), requested at read
+                resolution. Must be positive.
+            ioconfig (IOSegmentorConfig | None):
+                IO configuration for patch extraction and resolution.
+            patch_mode (bool):
+                Whether to treat input as patches (`True`) or WSIs (`False`). Default
+                is True.
+            save_dir (PathLike | None):
+                Directory to save output files. Required for WSI mode.
+            overwrite (bool):
+                Whether to overwrite existing output files. Default is False.
+            output_type (str):
+                Desired output format: "dict", "zarr", or "annotationstore". Default
+                is "dict".
+            **kwargs (SemanticSegmentorRunParams):
+                Additional runtime parameters to configure segmentation.
+
+            Optional Keys:
+                auto_get_mask (bool):
+                    Automatically generate segmentation masks using
+                    `wsireader.tissue_mask()` during processing.
+                batch_size (int):
+                    Number of image patches per forward pass.
+                class_dict (dict):
+                    Mapping of classification outputs to class names.
+                device (str):
+                    Device to run the model on (e.g., "cpu", "cuda").
+                labels (list):
+                    Optional labels for input images. Only a single label per image
+                    is supported.
+                memory_threshold (int):
+                    Memory usage threshold (percentage) to trigger caching behavior.
+                num_workers (int):
+                    Number of workers for DataLoader and post-processing.
+                output_file (str):
+                    Filename for saving output (e.g., ".zarr" or ".db").
+                output_resolutions (Resolution):
+                    Resolution used for writing output predictions.
+                patch_output_shape (tuple[int, int]):
+                    Shape of output patches (height, width).
+                return_labels (bool):
+                    Whether to return labels with predictions.
+                return_probabilities (bool):
+                    Whether to return per-class probabilities.
+                scale_factor (tuple[float, float]):
+                    Scale factor for annotations (model_mpp / slide_mpp).
+                    Used to convert coordinates to baseline resolution. 
+ stride_shape (tuple[int, int]): + Stride used during WSI processing. + Defaults to `patch_input_shape` if not provided. + verbose (bool): + Whether to enable verbose logging. + + Returns: + AnnotationStore | Path | str | dict | list[Path]: + - If `patch_mode` is True: returns predictions or path to saved output. + - If `patch_mode` is False: returns a dictionary mapping each WSI + to its output path. + + Examples: + >>> wsis = ['wsi1.svs', 'wsi2.svs'] + >>> image_patches = [np.ndarray, np.ndarray] + >>> segmentor = SemanticSegmentor(model="fcn-tissue_mask") + >>> output = segmentor.run(image_patches, patch_mode=True) + >>> output + ... "/path/to/Output.db" + + >>> output = segmentor.run( + ... image_patches, + ... patch_mode=True, + ... output_type="zarr" + ... ) + >>> output + ... "/path/to/Output.zarr" + + >>> output = segmentor.run(wsis, patch_mode=False) + >>> output.keys() + ... ['wsi1.svs', 'wsi2.svs'] + >>> output['wsi1.svs'] + ... "/path/to/wsi1.db" + + """ + return super().run( + images=images, + masks=masks, + input_resolutions=input_resolutions, + patch_input_shape=patch_input_shape, + ioconfig=ioconfig, + patch_mode=patch_mode, + save_dir=save_dir, + overwrite=overwrite, + output_type=output_type, + **kwargs, + ) From ab5cf7c8febbeb277a44b3704891648b2eb91bbb Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Thu, 8 Jan 2026 10:33:03 +0000 Subject: [PATCH 003/156] :fire: Remove previous codes --- .../models/engine/multi_task_segmentor.py | 181 +----------------- 1 file changed, 1 insertion(+), 180 deletions(-) diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index e26e4dfa6..ab61722a1 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -2,13 +2,9 @@ from __future__ import annotations -import shutil from typing import TYPE_CHECKING, Unpack -import joblib import numpy as np -from shapely.geometry import box as shapely_box -from shapely.strtree import STRtree from tiatoolbox.models.engine.nucleus_instance_segmentor import ( _process_instance_predictions, @@ -26,7 +22,7 @@ from tiatoolbox.type_hints import IntBounds, IntPair, Resolution, Units from tiatoolbox.wsicore import WSIReader - from .io_config import IOInstanceSegmentorConfig, IOSegmentorConfig + from .io_config import IOSegmentorConfig # Python is yet to be able to natively pickle Object method/static method. @@ -246,181 +242,6 @@ def __init__( verbose=verbose, ) - def _predict_one_wsi( - self: MultiTaskSegmentor, - wsi_idx: int, - ioconfig: IOInstanceSegmentorConfig, - save_path: str, - mode: str, - ) -> None: - """Make a prediction on tile/wsi. - - Args: - wsi_idx (int): - Index of the tile/wsi to be processed within `self`. - ioconfig (IOInstanceSegmentorConfig): - Object which defines I/O placement - during inference and when assembling back to full tile/wsi. - save_path (str): - Location to save output prediction as well as possible - intermediate results. - mode (str): - `tile` or `wsi` to indicate run mode. 
- - """ - cache_dir = f"{self._cache_dir}/" - wsi_path = self.imgs[wsi_idx] - mask_path = None if self.masks is None else self.masks[wsi_idx] - wsi_reader, mask_reader = self.get_reader( - wsi_path, - mask_path, - mode, - auto_get_mask=self.auto_generate_mask, - ) - - # assume ioconfig has already been converted to `baseline` for `tile` mode - resolution = ioconfig.highest_input_resolution - wsi_proc_shape = wsi_reader.slide_dimensions(**resolution) - - # * retrieve patch placement - # this is in XY - (patch_inputs, patch_outputs) = self.get_coordinates(wsi_proc_shape, ioconfig) - if mask_reader is not None: - sel = self.filter_coordinates(mask_reader, patch_outputs, **resolution) - patch_outputs = patch_outputs[sel] - patch_inputs = patch_inputs[sel] - - # assume to be in [top_left_x, top_left_y, bot_right_x, bot_right_y] - geometries = [shapely_box(*bounds) for bounds in patch_outputs] - spatial_indexer = STRtree(geometries) - - # * retrieve tile placement and tile info flag - # tile shape will always be corrected to be multiple of output - tile_info_sets = self._get_tile_info(wsi_proc_shape, ioconfig) - - # ! running order of each set matters ! - self._futures = [] - - indices_sem = [i for i, x in enumerate(self.output_types) if x == "semantic"] - - for s_id in range(len(indices_sem)): - shape = tuple(map(int, np.fliplr([wsi_proc_shape])[0])) - self.wsi_layers.append( - np.lib.format.open_memmap( - f"{cache_dir}/{s_id}.npy", - mode="w+", - shape=shape, - dtype=np.uint8, - ), - ) - self.wsi_layers[s_id][:] = 0 - - indices_inst = [i for i, x in enumerate(self.output_types) if x == "instance"] - - if not self._wsi_inst_info: # pragma: no cover - self._wsi_inst_info = [] - self._wsi_inst_info.extend({} for _ in indices_inst) - - for set_idx, (set_bounds, set_flags) in enumerate(tile_info_sets): - for tile_idx, tile_bounds in enumerate(set_bounds): - tile_flag = set_flags[tile_idx] - - # select any patches that have their output - # within the current tile - sel_box = shapely_box(*tile_bounds) - sel_indices = list(spatial_indexer.query(sel_box)) - - tile_patch_inputs = patch_inputs[sel_indices] - tile_patch_outputs = patch_outputs[sel_indices] - self._to_shared_space(wsi_idx, tile_patch_inputs, tile_patch_outputs) - - tile_infer_output = self._infer_once() - - self._process_tile_predictions( - ioconfig, - tile_bounds, - tile_flag, - set_idx, - tile_infer_output, - ) - self._merge_post_process_results() - - # Maybe change to store semantic annotations as contours in .dat file... 
- for i_id, inst_idx in enumerate(indices_inst): - joblib.dump(self._wsi_inst_info[i_id], f"{save_path}.{inst_idx}.dat") - self._wsi_inst_info = [] # clean up - - for s_id, sem_idx in enumerate(indices_sem): - shutil.copyfile(f"{cache_dir}/{s_id}.npy", f"{save_path}.{sem_idx}.npy") - # may need to chain it with parents - - def _process_tile_predictions( - self: MultiTaskSegmentor, - ioconfig: IOSegmentorConfig, - tile_bounds: IntBounds, - tile_flag: list, - tile_mode: int, - tile_output: list, - ) -> None: - """Function to dispatch parallel post processing.""" - args = [ - ioconfig, - tile_bounds, - tile_flag, - tile_mode, - tile_output, - self._wsi_inst_info, - self.model.postproc_func, - self.merge_prediction, - self.pretrained_model, - ] - if self._postproc_workers is not None: - future = self._postproc_workers.submit(_process_tile_predictions, *args) - else: - future = _process_tile_predictions(*args) - self._futures.append(future) - - def _merge_post_process_results(self: MultiTaskSegmentor) -> None: - """Helper to aggregate results from parallel workers.""" - - def callback( - new_inst_dicts: dict, - remove_uuid_lists: list, - tiles: dict, - bounds: IntBounds, - ) -> None: - """Helper to aggregate worker's results.""" - # ! DEPRECATION: - # ! will be deprecated upon finalization of SQL annotation store - for inst_id, new_inst_dict in enumerate(new_inst_dicts): - self._wsi_inst_info[inst_id].update(new_inst_dict) - for inst_uuid in remove_uuid_lists[inst_id]: - self._wsi_inst_info[inst_id].pop(inst_uuid, None) - - x_start, y_start, x_end, y_end = bounds - for sem_id, tile in enumerate(tiles): - max_h, max_w = self.wsi_layers[sem_id].shape - x_end, y_end = min(x_end, max_w), min(y_end, max_h) - tile_ = tile[0 : y_end - y_start, 0 : x_end - x_start] - self.wsi_layers[sem_id][y_start:y_end, x_start:x_end] = tile_ - # ! - - for future in self._futures: - # not actually future but the results - if self._postproc_workers is None: - callback(*future) - continue - # some errors happen, log it and propagate exception - # ! this will lead to discard a whole bunch of - # ! 
inferred tiles within this current WSI - if future.exception() is not None: - raise future.exception() - - # aggregate the result via callback - # manually call the callback rather than - # attaching it when receiving/creating the future - callback(*future.result()) - def run( self: MultiTaskSegmentor, images: list[os.PathLike | Path | WSIReader] | np.ndarray, From eab6db0dc0fde3f4f8405c17f508ef8c0e465413 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Thu, 8 Jan 2026 10:34:09 +0000 Subject: [PATCH 004/156] :fire: Remove previous codes --- .../models/engine/multi_task_segmentor.py | 144 +----------------- 1 file changed, 3 insertions(+), 141 deletions(-) diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index ab61722a1..4960cde3c 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -4,160 +4,22 @@ from typing import TYPE_CHECKING, Unpack -import numpy as np - -from tiatoolbox.models.engine.nucleus_instance_segmentor import ( - _process_instance_predictions, -) - from .semantic_segmentor import SemanticSegmentor, SemanticSegmentorRunParams if TYPE_CHECKING: # pragma: no cover import os - from collections.abc import Callable from pathlib import Path + import numpy as np + from tiatoolbox.annotation import AnnotationStore from tiatoolbox.models.models_abc import ModelABC - from tiatoolbox.type_hints import IntBounds, IntPair, Resolution, Units + from tiatoolbox.type_hints import IntPair, Resolution, Units from tiatoolbox.wsicore import WSIReader from .io_config import IOSegmentorConfig -# Python is yet to be able to natively pickle Object method/static method. -# Only top-level function is passable to multi-processing as caller. -# May need 3rd party libraries to use method/static method otherwise. -def _process_tile_predictions( # skipcq: PY-R1000 - ioconfig: IOSegmentorConfig, - tile_bounds: IntBounds, - tile_flag: list, - tile_mode: int, - tile_output: list, - # this would be replaced by annotation store - # in the future - ref_inst_dict: dict, - postproc: Callable, - merge_predictions: Callable, - model_name: str, -) -> tuple: - """Process Tile Predictions. - - Function to merge new tile prediction with existing prediction, - using the output from each task. - - Args: - ioconfig (:class:`IOSegmentorConfig`): Object defines information - about input and output placement of patches. - tile_bounds (:class:`numpy.array`): Boundary of the current tile, defined as - (top_left_x, top_left_y, bottom_x, bottom_y). - tile_flag (list): A list of flag to indicate if instances within - an area extended from each side (by `ioconfig.margin`) of - the tile should be replaced by those within the same spatial - region in the accumulated output this run. The format is - [top, bottom, left, right], 1 indicates removal while 0 is not. - For example, [1, 1, 0, 0] denotes replacing top and bottom instances - within `ref_inst_dict` with new ones after this processing. - tile_mode (int): A flag to indicate the type of this tile. There - are 4 flags: - - 0: A tile from tile grid without any overlapping, it is not - an overlapping tile from tile generation. The predicted - instances are immediately added to accumulated output. - - 1: Vertical tile strip that stands between two normal tiles - (flag 0). It has the the same height as normal tile but - less width (hence vertical strip). 
- - 2: Horizontal tile strip that stands between two normal tiles - (flag 0). It has the the same width as normal tile but - less height (hence horizontal strip). - - 3: tile strip stands at the cross section of four normal tiles - (flag 0). - tile_output (list): A list of patch predictions, that lie within this - tile, to be merged and processed. - ref_inst_dict (dict): Dictionary contains accumulated output. The - expected format is {instance_id: {type: int, - contour: List[List[int]], centroid:List[float], box:List[int]}. - postproc (callable): Function to post-process the raw assembled tile. - merge_predictions (callable): Function to merge the `tile_output` into - raw tile prediction. - model_name (string): Name of the existing models support by tiatoolbox - for processing the data. Refer to [URL] for details. - - Returns: - new_inst_dict (dict): A dictionary contain new instances to be accumulated. - The expected format is {instance_id: {type: int, - contour: List[List[int]], centroid:List[float], box:List[int]}. - remove_insts_in_orig (list): List of instance id within `ref_inst_dict` - to be removed to prevent overlapping predictions. These instances - are those get cutoff at the boundary due to the tiling process. - sem_maps (list): List of semantic segmentation maps. - tile_bounds (:class:`numpy.array`): Boundary of the current tile, defined as - (top_left_x, top_left_y, bottom_x, bottom_y). - - """ - locations, predictions = list(zip(*tile_output, strict=False)) - - # convert from WSI space to tile space - tile_tl = tile_bounds[:2] - tile_br = tile_bounds[2:] - locations = [np.reshape(loc, (2, -1)) for loc in locations] - locations_in_tile = [loc - tile_tl[None] for loc in locations] - locations_in_tile = [loc.flatten() for loc in locations_in_tile] - locations_in_tile = np.array(locations_in_tile) - - tile_shape = tile_br - tile_tl # in width height - - # as the placement output is calculated wrt highest possible resolution - # within input, the output will need to re-calibrate if it is at different - # resolution than the input - ioconfig = ioconfig.to_baseline() - fx_list = [v["resolution"] for v in ioconfig.output_resolutions] - - head_raws = [] - for idx, fx in enumerate(fx_list): - head_tile_shape = np.ceil(tile_shape * fx).astype(np.int32) - head_locations = np.ceil(locations_in_tile * fx).astype(np.int32) - head_predictions = [v[idx][0] for v in predictions] - head_raw = merge_predictions( - head_tile_shape[::-1], - head_predictions, - head_locations, - ) - head_raws.append(head_raw) - - if "hovernetplus" in model_name: - _, inst_dict, layer_map, _ = postproc(head_raws) - out_dicts = [inst_dict, layer_map] - elif "hovernet" in model_name: - _, inst_dict = postproc(head_raws) - out_dicts = [inst_dict] - else: - out_dicts = postproc(head_raws) - - inst_dicts = [out for out in out_dicts if isinstance(out, dict)] - sem_maps = [out for out in out_dicts if isinstance(out, np.ndarray)] - # Some output maps may not be aggregated into a single map - combine these - sem_maps = [ - np.argmax(s, axis=-1) if s.ndim == 3 else s # noqa: PLR2004 - for s in sem_maps - ] - - new_inst_dicts, remove_insts_in_origs = [], [] - for inst_id, inst_dict in enumerate(inst_dicts): - new_inst_dict, remove_insts_in_orig = _process_instance_predictions( - inst_dict, - ioconfig, - tile_shape, - tile_flag, - tile_mode, - tile_tl, - ref_inst_dict[inst_id], - ) - new_inst_dicts.append(new_inst_dict) - remove_insts_in_origs.append(remove_insts_in_orig) - - return new_inst_dicts, remove_insts_in_origs, 
sem_maps, tile_bounds - - class MultiTaskSegmentor(SemanticSegmentor): """An engine specifically designed to handle tiles or WSIs inference. From f8ff99c65d08e5adf6637f70d369b103a8a0bf53 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Thu, 8 Jan 2026 10:42:57 +0000 Subject: [PATCH 005/156] :white_check_mark: Add initialization test --- tests/engines/test_multi_task_segmentor.py | 18 +++++ .../models/engine/multi_task_segmentor.py | 67 ++----------------- 2 files changed, 22 insertions(+), 63 deletions(-) create mode 100644 tests/engines/test_multi_task_segmentor.py diff --git a/tests/engines/test_multi_task_segmentor.py b/tests/engines/test_multi_task_segmentor.py new file mode 100644 index 000000000..2954ef80f --- /dev/null +++ b/tests/engines/test_multi_task_segmentor.py @@ -0,0 +1,18 @@ +"""Test MultiTaskSegmentor.""" + +from __future__ import annotations + +import torch + +from tiatoolbox.models.engine.multi_task_segmentor import MultiTaskSegmentor +from tiatoolbox.utils import env_detection as toolbox_env + +device = "cuda" if toolbox_env.has_gpu() else "cpu" + + +def test_mtsegmentor_init() -> None: + """Tests SemanticSegmentor initialization.""" + segmentor = MultiTaskSegmentor(model="hovernetplus-oed", device=device) + + assert isinstance(segmentor, MultiTaskSegmentor) + assert isinstance(segmentor.model, torch.nn.Module) diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index 4960cde3c..eae914bb7 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -2,7 +2,9 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Unpack +from typing import TYPE_CHECKING + +from typing_extensions import Unpack from .semantic_segmentor import SemanticSegmentor, SemanticSegmentorRunParams @@ -21,68 +23,7 @@ class MultiTaskSegmentor(SemanticSegmentor): - """An engine specifically designed to handle tiles or WSIs inference. - - Note, if `model` is supplied in the arguments, it will ignore the - `pretrained_model` and `pretrained_weights` arguments. Each WSI's instance - predictions (e.g. nuclear instances) will be store under a `.dat` file and - the semantic segmentation predictions will be stored in a `.npy` file. The - `.dat` files contains a dictionary of form: - - .. code-block:: yaml - - inst_uid: - # top left and bottom right of bounding box - box: (start_x, start_y, end_x, end_y) - # centroid coordinates - centroid: (x, y) - # array/list of points - contour: [(x1, y1), (x2, y2), ...] - # the type of nuclei - type: int - # the probabilities of being this nuclei type - prob: float - - Args: - model (nn.Module): Use externally defined PyTorch model for prediction with. - weights already loaded. Default is `None`. If provided, - `pretrained_model` argument is ignored. - pretrained_model (str): Name of the existing models support by tiatoolbox - for processing the data. Refer to [URL] for details. - By default, the corresponding pretrained weights will also be - downloaded. However, you can override with your own set of weights - via the `pretrained_weights` argument. Argument is case insensitive. - pretrained_weights (str): Path to the weight of the corresponding - `pretrained_model`. - batch_size (int) : Number of images fed into the model each time. - num_loader_workers (int) : Number of workers to load the data. - Take note that they will also perform preprocessing. 
- num_postproc_workers (int) : Number of workers to post-process - predictions. - verbose (bool): Whether to output logging information. - dataset_class (obj): Dataset class to be used instead of default. - auto_generate_mask (bool): To automatically generate tile/WSI tissue mask - if is not provided. - output_types (list): Ordered list describing what sort of segmentation the - output from the model postproc gives for a two-task model this may be: - ['instance', 'semantic'] - - Examples: - >>> # Sample output of a network - >>> wsis = ['A/wsi.svs', 'B/wsi.svs'] - >>> predictor = MultiTaskSegmentor( - ... model='hovernetplus-oed', - ... output_type=['instance', 'semantic'], - ... ) - >>> output = predictor.predict(wsis, mode='wsi') - >>> list(output.keys()) - [('A/wsi.svs', 'output/0') , ('B/wsi.svs', 'output/1')] - >>> # Each output of 'A/wsi.svs' - >>> # will be respectively stored in 'output/0.0.dat', 'output/0.1.npy' - >>> # Here, the second integer represents the task number - >>> # e.g. between 0 or 1 for a two task model - - """ + """A multitask segmentation engine for models like hovernet and hovernetplus.""" def __init__( self: MultiTaskSegmentor, From 17c7a8a1e978434493d829e9a779709ac445ca67 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Thu, 8 Jan 2026 14:13:37 +0000 Subject: [PATCH 006/156] :boom: Add patch segmentation test --- tests/engines/test_multi_task_segmentor.py | 43 ++++++++ .../models/architecture/hovernetplus.py | 6 ++ tiatoolbox/models/engine/engine_abc.py | 2 +- .../models/engine/multi_task_segmentor.py | 99 ++++++++++++++++++- 4 files changed, 148 insertions(+), 2 deletions(-) diff --git a/tests/engines/test_multi_task_segmentor.py b/tests/engines/test_multi_task_segmentor.py index 2954ef80f..c59a0c686 100644 --- a/tests/engines/test_multi_task_segmentor.py +++ b/tests/engines/test_multi_task_segmentor.py @@ -2,10 +2,19 @@ from __future__ import annotations +from pathlib import Path +from typing import TYPE_CHECKING, Final + +import numpy as np import torch from tiatoolbox.models.engine.multi_task_segmentor import MultiTaskSegmentor from tiatoolbox.utils import env_detection as toolbox_env +from tiatoolbox.wsicore import WSIReader + +if TYPE_CHECKING: + from collections.abc import Callable + device = "cuda" if toolbox_env.has_gpu() else "cpu" @@ -16,3 +25,37 @@ def test_mtsegmentor_init() -> None: assert isinstance(segmentor, MultiTaskSegmentor) assert isinstance(segmentor.model, torch.nn.Module) + + +def test_mtsegmentor_patches(remote_sample: Callable, track_tmp_path: Path) -> None: + """Tests MultiTaskSegmentor on image patches.""" + segmentor = MultiTaskSegmentor( + model="hovernetplus-oed", batch_size=32, verbose=False, device=device + ) + + mini_wsi_svs = Path(remote_sample("wsi4_1k_1k_svs")) + mini_wsi = WSIReader.open(mini_wsi_svs) + size = (256, 256) + resolution = 0.25 + units: Final = "mpp" + + patch1 = mini_wsi.read_rect( + location=(0, 0), size=size, resolution=resolution, units=units + ) + patch2 = mini_wsi.read_rect( + location=(512, 512), size=size, resolution=resolution, units=units + ) + patch3 = np.zeros_like(patch1) + patches = np.stack([patch1, patch2, patch3], axis=0) + + assert not segmentor.patch_mode + + _ = segmentor.run( + images=patches, + return_probabilities=True, + return_labels=False, + device=device, + patch_mode=True, + ) + + _ = track_tmp_path diff --git a/tiatoolbox/models/architecture/hovernetplus.py b/tiatoolbox/models/architecture/hovernetplus.py index 
700eb303f..dda7322e3 100644 --- a/tiatoolbox/models/architecture/hovernetplus.py +++ b/tiatoolbox/models/architecture/hovernetplus.py @@ -5,6 +5,7 @@ from collections import OrderedDict import cv2 +import dask import numpy as np import torch import torch.nn.functional as F # noqa: N812 @@ -312,6 +313,11 @@ def postproc(raw_maps: list[np.ndarray]) -> tuple: """ np_map, hv_map, tp_map, ls_map = raw_maps + np_map = np_map.compute() if isinstance(np_map, dask.array.Array) else np_map + hv_map = hv_map.compute() if isinstance(hv_map, dask.array.Array) else hv_map + tp_map = tp_map.compute() if isinstance(tp_map, dask.array.Array) else tp_map + ls_map = ls_map.compute() if isinstance(ls_map, dask.array.Array) else ls_map + pred_inst = HoVerNetPlus._proc_np_hv(np_map, hv_map, scale_factor=0.5) # fx=0.5 as nuclear processing is at 0.5 mpp instead of 0.25 mpp diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index 633d0753e..b83ecdb5f 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -1285,7 +1285,7 @@ def _run_patch_mode( self: EngineABC, output_type: str, save_dir: Path, - **kwargs: EngineABCRunParams, + **kwargs: Unpack[EngineABCRunParams], ) -> dict | AnnotationStore | Path: """Run the engine in patch mode. diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index eae914bb7..ecc93b444 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -4,15 +4,20 @@ from typing import TYPE_CHECKING +import dask.array as da +import numpy as np +import torch from typing_extensions import Unpack +from tiatoolbox.utils.misc import get_tqdm + from .semantic_segmentor import SemanticSegmentor, SemanticSegmentorRunParams if TYPE_CHECKING: # pragma: no cover import os from pathlib import Path - import numpy as np + from torch.utils.data import DataLoader from tiatoolbox.annotation import AnnotationStore from tiatoolbox.models.models_abc import ModelABC @@ -45,6 +50,98 @@ def __init__( verbose=verbose, ) + def infer_patches( + self: MultiTaskSegmentor, + dataloader: DataLoader, + *, + return_coordinates: bool = False, + ) -> dict[str, list[da.Array]]: + """Run model inference on image patches and return predictions. + + This method performs batched inference using a PyTorch DataLoader, + and accumulates predictions in Dask arrays. It supports optional inclusion + of coordinates and labels in the output. + + Args: + dataloader (DataLoader): + PyTorch DataLoader containing image patches for inference. + return_coordinates (bool): + Whether to include coordinates in the output. Required when + called by `infer_wsi` and `patch_mode` is False. + + Returns: + dict[str, dask.array.Array]: + Dictionary containing prediction results as Dask arrays. + Keys include: + - "probabilities": Model output probabilities. + - "labels": Ground truth labels (if `return_labels` is True). + - "coordinates": Patch coordinates (if `return_coordinates` is + True). 
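+
+        Examples:
+            A sketch of the expected call (illustrative only; assumes
+            ``segmentor`` is a configured :class:`MultiTaskSegmentor` and
+            ``dataloader`` yields batches with an "image" key):
+
+            >>> raw = segmentor.infer_patches(dataloader)
+            >>> len(raw["probabilities"])  # one dask array per model head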
+
+        """
+        keys = ["probabilities"]
+        labels, coordinates = [], []
+
+        # Expected number of outputs from the model
+        batch_output = self.model.infer_batch(
+            self.model,
+            torch.Tensor(dataloader.dataset[0]["image"][np.newaxis, ...]),
+            device=self.device,
+        )
+
+        num_expected_output = len(batch_output)
+        probabilities = [[] for _ in range(num_expected_output)]
+
+        if return_coordinates:
+            keys.append("coordinates")
+            coordinates = []
+
+        # Main output dictionary
+        raw_predictions = {key: [] for key in keys}
+        raw_predictions["probabilities"] = [[] for _ in range(num_expected_output)]
+
+        # Inference loop
+        tqdm = get_tqdm()
+        tqdm_loop = (
+            tqdm(dataloader, leave=False, desc="Inferring patches")
+            if self.verbose
+            else dataloader
+        )
+
+        for batch_data in tqdm_loop:
+            batch_output = self.model.infer_batch(
+                self.model,
+                batch_data["image"],
+                device=self.device,
+            )
+
+            for i in range(num_expected_output):
+                probabilities[i].append(
+                    da.from_array(
+                        batch_output[i],  # probabilities
+                    )
+                )
+
+            if return_coordinates:
+                coordinates.append(
+                    da.from_array(
+                        self._get_coordinates(batch_data),
+                    )
+                )
+
+            if self.return_labels:
+                labels.append(da.from_array(np.array(batch_data["label"])))
+
+        for i in range(num_expected_output):
+            raw_predictions["probabilities"][i] = da.concatenate(
+                probabilities[i], axis=0
+            )
+
+        if return_coordinates:
+            raw_predictions["coordinates"] = da.concatenate(coordinates, axis=0)
+
+        if self.return_labels:
+            # Concatenate the collected labels, mirroring the coordinates handling.
+            raw_predictions["labels"] = da.concatenate(labels, axis=0)
+
+        return raw_predictions

From ced4924da166f81d37df7161b708761faaa6777b Mon Sep 17 00:00:00 2001
From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com>
Date: Thu, 8 Jan 2026 14:53:52 +0000
Subject: [PATCH 007/156] :boom: Update loop for postprocessing

---
 .../models/engine/multi_task_segmentor.py     | 38 +++++++++++++++++++
 1 file changed, 38 insertions(+)

diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py
index ecc93b444..aa896cc10 100644
--- a/tiatoolbox/models/engine/multi_task_segmentor.py
+++ b/tiatoolbox/models/engine/multi_task_segmentor.py
@@ -142,6 +142,44 @@ def infer_patches(
 
         return raw_predictions
 
+    def post_process_patches(  # skipcq: PYL-R0201
+        self: MultiTaskSegmentor,
+        raw_predictions: dict,
+        **kwargs: Unpack[SemanticSegmentorRunParams],  # noqa: ARG002
+    ) -> dict:
+        """Post-process raw patch predictions from inference.
+
+        This method applies a post-processing function (e.g., smoothing, filtering)
+        to the raw model predictions. It supports delayed execution using Dask
+        and returns a Dask array for efficient computation.
+
+        Args:
+            raw_predictions (dask.array.Array):
+                Raw model predictions as a dask array.
+            prediction_shape (tuple[int, ...]):
+                Shape of the prediction output.
+            prediction_dtype (type):
+                Data type of the prediction output.
+            **kwargs (EngineABCRunParams):
+                Additional runtime parameters used for post-processing.
+
+        Returns:
+            dask.array.Array:
+                Post-processed predictions as a Dask array. 
+ + """ + probabilities = raw_predictions["probabilities"] + predictions = [[] for _ in range(probabilities[0].shape[0])] + inst_dict = [[{}] for _ in range(probabilities[0].shape[0])] + for idx, probs_for_idx in enumerate(zip(*probabilities, strict=False)): + predictions[idx] = self.model.postproc_func(list(probs_for_idx)) + + raw_predictions["predictions"] = da.stack(predictions, axis=0) + for key in inst_dict[0]: + raw_predictions[key] = [d[key] for d in inst_dict] + + return raw_predictions + def run( self: MultiTaskSegmentor, images: list[os.PathLike | Path | WSIReader] | np.ndarray, From 8ea1056a972214de510a67889a8eeeaf7bd343e9 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Fri, 9 Jan 2026 12:17:50 +0000 Subject: [PATCH 008/156] :boom: Work in progress --- tiatoolbox/models/architecture/hovernet.py | 13 ++++++++++- .../models/architecture/hovernetplus.py | 23 ++++++++++++++++++- .../models/engine/multi_task_segmentor.py | 2 +- 3 files changed, 35 insertions(+), 3 deletions(-) diff --git a/tiatoolbox/models/architecture/hovernet.py b/tiatoolbox/models/architecture/hovernet.py index 19d02e7a5..3d5b0cdff 100644 --- a/tiatoolbox/models/architecture/hovernet.py +++ b/tiatoolbox/models/architecture/hovernet.py @@ -780,7 +780,18 @@ def postproc(raw_maps: list[np.ndarray]) -> tuple[np.ndarray, dict]: pred_inst = HoVerNet._proc_np_hv(np_map, hv_map) nuc_inst_info_dict = HoVerNet.get_instance_info(pred_inst, pred_type) - return pred_inst, nuc_inst_info_dict + nuc_task ={ + "pred_inst": pred_inst, + "nuc_inst_info_dict": nuc_inst_info_dict, + } + + task_types = ["nuc_task"] + + return [ + {'task_type':"nuc_task", 'pred_inst':..., 'nuc_inst_info_dict', ..,}, + ] + + # return task_types, nuc_task @staticmethod def infer_batch( # skipcq: PYL-W0221 diff --git a/tiatoolbox/models/architecture/hovernetplus.py b/tiatoolbox/models/architecture/hovernetplus.py index dda7322e3..da694283c 100644 --- a/tiatoolbox/models/architecture/hovernetplus.py +++ b/tiatoolbox/models/architecture/hovernetplus.py @@ -327,7 +327,28 @@ def postproc(raw_maps: list[np.ndarray]) -> tuple: nuc_inst_info_dict = HoVerNet.get_instance_info(pred_inst, pred_type) layer_info_dict = HoVerNetPlus._get_layer_info(pred_layer) - return pred_inst, nuc_inst_info_dict, pred_layer, layer_info_dict + layer_task = { + "pred_layer": pred_layer, + "layer_info_dict": layer_info_dict, + } + + nuc_task ={ + "pred_inst": pred_inst, + "nuc_inst_info_dict": nuc_inst_info_dict, + } + + tissue_mask = { + "inst": mask + } + + task_types = ["layer_task", "nuc_task"] + + return [ + {'task_type':"nuc_task", 'pred_inst':..., 'nuc_inst_info_dict'...,}, + {'task_type':"nuc_task", 'pred_inst':..., 'nuc_inst_info_dict'...,}, + ] + + return task_types, layer_task, nuc_task @staticmethod def infer_batch(model: nn.Module, batch_data: np.ndarray, *, device: str) -> tuple: diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index aa896cc10..ccf2b42e9 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -174,7 +174,7 @@ def post_process_patches( # skipcq: PYL-R0201 for idx, probs_for_idx in enumerate(zip(*probabilities, strict=False)): predictions[idx] = self.model.postproc_func(list(probs_for_idx)) - raw_predictions["predictions"] = da.stack(predictions, axis=0) + raw_predictions[curr_task_type] = da.stack(predictions, axis=0) for key in inst_dict[0]: raw_predictions[key] = [d[key] for 
d in inst_dict] From 335aa0737a15630ea1739c5dac834cda462db9b4 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Fri, 9 Jan 2026 17:01:37 +0000 Subject: [PATCH 009/156] :construction: Update multitask post processing outputs. --- tests/engines/test_multi_task_segmentor.py | 6 +- tiatoolbox/models/architecture/hovernet.py | 17 ++-- .../models/architecture/hovernetplus.py | 29 +++---- .../models/engine/multi_task_segmentor.py | 83 ++++++++++++++++--- 4 files changed, 91 insertions(+), 44 deletions(-) diff --git a/tests/engines/test_multi_task_segmentor.py b/tests/engines/test_multi_task_segmentor.py index c59a0c686..4dbd50c5a 100644 --- a/tests/engines/test_multi_task_segmentor.py +++ b/tests/engines/test_multi_task_segmentor.py @@ -29,7 +29,7 @@ def test_mtsegmentor_init() -> None: def test_mtsegmentor_patches(remote_sample: Callable, track_tmp_path: Path) -> None: """Tests MultiTaskSegmentor on image patches.""" - segmentor = MultiTaskSegmentor( + mtsegmentor = MultiTaskSegmentor( model="hovernetplus-oed", batch_size=32, verbose=False, device=device ) @@ -48,9 +48,9 @@ def test_mtsegmentor_patches(remote_sample: Callable, track_tmp_path: Path) -> N patch3 = np.zeros_like(patch1) patches = np.stack([patch1, patch2, patch3], axis=0) - assert not segmentor.patch_mode + assert not mtsegmentor.patch_mode - _ = segmentor.run( + _ = mtsegmentor.run( images=patches, return_probabilities=True, return_labels=False, diff --git a/tiatoolbox/models/architecture/hovernet.py b/tiatoolbox/models/architecture/hovernet.py index 3d5b0cdff..496410545 100644 --- a/tiatoolbox/models/architecture/hovernet.py +++ b/tiatoolbox/models/architecture/hovernet.py @@ -713,7 +713,7 @@ def get_instance_info(pred_inst: np.ndarray, pred_type: np.ndarray = None) -> di @staticmethod # skipcq: PYL-W0221 # noqa: ERA001 - def postproc(raw_maps: list[np.ndarray]) -> tuple[np.ndarray, dict]: + def postproc(raw_maps: list[np.ndarray]) -> tuple[dict, ...]: """Post-processing script for image tiles. Args: @@ -780,18 +780,13 @@ def postproc(raw_maps: list[np.ndarray]) -> tuple[np.ndarray, dict]: pred_inst = HoVerNet._proc_np_hv(np_map, hv_map) nuc_inst_info_dict = HoVerNet.get_instance_info(pred_inst, pred_type) - nuc_task ={ - "pred_inst": pred_inst, - "nuc_inst_info_dict": nuc_inst_info_dict, + nuclei_seg = { + "task_type": "nuclei_segmentation", + "predictions": pred_inst, + "info_dict": nuc_inst_info_dict, } - task_types = ["nuc_task"] - - return [ - {'task_type':"nuc_task", 'pred_inst':..., 'nuc_inst_info_dict', ..,}, - ] - - # return task_types, nuc_task + return (nuclei_seg,) # Ensure return type is tuple. @staticmethod def infer_batch( # skipcq: PYL-W0221 diff --git a/tiatoolbox/models/architecture/hovernetplus.py b/tiatoolbox/models/architecture/hovernetplus.py index da694283c..5dd6e3606 100644 --- a/tiatoolbox/models/architecture/hovernetplus.py +++ b/tiatoolbox/models/architecture/hovernetplus.py @@ -236,7 +236,7 @@ def _get_layer_info(pred_layer: np.ndarray) -> dict: @staticmethod # skipcq: PYL-W0221 # noqa: ERA001 - def postproc(raw_maps: list[np.ndarray]) -> tuple: + def postproc(raw_maps: list[np.ndarray]) -> tuple[dict, ...]: """Post-processing script for image tiles. 
Args: @@ -327,28 +327,19 @@ def postproc(raw_maps: list[np.ndarray]) -> tuple: nuc_inst_info_dict = HoVerNet.get_instance_info(pred_inst, pred_type) layer_info_dict = HoVerNetPlus._get_layer_info(pred_layer) - layer_task = { - "pred_layer": pred_layer, - "layer_info_dict": layer_info_dict, + nuclei_seg = { + "task_type": "nuclei_segmentation", + "predictions": pred_inst, + "info_dict": nuc_inst_info_dict, } - nuc_task ={ - "pred_inst": pred_inst, - "nuc_inst_info_dict": nuc_inst_info_dict, + layer_seg = { + "task_type": "layer_segmentation", + "predictions": pred_layer, + "info_dict": layer_info_dict, } - tissue_mask = { - "inst": mask - } - - task_types = ["layer_task", "nuc_task"] - - return [ - {'task_type':"nuc_task", 'pred_inst':..., 'nuc_inst_info_dict'...,}, - {'task_type':"nuc_task", 'pred_inst':..., 'nuc_inst_info_dict'...,}, - ] - - return task_types, layer_task, nuc_task + return nuclei_seg, layer_seg @staticmethod def infer_batch(model: nn.Module, batch_data: np.ndarray, *, device: str) -> tuple: diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index ccf2b42e9..c733fd94e 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -156,10 +156,6 @@ def post_process_patches( # skipcq: PYL-R0201 Args: raw_predictions (dask.array.Array): Raw model predictions as a dask array. - prediction_shape (tuple[int, ...]): - Shape of the prediction output. - prediction_dtype (type): - Data type of the prediction output. **kwargs (EngineABCRunParams): Additional runtime parameters used for post-processing. @@ -169,14 +165,18 @@ def post_process_patches( # skipcq: PYL-R0201 """ probabilities = raw_predictions["probabilities"] - predictions = [[] for _ in range(probabilities[0].shape[0])] - inst_dict = [[{}] for _ in range(probabilities[0].shape[0])] - for idx, probs_for_idx in enumerate(zip(*probabilities, strict=False)): - predictions[idx] = self.model.postproc_func(list(probs_for_idx)) + post_process_predictions = [ + self.model.postproc_func(list(probs_for_idx)) + for probs_for_idx in zip(*probabilities, strict=False) + ] + + raw_predictions = build_post_process_raw_predictions( + post_process_predictions=post_process_predictions, + raw_predictions=raw_predictions, + ) - raw_predictions[curr_task_type] = da.stack(predictions, axis=0) - for key in inst_dict[0]: - raw_predictions[key] = [d[key] for d in inst_dict] + # Need to update info_dict + _ = raw_predictions return raw_predictions @@ -306,3 +306,64 @@ def run( output_type=output_type, **kwargs, ) + + +def build_post_process_raw_predictions( + post_process_predictions: list[tuple], raw_predictions: dict +) -> dict: + """Merge per-image outputs into a task-organized prediction structure. + + This function takes a list of outputs, where each element corresponds to one + image and contains one or more segmentation dictionaries. Each segmentation + dictionary must include a ``"task_type"`` key along with any number of + additional fields (e.g., ``"predictions"``, ``"info_dict"``, or others). + + The function reorganizes these outputs into ``raw_predictions`` by grouping + entries under their respective task types. For each task, all keys except + ``"task_type"`` are stored in dictionaries indexed by ``img_id``. Existing + content in ``raw_predictions`` is preserved and extended as needed. + + Args: + post_process_predictions (list[tuple]): + A list where each element represents one image. 
Each element is an + iterable of segmentation dictionaries. Each segmentation dictionary + must contain a ``"task_type"`` field and may contain any number of + additional fields. + raw_predictions (dict): + A dictionary that will be updated in-place. It may already contain + task entries or other unrelated keys. New tasks and new fields are + added dynamically as they appear in ``outputs``. + + Returns: + dict: + The updated ``raw_predictions`` dictionary, containing all tasks and + their associated per-image fields. + + """ + tasks = set() + for seg_list in post_process_predictions: + for seg in seg_list: + task = seg["task_type"] + tasks.add(task) + + # Initialize task entry if needed + if task not in raw_predictions: + raw_predictions[task] = {} + + # For every key except task_type, store values by img_id + for key, value in seg.items(): + if key == "task_type": + continue + + # Initialize list for this key + if key not in raw_predictions[task]: + raw_predictions[task][key] = [] + + raw_predictions[task][key].append(value) + + for task in tasks: + for key, values in raw_predictions[task].items(): + if all(isinstance(v, (np.ndarray, da.Array)) for v in values): + raw_predictions[task][key] = da.stack(values, axis=0) + + return raw_predictions From 3447495c69b21fca1b7307f57c91195a2868481b Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Mon, 12 Jan 2026 10:50:39 +0000 Subject: [PATCH 010/156] :construction: Update postprocessing to incorporate dictionary outputs --- requirements/requirements.txt | 1 - .../models/architecture/hovernetplus.py | 47 ++++++++++++++++--- .../models/engine/multi_task_segmentor.py | 7 +++ 3 files changed, 47 insertions(+), 8 deletions(-) diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 21bfb6d53..6a2bb5a56 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -21,7 +21,6 @@ openslide-bin>=4.0.0.2 openslide-python>=1.4.0 pandas>=2.0.0 pillow>=9.3.0 -pyarrow>=22.0.0 pydicom>=2.3.1 # Used by wsidicom pyyaml>=6.0 requests>=2.28.1 diff --git a/tiatoolbox/models/architecture/hovernetplus.py b/tiatoolbox/models/architecture/hovernetplus.py index 5dd6e3606..ed7768458 100644 --- a/tiatoolbox/models/architecture/hovernetplus.py +++ b/tiatoolbox/models/architecture/hovernetplus.py @@ -5,8 +5,9 @@ from collections import OrderedDict import cv2 -import dask +import dask.array as da import numpy as np +import pandas as pd import torch import torch.nn.functional as F # noqa: N812 from skimage import morphology @@ -313,10 +314,10 @@ def postproc(raw_maps: list[np.ndarray]) -> tuple[dict, ...]: """ np_map, hv_map, tp_map, ls_map = raw_maps - np_map = np_map.compute() if isinstance(np_map, dask.array.Array) else np_map - hv_map = hv_map.compute() if isinstance(hv_map, dask.array.Array) else hv_map - tp_map = tp_map.compute() if isinstance(tp_map, dask.array.Array) else tp_map - ls_map = ls_map.compute() if isinstance(ls_map, dask.array.Array) else ls_map + np_map = np_map.compute() if isinstance(np_map, da.Array) else np_map + hv_map = hv_map.compute() if isinstance(hv_map, da.Array) else hv_map + tp_map = tp_map.compute() if isinstance(tp_map, da.Array) else tp_map + ls_map = ls_map.compute() if isinstance(ls_map, da.Array) else ls_map pred_inst = HoVerNetPlus._proc_np_hv(np_map, hv_map, scale_factor=0.5) # fx=0.5 as nuclear processing is at 0.5 mpp instead of 0.25 mpp @@ -327,16 +328,48 @@ def postproc(raw_maps: list[np.ndarray]) -> tuple[dict, ...]: 
nuc_inst_info_dict = HoVerNet.get_instance_info(pred_inst, pred_type) layer_info_dict = HoVerNetPlus._get_layer_info(pred_layer) + nuc_inst_info_dict_ = {} + if not nuc_inst_info_dict: + nuc_inst_info_dict_ = { # inst_id should start at 1 + "box": da.empty(shape=0), + "centroid": da.empty(shape=0), + "contour": da.empty(shape=0), + "prob": da.empty(shape=0), + "type": da.empty(shape=0), + } + else: + # dask dataframe does not support transpose + nuc_inst_info_df = pd.DataFrame(nuc_inst_info_dict).transpose() + for key, col in nuc_inst_info_df.items(): + nuc_inst_info_dict_[key] = da.from_array( + col.to_numpy(), + chunks=(len(col),), # one chunk, avoids auto-rechunking + ) + nuclei_seg = { "task_type": "nuclei_segmentation", "predictions": pred_inst, - "info_dict": nuc_inst_info_dict, + "info_dict": nuc_inst_info_dict_, } + layer_info_dict_ = {} + if not nuc_inst_info_dict: + layer_info_dict_ = { # inst_id should start at 1 + "contour": da.empty(shape=0), + } + else: + # dask dataframe does not support transpose + layer_info_df = pd.DataFrame(layer_info_dict).transpose() + for key, col in layer_info_df.items(): + layer_info_dict_[key] = da.from_array( + col.to_numpy(), + chunks=(len(col),), # one chunk, avoids auto-rechunking + ) + layer_seg = { "task_type": "layer_segmentation", "predictions": pred_layer, - "info_dict": layer_info_dict, + "info_dict": layer_info_dict_, } return nuclei_seg, layer_seg diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index c733fd94e..ec91691a0 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -366,4 +366,11 @@ def build_post_process_raw_predictions( if all(isinstance(v, (np.ndarray, da.Array)) for v in values): raw_predictions[task][key] = da.stack(values, axis=0) + if all(isinstance(v, dict) for v in values): + first = values[0] + # Expand each subkey into a list + expanded = {subkey: [d[subkey] for d in values] for subkey in first} + + raw_predictions[task][key] = expanded + return raw_predictions From 84a717d99d19df158ac22bd8e8d97437d95b42e8 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Mon, 12 Jan 2026 12:15:29 +0000 Subject: [PATCH 011/156] :white_check_mark: Add test for dictionary output --- tests/engines/test_multi_task_segmentor.py | 41 +++++++++++++++++-- tiatoolbox/models/architecture/hovernet.py | 24 ++++++++++- .../models/architecture/hovernetplus.py | 5 ++- 3 files changed, 62 insertions(+), 8 deletions(-) diff --git a/tests/engines/test_multi_task_segmentor.py b/tests/engines/test_multi_task_segmentor.py index 4dbd50c5a..5d2f99a64 100644 --- a/tests/engines/test_multi_task_segmentor.py +++ b/tests/engines/test_multi_task_segmentor.py @@ -3,7 +3,7 @@ from __future__ import annotations from pathlib import Path -from typing import TYPE_CHECKING, Final +from typing import TYPE_CHECKING, Any, Final import numpy as np import torch @@ -13,12 +13,34 @@ from tiatoolbox.wsicore import WSIReader if TYPE_CHECKING: - from collections.abc import Callable - + from collections.abc import Callable, Sequence +OutputType = dict[str, Any] | Any device = "cuda" if toolbox_env.has_gpu() else "cpu" +def assert_output_lengths(output: OutputType, expected_counts: Sequence[int]) -> None: + """Assert lengths of output dict fields against expected counts.""" + output = output["info_dict"] + for field in output: + for i, expected in enumerate(expected_counts): + assert 
len(output[field][i]) == expected, f"{field}[{i}] mismatch" + + +def assert_predictions_and_boxes( + output: OutputType, expected_counts: Sequence[int], *, is_zarr: bool = False +) -> None: + """Assert predictions maxima and box lengths against expected counts.""" + # predictions maxima + for idx, expected in enumerate(expected_counts): + if is_zarr and idx == 2: + # zarr output doesn't store predictions for patch 2 + continue + assert np.max(output["predictions"][idx][:]) == expected, ( + f"predictions[{idx}] mismatch" + ) + + def test_mtsegmentor_init() -> None: """Tests SemanticSegmentor initialization.""" segmentor = MultiTaskSegmentor(model="hovernetplus-oed", device=device) @@ -50,7 +72,7 @@ def test_mtsegmentor_patches(remote_sample: Callable, track_tmp_path: Path) -> N assert not mtsegmentor.patch_mode - _ = mtsegmentor.run( + output_dict = mtsegmentor.run( images=patches, return_probabilities=True, return_labels=False, @@ -58,4 +80,15 @@ def test_mtsegmentor_patches(remote_sample: Callable, track_tmp_path: Path) -> N patch_mode=True, ) + expected_counts = [50, 17, 0] + assert_output_lengths(output_dict["nuclei_segmentation"], expected_counts) + assert_predictions_and_boxes( + output_dict["nuclei_segmentation"], expected_counts, is_zarr=False + ) + expected_counts = [1, 1, 0] + assert_output_lengths(output_dict["layer_segmentation"], expected_counts) + assert_predictions_and_boxes( + output_dict["layer_segmentation"], expected_counts, is_zarr=False + ) + _ = track_tmp_path diff --git a/tiatoolbox/models/architecture/hovernet.py b/tiatoolbox/models/architecture/hovernet.py index 496410545..a36e43cf1 100644 --- a/tiatoolbox/models/architecture/hovernet.py +++ b/tiatoolbox/models/architecture/hovernet.py @@ -6,7 +6,9 @@ from collections import OrderedDict import cv2 +import dask.array as da import numpy as np +import pandas as pd import torch import torch.nn.functional as F # noqa: N812 from scipy import ndimage @@ -678,7 +680,7 @@ def get_instance_info(pred_inst: np.ndarray, pred_type: np.ndarray = None) -> di inst_info_dict[inst_id] = { # inst_id should start at 1 "box": inst_box, "centroid": inst_centroid, - "contour": inst_contour, + "contours": inst_contour, "prob": None, "type": None, } @@ -780,10 +782,28 @@ def postproc(raw_maps: list[np.ndarray]) -> tuple[dict, ...]: pred_inst = HoVerNet._proc_np_hv(np_map, hv_map) nuc_inst_info_dict = HoVerNet.get_instance_info(pred_inst, pred_type) + nuc_inst_info_dict_ = {} + if not nuc_inst_info_dict: + nuc_inst_info_dict_ = { # inst_id should start at 1 + "box": da.empty(shape=0), + "centroid": da.empty(shape=0), + "contours": da.empty(shape=0), + "prob": da.empty(shape=0), + "type": da.empty(shape=0), + } + else: + # dask dataframe does not support transpose + nuc_inst_info_df = pd.DataFrame(nuc_inst_info_dict).transpose() + for key, col in nuc_inst_info_df.items(): + nuc_inst_info_dict_[key] = da.from_array( + col.to_numpy(), + chunks=(len(col),), # one chunk, avoids auto-rechunking + ) + nuclei_seg = { "task_type": "nuclei_segmentation", "predictions": pred_inst, - "info_dict": nuc_inst_info_dict, + "info_dict": nuc_inst_info_dict_, } return (nuclei_seg,) # Ensure return type is tuple. 
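
The two hunks above convert HoVerNet's nested per-instance `info_dict`
({inst_id: {field: value, ...}}) into a column-oriented layout with one Dask
array per field. A minimal standalone sketch of that conversion, assuming only
numpy, pandas and dask; the two instances below are toy values, not real model
output:

    import dask.array as da
    import numpy as np
    import pandas as pd

    # Old layout: one dict per instance, keyed by inst_id (toy values).
    nuc_inst_info_dict = {
        1: {"box": np.array([0, 0, 4, 4]), "centroid": np.array([2.0, 2.0]),
            "contours": np.zeros((3, 2)), "prob": 0.9, "type": 1},
        2: {"box": np.array([5, 5, 9, 9]), "centroid": np.array([7.0, 7.0]),
            "contours": np.zeros((4, 2)), "prob": 0.8, "type": 2},
    }

    # New layout: transpose via pandas (dask.dataframe does not support
    # transpose), then build one dask array per field.
    nuc_inst_info_dict_ = {}
    nuc_inst_info_df = pd.DataFrame(nuc_inst_info_dict).transpose()
    for key, col in nuc_inst_info_df.items():
        nuc_inst_info_dict_[key] = da.from_array(
            col.to_numpy(),
            chunks=(len(col),),  # one chunk, avoids auto-rechunking
        )

    print(sorted(nuc_inst_info_dict_))            # ['box', 'centroid', 'contours', 'prob', 'type']
    print(nuc_inst_info_dict_["prob"].compute())  # [0.9 0.8]

The single-chunk choice mirrors the `chunks=(len(col),)` comment in the diff:
each field stays one block, so object-dtype columns are not auto-rechunked.
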
diff --git a/tiatoolbox/models/architecture/hovernetplus.py b/tiatoolbox/models/architecture/hovernetplus.py index ed7768458..11c3f6e11 100644 --- a/tiatoolbox/models/architecture/hovernetplus.py +++ b/tiatoolbox/models/architecture/hovernetplus.py @@ -333,7 +333,7 @@ def postproc(raw_maps: list[np.ndarray]) -> tuple[dict, ...]: nuc_inst_info_dict_ = { # inst_id should start at 1 "box": da.empty(shape=0), "centroid": da.empty(shape=0), - "contour": da.empty(shape=0), + "contours": da.empty(shape=0), "prob": da.empty(shape=0), "type": da.empty(shape=0), } @@ -355,7 +355,8 @@ def postproc(raw_maps: list[np.ndarray]) -> tuple[dict, ...]: layer_info_dict_ = {} if not nuc_inst_info_dict: layer_info_dict_ = { # inst_id should start at 1 - "contour": da.empty(shape=0), + "contours": da.empty(shape=0), + "type": da.empty(shape=0), } else: # dask dataframe does not support transpose From 6065409cd722e6757ad53acc87273318b5fb5529 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Mon, 12 Jan 2026 12:22:20 +0000 Subject: [PATCH 012/156] :bug: Fix hovernet and hovernetplus tests --- tests/models/test_hovernet.py | 4 ++-- tests/models/test_hovernetplus.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/models/test_hovernet.py b/tests/models/test_hovernet.py index aeb003721..a88bdf1e9 100644 --- a/tests/models/test_hovernet.py +++ b/tests/models/test_hovernet.py @@ -38,7 +38,7 @@ def test_functionality(remote_sample: Callable) -> None: output = model.infer_batch(model, batch, device=select_device(on_gpu=False)) output = [v[0] for v in output] output = model.postproc(output) - assert len(output[1]) > 0, "Must have some nuclei." + assert len(output[0]["info_dict"]) > 0, "Must have some nuclei." # * test original mode on CoNSeP dataset (architecture used in HoVerNet paper) patch = reader.read_bounds( @@ -55,7 +55,7 @@ def test_functionality(remote_sample: Callable) -> None: output = model.infer_batch(model, batch, device=select_device(on_gpu=False)) output = [v[0] for v in output] output = model.postproc(output) - assert len(output[1]) > 0, "Must have some nuclei." + assert len(output[0]["info_dict"]) > 0, "Must have some nuclei." # test crash when providing exotic mode with pytest.raises(ValueError, match=r".*Invalid mode.*"): diff --git a/tests/models/test_hovernetplus.py b/tests/models/test_hovernetplus.py index f336ef14f..9d45202f1 100644 --- a/tests/models/test_hovernetplus.py +++ b/tests/models/test_hovernetplus.py @@ -33,5 +33,5 @@ def test_functionality(remote_sample: Callable) -> None: assert len(output) == 4, "Must contain predictions for: np, hv, tp and ls branches." output = [v[0] for v in output] output = model.postproc(output) - assert len(output[1]) > 0, "Must have some nuclei." - assert len(output[3]) > 0, "Must have some layers." + assert len(output[0]["info_dict"]) > 0, "Must have some nuclei." + assert len(output[1]["info_dict"]) > 0, "Must have some layers." 
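
The test updates above pin down the contract introduced in patch 010:
`postproc` now returns a tuple of per-task dictionaries, each carrying a
"task_type" plus arbitrary fields, and the engine merges one such tuple per
image into a task-keyed structure. A condensed, self-contained sketch of that
grouping rule (mirroring `build_post_process_raw_predictions`, with
`setdefault` in place of the explicit membership checks; the toy outputs below
are invented for illustration):

    import dask.array as da
    import numpy as np

    def merge_by_task(post_process_predictions: list, raw_predictions: dict) -> dict:
        """Group per-image segmentation dicts under their "task_type"."""
        tasks = set()
        for seg_list in post_process_predictions:
            for seg in seg_list:
                task = seg["task_type"]
                tasks.add(task)
                for key, value in seg.items():
                    if key != "task_type":
                        raw_predictions.setdefault(task, {}).setdefault(key, []).append(value)
        for task in tasks:
            for key, values in list(raw_predictions[task].items()):
                if all(isinstance(v, (np.ndarray, da.Array)) for v in values):
                    # Stack per-image arrays into a single dask array.
                    raw_predictions[task][key] = da.stack(values, axis=0)
                elif all(isinstance(v, dict) for v in values):
                    # Expand dict-valued fields (e.g. info_dict) into per-subkey lists.
                    for subkey in values[0]:
                        raw_predictions[task][subkey] = [d[subkey] for d in values]
                    del raw_predictions[task][key]
        return raw_predictions

    # Two images, each yielding a nuclei and a layer dict (toy data).
    outputs = [
        (
            {"task_type": "nuclei_segmentation", "predictions": np.zeros((8, 8)),
             "info_dict": {"prob": np.array([0.9, 0.8])}},
            {"task_type": "layer_segmentation", "predictions": np.ones((8, 8))},
        ),
        (
            {"task_type": "nuclei_segmentation", "predictions": np.zeros((8, 8)),
             "info_dict": {"prob": np.array([0.7])}},
            {"task_type": "layer_segmentation", "predictions": np.ones((8, 8))},
        ),
    ]
    merged = merge_by_task(outputs, {})
    print(merged["nuclei_segmentation"]["predictions"].shape)           # (2, 8, 8)
    print([p.tolist() for p in merged["nuclei_segmentation"]["prob"]])  # [[0.9, 0.8], [0.7]]

Array-valued fields end up stacked across images, while dict-valued fields are
flattened into ragged per-image lists, which matches how the zarr writer stores
list-valued keys as one array per image under `{key}/{i}` components.
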
From bd6d4b7f2062f5ae2e07d33ec0175226822fb374 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Mon, 12 Jan 2026 17:31:57 +0000 Subject: [PATCH 013/156] :white_check_mark: Add tests for zarr output --- tests/engines/test_multi_task_segmentor.py | 49 +++- .../models/architecture/hovernetplus.py | 4 +- tiatoolbox/models/engine/engine_abc.py | 20 +- .../models/engine/multi_task_segmentor.py | 239 +++++++++++++----- 4 files changed, 230 insertions(+), 82 deletions(-) diff --git a/tests/engines/test_multi_task_segmentor.py b/tests/engines/test_multi_task_segmentor.py index 5d2f99a64..5ec6fd576 100644 --- a/tests/engines/test_multi_task_segmentor.py +++ b/tests/engines/test_multi_task_segmentor.py @@ -7,6 +7,7 @@ import numpy as np import torch +import zarr from tiatoolbox.models.engine.multi_task_segmentor import MultiTaskSegmentor from tiatoolbox.utils import env_detection as toolbox_env @@ -19,10 +20,11 @@ device = "cuda" if toolbox_env.has_gpu() else "cpu" -def assert_output_lengths(output: OutputType, expected_counts: Sequence[int]) -> None: +def assert_output_lengths( + output: OutputType, expected_counts: Sequence[int], fields: list[str] +) -> None: """Assert lengths of output dict fields against expected counts.""" - output = output["info_dict"] - for field in output: + for field in fields: for i, expected in enumerate(expected_counts): assert len(output[field][i]) == expected, f"{field}[{i}] mismatch" @@ -80,15 +82,42 @@ def test_mtsegmentor_patches(remote_sample: Callable, track_tmp_path: Path) -> N patch_mode=True, ) - expected_counts = [50, 17, 0] - assert_output_lengths(output_dict["nuclei_segmentation"], expected_counts) + expected_counts_nuclei = [50, 17, 0] + assert_output_lengths( + output_dict["nuclei_segmentation"], + expected_counts_nuclei, + fields=["box", "centroid", "contours", "prob", "type"], + ) assert_predictions_and_boxes( - output_dict["nuclei_segmentation"], expected_counts, is_zarr=False + output_dict["nuclei_segmentation"], expected_counts_nuclei, is_zarr=False + ) + expected_counts_layer = [1, 1, 0] + assert_output_lengths( + output_dict["layer_segmentation"], + expected_counts_layer, + fields=["contours", "type"], ) - expected_counts = [1, 1, 0] - assert_output_lengths(output_dict["layer_segmentation"], expected_counts) assert_predictions_and_boxes( - output_dict["layer_segmentation"], expected_counts, is_zarr=False + output_dict["layer_segmentation"], expected_counts_layer, is_zarr=False + ) + + # Zarr output comparison + output_zarr = mtsegmentor.run( + images=patches, + patch_mode=True, + device=device, + output_type="zarr", + save_dir=track_tmp_path / "patch_output_zarr", ) + output_zarr = zarr.open(output_zarr, mode="r") - _ = track_tmp_path + assert_output_lengths( + output_zarr["nuclei_segmentation"], + expected_counts_nuclei, + fields=["box", "centroid", "contours", "prob", "type"], + ) + assert_output_lengths( + output_zarr["layer_segmentation"], + expected_counts_layer, + fields=["contours", "type"], + ) diff --git a/tiatoolbox/models/architecture/hovernetplus.py b/tiatoolbox/models/architecture/hovernetplus.py index 11c3f6e11..ffcd42ae4 100644 --- a/tiatoolbox/models/architecture/hovernetplus.py +++ b/tiatoolbox/models/architecture/hovernetplus.py @@ -376,7 +376,9 @@ def postproc(raw_maps: list[np.ndarray]) -> tuple[dict, ...]: return nuclei_seg, layer_seg @staticmethod - def infer_batch(model: nn.Module, batch_data: np.ndarray, *, device: str) -> tuple: + def infer_batch( # skipcq: PYL-W0221 + 
model: nn.Module, batch_data: np.ndarray, *, device: str + ) -> tuple: """Run inference on an input batch. This contains logic for forward operation as well as batch i/o diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index f001c29f0..cce1c0ffe 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -748,6 +748,7 @@ def save_predictions_as_zarr( processed_predictions: dict, save_path: Path, keys_to_compute: list, + task_name: str | None = None, ) -> Path: """Save model predictions as a zarr file. @@ -761,6 +762,8 @@ def save_predictions_as_zarr( Path to save the zarr file. keys_to_compute (list): List of keys in processed_predictions to save. + task_name (str): + Task Name for Multitask outputs. Returns: save_path (Path): @@ -770,13 +773,20 @@ def save_predictions_as_zarr( if is_zarr(save_path): zarr_group = zarr.open(save_path, mode="r") keys_to_compute = [k for k in keys_to_compute if k not in zarr_group] + + # If the task group already exists, only compute missing keys + if task_name in zarr_group: + task_group = zarr_group[task_name] + keys_to_compute = [k for k in keys_to_compute if k not in task_group] + write_tasks = [] for key in keys_to_compute: dask_output = processed_predictions[key] if isinstance(dask_output, da.Array): dask_output = dask_output.rechunk("auto") + component = key if task_name is None else f"{task_name}/{key}" task = dask_output.to_zarr( - url=save_path, component=key, compute=False, object_codec=None + url=save_path, component=component, compute=False, object_codec=None ) write_tasks.append(task) @@ -785,9 +795,12 @@ def save_predictions_as_zarr( ): for i, dask_array in enumerate(dask_output): object_codec = Pickle() if dask_array.dtype == "object" else None + component = ( + f"{key}/{i}" if task_name is None else f"{task_name}/{key}/{i}" + ) task = dask_array.to_zarr( url=save_path, - component=f"{key}/{i}", + component=component, compute=False, zarr_array_kwargs={"object_codec": object_codec}, ) @@ -1387,6 +1400,9 @@ def _run_patch_mode( **kwargs, ) + if isinstance(out, dict): + return out + msg = f"Output file saved at {out}." logger.info(msg=msg) return out diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index ec91691a0..578fa7a88 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -41,6 +41,7 @@ def __init__( verbose: bool = True, ) -> None: """Initialize :class:`NucleusInstanceSegmentor`.""" + self.tasks = set() super().__init__( model=model, batch_size=batch_size, @@ -170,7 +171,7 @@ def post_process_patches( # skipcq: PYL-R0201 for probs_for_idx in zip(*probabilities, strict=False) ] - raw_predictions = build_post_process_raw_predictions( + raw_predictions = self.build_post_process_raw_predictions( post_process_predictions=post_process_predictions, raw_predictions=raw_predictions, ) @@ -180,6 +181,174 @@ def post_process_patches( # skipcq: PYL-R0201 return raw_predictions + def build_post_process_raw_predictions( + self: MultiTaskSegmentor, + post_process_predictions: list[tuple], + raw_predictions: dict, + ) -> dict: + """Merge per-image outputs into a task-organized prediction structure. + + This function takes a list of outputs, where each element corresponds to one + image and contains one or more segmentation dictionaries. 
Each segmentation
+        dictionary must include a ``"task_type"`` key along with any number of
+        additional fields (e.g., ``"predictions"``, ``"info_dict"``, or others).
+
+        The function reorganizes these outputs into ``raw_predictions`` by grouping
+        entries under their respective task types. For each task, all keys except
+        ``"task_type"`` are stored in dictionaries indexed by ``img_id``. Existing
+        content in ``raw_predictions`` is preserved and extended as needed.
+
+        Args:
+            post_process_predictions (list[tuple]):
+                A list where each element represents one image. Each element is an
+                iterable of segmentation dictionaries. Each segmentation dictionary
+                must contain a ``"task_type"`` field and may contain any number of
+                additional fields.
+            raw_predictions (dict):
+                A dictionary that will be updated in-place. It may already contain
+                task entries or other unrelated keys. New tasks and new fields are
+                added dynamically as they appear in ``post_process_predictions``.
+
+        Returns:
+            dict:
+                The updated ``raw_predictions`` dictionary, containing all tasks and
+                their associated per-image fields.
+
+        """
+        tasks = set()
+        for seg_list in post_process_predictions:
+            for seg in seg_list:
+                task = seg["task_type"]
+                tasks.add(task)
+
+                # Initialize task entry if needed
+                if task not in raw_predictions:
+                    raw_predictions[task] = {}
+
+                # For every key except task_type, store values by img_id
+                for key, value in seg.items():
+                    if key == "task_type":
+                        continue
+
+                    # Initialize list for this key
+                    if key not in raw_predictions[task]:
+                        raw_predictions[task][key] = []
+
+                    raw_predictions[task][key].append(value)
+
+        for task in tasks:
+            task_dict = raw_predictions[task]
+            for key in list(task_dict.keys()):
+                values = task_dict[key]
+                if all(isinstance(v, (np.ndarray, da.Array)) for v in values):
+                    raw_predictions[task][key] = da.stack(values, axis=0)
+                    continue
+
+                if all(isinstance(v, dict) for v in values):
+                    first = values[0]
+
+                    # Add new keys safely
+                    for subkey in first:
+                        raw_predictions[task][subkey] = [d[subkey] for d in values]
+
+                    del raw_predictions[task][key]
+
+        self.tasks = tasks
+        return raw_predictions
+
+    def save_predictions(
+        self: MultiTaskSegmentor,
+        processed_predictions: dict,
+        output_type: str,
+        save_path: Path | None = None,
+        **kwargs: Unpack[SemanticSegmentorRunParams],
+    ) -> dict | AnnotationStore | Path | list[Path]:
+        """Save model predictions to disk or return them in memory.
+
+        Depending on the output type, this method saves predictions as a zarr group,
+        an AnnotationStore (SQLite database), or returns them as a dictionary.
+
+        Args:
+            processed_predictions (dict):
+                Dictionary containing processed model predictions.
+            output_type (str):
+                Desired output format.
+                Supported values are "dict", "zarr", and "annotationstore".
+            save_path (Path | None):
+                Path to save the output file.
+                Required for "zarr" and "annotationstore" formats.
+            **kwargs (EngineABCRunParams):
+                Additional runtime parameters to update engine attributes.
+
+            Optional Keys:
+                auto_get_mask (bool):
+                    Automatically generate segmentation masks using
+                    `wsireader.tissue_mask()` during processing.
+                batch_size (int):
+                    Number of image patches per forward pass.
+                class_dict (dict):
+                    Mapping of classification outputs to class names.
+                device (str):
+                    Device to run the model on (e.g., "cpu", "cuda").
+                    See https://pytorch.org/docs/stable/tensor_attributes.html#torch.device
+                    for more details.
+                labels (list):
+                    Optional labels for input images. Only a single label per image
+                    is supported.
+
+                memory_threshold (int):
+                    Memory usage threshold (percentage) to trigger caching behavior.
+                num_workers (int):
+                    Number of workers for DataLoader and post-processing.
+                output_file (str):
+                    Filename for saving output (e.g., "zarr" or "annotationstore").
+                return_labels (bool):
+                    Whether to return labels with predictions.
+                scale_factor (tuple[float, float]):
+                    Scale factor for annotations (model_mpp / slide_mpp).
+                    Used to convert coordinates from non-baseline to baseline
+                    resolution.
+                stride_shape (IntPair):
+                    Stride used during WSI processing, at requested read resolution.
+                    Must be positive. Defaults to `patch_input_shape` if not
+                    provided.
+                verbose (bool):
+                    Whether to enable verbose logging.
+
+        Returns:
+            dict | AnnotationStore | Path | list[Path]:
+                - If output_type is "dict": returns predictions as a dictionary.
+                - If output_type is "zarr": returns path to saved zarr file.
+                - If output_type is "annotationstore": returns an AnnotationStore
+                  or path to .db file.
+
+        Raises:
+            TypeError:
+                If an unsupported output_type is provided.
+
+        """
+        if output_type.lower() == "dict":
+            return super().save_predictions(
+                processed_predictions, output_type, save_path=save_path, **kwargs
+            )
+
+        if output_type.lower() == "zarr":
+            for task_name in self.tasks:
+                keys_to_compute = [
+                    k
+                    for k in processed_predictions[task_name]
+                    if k not in self.drop_keys
+                ]
+                _ = self.save_predictions_as_zarr(
+                    processed_predictions=processed_predictions[task_name],
+                    save_path=save_path,
+                    keys_to_compute=keys_to_compute,
+                    task_name=task_name,
+                )
+            return save_path
+
+        # Need to update for AnnotationStore
+        return save_path
+
     def run(
         self: MultiTaskSegmentor,
         images: list[os.PathLike | Path | WSIReader] | np.ndarray,
@@ -306,71 +475,3 @@ def run(
         output_type=output_type,
         **kwargs,
     )
-
-
-def build_post_process_raw_predictions(
-    post_process_predictions: list[tuple], raw_predictions: dict
-) -> dict:
-    """Merge per-image outputs into a task-organized prediction structure.
-
-    This function takes a list of outputs, where each element corresponds to one
-    image and contains one or more segmentation dictionaries. Each segmentation
-    dictionary must include a ``"task_type"`` key along with any number of
-    additional fields (e.g., ``"predictions"``, ``"info_dict"``, or others).
-
-    The function reorganizes these outputs into ``raw_predictions`` by grouping
-    entries under their respective task types. For each task, all keys except
-    ``"task_type"`` are stored in dictionaries indexed by ``img_id``. Existing
-    content in ``raw_predictions`` is preserved and extended as needed.
-
-    Args:
-        post_process_predictions (list[tuple]):
-            A list where each element represents one image. Each element is an
-            iterable of segmentation dictionaries. Each segmentation dictionary
-            must contain a ``"task_type"`` field and may contain any number of
-            additional fields.
-        raw_predictions (dict):
-            A dictionary that will be updated in-place. It may already contain
-            task entries or other unrelated keys. New tasks and new fields are
-            added dynamically as they appear in ``outputs``.
-
-    Returns:
-        dict:
-            The updated ``raw_predictions`` dictionary, containing all tasks and
-            their associated per-image fields.
- - """ - tasks = set() - for seg_list in post_process_predictions: - for seg in seg_list: - task = seg["task_type"] - tasks.add(task) - - # Initialize task entry if needed - if task not in raw_predictions: - raw_predictions[task] = {} - - # For every key except task_type, store values by img_id - for key, value in seg.items(): - if key == "task_type": - continue - - # Initialize list for this key - if key not in raw_predictions[task]: - raw_predictions[task][key] = [] - - raw_predictions[task][key].append(value) - - for task in tasks: - for key, values in raw_predictions[task].items(): - if all(isinstance(v, (np.ndarray, da.Array)) for v in values): - raw_predictions[task][key] = da.stack(values, axis=0) - - if all(isinstance(v, dict) for v in values): - first = values[0] - # Expand each subkey into a list - expanded = {subkey: [d[subkey] for d in values] for subkey in first} - - raw_predictions[task][key] = expanded - - return raw_predictions From c7238b23efeffb451afdc2d8ec9cd40c59dffe3f Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Mon, 12 Jan 2026 22:15:03 +0000 Subject: [PATCH 014/156] :bug: Fix deepsource errors --- tiatoolbox/models/engine/engine_abc.py | 65 ++++++++++++------- .../models/engine/multi_task_segmentor.py | 4 +- 2 files changed, 43 insertions(+), 26 deletions(-) diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index cce1c0ffe..7034367c0 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -743,6 +743,41 @@ def save_predictions( msg = f"Unsupported output type: {output_type}" raise TypeError(msg) + @staticmethod + def _get_tasks_for_saving_zarr( + dask_output: da.Array | list, + key: str, + task_name: str | None, + save_path: Path, + write_tasks: list, + ) -> list: + """Helper function to get dask tasks for saving zarr output.""" + if isinstance(dask_output, da.Array): + dask_output = dask_output.rechunk("auto") + component = key if task_name is None else f"{task_name}/{key}" + task = dask_output.to_zarr( + url=save_path, component=component, compute=False, object_codec=None + ) + write_tasks.append(task) + + if isinstance(dask_output, list) and all( + isinstance(dask_array, da.Array) for dask_array in dask_output + ): + for i, dask_array in enumerate(dask_output): + object_codec = Pickle() if dask_array.dtype == "object" else None + component = ( + f"{key}/{i}" if task_name is None else f"{task_name}/{key}/{i}" + ) + task = dask_array.to_zarr( + url=save_path, + component=component, + compute=False, + zarr_array_kwargs={"object_codec": object_codec}, + ) + write_tasks.append(task) + + return write_tasks + def save_predictions_as_zarr( self: EngineABC, processed_predictions: dict, @@ -782,29 +817,13 @@ def save_predictions_as_zarr( write_tasks = [] for key in keys_to_compute: dask_output = processed_predictions[key] - if isinstance(dask_output, da.Array): - dask_output = dask_output.rechunk("auto") - component = key if task_name is None else f"{task_name}/{key}" - task = dask_output.to_zarr( - url=save_path, component=component, compute=False, object_codec=None - ) - write_tasks.append(task) - - if isinstance(dask_output, list) and all( - isinstance(dask_array, da.Array) for dask_array in dask_output - ): - for i, dask_array in enumerate(dask_output): - object_codec = Pickle() if dask_array.dtype == "object" else None - component = ( - f"{key}/{i}" if task_name is None else f"{task_name}/{key}/{i}" - ) - task = 
dask_array.to_zarr( - url=save_path, - component=component, - compute=False, - zarr_array_kwargs={"object_codec": object_codec}, - ) - write_tasks.append(task) + write_tasks = self._get_tasks_for_saving_zarr( + dask_output=dask_output, + key=key, + task_name=task_name, + save_path=save_path, + write_tasks=write_tasks, + ) msg = f"Saving output to {save_path}." logger.info(msg=msg) diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index 578fa7a88..68b4f0ea6 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -279,7 +279,6 @@ def save_predictions( Required for "zarr" and "annotationstore" formats. **kwargs (EngineABCRunParams): Additional runtime parameters to update engine attributes. - Optional Keys: auto_get_mask (bool): Automatically generate segmentation masks using @@ -290,8 +289,7 @@ def save_predictions( Mapping of classification outputs to class names. device (str): Device to run the model on (e.g., "cpu", "cuda"). - See https://pytorch.org/docs/stable/tensor_attributes.html#torch.device - for more details. + See :class:`torch.device` for more details. labels (list): Optional labels for input images. Only a single label per image is supported. From fa3dcc164424fa6a245c971b1eb4a9fef35a8690 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Tue, 13 Jan 2026 10:41:38 +0000 Subject: [PATCH 015/156] :white_check_mark: Add test for single output. --- tests/engines/test_multi_task_segmentor.py | 97 +++++++++++++++++++ tiatoolbox/models/architecture/hovernet.py | 5 +- .../models/engine/multi_task_segmentor.py | 15 ++- 3 files changed, 111 insertions(+), 6 deletions(-) diff --git a/tests/engines/test_multi_task_segmentor.py b/tests/engines/test_multi_task_segmentor.py index 5ec6fd576..18a261e0d 100644 --- a/tests/engines/test_multi_task_segmentor.py +++ b/tests/engines/test_multi_task_segmentor.py @@ -43,6 +43,23 @@ def assert_predictions_and_boxes( ) +def assert_output_equal( + output_a: OutputType, + output_b: OutputType, + fields: Sequence[str], + indices_a: Sequence[int], + indices_b: Sequence[int], +) -> None: + """Assert equality of arrays across outputs for given fields/indices.""" + for field in fields: + for i_a, i_b in zip(indices_a, indices_b, strict=False): + left = output_a[field][i_a] + right = output_b[field][i_b] + assert all( + np.array_equal(a, b) for a, b in zip(left, right, strict=False) + ), f"{field}[{i_a}] vs {field}[{i_b}] mismatch" + + def test_mtsegmentor_init() -> None: """Tests SemanticSegmentor initialization.""" segmentor = MultiTaskSegmentor(model="hovernetplus-oed", device=device) @@ -121,3 +138,83 @@ def test_mtsegmentor_patches(remote_sample: Callable, track_tmp_path: Path) -> N expected_counts_layer, fields=["contours", "type"], ) + + assert_output_equal( + output_zarr["nuclei_segmentation"], + output_dict["nuclei_segmentation"], + fields=["box", "centroid", "contours", "prob", "type"], + indices_a=[0, 1, 2], + indices_b=[0, 1, 2], + ) + assert_output_equal( + output_zarr["layer_segmentation"], + output_dict["layer_segmentation"], + fields=["contours", "type"], + indices_a=[0, 1, 2], + indices_b=[0, 1, 2], + ) + + +def test_single_output_mtsegmentor( + remote_sample: Callable, track_tmp_path: Path +) -> None: + """Tests MultiTaskSegmentor on single task output.""" + mtsegmentor = MultiTaskSegmentor( + model="hovernet_fast-pannuke", batch_size=32, verbose=False, device=device + ) 
+ mini_wsi_svs = Path(remote_sample("wsi4_1k_1k_svs")) + mini_wsi = WSIReader.open(mini_wsi_svs) + size = (256, 256) + resolution = 0.25 + units: Final = "mpp" + + patch1 = mini_wsi.read_rect( + location=(0, 0), size=size, resolution=resolution, units=units + ) + patch2 = mini_wsi.read_rect( + location=(512, 512), size=size, resolution=resolution, units=units + ) + patch3 = np.zeros_like(patch1) + patches = np.stack([patch1, patch2, patch3], axis=0) + + assert not mtsegmentor.patch_mode + + output_dict = mtsegmentor.run( + images=patches, + return_probabilities=True, + return_labels=False, + device=device, + patch_mode=True, + ) + + expected_counts_nuclei = [41, 17, 0] + assert_output_lengths( + output_dict, + expected_counts_nuclei, + fields=["box", "centroid", "contours", "prob", "type"], + ) + assert_predictions_and_boxes(output_dict, expected_counts_nuclei, is_zarr=False) + + # Zarr output comparison + output_zarr = mtsegmentor.run( + images=patches, + patch_mode=True, + device=device, + output_type="zarr", + save_dir=track_tmp_path / "patch_output_zarr", + ) + output_zarr = zarr.open(output_zarr, mode="r") + + assert_output_lengths( + output_zarr, + expected_counts_nuclei, + fields=["box", "centroid", "contours", "prob", "type"], + ) + + assert_output_equal( + output_zarr, + output_dict, + fields=["box", "centroid", "contours", "prob", "type"], + indices_a=[0, 1, 2], + indices_b=[0, 1, 2], + ) diff --git a/tiatoolbox/models/architecture/hovernet.py b/tiatoolbox/models/architecture/hovernet.py index a36e43cf1..1eef26b97 100644 --- a/tiatoolbox/models/architecture/hovernet.py +++ b/tiatoolbox/models/architecture/hovernet.py @@ -778,7 +778,10 @@ def postproc(raw_maps: list[np.ndarray]) -> tuple[dict, ...]: tp_map = None np_map, hv_map = raw_maps - pred_type = tp_map + np_map = np_map.compute() if isinstance(np_map, da.Array) else np_map + hv_map = hv_map.compute() if isinstance(hv_map, da.Array) else hv_map + pred_type = tp_map.compute() if isinstance(tp_map, da.Array) else tp_map + pred_inst = HoVerNet._proc_np_hv(np_map, hv_map) nuc_inst_info_dict = HoVerNet.get_instance_info(pred_inst, pred_type) diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index 68b4f0ea6..aa3e2ddd2 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -325,22 +325,27 @@ def save_predictions( """ if output_type.lower() == "dict": + # If there is a single task simplify the output. + if len(self.tasks) == 1: + task_output = processed_predictions.pop(next(iter(self.tasks))) + processed_predictions.update(task_output) return super().save_predictions( processed_predictions, output_type, save_path=save_path, **kwargs ) if output_type.lower() == "zarr": for task_name in self.tasks: + processed_predictions_ = processed_predictions.pop(task_name) + # If there is a single task simplify the output. 
+ task_name_ = None if len(self.tasks) == 1 else task_name keys_to_compute = [ - k - for k in processed_predictions[task_name] - if k not in self.drop_keys + k for k in processed_predictions_ if k not in self.drop_keys ] _ = self.save_predictions_as_zarr( - processed_predictions=processed_predictions[task_name], + processed_predictions=processed_predictions_, save_path=save_path, keys_to_compute=keys_to_compute, - task_name=task_name, + task_name=task_name_, ) return save_path From b042e86b118ec1d2ac5d694aa8a54892235e850b Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Tue, 13 Jan 2026 10:47:35 +0000 Subject: [PATCH 016/156] :white_check_mark: Add test for save_dir and None output_type --- tests/engines/test_engine_abc.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/tests/engines/test_engine_abc.py b/tests/engines/test_engine_abc.py index 0e7449f85..50da6309d 100644 --- a/tests/engines/test_engine_abc.py +++ b/tests/engines/test_engine_abc.py @@ -136,6 +136,38 @@ def test_incorrect_output_type() -> NoReturn: ) +def test_incorrect_output_type_save_dir() -> NoReturn: + """Test EngineABC for None output_type and output type zarr/annotationstore.""" + pretrained_model = "alexnet-kather100k" + + # Test engine run without ioconfig + eng = TestEngineABC(model=pretrained_model) + + with pytest.raises( + ValueError, + match=r".*Please provide save_dir for output_type=zarr*", + ): + _ = eng.run( + images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + on_gpu=False, + patch_mode=True, + ioconfig=None, + output_type="zarr", + ) + + with pytest.raises( + ValueError, + match=r".*Please provide save_dir for output_type=annotationstore*", + ): + _ = eng.run( + images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + on_gpu=False, + patch_mode=True, + ioconfig=None, + output_type="annotationstore", + ) + + def test_pretrained_ioconfig() -> NoReturn: """Test EngineABC initialization with pretrained model name in the toolbox.""" pretrained_model = "alexnet-kather100k" From 47efdd11acccf1d231d1f2f710c7f46577dcc2f4 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Tue, 13 Jan 2026 11:11:53 +0000 Subject: [PATCH 017/156] :white_check_mark: Improve test coverage. 
--- tests/engines/test_multi_task_segmentor.py | 26 ++++++++++++++++++- .../models/engine/multi_task_segmentor.py | 26 +++++++++---------- 2 files changed, 37 insertions(+), 15 deletions(-) diff --git a/tests/engines/test_multi_task_segmentor.py b/tests/engines/test_multi_task_segmentor.py index 18a261e0d..2b3cc6c9a 100644 --- a/tests/engines/test_multi_task_segmentor.py +++ b/tests/engines/test_multi_task_segmentor.py @@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Any, Final import numpy as np +import pytest import torch import zarr @@ -61,13 +62,36 @@ def assert_output_equal( def test_mtsegmentor_init() -> None: - """Tests SemanticSegmentor initialization.""" + """Tests MultiTaskSegmentor initialization.""" segmentor = MultiTaskSegmentor(model="hovernetplus-oed", device=device) assert isinstance(segmentor, MultiTaskSegmentor) assert isinstance(segmentor.model, torch.nn.Module) +def test_raise_value_error_return_labels_wsi( + sample_svs: Path, + track_tmp_path: Path, +) -> None: + """Tests MultiTaskSegmentor return_labels error.""" + mtsegmentor = MultiTaskSegmentor(model="hovernetplus-oed", device=device) + + with pytest.raises( + ValueError, + match=r".*return_labels` is not supported for MultiTaskSegmentor.", + ): + _ = mtsegmentor.run( + images=[sample_svs], + return_probabilities=False, + return_labels=True, + device=device, + patch_mode=False, + save_dir=track_tmp_path / "wsi_out_check", + batch_size=2, + output_type="zarr", + ) + + def test_mtsegmentor_patches(remote_sample: Callable, track_tmp_path: Path) -> None: """Tests MultiTaskSegmentor on image patches.""" mtsegmentor = MultiTaskSegmentor( diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index aa3e2ddd2..091bd0863 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -75,13 +75,12 @@ def infer_patches( Dictionary containing prediction results as Dask arrays. Keys include: - "probabilities": Model output probabilities. - - "labels": Ground truth labels (if `return_labels` is True). - "coordinates": Patch coordinates (if `return_coordinates` is True). """ keys = ["probabilities"] - labels, coordinates = [], [] + coordinates = [] # Expected number of outputs from the model batch_output = self.model.infer_batch( @@ -130,9 +129,6 @@ def infer_patches( ) ) - if self.return_labels: - labels.append(da.from_array(np.array(batch_data["label"]))) - for i in range(num_expected_output): raw_predictions["probabilities"][i] = da.concatenate( probabilities[i], axis=0 @@ -290,17 +286,12 @@ def save_predictions( device (str): Device to run the model on (e.g., "cpu", "cuda"). See :class:`torch.device` for more details. - labels (list): - Optional labels for input images. Only a single label per image - is supported. memory_threshold (int): Memory usage threshold (percentage) to trigger caching behavior. num_workers (int): Number of workers for DataLoader and post-processing. output_file (str): Filename for saving output (e.g., "zarr" or "annotationstore"). - return_labels (bool): - Whether to return labels with predictions. scale_factor (tuple[float, float]): Scale factor for annotations (model_mpp / slide_mpp). Used to convert coordinates from non-baseline to baseline @@ -411,9 +402,7 @@ def run( Mapping of classification outputs to class names. device (str): Device to run the model on (e.g., "cpu", "cuda"). - labels (list): - Optional labels for input images. Only a single label per image - is supported. 
+ memory_threshold (int): Memory usage threshold (percentage) to trigger caching behavior. num_workers (int): @@ -425,7 +414,7 @@ def run( patch_output_shape (tuple[int, int]): Shape of output patches (height, width). return_labels (bool): - Whether to return labels with predictions. + Whether to return labels with predictions. Should be False. return_probabilities (bool): Whether to return per-class probabilities. scale_factor (tuple[float, float]): @@ -466,6 +455,15 @@ def run( ... "/path/to/wsi1.db" """ + return_labels = kwargs.get("return_labels") + + # Passing multitask labels causes unnecessary memory overheads + if return_labels: + msg = "`return_labels` is not supported for MultiTaskSegmentor." + raise ValueError(msg) + + kwargs["return_labels"] = False + return super().run( images=images, masks=masks, From 645986d117455fc20fd2a10209bf546ffb11b259 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Tue, 13 Jan 2026 12:41:14 +0000 Subject: [PATCH 018/156] :white_check_mark: Add test for annotationstore single output. --- tests/engines/test_multi_task_segmentor.py | 68 +++++++ .../models/engine/multi_task_segmentor.py | 180 +++++++++++++++--- 2 files changed, 223 insertions(+), 25 deletions(-) diff --git a/tests/engines/test_multi_task_segmentor.py b/tests/engines/test_multi_task_segmentor.py index 2b3cc6c9a..0f1d0637f 100644 --- a/tests/engines/test_multi_task_segmentor.py +++ b/tests/engines/test_multi_task_segmentor.py @@ -10,6 +10,7 @@ import torch import zarr +from tiatoolbox.annotation import SQLiteStore from tiatoolbox.models.engine.multi_task_segmentor import MultiTaskSegmentor from tiatoolbox.utils import env_detection as toolbox_env from tiatoolbox.wsicore import WSIReader @@ -242,3 +243,70 @@ def test_single_output_mtsegmentor( indices_a=[0, 1, 2], indices_b=[0, 1, 2], ) + + # AnnotationStore output comparison + output_ann = mtsegmentor.run( + images=patches, + patch_mode=True, + device=device, + output_type="annotationstore", + save_dir=track_tmp_path / "patch_output_annotationstore", + ) + + assert len(output_ann) == 3 + assert output_ann[0] == track_tmp_path / "patch_output_annotationstore" / "0.db" + + for patch_idx, db_path in enumerate(output_ann): + assert ( + db_path + == track_tmp_path / "patch_output_annotationstore" / f"{patch_idx}.db" + ) + store_ = SQLiteStore.open(db_path) + annotations_ = store_.values() + annotations_geometry_type = [ + str(annotation_.geometry_type) for annotation_ in annotations_ + ] + annotations_list = list(annotations_) + if expected_counts_nuclei[patch_idx] > 0: + assert "Polygon" in annotations_geometry_type + + # Build result dict from annotation properties + result = {} + for ann in annotations_list: + for key, value in ann.properties.items(): + result.setdefault(key, []).append(value) + result["contours"] = [ + list(poly.exterior.coords) + for poly in (a.geometry for a in annotations_list) + ] + + # wrap it to make it compatible to assert_output_lengths + result_ = { + field: [result[field]] + for field in ["box", "centroid", "contours", "prob", "type"] + } + + # Lengths and equality checks for this patch + assert_output_lengths( + result_, + expected_counts=[expected_counts_nuclei[patch_idx]], + fields=["box", "centroid", "contours", "prob", "type"], + ) + assert_output_equal( + result_, + output_dict, + fields=["box", "centroid", "prob", "type"], + indices_a=[0], + indices_b=[patch_idx], + ) + + # Contour check (discard last point) + assert all( + np.array_equal(np.array(a[:-1], 
dtype=int), np.array(b, dtype=int)) + for a, b in zip( + result["contours"], output_dict["contours"][patch_idx], strict=False + ) + ) + else: + assert annotations_geometry_type == [] + assert annotations_list == [] diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index 091bd0863..4221e9ffc 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -2,20 +2,26 @@ from __future__ import annotations +from pathlib import Path from typing import TYPE_CHECKING import dask.array as da import numpy as np import torch +import zarr +from shapely.geometry import shape as feature2geometry from typing_extensions import Unpack -from tiatoolbox.utils.misc import get_tqdm +from tiatoolbox import logger +from tiatoolbox.annotation import SQLiteStore +from tiatoolbox.annotation.storage import Annotation +from tiatoolbox.utils.misc import get_tqdm, make_valid_poly +from tiatoolbox.wsicore.wsireader import is_zarr from .semantic_segmentor import SemanticSegmentor, SemanticSegmentorRunParams if TYPE_CHECKING: # pragma: no cover import os - from pathlib import Path from torch.utils.data import DataLoader @@ -252,6 +258,39 @@ def build_post_process_raw_predictions( self.tasks = tasks return raw_predictions + def _save_predictions_as_dict_zarr( + self: MultiTaskSegmentor, + processed_predictions: dict, + output_type: str, + save_path: Path | None = None, + **kwargs: Unpack[SemanticSegmentorRunParams], + ) -> dict | AnnotationStore | Path | list[Path]: + """Helper function to save predictions as dictionary or zarr.""" + if output_type.lower() == "dict": + # If there is a single task simplify the output. + if len(self.tasks) == 1: + task_output = processed_predictions.pop(next(iter(self.tasks))) + processed_predictions.update(task_output) + return super().save_predictions( + processed_predictions, output_type, save_path=save_path, **kwargs + ) + + # Save to zarr + for task_name in self.tasks: + processed_predictions_ = processed_predictions.pop(task_name) + # If there is a single task simplify the output. + task_name_ = None if len(self.tasks) == 1 else task_name + keys_to_compute = [ + k for k in processed_predictions_ if k not in self.drop_keys + ] + _ = self.save_predictions_as_zarr( + processed_predictions=processed_predictions_, + save_path=save_path, + keys_to_compute=keys_to_compute, + task_name=task_name_, + ) + return save_path + def save_predictions( self: MultiTaskSegmentor, processed_predictions: dict, @@ -315,33 +354,83 @@ def save_predictions( If an unsupported output_type is provided. """ - if output_type.lower() == "dict": - # If there is a single task simplify the output. - if len(self.tasks) == 1: - task_output = processed_predictions.pop(next(iter(self.tasks))) - processed_predictions.update(task_output) - return super().save_predictions( - processed_predictions, output_type, save_path=save_path, **kwargs + if output_type in ["dict", "zarr"]: + return self._save_predictions_as_dict_zarr( + processed_predictions=processed_predictions, + output_type=output_type, + save_path=save_path, + **kwargs, ) - if output_type.lower() == "zarr": - for task_name in self.tasks: - processed_predictions_ = processed_predictions.pop(task_name) - # If there is a single task simplify the output. 
- task_name_ = None if len(self.tasks) == 1 else task_name - keys_to_compute = [ - k for k in processed_predictions_ if k not in self.drop_keys - ] - _ = self.save_predictions_as_zarr( - processed_predictions=processed_predictions_, - save_path=save_path, - keys_to_compute=keys_to_compute, - task_name=task_name_, + # Save to AnnotationStore + return_probabilities = kwargs.get("return_probabilities", False) + output_type_ = ( + "zarr" + if is_zarr(save_path.with_suffix(".zarr")) or return_probabilities + else "dict" + ) + + # This runs dask.compute and returns numpy arrays + # for saving annotationstore output. + processed_predictions = self._save_predictions_as_dict_zarr( + processed_predictions, + output_type=output_type_, + save_path=save_path.with_suffix(".zarr"), + **kwargs, + ) + + if isinstance(processed_predictions, Path): + processed_predictions = zarr.open(str(processed_predictions), mode="r") + + # scale_factor set from kwargs + scale_factor = kwargs.get("scale_factor", (1.0, 1.0)) + # class_dict set from kwargs + class_dict = kwargs.get("class_dict") + # Need to add support for zarr conversion. + save_paths = [] + + logger.info("Saving predictions as AnnotationStore.") + + # Not required for annotationstore + processed_predictions.pop("predictions") + + if self.patch_mode: + for i, predictions in enumerate( + zip(*processed_predictions.values(), strict=False) + ): + predictions_ = dict( + zip(processed_predictions.keys(), predictions, strict=False) + ) + if isinstance(self.images[i], Path): + output_path = save_path.parent / (self.images[i].stem + ".db") + else: + output_path = save_path.parent / (str(i) + ".db") + + origin = predictions_.pop("coordinates")[:2] + store = SQLiteStore() + store = dict_to_store( + store=store, + processed_predictions=predictions_, + class_dict=class_dict, + scale_factor=scale_factor, + origin=origin, ) - return save_path - # Need to update for AnnotationStore - return save_path + store.commit() + store.dump(output_path) + + save_paths.append(output_path) + + if return_probabilities: + msg = ( + f"Probability maps cannot be saved as AnnotationStore. " + f"To visualise heatmaps in TIAToolbox Visualization tool," + f"convert heatmaps in {save_path} to ome.tiff using" + f"tiatoolbox.utils.misc.write_probability_heatmap_as_ome_tiff." + ) + logger.info(msg) + + return save_paths def run( self: MultiTaskSegmentor, @@ -476,3 +565,44 @@ def run( output_type=output_type, **kwargs, ) + + +def dict_to_store( + store: SQLiteStore, + processed_predictions: dict, + class_dict: dict | None = None, + origin: tuple[float, float] = (0, 0), + scale_factor: tuple[float, float] = (1, 1), +) -> AnnotationStore: + """Helper function to convert dict to store.""" + contour = processed_predictions.pop("contours") + + ann = [] + for i, contour_ in enumerate(contour): + ann_ = Annotation( + make_valid_poly( + feature2geometry( + { + "type": processed_predictions.get("geom_type", "Polygon"), + "coordinates": scale_factor * np.array([contour_]), + }, + ), + tuple(origin), + ), + { + prop: ( + class_dict[processed_predictions[prop][i]] + if prop == "type" and class_dict is not None + # Intention is convert arrays to list + # There might be int or float values which need to be + # converted to arrays first and then apply tolist(). 
+ else np.array(processed_predictions[prop][i]).tolist() + ) + for prop in processed_predictions + }, + ) + ann.append(ann_) + logger.info("Added %d annotations.", len(ann)) + store.append_many(ann) + + return store From e86e17662ce2f548210ffe758240eb24871a3f92 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Tue, 13 Jan 2026 14:56:54 +0000 Subject: [PATCH 019/156] :white_check_mark: Add tests for multitask annotation store --- tests/engines/test_multi_task_segmentor.py | 11 ++ .../models/engine/multi_task_segmentor.py | 130 ++++++++++++------ 2 files changed, 96 insertions(+), 45 deletions(-) diff --git a/tests/engines/test_multi_task_segmentor.py b/tests/engines/test_multi_task_segmentor.py index 0f1d0637f..87a32fc7a 100644 --- a/tests/engines/test_multi_task_segmentor.py +++ b/tests/engines/test_multi_task_segmentor.py @@ -179,6 +179,17 @@ def test_mtsegmentor_patches(remote_sample: Callable, track_tmp_path: Path) -> N indices_b=[0, 1, 2], ) + # AnnotationStore output comparison + output_ann = mtsegmentor.run( + images=patches, + patch_mode=True, + device=device, + output_type="annotationstore", + save_dir=track_tmp_path / "patch_output_annotationstore", + ) + + _ = output_ann + def test_single_output_mtsegmentor( remote_sample: Callable, track_tmp_path: Path diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index 4221e9ffc..f397b3ad0 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -291,6 +291,72 @@ def _save_predictions_as_dict_zarr( ) return save_path + def _save_predictions_as_annotationstore( + self: MultiTaskSegmentor, + processed_predictions: dict, + task_name: str | None = None, + save_path: Path | None = None, + **kwargs: Unpack[SemanticSegmentorRunParams], + ) -> dict | AnnotationStore | Path | list[Path]: + """Helper function to save predictions as annotationstore.""" + # scale_factor set from kwargs + scale_factor = kwargs.get("scale_factor", (1.0, 1.0)) + # class_dict set from kwargs + class_dict = kwargs.get("class_dict") + # Need to add support for zarr conversion. + save_paths = [] + + logger.info("Saving predictions as AnnotationStore.") + + # Not required for annotationstore + processed_predictions.pop("predictions") + + if self.patch_mode: + for i, predictions in enumerate( + zip(*processed_predictions.values(), strict=False) + ): + predictions_ = dict( + zip(processed_predictions.keys(), predictions, strict=False) + ) + if isinstance(self.images[i], Path): + store_file_name = ( + f"{self.images[i].stem}.db" + if task_name is None + else f"{self.images[i].stem}_{task_name}.db" + ) + output_path = save_path.parent / store_file_name + else: + store_file_name = ( + f"{i}.db" if task_name is None else f"{i}_{task_name}.db" + ) + output_path = save_path.parent / store_file_name + + origin = predictions_.pop("coordinates")[:2] + store = SQLiteStore() + store = dict_to_store( + store=store, + processed_predictions=predictions_, + class_dict=class_dict, + scale_factor=scale_factor, + origin=origin, + ) + + store.commit() + store.dump(output_path) + + save_paths.append(output_path) + return_probabilities = kwargs.get("return_probabilities", False) + if return_probabilities: + msg = ( + f"Probability maps cannot be saved as AnnotationStore. 
" + f"To visualise heatmaps in TIAToolbox Visualization tool," + f"convert heatmaps in {save_path} to ome.tiff using" + f"tiatoolbox.utils.misc.write_probability_heatmap_as_ome_tiff." + ) + logger.info(msg) + + return save_paths + def save_predictions( self: MultiTaskSegmentor, processed_predictions: dict, @@ -382,55 +448,29 @@ def save_predictions( if isinstance(processed_predictions, Path): processed_predictions = zarr.open(str(processed_predictions), mode="r") - # scale_factor set from kwargs - scale_factor = kwargs.get("scale_factor", (1.0, 1.0)) - # class_dict set from kwargs - class_dict = kwargs.get("class_dict") - # Need to add support for zarr conversion. save_paths = [] - - logger.info("Saving predictions as AnnotationStore.") - - # Not required for annotationstore - processed_predictions.pop("predictions") - - if self.patch_mode: - for i, predictions in enumerate( - zip(*processed_predictions.values(), strict=False) - ): - predictions_ = dict( - zip(processed_predictions.keys(), predictions, strict=False) - ) - if isinstance(self.images[i], Path): - output_path = save_path.parent / (self.images[i].stem + ".db") - else: - output_path = save_path.parent / (str(i) + ".db") - - origin = predictions_.pop("coordinates")[:2] - store = SQLiteStore() - store = dict_to_store( - store=store, - processed_predictions=predictions_, - class_dict=class_dict, - scale_factor=scale_factor, - origin=origin, + if self.tasks & processed_predictions.keys(): + for task_name in self.tasks: + dict_for_store = { + **processed_predictions[task_name], + "coordinates": processed_predictions["coordinates"], + } + out_path = self._save_predictions_as_annotationstore( + processed_predictions=dict_for_store, + task_name=task_name, + save_path=save_path, + **kwargs, ) + save_paths += out_path - store.commit() - store.dump(output_path) - - save_paths.append(output_path) + return save_paths - if return_probabilities: - msg = ( - f"Probability maps cannot be saved as AnnotationStore. " - f"To visualise heatmaps in TIAToolbox Visualization tool," - f"convert heatmaps in {save_path} to ome.tiff using" - f"tiatoolbox.utils.misc.write_probability_heatmap_as_ome_tiff." 
- ) - logger.info(msg) - - return save_paths + return self._save_predictions_as_annotationstore( + processed_predictions=processed_predictions, + task_name=None, + save_path=save_path, + **kwargs, + ) def run( self: MultiTaskSegmentor, From 4247afa3631bb90219cd3cc9e4cedf9f5d38fa51 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Tue, 13 Jan 2026 18:02:49 +0000 Subject: [PATCH 020/156] :white_check_mark: Test output of layer and nuclei segmentation for annotationstore --- tests/engines/test_multi_task_segmentor.py | 153 +++++++++++------- .../models/engine/multi_task_segmentor.py | 1 - 2 files changed, 97 insertions(+), 57 deletions(-) diff --git a/tests/engines/test_multi_task_segmentor.py b/tests/engines/test_multi_task_segmentor.py index 87a32fc7a..c91f34ae5 100644 --- a/tests/engines/test_multi_task_segmentor.py +++ b/tests/engines/test_multi_task_segmentor.py @@ -62,6 +62,75 @@ def assert_output_equal( ), f"{field}[{i_a}] vs {field}[{i_b}] mismatch" +def assert_annotation_store_patch_output( + output_ann: list[Path], + task_name: str | None, + track_tmp_path: Path, + expected_counts: Sequence[int], + output_dict: OutputType, + fields: list[str], +) -> None: + """Helper function to test AnnotationStore output.""" + for patch_idx, db_path in enumerate(output_ann): + store_file_name = ( + f"{patch_idx}.db" if task_name is None else f"{patch_idx}_{task_name}.db" + ) + assert ( + db_path == track_tmp_path / "patch_output_annotationstore" / store_file_name + ) + store_ = SQLiteStore.open(db_path) + annotations_ = store_.values() + annotations_geometry_type = [ + str(annotation_.geometry_type) for annotation_ in annotations_ + ] + annotations_list = list(annotations_) + if expected_counts[patch_idx] > 0: + assert "Polygon" in annotations_geometry_type + + # Build result dict from annotation properties + result = {} + for ann in annotations_list: + for key, value in ann.properties.items(): + result.setdefault(key, []).append(value) + result["contours"] = [ + list(poly.exterior.coords) + for poly in (a.geometry for a in annotations_list) + ] + + # wrap it to make it compatible to assert_output_lengths + result_ = {field: [result[field]] for field in fields} + + # Lengths and equality checks for this patch + assert_output_lengths( + result_, + expected_counts=[expected_counts[patch_idx]], + fields=fields, + ) + fields_ = fields.copy() + fields_.remove("contours") + assert_output_equal( + result_, + output_dict, + fields=fields_, + indices_a=[0], + indices_b=[patch_idx], + ) + + # Contour check (discard last point) + matches = [ + np.array_equal(np.array(a[:-1], dtype=int), np.array(b, dtype=int)) + for a, b in zip( + result["contours"], output_dict["contours"][patch_idx], strict=False + ) + ] + # Due to make valid poly there might be translation in a few points + # in AnnotationStore + assert sum(matches) / len(matches) >= 0.95 + else: + assert annotations_geometry_type == [] + assert annotations_list == [] + + def test_mtsegmentor_init() -> None: """Tests MultiTaskSegmentor initialization.""" segmentor = MultiTaskSegmentor(model="hovernetplus-oed", device=device) @@ -188,7 +257,26 @@ def test_mtsegmentor_patches(remote_sample: Callable, track_tmp_path: Path) -> N save_dir=track_tmp_path / "patch_output_annotationstore", ) - _ = output_ann + assert len(output_ann) == 6 + + for task_name in mtsegmentor.tasks: + fields_nuclei = ["box", "centroid", "contours", "prob", "type"] + fields_layer = ["contours", "type"] + fields = fields_nuclei if 
task_name == "nuclei_segmentation" else fields_layer + output_ann_ = [p for p in output_ann if p.name.endswith(f"{task_name}.db")] + expected_counts = ( + expected_counts_nuclei + if task_name == "nuclei_segmentation" + else expected_counts_layer + ) + assert_annotation_store_patch_output( + output_ann=output_ann_, + output_dict=output_dict[task_name], + track_tmp_path=track_tmp_path, + fields=fields, + expected_counts=expected_counts, + task_name=task_name, + ) def test_single_output_mtsegmentor( @@ -265,59 +353,12 @@ def test_single_output_mtsegmentor( ) assert len(output_ann) == 3 - assert output_ann[0] == track_tmp_path / "patch_output_annotationstore" / "0.db" - - for patch_idx, db_path in enumerate(output_ann): - assert ( - db_path - == track_tmp_path / "patch_output_annotationstore" / f"{patch_idx}.db" - ) - store_ = SQLiteStore.open(db_path) - annotations_ = store_.values() - annotations_geometry_type = [ - str(annotation_.geometry_type) for annotation_ in annotations_ - ] - annotations_list = list(annotations_) - if expected_counts_nuclei[patch_idx] > 0: - assert "Polygon" in annotations_geometry_type - - # Build result dict from annotation properties - result = {} - for ann in annotations_list: - for key, value in ann.properties.items(): - result.setdefault(key, []).append(value) - result["contours"] = [ - list(poly.exterior.coords) - for poly in (a.geometry for a in annotations_list) - ] - - # wrap it to make it compatible to assert_output_lengths - result_ = { - field: [result[field]] - for field in ["box", "centroid", "contours", "prob", "type"] - } - - # Lengths and equality checks for this patch - assert_output_lengths( - result_, - expected_counts=[expected_counts_nuclei[patch_idx]], - fields=["box", "centroid", "contours", "prob", "type"], - ) - assert_output_equal( - result_, - output_dict, - fields=["box", "centroid", "prob", "type"], - indices_a=[0], - indices_b=[patch_idx], - ) - # Contour check (discard last point) - assert all( - np.array_equal(np.array(a[:-1], dtype=int), np.array(b, dtype=int)) - for a, b in zip( - result["contours"], output_dict["contours"][patch_idx], strict=False - ) - ) - else: - assert annotations_geometry_type == [] - assert annotations_list == [] + assert_annotation_store_patch_output( + output_ann=output_ann, + output_dict=output_dict, + track_tmp_path=track_tmp_path, + fields=["box", "centroid", "contours", "prob", "type"], + expected_counts=expected_counts_nuclei, + task_name=None, + ) diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index f397b3ad0..09899405b 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -244,7 +244,6 @@ def build_post_process_raw_predictions( values = task_dict[key] if all(isinstance(v, (np.ndarray, da.Array)) for v in values): raw_predictions[task][key] = da.stack(values, axis=0) - continue if all(isinstance(v, dict) for v in values): first = values[0] From 94f427e2dda7a9457ee4e829f1d2189e29ddd0c7 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Tue, 13 Jan 2026 18:25:50 +0000 Subject: [PATCH 021/156] :white_check_mark: Test output of layer and nuclei segmentation for annotationstore --- tiatoolbox/models/engine/multi_task_segmentor.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index 09899405b..216acf2f7 
100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -282,6 +282,12 @@ def _save_predictions_as_dict_zarr( keys_to_compute = [ k for k in processed_predictions_ if k not in self.drop_keys ] + + if "coordinates" in processed_predictions: + processed_predictions_.update( + {"coordinates": processed_predictions["coordinates"]} + ) + keys_to_compute.extend(["coordinates"]) _ = self.save_predictions_as_zarr( processed_predictions=processed_predictions_, save_path=save_path, @@ -445,7 +451,7 @@ def save_predictions( ) if isinstance(processed_predictions, Path): - processed_predictions = zarr.open(str(processed_predictions), mode="r") + processed_predictions = zarr.open(str(processed_predictions), mode="r+") save_paths = [] if self.tasks & processed_predictions.keys(): From b16b57b3b674c5fc721cae52fca11a0bbdcde6f1 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Tue, 13 Jan 2026 21:16:00 +0000 Subject: [PATCH 022/156] :bug: Read image at correct resolution --- tests/engines/test_multi_task_segmentor.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/engines/test_multi_task_segmentor.py b/tests/engines/test_multi_task_segmentor.py index c91f34ae5..87e4a01e2 100644 --- a/tests/engines/test_multi_task_segmentor.py +++ b/tests/engines/test_multi_task_segmentor.py @@ -171,7 +171,7 @@ def test_mtsegmentor_patches(remote_sample: Callable, track_tmp_path: Path) -> N mini_wsi_svs = Path(remote_sample("wsi4_1k_1k_svs")) mini_wsi = WSIReader.open(mini_wsi_svs) size = (256, 256) - resolution = 0.25 + resolution = 0.50 units: Final = "mpp" patch1 = mini_wsi.read_rect( @@ -193,7 +193,7 @@ def test_mtsegmentor_patches(remote_sample: Callable, track_tmp_path: Path) -> N patch_mode=True, ) - expected_counts_nuclei = [50, 17, 0] + expected_counts_nuclei = [95, 33, 0] assert_output_lengths( output_dict["nuclei_segmentation"], expected_counts_nuclei, @@ -350,6 +350,7 @@ def test_single_output_mtsegmentor( device=device, output_type="annotationstore", save_dir=track_tmp_path / "patch_output_annotationstore", + return_probabilities=True, ) assert len(output_ann) == 3 From 8aebbaf7c328c23c5dadf7eb533b0731fe1ee9d5 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Tue, 13 Jan 2026 22:57:05 +0000 Subject: [PATCH 023/156] :bug: Fix saving annotationstore using zarr --- tests/engines/test_multi_task_segmentor.py | 18 ++++++++++++++ tiatoolbox/models/engine/engine_abc.py | 4 +++- .../models/engine/multi_task_segmentor.py | 24 +++++++++++++------ 3 files changed, 38 insertions(+), 8 deletions(-) diff --git a/tests/engines/test_multi_task_segmentor.py b/tests/engines/test_multi_task_segmentor.py index 87e4a01e2..53142cdde 100644 --- a/tests/engines/test_multi_task_segmentor.py +++ b/tests/engines/test_multi_task_segmentor.py @@ -344,6 +344,9 @@ def test_single_output_mtsegmentor( ) # AnnotationStore output comparison + + # Reinitialize to check for probabilities in output. 
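# (A note on the line below, inferred from _save_predictions_as_dict_zarr in
# this series: self.drop_keys filters which output arrays are computed and
# written, via keys_to_compute = [k for k in ... if k not in self.drop_keys].
# Clearing it therefore keeps "probabilities" in the saved zarr so the
# assertions that follow can check for it.)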
+ mtsegmentor.drop_keys = [] output_ann = mtsegmentor.run( images=patches, patch_mode=True, @@ -363,3 +366,18 @@ def test_single_output_mtsegmentor( expected_counts=expected_counts_nuclei, task_name=None, ) + + zarr_file = track_tmp_path / "patch_output_annotationstore" / "output.zarr" + + assert zarr_file.exists() + + zarr_group = zarr.open( + str(zarr_file), + mode="r", + ) + + assert "probabilities" in zarr_group + + fields = ["box", "centroid", "contours", "prob", "type", "predictions"] + for field in fields: + assert field not in zarr_group diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index 7034367c0..caaf80d73 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -772,7 +772,9 @@ def _get_tasks_for_saving_zarr( url=save_path, component=component, compute=False, - zarr_array_kwargs={"object_codec": object_codec}, + zarr_array_kwargs={ + "object_codec": object_codec, + }, ) write_tasks.append(task) diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index 216acf2f7..5bd517745 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -288,6 +288,11 @@ def _save_predictions_as_dict_zarr( {"coordinates": processed_predictions["coordinates"]} ) keys_to_compute.extend(["coordinates"]) + if kwargs.get("return_probabilities", False): + processed_predictions_.update( + {"probabilities": processed_predictions["probabilities"]} + ) + keys_to_compute.extend(["probabilities"]) _ = self.save_predictions_as_zarr( processed_predictions=processed_predictions_, save_path=save_path, @@ -313,16 +318,17 @@ def _save_predictions_as_annotationstore( logger.info("Saving predictions as AnnotationStore.") - # Not required for annotationstore + # predictions are not required when saving to AnnotationStore. processed_predictions.pop("predictions") + keys_to_compute = list(processed_predictions.keys()) + if "probabilities" in keys_to_compute: + keys_to_compute.remove("probabilities") + if self.patch_mode: - for i, predictions in enumerate( - zip(*processed_predictions.values(), strict=False) - ): - predictions_ = dict( - zip(processed_predictions.keys(), predictions, strict=False) - ) + for i in range(len(self.images)): + values = [processed_predictions[key][i] for key in keys_to_compute] + predictions_ = dict(zip(keys_to_compute, values, strict=False)) if isinstance(self.images[i], Path): store_file_name = ( f"{self.images[i].stem}.db" @@ -350,6 +356,10 @@ def _save_predictions_as_annotationstore( store.dump(output_path) save_paths.append(output_path) + + for key in keys_to_compute: + del processed_predictions[key] + return_probabilities = kwargs.get("return_probabilities", False) if return_probabilities: msg = ( From e4392c9752be68a072973bf3d97ef9821171d9a7 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Tue, 13 Jan 2026 23:02:09 +0000 Subject: [PATCH 024/156] :white_check_mark: Add test for logger. 
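pytest's caplog fixture hooks into the standard logging machinery, so the
warning emitted when probability maps are requested together with an
AnnotationStore output can be asserted on directly via caplog.text. A
minimal sketch of the pattern (illustrative test name; assumes the
tiatoolbox logger propagates to the root logger, as stdlib loggers do by
default):

    import logging

    import pytest

    def test_warns(caplog: pytest.LogCaptureFixture) -> None:
        logging.getLogger("tiatoolbox").warning(
            "Probability maps cannot be saved as AnnotationStore."
        )
        # caplog.text concatenates every captured record.
        assert "cannot be saved as AnnotationStore" in caplog.text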
--- tests/engines/test_multi_task_segmentor.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/engines/test_multi_task_segmentor.py b/tests/engines/test_multi_task_segmentor.py index 53142cdde..33c5d9c1c 100644 --- a/tests/engines/test_multi_task_segmentor.py +++ b/tests/engines/test_multi_task_segmentor.py @@ -280,7 +280,9 @@ def test_mtsegmentor_patches(remote_sample: Callable, track_tmp_path: Path) -> N def test_single_output_mtsegmentor( - remote_sample: Callable, track_tmp_path: Path + remote_sample: Callable, + track_tmp_path: Path, + caplog: pytest.LogCaptureFixture, ) -> None: """Tests MultiTaskSegmentor on single task output.""" mtsegmentor = MultiTaskSegmentor( @@ -381,3 +383,5 @@ def test_single_output_mtsegmentor( fields = ["box", "centroid", "contours", "prob", "type", "predictions"] for field in fields: assert field not in zarr_group + + assert "Probability maps cannot be saved as AnnotationStore" in caplog.text From f9f5f69186301334383a9a55dc027a369c5bcef5 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Wed, 14 Jan 2026 09:53:52 +0000 Subject: [PATCH 025/156] :bug: Fix deepsource error `PTC-W0060` --- tiatoolbox/models/engine/multi_task_segmentor.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index 5bd517745..e93cc378b 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -326,19 +326,19 @@ def _save_predictions_as_annotationstore( keys_to_compute.remove("probabilities") if self.patch_mode: - for i in range(len(self.images)): - values = [processed_predictions[key][i] for key in keys_to_compute] + for idx, curr_image in enumerate(self.images): + values = [processed_predictions[key][idx] for key in keys_to_compute] predictions_ = dict(zip(keys_to_compute, values, strict=False)) - if isinstance(self.images[i], Path): + if isinstance(curr_image, Path): store_file_name = ( - f"{self.images[i].stem}.db" + f"{curr_image.stem}.db" if task_name is None - else f"{self.images[i].stem}_{task_name}.db" + else f"{curr_image.stem}_{task_name}.db" ) output_path = save_path.parent / store_file_name else: store_file_name = ( - f"{i}.db" if task_name is None else f"{i}_{task_name}.db" + f"{idx}.db" if task_name is None else f"{idx}_{task_name}.db" ) output_path = save_path.parent / store_file_name From 84a461191ab6d06855430e02f58dd0d58417031d Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Wed, 14 Jan 2026 11:49:00 +0000 Subject: [PATCH 026/156] :white_check_mark: Use file paths for a test --- tests/engines/test_multi_task_segmentor.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/tests/engines/test_multi_task_segmentor.py b/tests/engines/test_multi_task_segmentor.py index 33c5d9c1c..0c3455f1c 100644 --- a/tests/engines/test_multi_task_segmentor.py +++ b/tests/engines/test_multi_task_segmentor.py @@ -13,6 +13,7 @@ from tiatoolbox.annotation import SQLiteStore from tiatoolbox.models.engine.multi_task_segmentor import MultiTaskSegmentor from tiatoolbox.utils import env_detection as toolbox_env +from tiatoolbox.utils import imwrite from tiatoolbox.wsicore import WSIReader if TYPE_CHECKING: @@ -301,12 +302,22 @@ def test_single_output_mtsegmentor( location=(512, 512), size=size, resolution=resolution, units=units ) patch3 = 
np.zeros_like(patch1) + + patch1_path = track_tmp_path / "patch1.png" + patch2_path = track_tmp_path / "patch2.png" + patch3_path = track_tmp_path / "patch3.png" + + imwrite(patch1_path, patch1) + imwrite(patch2_path, patch2) + imwrite(patch3_path, patch3) + + inputs = [Path(patch1_path), str(patch2_path), str(patch3_path)] patches = np.stack([patch1, patch2, patch3], axis=0) assert not mtsegmentor.patch_mode output_dict = mtsegmentor.run( - images=patches, + images=inputs, return_probabilities=True, return_labels=False, device=device, From 6e1b55f96ea7307b30cdb078e570861d01e9ec93 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Wed, 14 Jan 2026 14:20:06 +0000 Subject: [PATCH 027/156] :bug: Fix paths to input files in test --- tests/engines/test_multi_task_segmentor.py | 26 ++++++++++++++++------ 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/tests/engines/test_multi_task_segmentor.py b/tests/engines/test_multi_task_segmentor.py index 0c3455f1c..c593f7a49 100644 --- a/tests/engines/test_multi_task_segmentor.py +++ b/tests/engines/test_multi_task_segmentor.py @@ -64,6 +64,7 @@ def assert_output_equal( def assert_annotation_store_patch_output( + inputs: list | np.ndarray, output_ann: list[Path], task_name: str | None, track_tmp_path: Path, @@ -73,9 +74,19 @@ def assert_annotation_store_patch_output( ) -> None: """Helper function to test AnnotationStore output.""" for patch_idx, db_path in enumerate(output_ann): - store_file_name = ( - f"{patch_idx}.db" if task_name is None else f"{patch_idx}_{task_name}.db" - ) + if isinstance(inputs[patch_idx], Path): + store_file_name = ( + f"{inputs[patch_idx].stem}.db" + if task_name is None + else f"{inputs[patch_idx].stem}_{task_name}.db" + ) + else: + store_file_name = ( + f"{patch_idx}.db" + if task_name is None + else f"{patch_idx}_{task_name}.db" + ) + assert ( db_path == track_tmp_path / "patch_output_annotationstore" / store_file_name ) @@ -271,6 +282,7 @@ def test_mtsegmentor_patches(remote_sample: Callable, track_tmp_path: Path) -> N else expected_counts_layer ) assert_annotation_store_patch_output( + inputs=patches, output_ann=output_ann_, output_dict=output_dict[task_name], track_tmp_path=track_tmp_path, @@ -311,8 +323,7 @@ def test_single_output_mtsegmentor( imwrite(patch2_path, patch2) imwrite(patch3_path, patch3) - inputs = [Path(patch1_path), str(patch2_path), str(patch3_path)] - patches = np.stack([patch1, patch2, patch3], axis=0) + inputs = [Path(patch1_path), Path(patch2_path), Path(patch3_path)] assert not mtsegmentor.patch_mode @@ -334,7 +345,7 @@ def test_single_output_mtsegmentor( # Zarr output comparison output_zarr = mtsegmentor.run( - images=patches, + images=inputs, patch_mode=True, device=device, output_type="zarr", @@ -361,7 +372,7 @@ def test_single_output_mtsegmentor( # Reinitialize to check for probabilities in output. 
mtsegmentor.drop_keys = [] output_ann = mtsegmentor.run( - images=patches, + images=inputs, patch_mode=True, device=device, output_type="annotationstore", @@ -372,6 +383,7 @@ def test_single_output_mtsegmentor( assert len(output_ann) == 3 assert_annotation_store_patch_output( + inputs=inputs, output_ann=output_ann, output_dict=output_dict, track_tmp_path=track_tmp_path, From ff804db853fe0579e7d7ee2051ba6ffc03be456b Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Fri, 16 Jan 2026 17:00:02 +0000 Subject: [PATCH 028/156] :construction: Add zarr output for WSIs --- tests/engines/test_multi_task_segmentor.py | 38 +- tiatoolbox/models/engine/engine_abc.py | 13 +- .../models/engine/multi_task_segmentor.py | 517 +++++++++++++++++- .../models/engine/semantic_segmentor.py | 11 +- 4 files changed, 569 insertions(+), 10 deletions(-) diff --git a/tests/engines/test_multi_task_segmentor.py b/tests/engines/test_multi_task_segmentor.py index c593f7a49..a926d457b 100644 --- a/tests/engines/test_multi_task_segmentor.py +++ b/tests/engines/test_multi_task_segmentor.py @@ -292,7 +292,7 @@ def test_mtsegmentor_patches(remote_sample: Callable, track_tmp_path: Path) -> N ) -def test_single_output_mtsegmentor( +def test_single_task_mtsegmentor( remote_sample: Callable, track_tmp_path: Path, caplog: pytest.LogCaptureFixture, @@ -408,3 +408,39 @@ def test_single_output_mtsegmentor( assert field not in zarr_group assert "Probability maps cannot be saved as AnnotationStore" in caplog.text + + +def test_wsi_mtsegmentor_zarr( + sample_svs: Path, + track_tmp_path: Path, +) -> None: + """Test MultiTaskSegmentor for WSIs with zarr output.""" + mtsegmentor = MultiTaskSegmentor( + model="hovernetplus-oed", + batch_size=64, + verbose=False, + num_workers=1, + ) + # Return Probabilities is False + output = mtsegmentor.run( + images=[sample_svs], + return_probabilities=False, + return_labels=False, + device=device, + patch_mode=False, + save_dir=track_tmp_path / "wsi_out_check", + batch_size=2, + output_type="zarr", + memory_threshold=1, + stride_shape=(160, 160), + ) + + output_ = zarr.open(output[sample_svs], mode="r") + assert 18 < np.mean(output_["nuclei_segmentation"]["predictions"][:]) < 22 + assert 0.43 < np.mean(output_["layer_segmentation"]["predictions"][:]) < 0.45 + assert "probabilities" not in output_["nuclei_segmentation"] + assert "canvas" not in output_["nuclei_segmentation"] + assert "count" not in output_["nuclei_segmentation"] + assert "probabilities" not in output_["layer_segmentation"] + assert "canvas" not in output_["layer_segmentation"] + assert "count" not in output_["layer_segmentation"] diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index caaf80d73..8670ef8e4 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -753,10 +753,19 @@ def _get_tasks_for_saving_zarr( ) -> list: """Helper function to get dask tasks for saving zarr output.""" if isinstance(dask_output, da.Array): - dask_output = dask_output.rechunk("auto") + dask_output_dtype = dask_output.dtype + object_codec = Pickle() + if dask_output_dtype != "object": + dask_output = dask_output.rechunk("auto") + object_codec = None component = key if task_name is None else f"{task_name}/{key}" task = dask_output.to_zarr( - url=save_path, component=component, compute=False, object_codec=None + url=save_path, + component=component, + compute=False, + zarr_array_kwargs={ + "object_codec": object_codec, + }, ) 
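# With compute=False, to_zarr defers the write and returns a dask Delayed
# task instead of executing immediately; the queued tasks are then
# materialised together in one pass, roughly (a sketch; dask's compute is
# already imported at the top of this module):
#
#   from dask import compute
#   compute(*write_tasks)  # executes every deferred zarr write at once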
write_tasks.append(task) diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index e93cc378b..67b47712e 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -2,13 +2,16 @@ from __future__ import annotations +import gc from pathlib import Path from typing import TYPE_CHECKING import dask.array as da import numpy as np +import psutil import torch import zarr +from dask import compute from shapely.geometry import shape as feature2geometry from typing_extensions import Unpack @@ -18,12 +21,19 @@ from tiatoolbox.utils.misc import get_tqdm, make_valid_poly from tiatoolbox.wsicore.wsireader import is_zarr -from .semantic_segmentor import SemanticSegmentor, SemanticSegmentorRunParams +from .semantic_segmentor import ( + SemanticSegmentor, + SemanticSegmentorRunParams, + concatenate_none, + merge_batch_to_canvas, + store_probabilities, +) if TYPE_CHECKING: # pragma: no cover import os from torch.utils.data import DataLoader + from tqdm import tqdm, tqdm_notebook from tiatoolbox.annotation import AnnotationStore from tiatoolbox.models.models_abc import ModelABC @@ -107,9 +117,9 @@ def infer_patches( raw_predictions["probabilities"] = [[] for _ in range(num_expected_output)] # Inference loop - tqdm = get_tqdm() + tqdm_ = get_tqdm() tqdm_loop = ( - tqdm(dataloader, leave=False, desc="Inferring patches") + tqdm_(dataloader, leave=False, desc="Inferring patches") if self.verbose else self.dataloader ) @@ -145,6 +155,172 @@ def infer_patches( return raw_predictions + def infer_wsi( + self: SemanticSegmentor, + dataloader: DataLoader, + save_path: Path, + **kwargs: Unpack[SemanticSegmentorRunParams], + ) -> dict[str, da.Array]: + """Perform model inference on a whole slide image (WSI).""" + # Default Memory threshold percentage is 80. 
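# A minimal sketch of the spill heuristic this method applies while rows
# of patch predictions accumulate (names are the ones introduced in this
# patch; the flush path is save_multitask_to_cache, defined further below):
#
#   vm = psutil.virtual_memory()
#   canvas_pct = sum(a.nbytes for a in canvas) / vm.free * 100
#   if vm.percent > memory_threshold or canvas_pct > memory_threshold:
#       canvas_zarr, count_zarr = save_multitask_to_cache(
#           canvas, count, canvas_zarr, count_zarr, save_path=save_path
#       )
#       canvas = count = [None] * num_expected_output  # reset in-memory state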
+ memory_threshold = kwargs.get("memory_threshold", 80) + vm = psutil.virtual_memory() + + keys = ["probabilities", "coordinates"] + coordinates = [] + + # Main output dictionary + raw_predictions = dict( + zip(keys, [da.empty(shape=(0, 0))] * len(keys), strict=False) + ) + + # Inference loop + tqdm_ = get_tqdm() + tqdm_loop = ( + tqdm_(dataloader, leave=False, desc="Inferring patches") + if self.verbose + else dataloader + ) + + # Expected number of outputs from the model + batch_output = self.model.infer_batch( + self.model, + torch.Tensor(dataloader.dataset[0]["image"][np.newaxis, ...]), + device=self.device, + ) + + num_expected_output = len(batch_output) + canvas_np = [None for _ in range(num_expected_output)] + canvas = [None for _ in range(num_expected_output)] + count = [None for _ in range(num_expected_output)] + canvas_zarr = [None for _ in range(num_expected_output)] + count_zarr = [None for _ in range(num_expected_output)] + + output_locs_y_, output_locs = None, None + + full_output_locs = ( + dataloader.dataset.full_outputs + if hasattr(dataloader.dataset, "full_outputs") + else dataloader.dataset.outputs + ) + + infer_batch = self._get_model_attr("infer_batch") + for batch_idx, batch_data in enumerate(tqdm_loop): + batch_output = infer_batch( + self.model, + batch_data["image"], + device=self.device, + ) + + batch_locs = batch_data["output_locs"].numpy() + + # Interpolate outputs for masked regions + full_batch_output, full_output_locs, output_locs = ( + prepare_multitask_full_batch( + batch_output, + batch_locs, + full_output_locs, + output_locs, + is_last=(batch_idx == (len(dataloader) - 1)), + ) + ) + + for idx, full_batch_output_ in enumerate(full_batch_output): + canvas_np[idx] = concatenate_none( + old_arr=canvas_np[idx], new_arr=full_batch_output_ + ) + + # Determine if dataloader is moved to next row of patches + change_indices = np.where(np.diff(output_locs[:, 1]) != 0)[0] + 1 + + # If a row of patches has been processed. + if change_indices.size > 0: + canvas, count, canvas_np, output_locs, output_locs_y_ = ( + merge_multitask_horizontal( + canvas, + count, + output_locs_y_, + canvas_np, + output_locs, + change_indices, + ) + ) + + used_percent = vm.percent + total_bytes = sum(arr.nbytes for arr in canvas) if canvas else 0 + canvas_used_percent = (total_bytes / vm.free) * 100 + + if ( + used_percent > memory_threshold + or canvas_used_percent > memory_threshold + ): + tqdm_loop.desc = "Spill intermediate data to disk" + used_percent = ( + canvas_used_percent + if (canvas_used_percent > memory_threshold) + else used_percent + ) + msg = ( + f"Current Memory usage: {used_percent} % " + f"exceeds specified threshold: {memory_threshold}. " + f"Saving intermediate results to disk." 
+ ) + tqdm_.write(msg) + # Flush data in Memory and clear dask graph + canvas_zarr, count_zarr = save_multitask_to_cache( + canvas, + count, + canvas_zarr, + count_zarr, + save_path=save_path, + ) + canvas = [None for _ in range(num_expected_output)] + count = [None for _ in range(num_expected_output)] + gc.collect() + tqdm_loop.desc = "Inferring patches" + + coordinates.append( + da.from_array( + self._get_coordinates(batch_data), + ) + ) + + canvas, count, _, _, output_locs_y_ = merge_multitask_horizontal( + canvas, + count, + output_locs_y_, + canvas_np, + output_locs, + change_indices=[len(output_locs)], + ) + + zarr_group = None + if canvas_zarr is not None: + canvas_zarr, count_zarr = save_multitask_to_cache( + canvas, count, canvas_zarr, count_zarr + ) + # Wrap zarr in dask array + for idx, canvas_zarr_ in enumerate(canvas_zarr): + canvas[idx] = da.from_zarr(canvas_zarr_, chunks=canvas_zarr_.chunks) + count[idx] = da.from_zarr( + count_zarr[idx], chunks=count_zarr[idx].chunks + ) + + zarr_group = zarr.open(canvas_zarr[0].store.path, mode="a") + + # Final vertical merge + raw_predictions["probabilities"] = merge_multitask_vertical_chunkwise( + canvas, + count, + output_locs_y_, + zarr_group, + save_path, + memory_threshold, + ) + raw_predictions["coordinates"] = da.concatenate(coordinates, axis=0) + + return raw_predictions + def post_process_patches( # skipcq: PYL-R0201 self: MultiTaskSegmentor, raw_predictions: dict, @@ -183,6 +359,36 @@ def post_process_patches( # skipcq: PYL-R0201 return raw_predictions + def post_process_wsi( # skipcq: PYL-R0201 + self: MultiTaskSegmentor, + raw_predictions: dict, + save_path: Path, # noqa: ARG002 + **kwargs: Unpack[SemanticSegmentorRunParams], # noqa: ARG002 + ) -> dict: + """Post-process raw patch predictions from inference.""" + probabilities = raw_predictions["probabilities"] + post_process_predictions = self.model.postproc_func(probabilities) + + tasks = set() + for seg in post_process_predictions: + task_name = seg["task_type"] + tasks.add(task_name) + raw_predictions[task_name] = {} + + for key, value in seg.items(): + if key == "task_type": + continue + if isinstance(value, (np.ndarray, da.Array)): + raw_predictions[task_name][key] = da.array(value) + + if isinstance(value, dict): + for k, v in value.items(): + raw_predictions[task_name][k] = v + + self.tasks = tasks + + return raw_predictions + def build_post_process_raw_predictions( self: MultiTaskSegmentor, post_process_predictions: list[tuple], @@ -661,3 +867,308 @@ def dict_to_store( store.append_many(ann) return store + + +def prepare_multitask_full_batch( + batch_output: tuple[np.ndarray], + batch_locs: np.ndarray, + full_output_locs: np.ndarray, + output_locs: np.ndarray, + *, + is_last: bool, +) -> tuple[list[np.ndarray], np.ndarray, np.ndarray]: + """Prepare full-sized output and count arrays for a batch of patch predictions. + + This function aligns patch-level predictions with global output locations when + a mask (e.g., auto_get_mask) is applied. It initializes full-sized arrays and + fills them using matched indices. If the batch is the last in the sequence, + it pads the arrays to cover remaining locations. + + Args: + batch_output (np.ndarray): + Patch-level model predictions of shape (N, H, W, C). + batch_locs (np.ndarray): + Output locations corresponding to `batch_output`. + full_output_locs (np.ndarray): + Remaining global output locations to be matched. + output_locs (np.ndarray): + Accumulated output location array across batches. 
+ is_last (bool): + Flag indicating whether this is the final batch. + + Returns: + tuple[list[np.ndarray], np.ndarray, np.ndarray]: + - full_batch_output: Full-sized output array with predictions placed. + - full_output_locs: Updated remaining global output locations. + - output_locs: Updated accumulated output locations. + + """ + # Use np.intersect1d once numpy version is upgraded to 2.0 + full_output_dict = {tuple(row): i for i, row in enumerate(full_output_locs)} + matches = [full_output_dict[tuple(row)] for row in batch_locs] + + total_size = np.max(matches).astype(np.uint32) + 1 + + full_batch_output = [np.empty(0) for _ in range(len(batch_output))] + + for idx, batch_output_ in enumerate(batch_output): + # Initialize full output array + full_batch_output[idx] = np.zeros( + shape=(total_size, *batch_output_.shape[1:]), + dtype=batch_output_.dtype, + ) + + # Place matching outputs using matching indices + full_batch_output[idx][matches] = batch_output_ + + output_locs = concatenate_none( + old_arr=output_locs, new_arr=full_output_locs[:total_size] + ) + full_output_locs = full_output_locs[total_size:] + + if is_last: + output_locs = concatenate_none(old_arr=output_locs, new_arr=full_output_locs) + for idx, batch_output_ in enumerate(batch_output): + full_batch_output[idx] = concatenate_none( + old_arr=full_batch_output[idx], + new_arr=np.zeros( + shape=(len(full_output_locs), *batch_output_.shape[1:]), + dtype=np.uint8, + ), + ) + + return full_batch_output, full_output_locs, output_locs + + +def merge_multitask_horizontal( + canvas: list[None] | list[da.Array], + count: list[None] | list[da.Array], + output_locs_y_: np.ndarray, + canvas_np: list[np.ndarray], + output_locs: np.ndarray, + change_indices: np.ndarray | list[int], +) -> tuple[list[da.Array], list[da.Array], list[np.ndarray], np.ndarray, np.ndarray]: + """Merge horizontal patches incrementally for each row of patches.""" + start_idx = 0 + for c_idx in change_indices: + output_locs_ = output_locs[: c_idx - start_idx] + + batch_xs = np.min(output_locs[:, 0], axis=0) + batch_xe = np.max(output_locs[:, 2], axis=0) + + for idx, canvas_np_ in enumerate(canvas_np): + canvas_np__ = canvas_np_[: c_idx - start_idx] + merged_shape = ( + canvas_np__.shape[1], + batch_xe - batch_xs, + canvas_np__.shape[3], + ) + canvas_merge, count_merge = merge_batch_to_canvas( + blocks=canvas_np__, + output_locations=output_locs_, + merged_shape=merged_shape, + ) + canvas_merge = da.from_array(canvas_merge, chunks=canvas_merge.shape) + count_merge = da.from_array(count_merge, chunks=count_merge.shape) + canvas[idx] = concatenate_none(old_arr=canvas[idx], new_arr=canvas_merge) + count[idx] = concatenate_none(old_arr=count[idx], new_arr=count_merge) + canvas_np[idx] = canvas_np[idx][c_idx - start_idx :] + + output_locs_y_ = concatenate_none( + old_arr=output_locs_y_, new_arr=output_locs[:, (1, 3)] + ) + + output_locs = output_locs[c_idx - start_idx :] + start_idx = c_idx + + return canvas, count, canvas_np, output_locs, output_locs_y_ + + +def save_multitask_to_cache( + canvas: list[da.Array], + count: list[da.Array], + canvas_zarr: list[zarr.Array | None], + count_zarr: list[zarr.Array | None], + save_path: str | Path = "temp.zarr", +) -> tuple[list[zarr.Array], list[zarr.Array]]: + """Save computed canvas and count list of arrays to Zarr cache.""" + zarr_group = None + for idx, canvas_ in enumerate(canvas): + computed_values = compute(*[canvas_, count[idx]]) + canvas_computed, count_computed = computed_values + + chunk_shape = tuple(chunk[0] for 
chunk in canvas_.chunks) + if canvas_zarr[idx] is None: + # Only open zarr for first canvas. + zarr_group = zarr.open(str(save_path), mode="w") if idx == 0 else zarr_group + + canvas_zarr[idx] = zarr_group.create_dataset( + name=f"canvas/{idx}", + shape=(0, *canvas_computed.shape[1:]), + chunks=(chunk_shape[0], *canvas_computed.shape[1:]), + dtype=canvas_computed.dtype, + overwrite=True, + ) + + count_zarr[idx] = zarr_group.create_dataset( + name=f"count/{idx}", + shape=(0, *count_computed.shape[1:]), + dtype=count_computed.dtype, + chunks=(chunk_shape[0], *count_computed.shape[1:]), + overwrite=True, + ) + + canvas_zarr[idx].resize( + ( + canvas_zarr[idx].shape[0] + canvas_computed.shape[0], + *canvas_zarr[idx].shape[1:], + ) + ) + canvas_zarr[idx][-canvas_computed.shape[0] :] = canvas_computed + + count_zarr[idx].resize( + ( + count_zarr[idx].shape[0] + count_computed.shape[0], + *count_zarr[idx].shape[1:], + ) + ) + count_zarr[idx][-count_computed.shape[0] :] = count_computed + + return canvas_zarr, count_zarr + + +def merge_multitask_vertical_chunkwise( + canvas: list[da.Array], + count: list[da.Array], + output_locs_y_: np.ndarray, + zarr_group: zarr.Group, + save_path: Path, + memory_threshold: int = 80, +) -> list[da.Array]: + """Merge vertically chunked arrays into a single probability map.""" + y0s, y1s = np.unique(output_locs_y_[:, 0]), np.unique(output_locs_y_[:, 1]) + overlaps = np.append(y1s[:-1] - y0s[1:], 0) + + probabilities_zarr = [None for _ in range(len(canvas))] + probabilities_da = [None for _ in range(len(canvas))] + + for idx, canvas_ in enumerate(canvas): + num_chunks = canvas_.numblocks[0] + chunk_shape = tuple(chunk[0] for chunk in canvas_.chunks) + + tqdm_ = get_tqdm() + tqdm_loop = tqdm_(overlaps, leave=False, desc="Merging rows") + + curr_chunk = canvas_.blocks[0, 0].compute() + curr_count = count[idx].blocks[0, 0].compute() + next_chunk = canvas_.blocks[1, 0].compute() if num_chunks > 1 else None + next_count = count[idx].blocks[1, 0].compute() if num_chunks > 1 else None + + for i, overlap in enumerate(tqdm_loop): + if next_chunk is not None and overlap > 0: + curr_chunk[-overlap:] += next_chunk[:overlap] + curr_count[-overlap:] += next_count[:overlap] + + # Normalize + curr_count = np.where(curr_count == 0, 1, curr_count) + probabilities = curr_chunk / curr_count.astype(np.float32) + + probabilities_zarr[idx], probabilities_da[idx] = store_probabilities( + probabilities=probabilities, + chunk_shape=chunk_shape, + probabilities_zarr=probabilities_zarr[idx], + probabilities_da=probabilities_da[idx], + zarr_group=zarr_group, + name=f"probabilities/{idx}", + ) + + probabilities_zarr, probabilities_da = _save_multitask_vertical_to_cache( + probabilities_zarr=probabilities_zarr, + probabilities_da=probabilities_da, + probabilities=probabilities, + idx=idx, + tqdm_=tqdm_, + save_path=save_path, + chunk_shape=chunk_shape, + memory_threshold=memory_threshold, + ) + + if next_chunk is not None: + curr_chunk, curr_count = next_chunk[overlap:], next_count[overlap:] + + if i + 2 < num_chunks: + next_chunk = canvas_.blocks[i + 2, 0].compute() + next_count = count[idx].blocks[i + 2, 0].compute() + else: + next_chunk, next_count = None, None + + probabilities_da[idx] = _clear_zarr( + probabilities_zarr=probabilities_zarr[idx], + probabilities_da=probabilities_da[idx], + zarr_group=zarr_group, + idx=idx, + chunk_shape=chunk_shape, + probabilities_shape=curr_chunk.shape[1:], + ) + + return probabilities_da + + +def _save_multitask_vertical_to_cache( + probabilities_zarr: 
list[zarr.Array], + probabilities_da: list[da.Array] | None, + probabilities: np.ndarray, + idx: int, + tqdm_: type[tqdm_notebook | tqdm], + save_path: Path, + chunk_shape: tuple, + memory_threshold: int = 80, +) -> tuple[list[zarr.Array], list[da.Array] | None]: + """Helper function to save to zarr if vertical merge is out of memory.""" + used_percent = 0 + if probabilities_da[idx] is not None: + vm = psutil.virtual_memory() + total_bytes = ( + sum(arr.nbytes for arr in probabilities_da) if probabilities_da else 0 + ) + used_percent = (total_bytes / vm.free) * 100 + if probabilities_zarr[idx] is None and used_percent > memory_threshold: + msg = ( + f"Current Memory usage: {used_percent} % " + f"exceeds specified threshold: {memory_threshold}. " + f"Saving intermediate results to disk." + ) + tqdm_.write(msg) + zarr_group = zarr.open(str(save_path), mode="a") + probabilities_zarr = zarr_group.create_dataset( + name=f"probabilities/{idx}", + shape=probabilities_da.shape, + chunks=(chunk_shape[0], *probabilities.shape[1:]), + dtype=probabilities.dtype, + overwrite=True, + ) + probabilities_zarr[idx][:] = probabilities_da.compute() + + probabilities_da = None + + return probabilities_zarr, probabilities_da + + +def _clear_zarr( + probabilities_zarr: zarr.Array, + probabilities_da: da.Array | None, + zarr_group: zarr.Group, + idx: int, + chunk_shape: tuple, + probabilities_shape: tuple, +) -> da.Array | None: + """Helper function to clear all zarr contents and return dask array.""" + if probabilities_zarr: + if "canvas" in zarr_group: + del zarr_group["canvas"][idx] + if "count" in zarr_group: + del zarr_group["count"][idx] + return da.from_zarr( + probabilities_zarr, chunks=(chunk_shape[0], *probabilities_shape) + ) + return probabilities_da diff --git a/tiatoolbox/models/engine/semantic_segmentor.py b/tiatoolbox/models/engine/semantic_segmentor.py index 23f04fc56..20338960e 100644 --- a/tiatoolbox/models/engine/semantic_segmentor.py +++ b/tiatoolbox/models/engine/semantic_segmentor.py @@ -961,7 +961,7 @@ def run( def concatenate_none( - old_arr: np.ndarray | da.Array, + old_arr: np.ndarray | da.Array | None, new_arr: np.ndarray | da.Array, ) -> np.ndarray | da.Array: """Concatenate arrays, handling None values gracefully. @@ -971,7 +971,7 @@ def concatenate_none( arrays. Args: - old_arr (np.ndarray | da.Array): + old_arr (np.ndarray | da.Array | None): Existing array to append to. Can be None. new_arr (np.ndarray | da.Array): New array to append. @@ -1034,7 +1034,7 @@ def merge_horizontal( output_locs_y_: np.ndarray, canvas_np: np.ndarray, output_locs: np.ndarray, - change_indices: np.ndarray | list[np.ndarray], + change_indices: np.ndarray | list[int], ) -> tuple[da.Array, da.Array, np.ndarray, np.ndarray, np.ndarray]: """Merge horizontal patches incrementally for each row of patches. @@ -1283,6 +1283,7 @@ def store_probabilities( probabilities_zarr: zarr.Array | None, probabilities_da: da.Array | None, zarr_group: zarr.Group | None, + name: str = "probabilities", ) -> tuple[zarr.Array | None, da.Array | None]: """Store computed probability data into a Zarr dataset or accumulate in memory. @@ -1301,6 +1302,8 @@ def store_probabilities( Existing Dask array for in-memory accumulation. zarr_group (zarr.Group | None): Zarr group used to create or access the dataset. + name (str): + Name to create Zarr dataset. 
Returns: tuple[zarr.Array | None, da.Array | None]: @@ -1310,7 +1313,7 @@ def store_probabilities( if zarr_group is not None: if probabilities_zarr is None: probabilities_zarr = zarr_group.create_dataset( - name="probabilities", + name=name, shape=(0, *probabilities.shape[1:]), chunks=(chunk_shape[0], *probabilities.shape[1:]), dtype=probabilities.dtype, From 04dada48e1be29f2ca9c9aab7aae96efb64692d2 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Mon, 19 Jan 2026 10:12:39 +0000 Subject: [PATCH 029/156] :bug: Fix deepsource error --- .../models/engine/multi_task_segmentor.py | 162 ++++++++++++------ 1 file changed, 108 insertions(+), 54 deletions(-) diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index 67b47712e..13ad3cd32 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -164,7 +164,6 @@ def infer_wsi( """Perform model inference on a whole slide image (WSI).""" # Default Memory threshold percentage is 80. memory_threshold = kwargs.get("memory_threshold", 80) - vm = psutil.virtual_memory() keys = ["probabilities", "coordinates"] coordinates = [] @@ -246,38 +245,19 @@ def infer_wsi( ) ) - used_percent = vm.percent - total_bytes = sum(arr.nbytes for arr in canvas) if canvas else 0 - canvas_used_percent = (total_bytes / vm.free) * 100 - - if ( - used_percent > memory_threshold - or canvas_used_percent > memory_threshold - ): - tqdm_loop.desc = "Spill intermediate data to disk" - used_percent = ( - canvas_used_percent - if (canvas_used_percent > memory_threshold) - else used_percent - ) - msg = ( - f"Current Memory usage: {used_percent} % " - f"exceeds specified threshold: {memory_threshold}. " - f"Saving intermediate results to disk." 
- ) - tqdm_.write(msg) - # Flush data in Memory and clear dask graph - canvas_zarr, count_zarr = save_multitask_to_cache( - canvas, - count, - canvas_zarr, - count_zarr, + canvas, count, canvas_zarr, count_zarr, tqdm_loop = ( + _check_and_update_for_memory_overload( + canvas=canvas, + count=count, + canvas_zarr=canvas_zarr, + count_zarr=count_zarr, + memory_threshold=memory_threshold, + tqdm_loop=tqdm_loop, + tqdm_=tqdm_, save_path=save_path, + num_expected_output=num_expected_output, ) - canvas = [None for _ in range(num_expected_output)] - count = [None for _ in range(num_expected_output)] - gc.collect() - tqdm_loop.desc = "Inferring patches" + ) coordinates.append( da.from_array( @@ -294,29 +274,16 @@ def infer_wsi( change_indices=[len(output_locs)], ) - zarr_group = None - if canvas_zarr is not None: - canvas_zarr, count_zarr = save_multitask_to_cache( - canvas, count, canvas_zarr, count_zarr - ) - # Wrap zarr in dask array - for idx, canvas_zarr_ in enumerate(canvas_zarr): - canvas[idx] = da.from_zarr(canvas_zarr_, chunks=canvas_zarr_.chunks) - count[idx] = da.from_zarr( - count_zarr[idx], chunks=count_zarr[idx].chunks - ) - - zarr_group = zarr.open(canvas_zarr[0].store.path, mode="a") - - # Final vertical merge - raw_predictions["probabilities"] = merge_multitask_vertical_chunkwise( - canvas, - count, - output_locs_y_, - zarr_group, - save_path, - memory_threshold, + raw_predictions["probabilities"] = _calculate_probabilities( + canvas_zarr=canvas_zarr, + count_zarr=count_zarr, + canvas=canvas, + count=count, + output_locs_y_=output_locs_y_, + save_path=save_path, + memory_threshold=memory_threshold, ) + raw_predictions["coordinates"] = da.concatenate(coordinates, axis=0) return raw_predictions @@ -1155,7 +1122,7 @@ def _save_multitask_vertical_to_cache( def _clear_zarr( - probabilities_zarr: zarr.Array, + probabilities_zarr: zarr.Array | None, probabilities_da: da.Array | None, zarr_group: zarr.Group, idx: int, @@ -1172,3 +1139,90 @@ def _clear_zarr( probabilities_zarr, chunks=(chunk_shape[0], *probabilities_shape) ) return probabilities_da + + +def _calculate_probabilities( + canvas_zarr: list[zarr.Array | None], + count_zarr: list[zarr.Array | None], + canvas: list[da.Array | None], + count: list[da.Array | None], + output_locs_y_: np.ndarray, + save_path: Path, + memory_threshold: int, +) -> list[da.Array]: + """Helper function to calculate probabilities for MultiTaskSegmentor.""" + zarr_group = None + if canvas_zarr is not None: + canvas_zarr, count_zarr = save_multitask_to_cache( + canvas, count, canvas_zarr, count_zarr + ) + # Wrap zarr in dask array + for idx, canvas_zarr_ in enumerate(canvas_zarr): + canvas[idx] = da.from_zarr(canvas_zarr_, chunks=canvas_zarr_.chunks) + count[idx] = da.from_zarr(count_zarr[idx], chunks=count_zarr[idx].chunks) + + zarr_group = zarr.open(canvas_zarr[0].store.path, mode="a") + + # Final vertical merge + return merge_multitask_vertical_chunkwise( + canvas, + count, + output_locs_y_, + zarr_group, + save_path, + memory_threshold, + ) + + +def _check_and_update_for_memory_overload( + canvas: list[da.Array | None], + count: list[da.Array | None], + canvas_zarr: list[zarr.Array | None], + count_zarr: list[zarr.Array | None], + memory_threshold: int, + tqdm_loop: DataLoader | tqdm, + tqdm_: type[tqdm_notebook | tqdm], + save_path: Path, + num_expected_output: int, +) -> tuple[ + list[da.Array | None], + list[da.Array | None], + list[zarr.Array | None], + list[zarr.Array | None], + DataLoader | tqdm, +]: + """Helper function to check and update the 
memory usage for multitask segmentor.""" + vm = psutil.virtual_memory() + used_percent = vm.percent + total_bytes = sum(arr.nbytes for arr in canvas) if canvas else 0 + canvas_used_percent = (total_bytes / vm.free) * 100 + + if not (used_percent > memory_threshold or canvas_used_percent > memory_threshold): + return canvas, count, canvas_zarr, count_zarr, tqdm_loop + + tqdm_loop.desc = "Spill intermediate data to disk" + used_percent = ( + canvas_used_percent + if (canvas_used_percent > memory_threshold) + else used_percent + ) + msg = ( + f"Current Memory usage: {used_percent} % " + f"exceeds specified threshold: {memory_threshold}. " + f"Saving intermediate results to disk." + ) + tqdm_.write(msg) + # Flush data in Memory and clear dask graph + canvas_zarr, count_zarr = save_multitask_to_cache( + canvas, + count, + canvas_zarr, + count_zarr, + save_path=save_path, + ) + canvas = [None for _ in range(num_expected_output)] + count = [None for _ in range(num_expected_output)] + gc.collect() + tqdm_loop.desc = "Inferring patches" + + return canvas, count, canvas_zarr, count_zarr, tqdm_loop From a8ca393121082e54dda73812af4b163d745ad9ae Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Mon, 19 Jan 2026 10:50:23 +0000 Subject: [PATCH 030/156] :bug: Fix deepsource error --- tests/engines/test_multi_task_segmentor.py | 49 ++++++++++++++++++- .../models/engine/multi_task_segmentor.py | 15 ++++-- 2 files changed, 57 insertions(+), 7 deletions(-) diff --git a/tests/engines/test_multi_task_segmentor.py b/tests/engines/test_multi_task_segmentor.py index a926d457b..bf4bc20ab 100644 --- a/tests/engines/test_multi_task_segmentor.py +++ b/tests/engines/test_multi_task_segmentor.py @@ -2,6 +2,7 @@ from __future__ import annotations +import shutil from pathlib import Path from typing import TYPE_CHECKING, Any, Final @@ -411,10 +412,12 @@ def test_single_task_mtsegmentor( def test_wsi_mtsegmentor_zarr( + remote_sample: Callable, sample_svs: Path, track_tmp_path: Path, ) -> None: """Test MultiTaskSegmentor for WSIs with zarr output.""" + wsi1_2k_2k_svs = Path(remote_sample("wsi1_2k_2k_svs")) mtsegmentor = MultiTaskSegmentor( model="hovernetplus-oed", batch_size=64, @@ -438,9 +441,51 @@ def test_wsi_mtsegmentor_zarr( output_ = zarr.open(output[sample_svs], mode="r") assert 18 < np.mean(output_["nuclei_segmentation"]["predictions"][:]) < 22 assert 0.43 < np.mean(output_["layer_segmentation"]["predictions"][:]) < 0.45 - assert "probabilities" not in output_["nuclei_segmentation"] + assert "probabilities" not in output_ assert "canvas" not in output_["nuclei_segmentation"] assert "count" not in output_["nuclei_segmentation"] - assert "probabilities" not in output_["layer_segmentation"] assert "canvas" not in output_["layer_segmentation"] assert "count" not in output_["layer_segmentation"] + shutil.rmtree(output[sample_svs]) + + mtsegmentor = MultiTaskSegmentor( + model="hovernetplus-oed", + batch_size=64, + verbose=False, + num_workers=1, + ) + # Return Probabilities is True + # Add multi-input test + output = mtsegmentor.run( + images=[sample_svs, wsi1_2k_2k_svs], + return_probabilities=True, + return_labels=False, + device=device, + patch_mode=False, + save_dir=track_tmp_path / "return_probabilities_check", + batch_size=2, + output_type="zarr", + stride_shape=(160, 160), + verbose=True, + ) + + output_ = zarr.open(output[sample_svs], mode="r") + assert 18 < np.mean(output_["nuclei_segmentation"]["predictions"][:]) < 22 + assert 0.43 < 
np.mean(output_["layer_segmentation"]["predictions"][:]) < 0.45 + assert "probabilities" in output_ + assert "canvas" not in output_["nuclei_segmentation"] + assert "count" not in output_["nuclei_segmentation"] + assert "canvas" not in output_["layer_segmentation"] + assert "count" not in output_["layer_segmentation"] + + output_ = zarr.open(output[wsi1_2k_2k_svs], mode="r") + assert 69 < np.mean(output_["nuclei_segmentation"]["predictions"][:]) < 73 + assert 0.8 < np.mean(output_["layer_segmentation"]["predictions"][:]) < 1.2 + assert "probabilities" in output_ + assert "canvas" not in output_["nuclei_segmentation"] + assert "count" not in output_["nuclei_segmentation"] + assert "canvas" not in output_["layer_segmentation"] + assert "count" not in output_["layer_segmentation"] + + shutil.rmtree(output[sample_svs]) + shutil.rmtree(output[wsi1_2k_2k_svs]) diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index 13ad3cd32..997689b98 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -448,6 +448,16 @@ def _save_predictions_as_dict_zarr( ) # Save to zarr + if kwargs.get("return_probabilities", False): + _ = self.save_predictions_as_zarr( + processed_predictions={ + "probabilities": processed_predictions.pop("probabilities") + }, + save_path=save_path, + keys_to_compute=["probabilities"], + task_name=None, + ) + for task_name in self.tasks: processed_predictions_ = processed_predictions.pop(task_name) # If there is a single task simplify the output. @@ -461,11 +471,6 @@ def _save_predictions_as_dict_zarr( {"coordinates": processed_predictions["coordinates"]} ) keys_to_compute.extend(["coordinates"]) - if kwargs.get("return_probabilities", False): - processed_predictions_.update( - {"probabilities": processed_predictions["probabilities"]} - ) - keys_to_compute.extend(["probabilities"]) _ = self.save_predictions_as_zarr( processed_predictions=processed_predictions_, save_path=save_path, From a3365695df999f7cee76fa33a98839ae4fdc2f40 Mon Sep 17 00:00:00 2001 From: Jiaqi Lv Date: Mon, 19 Jan 2026 11:35:46 +0000 Subject: [PATCH 031/156] add type annotation utils.transforms.py --- tiatoolbox/utils/transforms.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tiatoolbox/utils/transforms.py b/tiatoolbox/utils/transforms.py index 8fba6fee9..fbc8c20c0 100644 --- a/tiatoolbox/utils/transforms.py +++ b/tiatoolbox/utils/transforms.py @@ -153,7 +153,7 @@ def imresize( # can work on out-of-the-box (anything else will cause # error). The `converted type` has been selected so that # they can maintain the numeric precision of the `original type`. 
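# A short illustration of the round-trip this mapping enables (a sketch
# based on the surrounding imresize code, which this hunk does not show in
# full): cv2.resize accepts only a limited set of dtypes, so e.g. a boolean
# mask is promoted to uint8 for the resize and cast back afterwards:
#
#   import cv2
#   import numpy as np
#   mask = np.zeros((512, 512), dtype=np.bool_)
#   out = cv2.resize(mask.astype(np.uint8), (256, 256)).astype(np.bool_)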
- dtype_mapping = [ + dtype_mapping: list[tuple[type, type]] = [ (np.bool_, np.uint8), (np.int8, np.int16), (np.int16, np.int16), @@ -167,7 +167,7 @@ def imresize( (np.float32, np.float32), (np.float64, np.float64), ] - source_dtypes = [v[0] for v in dtype_mapping] + source_dtypes = [np.dtype(v[0]) for v in dtype_mapping] original_dtype = img.dtype if original_dtype not in source_dtypes: msg = f"Does not support resizing for array of dtype: {original_dtype}" @@ -415,6 +415,6 @@ def pad_bounds( elif np.size(padding) == ndims: # pragma: no cover padding = np.tile(padding, 2) - signs = np.repeat([-1, 1], ndims) + signs: np.ndarray = np.repeat([-1, 1], ndims) result = np.add(bounds, padding * signs) return (result[0], result[1], result[2], result[3]) From 57175120e094f589d29eb47a479011b5b61aabef Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Mon, 19 Jan 2026 15:00:28 +0000 Subject: [PATCH 032/156] :sparkles: Add WSI AnnotationStore support --- tests/engines/test_multi_task_segmentor.py | 101 ++++++++++++++++- tests/engines/test_semantic_segmentor.py | 2 +- .../models/architecture/hovernetplus.py | 12 ++ .../models/engine/multi_task_segmentor.py | 103 +++++++++++++----- 4 files changed, 187 insertions(+), 31 deletions(-) diff --git a/tests/engines/test_multi_task_segmentor.py b/tests/engines/test_multi_task_segmentor.py index bf4bc20ab..621dd837e 100644 --- a/tests/engines/test_multi_task_segmentor.py +++ b/tests/engines/test_multi_task_segmentor.py @@ -6,13 +6,17 @@ from pathlib import Path from typing import TYPE_CHECKING, Any, Final +import dask.array as da import numpy as np import pytest import torch import zarr from tiatoolbox.annotation import SQLiteStore -from tiatoolbox.models.engine.multi_task_segmentor import MultiTaskSegmentor +from tiatoolbox.models.engine.multi_task_segmentor import ( + MultiTaskSegmentor, + _clear_zarr, +) from tiatoolbox.utils import env_detection as toolbox_env from tiatoolbox.utils import imwrite from tiatoolbox.wsicore import WSIReader @@ -175,6 +179,43 @@ def test_raise_value_error_return_labels_wsi( ) +def test_clear_zarr() -> None: + """Test _clear_zarr working appropriately. + + This test only covers scenarios which are not feasible to run on GitHub Actions. 
+ + """ + store = zarr.MemoryStore() + root = zarr.group(store=store) + + # Create a dummy zarr array for probabilities_zarr + probabilities_zarr = root.create_dataset("probs", data=np.zeros((5, 3, 3))) + + idx = 2 + chunk_shape = (1,) + probabilities_shape = (3, 3) + + # Add canvas and count arrays with multiple entries + root.create_dataset(f"canvas/{idx}", data=np.arange(10)) + root.create_dataset(f"count/{idx}", data=np.arange(10)) + + result = _clear_zarr( + probabilities_zarr=probabilities_zarr, + probabilities_da=None, + zarr_group=root, + idx=idx, + chunk_shape=chunk_shape, + probabilities_shape=probabilities_shape, + ) + + # Ensure the keys still exist but the specific index was removed + assert "canvas" in root + assert "count" in root + assert 2 not in root["canvas"] + assert 2 not in root["count"] + assert isinstance(result, da.Array) + + def test_mtsegmentor_patches(remote_sample: Callable, track_tmp_path: Path) -> None: """Tests MultiTaskSegmentor on image patches.""" mtsegmentor = MultiTaskSegmentor( @@ -489,3 +530,61 @@ def test_wsi_mtsegmentor_zarr( shutil.rmtree(output[sample_svs]) shutil.rmtree(output[wsi1_2k_2k_svs]) + + +def test_wsi_segmentor_annotationstore( + sample_svs: Path, track_tmp_path: Path, caplog: pytest.CaptureFixture +) -> None: + """Test MultiTaskSegmentor for WSIs with AnnotationStore output.""" + mtsegmentor = MultiTaskSegmentor( + model="hovernetplus-oed", + batch_size=32, + verbose=False, + ) + # Return Probabilities is False + output = mtsegmentor.run( + images=[sample_svs], + return_probabilities=False, + device=device, + patch_mode=False, + save_dir=track_tmp_path / "wsi_out_check", + verbose=True, + output_type="annotationstore", + ) + + for output_ in output[sample_svs]: + assert output_.suffix != ".zarr" + + for task_name in mtsegmentor.tasks: + store_file_name = f"{sample_svs.stem}_{task_name}.db" + store_file_path = track_tmp_path / "wsi_out_check" / store_file_name + assert store_file_path.exists() + assert store_file_path in output[sample_svs] + + # Return Probabilities is True + mtsegmentor = MultiTaskSegmentor( + model="hovernetplus-oed", + batch_size=32, + verbose=False, + ) + + output = mtsegmentor.run( + images=[sample_svs], + return_probabilities=True, + device=device, + patch_mode=False, + save_dir=track_tmp_path / "wsi_prob_out_check", + verbose=True, + output_type="annotationstore", + ) + + assert "Probability maps cannot be saved as AnnotationStore." 
in caplog.text + zarr_group = zarr.open(output[sample_svs][0], mode="r") + assert "probabilities" in zarr_group + + for task_name in mtsegmentor.tasks: + store_file_name = f"{sample_svs.stem}_{task_name}.db" + store_file_path = track_tmp_path / "wsi_prob_out_check" / store_file_name + assert store_file_path.exists() + assert store_file_path in output[sample_svs] + assert task_name not in zarr_group diff --git a/tests/engines/test_semantic_segmentor.py b/tests/engines/test_semantic_segmentor.py index 14d492a8c..30647cf21 100644 --- a/tests/engines/test_semantic_segmentor.py +++ b/tests/engines/test_semantic_segmentor.py @@ -470,7 +470,7 @@ def test_wsi_segmentor_annotationstore( batch_size=32, verbose=False, ) - # Return Probabilities is False + # Return Probabilities is True output = segmentor.run( images=[sample_svs], return_probabilities=True, diff --git a/tiatoolbox/models/architecture/hovernetplus.py b/tiatoolbox/models/architecture/hovernetplus.py index ffcd42ae4..3cab5ad16 100644 --- a/tiatoolbox/models/architecture/hovernetplus.py +++ b/tiatoolbox/models/architecture/hovernetplus.py @@ -225,7 +225,19 @@ def _get_layer_info(pred_layer: np.ndarray) -> dict: cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE, ) + for layer in contours: + # * opencv protocol format may break + contour_ = np.squeeze(layer) + + # < 3 points does not make a contour, so skip, likely artifact too + # as the contours obtained via approximation => too small + if contour_.shape[0] < 3: # pragma: no cover # noqa: PLR2004 + continue + # ! check for trickery shape + if len(contour_.shape) != 2: # pragma: no cover # noqa: PLR2004 + continue + coords = layer[:, 0, :] layer_info_dict[count] = { "contours": coords, diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index 997689b98..4eceadf61 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -465,7 +465,6 @@ def _save_predictions_as_dict_zarr( keys_to_compute = [ k for k in processed_predictions_ if k not in self.drop_keys ] - if "coordinates" in processed_predictions: processed_predictions_.update( {"coordinates": processed_predictions["coordinates"]} @@ -506,33 +505,31 @@ def _save_predictions_as_annotationstore( if self.patch_mode: for idx, curr_image in enumerate(self.images): values = [processed_predictions[key][idx] for key in keys_to_compute] - predictions_ = dict(zip(keys_to_compute, values, strict=False)) - if isinstance(curr_image, Path): - store_file_name = ( - f"{curr_image.stem}.db" - if task_name is None - else f"{curr_image.stem}_{task_name}.db" - ) - output_path = save_path.parent / store_file_name - else: - store_file_name = ( - f"{idx}.db" if task_name is None else f"{idx}_{task_name}.db" - ) - output_path = save_path.parent / store_file_name - - origin = predictions_.pop("coordinates")[:2] - store = SQLiteStore() - store = dict_to_store( - store=store, - processed_predictions=predictions_, + output_path = _save_annotation_store( + curr_image=curr_image, + keys_to_compute=keys_to_compute, + values=values, + task_name=task_name, + idx=idx, + save_path=save_path, class_dict=class_dict, scale_factor=scale_factor, - origin=origin, ) + save_paths.append(output_path) - store.commit() - store.dump(output_path) - + else: + for idx, curr_image in enumerate(self.images): + values = [processed_predictions[key] for key in keys_to_compute] + output_path = _save_annotation_store( + curr_image=curr_image, + keys_to_compute=keys_to_compute, + 
values=values, + task_name=task_name, + idx=idx, + save_path=save_path, + class_dict=class_dict, + scale_factor=scale_factor, + ) save_paths.append(output_path) for key in keys_to_compute: @@ -638,16 +635,20 @@ def save_predictions( **kwargs, ) + save_paths = [] if isinstance(processed_predictions, Path): + if return_probabilities: + save_paths.append(processed_predictions) processed_predictions = zarr.open(str(processed_predictions), mode="r+") - save_paths = [] if self.tasks & processed_predictions.keys(): for task_name in self.tasks: - dict_for_store = { - **processed_predictions[task_name], - "coordinates": processed_predictions["coordinates"], - } + dict_for_store = processed_predictions[task_name] + if "coordinates" in processed_predictions: + dict_for_store = { + **processed_predictions[task_name], + "coordinates": processed_predictions["coordinates"], + } out_path = self._save_predictions_as_annotationstore( processed_predictions=dict_for_store, task_name=task_name, @@ -655,6 +656,7 @@ def save_predictions( **kwargs, ) save_paths += out_path + del processed_predictions[task_name] return save_paths @@ -1231,3 +1233,46 @@ def _check_and_update_for_memory_overload( tqdm_loop.desc = "Inferring patches" return canvas, count, canvas_zarr, count_zarr, tqdm_loop + + +def _save_annotation_store( + curr_image: Path | None, + keys_to_compute: list[str], + values: list[da.Array | list[da.Array]], + task_name: str, + idx: int, + save_path: Path, + class_dict: dict, + scale_factor: tuple[float, float], +) -> Path: + """Helper function to save to annotation store.""" + if isinstance(curr_image, Path): + store_file_name = ( + f"{curr_image.stem}.db" + if task_name is None + else f"{curr_image.stem}_{task_name}.db" + ) + else: + store_file_name = f"{idx}.db" if task_name is None else f"{idx}_{task_name}.db" + predictions_ = dict(zip(keys_to_compute, values, strict=False)) + output_path = save_path.parent / store_file_name + # Patch mode indexes the "coordinates" while calculating "values" variable. 
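# In other words (a sketch of the two shapes the branch below separates):
# patch mode has already picked one (x_start, y_start, x_end, y_end) row
# per image, while WSI mode still carries the full (N, 4) coordinates
# array, so the first row's top-left corner becomes the store origin:
#
#   import numpy as np
#   coords_wsi = np.array([[0, 0, 256, 256], [256, 0, 512, 256]])
#   coords_wsi[0][:2]                        # WSI mode   -> array([0, 0])
#   np.array([256, 0, 512, 256])[:2]         # patch mode -> array([256, 0])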
+ origin = ( + predictions_.pop("coordinates")[0][:2] + if len(predictions_["coordinates"].shape) > 1 + else predictions_.pop("coordinates")[:2] + ) + origin = tuple(max(0, x) for x in origin) + store = SQLiteStore() + store = dict_to_store( + store=store, + processed_predictions=predictions_, + class_dict=class_dict, + scale_factor=scale_factor, + origin=origin, + ) + + store.commit() + store.dump(output_path) + + return output_path From 568964d947bfb2eaefed1535eb758171fb785279 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Mon, 19 Jan 2026 17:21:51 +0000 Subject: [PATCH 033/156] :white_check_mark: Add check for single task and class_dict --- tests/engines/test_multi_task_segmentor.py | 59 ++++++++++--------- tiatoolbox/models/architecture/hovernet.py | 10 +++- .../models/architecture/hovernetplus.py | 14 +++-- tiatoolbox/models/engine/engine_abc.py | 2 +- .../models/engine/multi_task_segmentor.py | 12 +++- .../models/engine/semantic_segmentor.py | 2 +- tiatoolbox/models/models_abc.py | 1 + 7 files changed, 61 insertions(+), 39 deletions(-) diff --git a/tests/engines/test_multi_task_segmentor.py b/tests/engines/test_multi_task_segmentor.py index 621dd837e..cda98c941 100644 --- a/tests/engines/test_multi_task_segmentor.py +++ b/tests/engines/test_multi_task_segmentor.py @@ -189,16 +189,12 @@ def test_clear_zarr() -> None: root = zarr.group(store=store) # Create a dummy zarr array for probabilities_zarr - probabilities_zarr = root.create_dataset("probs", data=np.zeros((5, 3, 3))) + probabilities_zarr = root.create_dataset("probabilities", data=np.zeros((5, 3, 3))) idx = 2 chunk_shape = (1,) probabilities_shape = (3, 3) - # Add canvas and count arrays with multiple entries - root.create_dataset(f"canvas/{idx}", data=np.arange(10)) - root.create_dataset(f"count/{idx}", data=np.arange(10)) - result = _clear_zarr( probabilities_zarr=probabilities_zarr, probabilities_da=None, @@ -209,12 +205,21 @@ def test_clear_zarr() -> None: ) # Ensure the keys still exist but the specific index was removed - assert "canvas" in root - assert "count" in root - assert 2 not in root["canvas"] - assert 2 not in root["count"] + assert "canvas" not in root + assert "count" not in root assert isinstance(result, da.Array) + result_ = _clear_zarr( + probabilities_zarr=None, + probabilities_da=result, + zarr_group=root, + idx=idx, + chunk_shape=chunk_shape, + probabilities_shape=probabilities_shape, + ) + + assert np.all(result_.compute() == result.compute()) + def test_mtsegmentor_patches(remote_sample: Callable, track_tmp_path: Path) -> None: """Tests MultiTaskSegmentor on image patches.""" @@ -490,13 +495,14 @@ def test_wsi_mtsegmentor_zarr( shutil.rmtree(output[sample_svs]) mtsegmentor = MultiTaskSegmentor( - model="hovernetplus-oed", + model="hovernet_fast-pannuke", batch_size=64, verbose=False, num_workers=1, ) # Return Probabilities is True # Add multi-input test + # Use single task output from hovernet output = mtsegmentor.run( images=[sample_svs, wsi1_2k_2k_svs], return_probabilities=True, @@ -511,22 +517,16 @@ def test_wsi_mtsegmentor_zarr( ) output_ = zarr.open(output[sample_svs], mode="r") - assert 18 < np.mean(output_["nuclei_segmentation"]["predictions"][:]) < 22 - assert 0.43 < np.mean(output_["layer_segmentation"]["predictions"][:]) < 0.45 + assert 34 < np.mean(output_["predictions"][:]) < 38 assert "probabilities" in output_ - assert "canvas" not in output_["nuclei_segmentation"] - assert "count" not in output_["nuclei_segmentation"] - assert 
"canvas" not in output_["layer_segmentation"] - assert "count" not in output_["layer_segmentation"] + assert "canvas" not in output_ + assert "count" not in output_ output_ = zarr.open(output[wsi1_2k_2k_svs], mode="r") - assert 69 < np.mean(output_["nuclei_segmentation"]["predictions"][:]) < 73 - assert 0.8 < np.mean(output_["layer_segmentation"]["predictions"][:]) < 1.2 + assert 136 < np.mean(output_["predictions"][:]) < 140 assert "probabilities" in output_ - assert "canvas" not in output_["nuclei_segmentation"] - assert "count" not in output_["nuclei_segmentation"] - assert "canvas" not in output_["layer_segmentation"] - assert "count" not in output_["layer_segmentation"] + assert "canvas" not in output_ + assert "count" not in output_ shutil.rmtree(output[sample_svs]) shutil.rmtree(output[wsi1_2k_2k_svs]) @@ -537,10 +537,13 @@ def test_wsi_segmentor_annotationstore( ) -> None: """Test MultiTaskSegmentor for WSIs with AnnotationStore output.""" mtsegmentor = MultiTaskSegmentor( - model="hovernetplus-oed", + model="hovernet_fast-pannuke", batch_size=32, verbose=False, ) + + class_dict = mtsegmentor.model.class_dict + # Return Probabilities is False output = mtsegmentor.run( images=[sample_svs], @@ -550,16 +553,16 @@ def test_wsi_segmentor_annotationstore( save_dir=track_tmp_path / "wsi_out_check", verbose=True, output_type="annotationstore", + class_dict=class_dict, ) for output_ in output[sample_svs]: assert output_.suffix != ".zarr" - for task_name in mtsegmentor.tasks: - store_file_name = f"{sample_svs.stem}_{task_name}.db" - store_file_path = track_tmp_path / "wsi_out_check" / store_file_name - assert store_file_path.exists() - assert store_file_path in output[sample_svs] + store_file_name = f"{sample_svs.stem}.db" + store_file_path = track_tmp_path / "wsi_out_check" / store_file_name + assert store_file_path.exists() + assert store_file_path == output[sample_svs][0] # Return Probabilities is True mtsegmentor = MultiTaskSegmentor( diff --git a/tiatoolbox/models/architecture/hovernet.py b/tiatoolbox/models/architecture/hovernet.py index 1eef26b97..b5ce75bfe 100644 --- a/tiatoolbox/models/architecture/hovernet.py +++ b/tiatoolbox/models/architecture/hovernet.py @@ -337,6 +337,11 @@ def __init__( self.mode = mode self.num_types = num_types self.nuc_type_dict = nuc_type_dict + self.tasks = [] + self.tasks.append("nuclei_segmentation") + self.class_dict = { + self.tasks[0]: nuc_type_dict, + } if mode not in ["original", "fast"]: msg = ( @@ -713,9 +718,8 @@ def get_instance_info(pred_inst: np.ndarray, pred_type: np.ndarray = None) -> di return inst_info_dict - @staticmethod # skipcq: PYL-W0221 # noqa: ERA001 - def postproc(raw_maps: list[np.ndarray]) -> tuple[dict, ...]: + def postproc(self: HoVerNet, raw_maps: list[np.ndarray]) -> tuple[dict, ...]: """Post-processing script for image tiles. 
Args: @@ -804,7 +808,7 @@ def postproc(raw_maps: list[np.ndarray]) -> tuple[dict, ...]: ) nuclei_seg = { - "task_type": "nuclei_segmentation", + "task_type": self.tasks[0], "predictions": pred_inst, "info_dict": nuc_inst_info_dict_, } diff --git a/tiatoolbox/models/architecture/hovernetplus.py b/tiatoolbox/models/architecture/hovernetplus.py index 3cab5ad16..1cd3f0c8c 100644 --- a/tiatoolbox/models/architecture/hovernetplus.py +++ b/tiatoolbox/models/architecture/hovernetplus.py @@ -96,6 +96,13 @@ def __init__( self.num_layers = num_layers self.nuc_type_dict = nuc_type_dict self.layer_type_dict = layer_type_dict + self.tasks = [] + self.tasks.append("nuclei_segmentation") + self.tasks.append("layer_segmentation") + self.class_dict = { + self.tasks[0]: nuc_type_dict, + self.tasks[1]: layer_type_dict, + } ksize = 3 self.decoder = nn.ModuleDict( @@ -247,9 +254,8 @@ def _get_layer_info(pred_layer: np.ndarray) -> dict: return layer_info_dict - @staticmethod # skipcq: PYL-W0221 # noqa: ERA001 - def postproc(raw_maps: list[np.ndarray]) -> tuple[dict, ...]: + def postproc(self: HoVerNetPlus, raw_maps: list[np.ndarray]) -> tuple[dict, ...]: """Post-processing script for image tiles. Args: @@ -359,7 +365,7 @@ def postproc(raw_maps: list[np.ndarray]) -> tuple[dict, ...]: ) nuclei_seg = { - "task_type": "nuclei_segmentation", + "task_type": self.tasks[0], "predictions": pred_inst, "info_dict": nuc_inst_info_dict_, } @@ -380,7 +386,7 @@ def postproc(raw_maps: list[np.ndarray]) -> tuple[dict, ...]: ) layer_seg = { - "task_type": "layer_segmentation", + "task_type": self.tasks[1], "predictions": pred_layer, "info_dict": layer_info_dict_, } diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index 8670ef8e4..4b7ba16d1 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -731,7 +731,7 @@ def save_predictions( # scale_factor set from kwargs scale_factor = kwargs.get("scale_factor", (1.0, 1.0)) # class_dict set from kwargs - class_dict = kwargs.get("class_dict") + class_dict = kwargs.get("class_dict", self.model.class_dict) return dict_to_store_patch_predictions( processed_predictions, diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index 4eceadf61..9cd7c8d89 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -490,6 +490,7 @@ def _save_predictions_as_annotationstore( scale_factor = kwargs.get("scale_factor", (1.0, 1.0)) # class_dict set from kwargs class_dict = kwargs.get("class_dict") + # Need to add support for zarr conversion. save_paths = [] @@ -501,7 +502,6 @@ def _save_predictions_as_annotationstore( keys_to_compute = list(processed_predictions.keys()) if "probabilities" in keys_to_compute: keys_to_compute.remove("probabilities") - if self.patch_mode: for idx, curr_image in enumerate(self.images): values = [processed_predictions[key][idx] for key in keys_to_compute] @@ -628,6 +628,12 @@ def save_predictions( # This runs dask.compute and returns numpy arrays # for saving annotationstore output. 
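+        # ``model.class_dict`` maps task name -> {label id: label name}; a
+        # hypothetical single-task example (label names illustrative only):
+        # {"nuclei_segmentation": {0: "background", 1: "nuclei"}} is unwrapped
+        # to {0: "background", 1: "nuclei"}, since the store writer expects
+        # the flat mapping. Multi-task models keep the per-task dict and the
+        # matching entry is selected for each task further below.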
+ class_dict = kwargs.get("class_dict", self.model.class_dict) + if len(self.tasks) == 1: + kwargs["class_dict"] = class_dict[next(iter(self.tasks))] + else: + kwargs["class_dict"] = class_dict + processed_predictions = self._save_predictions_as_dict_zarr( processed_predictions, output_type=output_type_, @@ -641,9 +647,11 @@ def save_predictions( save_paths.append(processed_predictions) processed_predictions = zarr.open(str(processed_predictions), mode="r+") + # For single tasks there should be no overlap if self.tasks & processed_predictions.keys(): for task_name in self.tasks: dict_for_store = processed_predictions[task_name] + kwargs["class_dict"] = class_dict[task_name] if "coordinates" in processed_predictions: dict_for_store = { **processed_predictions[task_name], @@ -1262,7 +1270,7 @@ def _save_annotation_store( if len(predictions_["coordinates"].shape) > 1 else predictions_.pop("coordinates")[:2] ) - origin = tuple(max(0, x) for x in origin) + origin = tuple(max(0.0, float(x)) for x in origin) store = SQLiteStore() store = dict_to_store( store=store, diff --git a/tiatoolbox/models/engine/semantic_segmentor.py b/tiatoolbox/models/engine/semantic_segmentor.py index 20338960e..f3417bcf1 100644 --- a/tiatoolbox/models/engine/semantic_segmentor.py +++ b/tiatoolbox/models/engine/semantic_segmentor.py @@ -679,7 +679,7 @@ def save_predictions( # scale_factor set from kwargs scale_factor = kwargs.get("scale_factor", (1.0, 1.0)) # class_dict set from kwargs - class_dict = kwargs.get("class_dict") + class_dict = kwargs.get("class_dict", self.model.class_dict) # Need to add support for zarr conversion. save_paths = [] diff --git a/tiatoolbox/models/models_abc.py b/tiatoolbox/models/models_abc.py index dcca370f5..c82f8ac71 100644 --- a/tiatoolbox/models/models_abc.py +++ b/tiatoolbox/models/models_abc.py @@ -92,6 +92,7 @@ def __init__(self: ModelABC) -> None: super().__init__() self._postproc = self.postproc self._preproc = self.preproc + self.class_dict = None @abstractmethod # This is generic abc, else pylint will complain From 943aebf2d61e27721f01eecb653bdc5762294996 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Mon, 19 Jan 2026 21:15:50 +0000 Subject: [PATCH 034/156] :white_check_mark: Add check for single task and class_dict support --- tests/engines/test_multi_task_segmentor.py | 404 +++++++++++---------- 1 file changed, 210 insertions(+), 194 deletions(-) diff --git a/tests/engines/test_multi_task_segmentor.py b/tests/engines/test_multi_task_segmentor.py index cda98c941..7d36a05b8 100644 --- a/tests/engines/test_multi_task_segmentor.py +++ b/tests/engines/test_multi_task_segmentor.py @@ -28,126 +28,6 @@ device = "cuda" if toolbox_env.has_gpu() else "cpu" -def assert_output_lengths( - output: OutputType, expected_counts: Sequence[int], fields: list[str] -) -> None: - """Assert lengths of output dict fields against expected counts.""" - for field in fields: - for i, expected in enumerate(expected_counts): - assert len(output[field][i]) == expected, f"{field}[{i}] mismatch" - - -def assert_predictions_and_boxes( - output: OutputType, expected_counts: Sequence[int], *, is_zarr: bool = False -) -> None: - """Assert predictions maxima and box lengths against expected counts.""" - # predictions maxima - for idx, expected in enumerate(expected_counts): - if is_zarr and idx == 2: - # zarr output doesn't store predictions for patch 2 - continue - assert np.max(output["predictions"][idx][:]) == expected, ( - f"predictions[{idx}] mismatch" - ) 
- - -def assert_output_equal( - output_a: OutputType, - output_b: OutputType, - fields: Sequence[str], - indices_a: Sequence[int], - indices_b: Sequence[int], -) -> None: - """Assert equality of arrays across outputs for given fields/indices.""" - for field in fields: - for i_a, i_b in zip(indices_a, indices_b, strict=False): - left = output_a[field][i_a] - right = output_b[field][i_b] - assert all( - np.array_equal(a, b) for a, b in zip(left, right, strict=False) - ), f"{field}[{i_a}] vs {field}[{i_b}] mismatch" - - -def assert_annotation_store_patch_output( - inputs: list | np.ndarray, - output_ann: list[Path], - task_name: str | None, - track_tmp_path: Path, - expected_counts: Sequence[int], - output_dict: OutputType, - fields: list[str], -) -> None: - """Helper function to test AnnotationStore output.""" - for patch_idx, db_path in enumerate(output_ann): - if isinstance(inputs[patch_idx], Path): - store_file_name = ( - f"{inputs[patch_idx].stem}.db" - if task_name is None - else f"{inputs[patch_idx].stem}_{task_name}.db" - ) - else: - store_file_name = ( - f"{patch_idx}.db" - if task_name is None - else f"{patch_idx}_{task_name}.db" - ) - - assert ( - db_path == track_tmp_path / "patch_output_annotationstore" / store_file_name - ) - store_ = SQLiteStore.open(db_path) - annotations_ = store_.values() - annotations_geometry_type = [ - str(annotation_.geometry_type) for annotation_ in annotations_ - ] - annotations_list = list(annotations_) - if expected_counts[patch_idx] > 0: - assert "Polygon" in annotations_geometry_type - - # Build result dict from annotation properties - result = {} - for ann in annotations_list: - for key, value in ann.properties.items(): - result.setdefault(key, []).append(value) - result["contours"] = [ - list(poly.exterior.coords) - for poly in (a.geometry for a in annotations_list) - ] - - # wrap it to make it compatible to assert_output_lengths - result_ = {field: [result[field]] for field in fields} - - # Lengths and equality checks for this patch - assert_output_lengths( - result_, - expected_counts=[expected_counts[patch_idx]], - fields=fields, - ) - fields_ = fields.copy() - fields_.remove("contours") - assert_output_equal( - result_, - output_dict, - fields=fields_, - indices_a=[0], - indices_b=[patch_idx], - ) - - # Contour check (discard last point) - matches = [ - np.array_equal(np.array(a[:-1], dtype=int), np.array(b, dtype=int)) - for a, b in zip( - result["contours"], output_dict["contours"][patch_idx], strict=False - ) - ] - # Due to make valid poly there might be translation in a few points - # in AnnotationStore - assert sum(matches) / len(matches) >= 0.95 - else: - assert annotations_geometry_type == [] - assert annotations_list == [] - - def test_mtsegmentor_init() -> None: """Tests MultiTaskSegmentor initialization.""" segmentor = MultiTaskSegmentor(model="hovernetplus-oed", device=device) @@ -156,71 +36,6 @@ def test_mtsegmentor_init() -> None: assert isinstance(segmentor.model, torch.nn.Module) -def test_raise_value_error_return_labels_wsi( - sample_svs: Path, - track_tmp_path: Path, -) -> None: - """Tests MultiTaskSegmentor return_labels error.""" - mtsegmentor = MultiTaskSegmentor(model="hovernetplus-oed", device=device) - - with pytest.raises( - ValueError, - match=r".*return_labels` is not supported for MultiTaskSegmentor.", - ): - _ = mtsegmentor.run( - images=[sample_svs], - return_probabilities=False, - return_labels=True, - device=device, - patch_mode=False, - save_dir=track_tmp_path / "wsi_out_check", - batch_size=2, - 
output_type="zarr", - ) - - -def test_clear_zarr() -> None: - """Test _clear_zarr working appropriately. - - This test only covers scenarios which are not feasible to run on GitHub Actions. - - """ - store = zarr.MemoryStore() - root = zarr.group(store=store) - - # Create a dummy zarr array for probabilities_zarr - probabilities_zarr = root.create_dataset("probabilities", data=np.zeros((5, 3, 3))) - - idx = 2 - chunk_shape = (1,) - probabilities_shape = (3, 3) - - result = _clear_zarr( - probabilities_zarr=probabilities_zarr, - probabilities_da=None, - zarr_group=root, - idx=idx, - chunk_shape=chunk_shape, - probabilities_shape=probabilities_shape, - ) - - # Ensure the keys still exist but the specific index was removed - assert "canvas" not in root - assert "count" not in root - assert isinstance(result, da.Array) - - result_ = _clear_zarr( - probabilities_zarr=None, - probabilities_da=result, - zarr_group=root, - idx=idx, - chunk_shape=chunk_shape, - probabilities_shape=probabilities_shape, - ) - - assert np.all(result_.compute() == result.compute()) - - def test_mtsegmentor_patches(remote_sample: Callable, track_tmp_path: Path) -> None: """Tests MultiTaskSegmentor on image patches.""" mtsegmentor = MultiTaskSegmentor( @@ -336,6 +151,7 @@ def test_mtsegmentor_patches(remote_sample: Callable, track_tmp_path: Path) -> N fields=fields, expected_counts=expected_counts, task_name=task_name, + class_dict=mtsegmentor.model.class_dict, ) @@ -437,6 +253,7 @@ def test_single_task_mtsegmentor( fields=["box", "centroid", "contours", "prob", "type"], expected_counts=expected_counts_nuclei, task_name=None, + class_dict=mtsegmentor.model.class_dict["nuclei_segmentation"], ) zarr_file = track_tmp_path / "patch_output_annotationstore" / "output.zarr" @@ -458,12 +275,10 @@ def test_single_task_mtsegmentor( def test_wsi_mtsegmentor_zarr( - remote_sample: Callable, sample_svs: Path, track_tmp_path: Path, ) -> None: """Test MultiTaskSegmentor for WSIs with zarr output.""" - wsi1_2k_2k_svs = Path(remote_sample("wsi1_2k_2k_svs")) mtsegmentor = MultiTaskSegmentor( model="hovernetplus-oed", batch_size=64, @@ -492,17 +307,24 @@ def test_wsi_mtsegmentor_zarr( assert "count" not in output_["nuclei_segmentation"] assert "canvas" not in output_["layer_segmentation"] assert "count" not in output_["layer_segmentation"] - shutil.rmtree(output[sample_svs]) + +def test_multi_input_wsi_mtsegmentor_zarr( + remote_sample: Callable, + sample_svs: Path, + track_tmp_path: Path, +) -> None: + """Test MultiTaskSegmentor for multiple WSIs with zarr output.""" + wsi1_2k_2k_svs = Path(remote_sample("wsi1_2k_2k_svs")) + # Return Probabilities is True + # Add multi-input test + # Use single task output from hovernet mtsegmentor = MultiTaskSegmentor( model="hovernet_fast-pannuke", batch_size=64, verbose=False, num_workers=1, ) - # Return Probabilities is True - # Add multi-input test - # Use single task output from hovernet output = mtsegmentor.run( images=[sample_svs, wsi1_2k_2k_svs], return_probabilities=True, @@ -532,9 +354,7 @@ def test_wsi_mtsegmentor_zarr( shutil.rmtree(output[wsi1_2k_2k_svs]) -def test_wsi_segmentor_annotationstore( - sample_svs: Path, track_tmp_path: Path, caplog: pytest.CaptureFixture -) -> None: +def test_wsi_segmentor_annotationstore(sample_svs: Path, track_tmp_path: Path) -> None: """Test MultiTaskSegmentor for WSIs with AnnotationStore output.""" mtsegmentor = MultiTaskSegmentor( model="hovernet_fast-pannuke", @@ -564,6 +384,11 @@ def test_wsi_segmentor_annotationstore( assert store_file_path.exists() 
assert store_file_path == output[sample_svs][0] + +def test_wsi_segmentor_annotationstore_probabilities( + sample_svs: Path, track_tmp_path: Path, caplog: pytest.CaptureFixture +) -> None: + """Test MultiTaskSegmentor with AnnotationStore and probabilities output.""" # Return Probabilities is True mtsegmentor = MultiTaskSegmentor( model="hovernetplus-oed", @@ -591,3 +416,194 @@ def test_wsi_segmentor_annotationstore( assert store_file_path.exists() assert store_file_path in output[sample_svs] assert task_name not in zarr_group + + +def test_raise_value_error_return_labels_wsi( + sample_svs: Path, + track_tmp_path: Path, +) -> None: + """Tests MultiTaskSegmentor return_labels error.""" + mtsegmentor = MultiTaskSegmentor(model="hovernetplus-oed", device=device) + + with pytest.raises( + ValueError, + match=r".*return_labels` is not supported for MultiTaskSegmentor.", + ): + _ = mtsegmentor.run( + images=[sample_svs], + return_probabilities=False, + return_labels=True, + device=device, + patch_mode=False, + save_dir=track_tmp_path / "wsi_out_check", + batch_size=2, + output_type="zarr", + ) + + +def test_clear_zarr() -> None: + """Test _clear_zarr working appropriately. + + This test only covers scenarios which are not feasible to run on GitHub Actions. + + """ + store = zarr.MemoryStore() + root = zarr.group(store=store) + + # Create a dummy zarr array for probabilities_zarr + probabilities_zarr = root.create_dataset("probabilities", data=np.zeros((5, 3, 3))) + + idx = 2 + chunk_shape = (1,) + probabilities_shape = (3, 3) + + result = _clear_zarr( + probabilities_zarr=probabilities_zarr, + probabilities_da=None, + zarr_group=root, + idx=idx, + chunk_shape=chunk_shape, + probabilities_shape=probabilities_shape, + ) + + # Ensure the keys still exist but the specific index was removed + assert "canvas" not in root + assert "count" not in root + assert isinstance(result, da.Array) + + result_ = _clear_zarr( + probabilities_zarr=None, + probabilities_da=result, + zarr_group=root, + idx=idx, + chunk_shape=chunk_shape, + probabilities_shape=probabilities_shape, + ) + + assert np.all(result_.compute() == result.compute()) + + +# HELPER functions +def assert_output_lengths( + output: OutputType, expected_counts: Sequence[int], fields: list[str] +) -> None: + """Assert lengths of output dict fields against expected counts.""" + for field in fields: + for i, expected in enumerate(expected_counts): + assert len(output[field][i]) == expected, f"{field}[{i}] mismatch" + + +def assert_predictions_and_boxes( + output: OutputType, expected_counts: Sequence[int], *, is_zarr: bool = False +) -> None: + """Assert predictions maxima and box lengths against expected counts.""" + # predictions maxima + for idx, expected in enumerate(expected_counts): + if is_zarr and idx == 2: + # zarr output doesn't store predictions for patch 2 + continue + assert np.max(output["predictions"][idx][:]) == expected, ( + f"predictions[{idx}] mismatch" + ) + + +def assert_output_equal( + output_a: OutputType, + output_b: OutputType, + fields: Sequence[str], + indices_a: Sequence[int], + indices_b: Sequence[int], +) -> None: + """Assert equality of arrays across outputs for given fields/indices.""" + for field in fields: + for i_a, i_b in zip(indices_a, indices_b, strict=False): + left = output_a[field][i_a] + right = output_b[field][i_b] + assert all( + np.array_equal(a, b) for a, b in zip(left, right, strict=False) + ), f"{field}[{i_a}] vs {field}[{i_b}] mismatch" + + +def assert_annotation_store_patch_output( + inputs: list | 
np.ndarray, + output_ann: list[Path], + task_name: str | None, + track_tmp_path: Path, + expected_counts: Sequence[int], + output_dict: OutputType, + fields: list[str], + class_dict: dict, +) -> None: + """Helper function to test AnnotationStore output.""" + for patch_idx, db_path in enumerate(output_ann): + if isinstance(inputs[patch_idx], Path): + store_file_name = ( + f"{inputs[patch_idx].stem}.db" + if task_name is None + else f"{inputs[patch_idx].stem}_{task_name}.db" + ) + else: + store_file_name = ( + f"{patch_idx}.db" + if task_name is None + else f"{patch_idx}_{task_name}.db" + ) + + assert ( + db_path == track_tmp_path / "patch_output_annotationstore" / store_file_name + ) + store_ = SQLiteStore.open(db_path) + annotations_ = store_.values() + annotations_geometry_type = [ + str(annotation_.geometry_type) for annotation_ in annotations_ + ] + annotations_list = list(annotations_) + if expected_counts[patch_idx] > 0: + assert "Polygon" in annotations_geometry_type + + # Build result dict from annotation properties + result = {} + for ann in annotations_list: + for key, value in ann.properties.items(): + result.setdefault(key, []).append(value) + result["contours"] = [ + list(poly.exterior.coords) + for poly in (a.geometry for a in annotations_list) + ] + + # wrap it to make it compatible to assert_output_lengths + result_ = {field: [result[field]] for field in fields} + + # Lengths and equality checks for this patch + assert_output_lengths( + result_, + expected_counts=[expected_counts[patch_idx]], + fields=fields, + ) + fields_ = fields.copy() + fields_.remove("contours") + + class_dict_ = class_dict[task_name] if task_name else class_dict + type_ = [class_dict_[c_id] for c_id in output_dict["type"][patch_idx]] + output_dict["type"][patch_idx] = type_ + assert_output_equal( + result_, + output_dict, + fields=fields_, + indices_a=[0], + indices_b=[patch_idx], + ) + + # Contour check (discard last point) + matches = [ + np.array_equal(np.array(a[:-1], dtype=int), np.array(b, dtype=int)) + for a, b in zip( + result["contours"], output_dict["contours"][patch_idx], strict=False + ) + ] + # Due to make valid poly there might be translation in a few points + # in AnnotationStore + assert sum(matches) / len(matches) >= 0.95 + else: + assert annotations_geometry_type == [] + assert annotations_list == [] From e3830b15b2e27eba3e009c153665b3f8953e234d Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Mon, 19 Jan 2026 21:27:29 +0000 Subject: [PATCH 035/156] :bug: Fix PY-W0070 deep source error --- tiatoolbox/models/architecture/hovernet.py | 3 +-- tiatoolbox/models/architecture/hovernetplus.py | 4 +--- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/tiatoolbox/models/architecture/hovernet.py b/tiatoolbox/models/architecture/hovernet.py index b5ce75bfe..659f255b3 100644 --- a/tiatoolbox/models/architecture/hovernet.py +++ b/tiatoolbox/models/architecture/hovernet.py @@ -337,8 +337,7 @@ def __init__( self.mode = mode self.num_types = num_types self.nuc_type_dict = nuc_type_dict - self.tasks = [] - self.tasks.append("nuclei_segmentation") + self.tasks = ["nuclei_segmentation"] self.class_dict = { self.tasks[0]: nuc_type_dict, } diff --git a/tiatoolbox/models/architecture/hovernetplus.py b/tiatoolbox/models/architecture/hovernetplus.py index 1cd3f0c8c..c4910e3a6 100644 --- a/tiatoolbox/models/architecture/hovernetplus.py +++ b/tiatoolbox/models/architecture/hovernetplus.py @@ -96,9 +96,7 @@ def __init__( self.num_layers = 
num_layers self.nuc_type_dict = nuc_type_dict self.layer_type_dict = layer_type_dict - self.tasks = [] - self.tasks.append("nuclei_segmentation") - self.tasks.append("layer_segmentation") + self.tasks = ["nuclei_segmentation", "layer_segmentation"] self.class_dict = { self.tasks[0]: nuc_type_dict, self.tasks[1]: layer_type_dict, From 528b72e1ebad03c46a9f6ccb075307f440b7b527 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Mon, 19 Jan 2026 22:02:44 +0000 Subject: [PATCH 036/156] :zap: Use smaller images for faster test runs. --- tests/engines/test_multi_task_segmentor.py | 63 ++++++++++++---------- 1 file changed, 36 insertions(+), 27 deletions(-) diff --git a/tests/engines/test_multi_task_segmentor.py b/tests/engines/test_multi_task_segmentor.py index 7d36a05b8..32b621350 100644 --- a/tests/engines/test_multi_task_segmentor.py +++ b/tests/engines/test_multi_task_segmentor.py @@ -275,10 +275,11 @@ def test_single_task_mtsegmentor( def test_wsi_mtsegmentor_zarr( - sample_svs: Path, + remote_sample: Callable, track_tmp_path: Path, ) -> None: """Test MultiTaskSegmentor for WSIs with zarr output.""" + wsi4_512_512_svs = remote_sample("wsi4_512_512_svs") mtsegmentor = MultiTaskSegmentor( model="hovernetplus-oed", batch_size=64, @@ -287,7 +288,7 @@ def test_wsi_mtsegmentor_zarr( ) # Return Probabilities is False output = mtsegmentor.run( - images=[sample_svs], + images=[wsi4_512_512_svs], return_probabilities=False, return_labels=False, device=device, @@ -299,9 +300,9 @@ def test_wsi_mtsegmentor_zarr( stride_shape=(160, 160), ) - output_ = zarr.open(output[sample_svs], mode="r") - assert 18 < np.mean(output_["nuclei_segmentation"]["predictions"][:]) < 22 - assert 0.43 < np.mean(output_["layer_segmentation"]["predictions"][:]) < 0.45 + output_ = zarr.open(output[wsi4_512_512_svs], mode="r") + assert 0.8 < np.mean(output_["nuclei_segmentation"]["predictions"][:]) < 1.0 + assert 0.57 < np.mean(output_["layer_segmentation"]["predictions"][:]) < 0.61 assert "probabilities" not in output_ assert "canvas" not in output_["nuclei_segmentation"] assert "count" not in output_["nuclei_segmentation"] @@ -311,11 +312,17 @@ def test_wsi_mtsegmentor_zarr( def test_multi_input_wsi_mtsegmentor_zarr( remote_sample: Callable, - sample_svs: Path, track_tmp_path: Path, ) -> None: """Test MultiTaskSegmentor for multiple WSIs with zarr output.""" - wsi1_2k_2k_svs = Path(remote_sample("wsi1_2k_2k_svs")) + wsi4_512_512_svs = Path(remote_sample("wsi4_512_512_svs")) + wsi4_512_512_svs_2 = wsi4_512_512_svs.parent / ( + wsi4_512_512_svs.stem + "_2" + wsi4_512_512_svs.suffix + ) + wsi4_512_512_svs_2 = Path( + shutil.copy(str(wsi4_512_512_svs), str(wsi4_512_512_svs_2)) + ) + # Return Probabilities is True # Add multi-input test # Use single task output from hovernet @@ -326,7 +333,7 @@ def test_multi_input_wsi_mtsegmentor_zarr( num_workers=1, ) output = mtsegmentor.run( - images=[sample_svs, wsi1_2k_2k_svs], + images=[wsi4_512_512_svs_2, wsi4_512_512_svs], return_probabilities=True, return_labels=False, device=device, @@ -338,24 +345,24 @@ def test_multi_input_wsi_mtsegmentor_zarr( verbose=True, ) - output_ = zarr.open(output[sample_svs], mode="r") - assert 34 < np.mean(output_["predictions"][:]) < 38 + output_ = zarr.open(output[wsi4_512_512_svs], mode="r") + assert 23 < np.mean(output_["predictions"][:]) < 27 assert "probabilities" in output_ assert "canvas" not in output_ assert "count" not in output_ - output_ = zarr.open(output[wsi1_2k_2k_svs], mode="r") - assert 136 < 
np.mean(output_["predictions"][:]) < 140 + output_ = zarr.open(output[wsi4_512_512_svs_2], mode="r") + assert 23 < np.mean(output_["predictions"][:]) < 27 assert "probabilities" in output_ assert "canvas" not in output_ assert "count" not in output_ - shutil.rmtree(output[sample_svs]) - shutil.rmtree(output[wsi1_2k_2k_svs]) - -def test_wsi_segmentor_annotationstore(sample_svs: Path, track_tmp_path: Path) -> None: +def test_wsi_segmentor_annotationstore( + remote_sample: Callable, track_tmp_path: Path +) -> None: """Test MultiTaskSegmentor for WSIs with AnnotationStore output.""" + wsi4_512_512_svs = remote_sample("wsi4_512_512_svs") mtsegmentor = MultiTaskSegmentor( model="hovernet_fast-pannuke", batch_size=32, @@ -366,7 +373,7 @@ def test_wsi_segmentor_annotationstore(sample_svs: Path, track_tmp_path: Path) - # Return Probabilities is False output = mtsegmentor.run( - images=[sample_svs], + images=[wsi4_512_512_svs], return_probabilities=False, device=device, patch_mode=False, @@ -376,19 +383,20 @@ def test_wsi_segmentor_annotationstore(sample_svs: Path, track_tmp_path: Path) - class_dict=class_dict, ) - for output_ in output[sample_svs]: + for output_ in output[wsi4_512_512_svs]: assert output_.suffix != ".zarr" - store_file_name = f"{sample_svs.stem}.db" + store_file_name = f"{wsi4_512_512_svs.stem}.db" store_file_path = track_tmp_path / "wsi_out_check" / store_file_name assert store_file_path.exists() - assert store_file_path == output[sample_svs][0] + assert store_file_path == output[wsi4_512_512_svs][0] def test_wsi_segmentor_annotationstore_probabilities( - sample_svs: Path, track_tmp_path: Path, caplog: pytest.CaptureFixture + remote_sample: Callable, track_tmp_path: Path, caplog: pytest.CaptureFixture ) -> None: """Test MultiTaskSegmentor with AnnotationStore and probabilities output.""" + wsi4_512_512_svs = remote_sample("wsi4_512_512_svs") # Return Probabilities is True mtsegmentor = MultiTaskSegmentor( model="hovernetplus-oed", @@ -397,7 +405,7 @@ def test_wsi_segmentor_annotationstore_probabilities( ) output = mtsegmentor.run( - images=[sample_svs], + images=[wsi4_512_512_svs], return_probabilities=True, device=device, patch_mode=False, @@ -407,22 +415,23 @@ def test_wsi_segmentor_annotationstore_probabilities( ) assert "Probability maps cannot be saved as AnnotationStore." 
in caplog.text - zarr_group = zarr.open(output[sample_svs][0], mode="r") + zarr_group = zarr.open(output[wsi4_512_512_svs][0], mode="r") assert "probabilities" in zarr_group for task_name in mtsegmentor.tasks: - store_file_name = f"{sample_svs.stem}_{task_name}.db" + store_file_name = f"{wsi4_512_512_svs.stem}_{task_name}.db" store_file_path = track_tmp_path / "wsi_prob_out_check" / store_file_name assert store_file_path.exists() - assert store_file_path in output[sample_svs] + assert store_file_path in output[wsi4_512_512_svs] assert task_name not in zarr_group def test_raise_value_error_return_labels_wsi( - sample_svs: Path, + remote_sample: Callable, track_tmp_path: Path, ) -> None: """Tests MultiTaskSegmentor return_labels error.""" + wsi4_512_512_svs = remote_sample("wsi4_512_512_svs") mtsegmentor = MultiTaskSegmentor(model="hovernetplus-oed", device=device) with pytest.raises( @@ -430,7 +439,7 @@ def test_raise_value_error_return_labels_wsi( match=r".*return_labels` is not supported for MultiTaskSegmentor.", ): _ = mtsegmentor.run( - images=[sample_svs], + images=[wsi4_512_512_svs], return_probabilities=False, return_labels=True, device=device, From 8e5578ab3bb4e44b34af61b7af92ac14f7de28b0 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Tue, 20 Jan 2026 10:21:43 +0000 Subject: [PATCH 037/156] :bug: Fix memory check for saving arrays --- tiatoolbox/models/engine/multi_task_segmentor.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index 9cd7c8d89..dc71701e4 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -1110,8 +1110,9 @@ def _save_multitask_vertical_to_cache( used_percent = 0 if probabilities_da[idx] is not None: vm = psutil.virtual_memory() - total_bytes = ( - sum(arr.nbytes for arr in probabilities_da) if probabilities_da else 0 + # Calculate total bytes for all outputs + total_bytes = sum( + arr.nbytes if arr is not None else 0 for arr in probabilities_da ) used_percent = (total_bytes / vm.free) * 100 if probabilities_zarr[idx] is None and used_percent > memory_threshold: @@ -1167,7 +1168,7 @@ def _calculate_probabilities( ) -> list[da.Array]: """Helper function to calculate probabilities for MultiTaskSegmentor.""" zarr_group = None - if canvas_zarr is not None: + if canvas_zarr[0] is not None: canvas_zarr, count_zarr = save_multitask_to_cache( canvas, count, canvas_zarr, count_zarr ) From 60faab5383809f58097f4a34d8b8ec89a13e7b39 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Tue, 20 Jan 2026 10:23:27 +0000 Subject: [PATCH 038/156] :lipstick: Update logic for better readability --- tiatoolbox/models/engine/multi_task_segmentor.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index dc71701e4..3c69306fd 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -1111,9 +1111,7 @@ def _save_multitask_vertical_to_cache( if probabilities_da[idx] is not None: vm = psutil.virtual_memory() # Calculate total bytes for all outputs - total_bytes = sum( - arr.nbytes if arr is not None else 0 for arr in probabilities_da - ) + total_bytes = sum(0 if arr is None else arr.nbytes for arr in probabilities_da) 
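+        # A worked example under assumed numbers: with ~2 GiB of pending dask
+        # output and 8 GiB of free RAM this gives
+        # used_percent = (2 / 8) * 100 = 25, so caching to the on-disk zarr
+        # store triggers once that figure exceeds ``memory_threshold``.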
used_percent = (total_bytes / vm.free) * 100 if probabilities_zarr[idx] is None and used_percent > memory_threshold: msg = ( From ddc2a540d5da6c8c90c54657b08a6b4f9554790c Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Tue, 20 Jan 2026 11:45:31 +0000 Subject: [PATCH 039/156] :bug: Remove residual canvas and count keys --- tests/engines/test_multi_task_segmentor.py | 62 ++++++++++++++++++- .../models/engine/multi_task_segmentor.py | 19 +++--- 2 files changed, 71 insertions(+), 10 deletions(-) diff --git a/tests/engines/test_multi_task_segmentor.py b/tests/engines/test_multi_task_segmentor.py index 32b621350..f9e63f1e6 100644 --- a/tests/engines/test_multi_task_segmentor.py +++ b/tests/engines/test_multi_task_segmentor.py @@ -4,10 +4,11 @@ import shutil from pathlib import Path -from typing import TYPE_CHECKING, Any, Final +from typing import TYPE_CHECKING, Any, ClassVar, Final import dask.array as da import numpy as np +import psutil import pytest import torch import zarr @@ -16,6 +17,7 @@ from tiatoolbox.models.engine.multi_task_segmentor import ( MultiTaskSegmentor, _clear_zarr, + _save_multitask_vertical_to_cache, ) from tiatoolbox.utils import env_detection as toolbox_env from tiatoolbox.utils import imwrite @@ -381,6 +383,7 @@ def test_wsi_segmentor_annotationstore( verbose=True, output_type="annotationstore", class_dict=class_dict, + memory_threshold=0, ) for output_ in output[wsi4_512_512_svs]: @@ -492,6 +495,63 @@ def test_clear_zarr() -> None: assert np.all(result_.compute() == result.compute()) +def test_vertical_save_branch_without_patch( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + """Test saving to cache if memory threshold is breached for vertical merge.""" + idx = 0 + + # --- Fake psutil.virtual_memory() with extremely low free memory --- + class FakeVM: + free = 1 # force used_percent > memory_threshold + + monkeypatch.setattr(psutil, "virtual_memory", lambda: FakeVM()) + + # --- Real dask array --- + da_arr = da.from_array(np.array([[1, 2, 3]]), chunks=(1, 3)) + probabilities_da = [da_arr] + + # --- probabilities_zarr slot is None to trigger the branch --- + probabilities_zarr = [None] + + # --- Real numpy array for shape/dtype --- + probabilities = np.zeros((1, 3)) + + # --- Dummy tqdm with a write() method --- + class DummyTqdm: + messages: ClassVar[list[str]] = [] + + @classmethod + def write(cls: DummyTqdm, msg: str) -> None: + cls.messages.append(msg) + + # --- Call function --- + new_zarr, new_da = _save_multitask_vertical_to_cache( + probabilities_zarr=probabilities_zarr, + probabilities_da=probabilities_da, + probabilities=probabilities, + idx=idx, + tqdm_=DummyTqdm, + save_path=tmp_path / "cache.zarr", + chunk_shape=(1,), + memory_threshold=0, # ensure branch triggers + ) + + # --- Assertions --- + # tqdm.write was called + assert len(DummyTqdm.messages) == 1 + assert "Saving intermediate results to disk" in DummyTqdm.messages[0] + + # probabilities_da must be set to None + assert new_da[idx] is None + + # new_zarr must be a real zarr array + assert isinstance(new_zarr[idx], zarr.Array) + + # Data was written correctly + assert np.array_equal(new_zarr[idx][:], np.array([[1, 2, 3]])) + + # HELPER functions def assert_output_lengths( output: OutputType, expected_counts: Sequence[int], fields: list[str] diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index 3c69306fd..93fc4d8da 100644 --- 
a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -497,7 +497,8 @@ def _save_predictions_as_annotationstore( logger.info("Saving predictions as AnnotationStore.") # predictions are not required when saving to AnnotationStore. - processed_predictions.pop("predictions") + for key in ("canvas", "count", "predictions"): + processed_predictions.pop(key, None) keys_to_compute = list(processed_predictions.keys()) if "probabilities" in keys_to_compute: @@ -1097,8 +1098,8 @@ def merge_multitask_vertical_chunkwise( def _save_multitask_vertical_to_cache( - probabilities_zarr: list[zarr.Array], - probabilities_da: list[da.Array] | None, + probabilities_zarr: list[zarr.Array] | list[None], + probabilities_da: list[da.Array] | list[None], probabilities: np.ndarray, idx: int, tqdm_: type[tqdm_notebook | tqdm], @@ -1121,16 +1122,16 @@ def _save_multitask_vertical_to_cache( ) tqdm_.write(msg) zarr_group = zarr.open(str(save_path), mode="a") - probabilities_zarr = zarr_group.create_dataset( + probabilities_zarr[idx] = zarr_group.create_dataset( name=f"probabilities/{idx}", - shape=probabilities_da.shape, + shape=probabilities_da[idx].shape, chunks=(chunk_shape[0], *probabilities.shape[1:]), dtype=probabilities.dtype, overwrite=True, ) - probabilities_zarr[idx][:] = probabilities_da.compute() + probabilities_zarr[idx][:] = probabilities_da[idx].compute() - probabilities_da = None + probabilities_da[idx] = None return probabilities_zarr, probabilities_da @@ -1156,8 +1157,8 @@ def _clear_zarr( def _calculate_probabilities( - canvas_zarr: list[zarr.Array | None], - count_zarr: list[zarr.Array | None], + canvas_zarr: list[zarr.Array] | list[None], + count_zarr: list[zarr.Array] | list[None], canvas: list[da.Array | None], count: list[da.Array | None], output_locs_y_: np.ndarray, From 8cd457ce555cb168d144ed798e9018f275195d58 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Tue, 20 Jan 2026 11:56:27 +0000 Subject: [PATCH 040/156] :bug: Fix deepsource errors --- tests/engines/test_multi_task_segmentor.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/engines/test_multi_task_segmentor.py b/tests/engines/test_multi_task_segmentor.py index f9e63f1e6..5cff72b35 100644 --- a/tests/engines/test_multi_task_segmentor.py +++ b/tests/engines/test_multi_task_segmentor.py @@ -501,11 +501,12 @@ def test_vertical_save_branch_without_patch( """Test saving to cache if memory threshold is breached for vertical merge.""" idx = 0 - # --- Fake psutil.virtual_memory() with extremely low free memory --- class FakeVM: + """Fake psutil.virtual_memory() with extremely low free memory.""" + free = 1 # force used_percent > memory_threshold - monkeypatch.setattr(psutil, "virtual_memory", lambda: FakeVM()) + monkeypatch.setattr(psutil, "virtual_memory", FakeVM) # --- Real dask array --- da_arr = da.from_array(np.array([[1, 2, 3]]), chunks=(1, 3)) @@ -517,12 +518,14 @@ class FakeVM: # --- Real numpy array for shape/dtype --- probabilities = np.zeros((1, 3)) - # --- Dummy tqdm with a write() method --- class DummyTqdm: + """Dummy tqdm with a write() method.""" + messages: ClassVar[list[str]] = [] @classmethod def write(cls: DummyTqdm, msg: str) -> None: + """Append a message to the messages list.""" cls.messages.append(msg) # --- Call function --- From 4bb49006e4ddbcbade52ca3d383e401451d5de0d Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Fri, 23 
Jan 2026 16:03:18 +0000 Subject: [PATCH 041/156] :bug: Fix count in merge_predictions --- tiatoolbox/models/engine/semantic_segmentor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tiatoolbox/models/engine/semantic_segmentor.py b/tiatoolbox/models/engine/semantic_segmentor.py index f3417bcf1..2405c9374 100644 --- a/tiatoolbox/models/engine/semantic_segmentor.py +++ b/tiatoolbox/models/engine/semantic_segmentor.py @@ -1024,7 +1024,7 @@ def merge_batch_to_canvas( continue # To deal with edge cases canvas[0 : ye - ys, xs:xe, :] += block[0 : ye - ys, 0 : xe - xs, :] - count[ys:ye, xs:xe, 0] += 1 + count[0 : ye - ys, xs:xe, 0] += 1 return canvas, count From 5e659c0ed2ea0fced5afb3d423e361bbd9c51137 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Fri, 23 Jan 2026 16:18:51 +0000 Subject: [PATCH 042/156] :bug: Fix count in merge_predictions --- tests/engines/test_multi_task_segmentor.py | 2 +- tests/engines/test_semantic_segmentor.py | 9 +++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/engines/test_multi_task_segmentor.py b/tests/engines/test_multi_task_segmentor.py index 5cff72b35..f7de5e468 100644 --- a/tests/engines/test_multi_task_segmentor.py +++ b/tests/engines/test_multi_task_segmentor.py @@ -303,7 +303,7 @@ def test_wsi_mtsegmentor_zarr( ) output_ = zarr.open(output[wsi4_512_512_svs], mode="r") - assert 0.8 < np.mean(output_["nuclei_segmentation"]["predictions"][:]) < 1.0 + assert 15 < np.mean(output_["nuclei_segmentation"]["predictions"][:]) < 19 assert 0.57 < np.mean(output_["layer_segmentation"]["predictions"][:]) < 0.61 assert "probabilities" not in output_ assert "canvas" not in output_["nuclei_segmentation"] diff --git a/tests/engines/test_semantic_segmentor.py b/tests/engines/test_semantic_segmentor.py index 30647cf21..2b8d5fbf3 100644 --- a/tests/engines/test_semantic_segmentor.py +++ b/tests/engines/test_semantic_segmentor.py @@ -432,7 +432,7 @@ def test_wsi_segmentor_zarr( output_ = zarr.open(output[sample_svs], mode="r") assert 0.17 < np.mean(output_["predictions"][:]) < 0.19 - assert 0.52 < np.mean(output_["probabilities"][:]) < 0.56 + assert 0.48 < np.mean(output_["probabilities"][:]) < 0.52 output_ = zarr.open(output[wsi1_2k_2k_svs], mode="r") assert 0.24 < np.mean(output_["predictions"][:]) < 0.25 @@ -497,15 +497,16 @@ def test_wsi_segmentor_annotationstore( # ------------------------------------------------------------------------------------- -def test_cli_model_single_file(sample_svs: Path, track_tmp_path: Path) -> None: +def test_cli_model_single_file(remote_sample: Callable, track_tmp_path: Path) -> None: """Test semantic segmentor CLI single file.""" + wsi4_512_512_svs = remote_sample("wsi4_512_512_svs") runner = CliRunner() models_wsi_result = runner.invoke( cli.main, [ "semantic-segmentor", "--img-input", - str(sample_svs), + str(wsi4_512_512_svs), "--patch-mode", "False", "--output-path", @@ -514,4 +515,4 @@ def test_cli_model_single_file(sample_svs: Path, track_tmp_path: Path) -> None: ) assert models_wsi_result.exit_code == 0 - assert (track_tmp_path / "output" / (sample_svs.stem + ".db")).exists() + assert (track_tmp_path / "output" / (wsi4_512_512_svs.stem + ".db")).exists() From c0decb6738e992cad4900dd3eaee8b205f9b8b81 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Sun, 25 Jan 2026 11:08:37 +0000 Subject: [PATCH 043/156] :bug: Fix HoVerNetPlus postprocessing. 
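
The empty-output placeholder for the layer branch was gated on
`nuc_inst_info_dict`, apparently a copy-and-paste slip from the nuclei
branch, so it fired on empty nucleus results rather than empty layer
results. Gate it on `layer_info_dict` instead, so an empty layer output
receives the dask-friendly placeholder.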
--- tiatoolbox/models/architecture/hovernetplus.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tiatoolbox/models/architecture/hovernetplus.py b/tiatoolbox/models/architecture/hovernetplus.py index c4910e3a6..7c17b8468 100644 --- a/tiatoolbox/models/architecture/hovernetplus.py +++ b/tiatoolbox/models/architecture/hovernetplus.py @@ -369,7 +369,7 @@ def postproc(self: HoVerNetPlus, raw_maps: list[np.ndarray]) -> tuple[dict, ...] } layer_info_dict_ = {} - if not nuc_inst_info_dict: + if not layer_info_dict: layer_info_dict_ = { # inst_id should start at 1 "contours": da.empty(shape=0), "type": da.empty(shape=0), From d934167944a45abc37368eae0980fc9be2bde184 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Mon, 26 Jan 2026 17:27:00 +0000 Subject: [PATCH 044/156] :construction: Merge using tile outputs. --- tests/engines/test_multi_task_segmentor.py | 3 +- tiatoolbox/data/pretrained_model.yaml | 2 +- tiatoolbox/models/architecture/hovernet.py | 2 +- .../models/architecture/hovernetplus.py | 8 +- .../models/engine/multi_task_segmentor.py | 442 +++++++++++++++++- 5 files changed, 450 insertions(+), 7 deletions(-) diff --git a/tests/engines/test_multi_task_segmentor.py b/tests/engines/test_multi_task_segmentor.py index f7de5e468..5c7bab463 100644 --- a/tests/engines/test_multi_task_segmentor.py +++ b/tests/engines/test_multi_task_segmentor.py @@ -281,7 +281,7 @@ def test_wsi_mtsegmentor_zarr( track_tmp_path: Path, ) -> None: """Test MultiTaskSegmentor for WSIs with zarr output.""" - wsi4_512_512_svs = remote_sample("wsi4_512_512_svs") + wsi4_512_512_svs = remote_sample("wsi2_4k_4k_svs") mtsegmentor = MultiTaskSegmentor( model="hovernetplus-oed", batch_size=64, @@ -299,7 +299,6 @@ def test_wsi_mtsegmentor_zarr( batch_size=2, output_type="zarr", memory_threshold=1, - stride_shape=(160, 160), ) output_ = zarr.open(output[wsi4_512_512_svs], mode="r") diff --git a/tiatoolbox/data/pretrained_model.yaml b/tiatoolbox/data/pretrained_model.yaml index 9f715593e..2e10ba199 100644 --- a/tiatoolbox/data/pretrained_model.yaml +++ b/tiatoolbox/data/pretrained_model.yaml @@ -779,7 +779,7 @@ hovernetplus-oed: - {"units": "mpp", "resolution": 0.50} - {"units": "mpp", "resolution": 0.50} margin: 128 - tile_shape: [2048, 2048] + tile_shape: [512, 512] patch_input_shape: [256, 256] patch_output_shape: [164, 164] stride_shape: [164, 164] diff --git a/tiatoolbox/models/architecture/hovernet.py b/tiatoolbox/models/architecture/hovernet.py index 659f255b3..c23ee1323 100644 --- a/tiatoolbox/models/architecture/hovernet.py +++ b/tiatoolbox/models/architecture/hovernet.py @@ -808,7 +808,7 @@ def postproc(self: HoVerNet, raw_maps: list[np.ndarray]) -> tuple[dict, ...]: nuclei_seg = { "task_type": self.tasks[0], - "predictions": pred_inst, + "predictions": da.array(pred_inst) if isinstance(raw_maps[0], da.Array) else pred_inst, "info_dict": nuc_inst_info_dict_, } diff --git a/tiatoolbox/models/architecture/hovernetplus.py b/tiatoolbox/models/architecture/hovernetplus.py index 7c17b8468..37477c10c 100644 --- a/tiatoolbox/models/architecture/hovernetplus.py +++ b/tiatoolbox/models/architecture/hovernetplus.py @@ -15,6 +15,7 @@ from tiatoolbox.models.architecture.hovernet import HoVerNet from tiatoolbox.models.architecture.utils import UpSample2x +from tiatoolbox.utils.misc import get_bounding_box class HoVerNetPlus(HoVerNet): @@ -225,6 +226,7 @@ def _get_layer_info(pred_layer: np.ndarray) -> dict: for type_class in layer_list: layer = np.where(pred_layer 
== type_class, 1, 0).astype("uint8") + bounding_box = get_bounding_box(layer) contours, _ = cv2.findContours( layer.astype("uint8"), cv2.RETR_TREE, @@ -245,6 +247,7 @@ def _get_layer_info(pred_layer: np.ndarray) -> dict: coords = layer[:, 0, :] layer_info_dict[count] = { + "box": bounding_box, "contours": coords, "type": type_class, } @@ -364,13 +367,14 @@ def postproc(self: HoVerNetPlus, raw_maps: list[np.ndarray]) -> tuple[dict, ...] nuclei_seg = { "task_type": self.tasks[0], - "predictions": pred_inst, + "predictions": da.array(pred_inst) if isinstance(raw_maps[0], da.Array) else pred_inst, "info_dict": nuc_inst_info_dict_, } layer_info_dict_ = {} if not layer_info_dict: layer_info_dict_ = { # inst_id should start at 1 + "box": da.empty(shape=0), "contours": da.empty(shape=0), "type": da.empty(shape=0), } @@ -385,7 +389,7 @@ def postproc(self: HoVerNetPlus, raw_maps: list[np.ndarray]) -> tuple[dict, ...] layer_seg = { "task_type": self.tasks[1], - "predictions": pred_layer, + "predictions": da.array(pred_layer) if isinstance(raw_maps[0], da.Array) else pred_layer, "info_dict": layer_info_dict_, } diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index 93fc4d8da..a3174d391 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -3,6 +3,7 @@ from __future__ import annotations import gc +import uuid from pathlib import Path from typing import TYPE_CHECKING @@ -28,6 +29,10 @@ merge_batch_to_canvas, store_probabilities, ) +from tiatoolbox.tools.patchextraction import PatchExtractor +from shapely.geometry import box as shapely_box +from shapely.strtree import STRtree +from collections import deque if TYPE_CHECKING: # pragma: no cover import os @@ -334,7 +339,18 @@ def post_process_wsi( # skipcq: PYL-R0201 ) -> dict: """Post-process raw patch predictions from inference.""" probabilities = raw_predictions["probabilities"] - post_process_predictions = self.model.postproc_func(probabilities) + + probabilities_is_zarr = False + for probabilities_ in probabilities: + if any("from-zarr" in str(key) for key in probabilities_.dask.layers.keys()): + probabilities_is_zarr = True + break + + # If dask array can fit in memory process without tiling. 
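+        # A "from-zarr" layer in the dask task graph (the layer name that
+        # ``da.from_zarr`` generates, e.g. "from-zarr-<token>") indicates the
+        # probability maps are backed by an on-disk zarr cache rather than
+        # held in RAM, so they are assumed too large to post-process in a
+        # single call and are handled tile by tile instead.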
+ if not probabilities_is_zarr: + post_process_predictions = self.model.postproc_func(probabilities) + else: + post_process_predictions = self._process_tile_mode(probabilities) tasks = set() for seg in post_process_predictions: @@ -356,6 +372,250 @@ def post_process_wsi( # skipcq: PYL-R0201 return raw_predictions + def _process_tile_mode( + self: MultiTaskSegmentor, + probabilities: list[da.Array | np.ndarray], + *, + return_predictions: bool = False, + ) -> list[dict]: + """Helper function to process WSI in tile mode.""" + highest_input_resolution = self.ioconfig.highest_input_resolution + wsi_reader = self.dataloader.dataset.reader + + # assume ioconfig has already been converted to `baseline` for `tile` mode + wsi_proc_shape = wsi_reader.slide_dimensions(**highest_input_resolution) + + # * retrieve tile placement and tile info flag + # tile shape will always be corrected to be multiple of output + tile_info_sets = self._get_tile_info(wsi_proc_shape, self.ioconfig) + ioconfig = self.ioconfig.to_baseline() + + merged = [] + wsi_info_dict = None + for set_idx, (set_bounds, set_flags) in enumerate(tile_info_sets): + for tile_idx, tile_bounds in enumerate(set_bounds): + tile_flag = set_flags[tile_idx] + tile_tl = tile_bounds[:2] + tile_br = tile_bounds[2:] + tile_shape = tile_br - tile_tl # in width height + head_raws = [ + probabilities_[ + tile_bounds[1] : tile_bounds[3], + tile_bounds[0] : tile_bounds[2], + :, + ] + for probabilities_ in probabilities + ] + post_process_output = self.model.postproc_func(head_raws) + + # create a list for info dict for each task + wsi_info_dict = [{} for _ in post_process_output] if wsi_info_dict is None else wsi_info_dict + inst_dicts = _get_inst_dicts(post_process_output=post_process_output) + + tile_mode = set_idx + new_inst_dicts, remove_insts_in_origs = [], [] + for inst_id, inst_dict in enumerate(inst_dicts): + new_inst_dict, remove_insts_in_orig = _process_instance_predictions( + inst_dict, + ioconfig, + tile_shape, + tile_flag, + tile_mode, + tile_tl, + wsi_info_dict[inst_id], + ) + new_inst_dicts.append(new_inst_dict) + remove_insts_in_origs.append(remove_insts_in_orig) + + merged.append((new_inst_dicts, remove_insts_in_origs)) + + for new_inst_dicts, remove_uuid_lists in merged: + for inst_id, new_inst_dict in enumerate(new_inst_dicts): + wsi_info_dict[inst_id].update(new_inst_dict) + for inst_uuid in remove_uuid_lists[inst_id]: + wsi_info_dict[inst_id].pop(inst_uuid, None) + + return wsi_info_dict + + @staticmethod + def _get_tile_info( + image_shape: list[int] | np.ndarray, + ioconfig: IOSegmentorConfig, + ) -> list[list, ...]: + """Generating tile information. + + To avoid out of memory problem when processing WSI-scale in + general, the predictor will perform the inference and assemble + on a large image tiles (each may have size of 4000x4000 compared + to patch output of 256x256) first before stitching every tiles + by the end to complete the WSI output. For nuclei instance + segmentation, the stitching process will require removal of + predictions within some bounding areas. This function generates + both the tile placement and the flag to indicate how the removal + should be done to achieve the above goal. + + Args: + image_shape (:class:`numpy.ndarray`, list(int)): + The shape of WSI to extract the tile from, assumed to be + in `[width, height]`. + ioconfig (:obj:IOSegmentorConfig): + The input and output configuration objects. 
+ + Returns: + list: + - :py:obj:`list` - Tiles and flags + - :class:`numpy.ndarray` - Grid tiles + - :class:`numpy.ndarray` - Removal flags + - :py:obj:`list` - Tiles and flags + - :class:`numpy.ndarray` - Vertical strip tiles + - :class:`numpy.ndarray` - Removal flags + - :py:obj:`list` - Tiles and flags + - :class:`numpy.ndarray` - Horizontal strip tiles + - :class:`numpy.ndarray` - Removal flags + - :py:obj:`list` - Tiles and flags + - :class:`numpy.ndarray` - Cross-section tiles + - :class:`numpy.ndarray` - Removal flags + + """ + margin = np.array(ioconfig.margin) + tile_shape = np.array(ioconfig.tile_shape) + tile_shape = ( + np.floor(tile_shape / ioconfig.patch_output_shape) + * ioconfig.patch_output_shape + ).astype(np.int32) + image_shape = np.array(image_shape) + tile_outputs = PatchExtractor.get_coordinates( + image_shape=image_shape, + patch_input_shape=tile_shape, + patch_output_shape=tile_shape, + stride_shape=tile_shape, + ) + + # * === Now generating the flags to indicate which side should + # * === be removed in postproc callback + boxes = tile_outputs[1] + + # This saves computation time if the image is smaller than the expected tile + if np.all(image_shape <= tile_shape): + flag = np.zeros([boxes.shape[0], 4], dtype=np.int32) + return [[boxes, flag]] + + # * remove all sides for boxes + # unset for those lie within the selection + def unset_removal_flag(boxes: tuple, removal_flag: np.ndarray) -> np.ndarray: + """Unset removal flags for tiles intersecting image boundaries.""" + sel_boxes = [ + shapely_box(0, 0, w, 0), # top edge + shapely_box(0, h, w, h), # bottom edge + shapely_box(0, 0, 0, h), # left + shapely_box(w, 0, w, h), # right + ] + geometries = [shapely_box(*bounds) for bounds in boxes] + spatial_indexer = STRtree(geometries) + + for idx, sel_box in enumerate(sel_boxes): + sel_indices = list(spatial_indexer.query(sel_box)) + removal_flag[sel_indices, idx] = 0 + return removal_flag + + w, h = image_shape + boxes = tile_outputs[1] + # expand to full four corners + boxes_br = boxes[:, 2:] + boxes_tr = np.dstack([boxes[:, 2], boxes[:, 1]])[0] + boxes_bl = np.dstack([boxes[:, 0], boxes[:, 3]])[0] + + # * remove edges on all sides, excluding edges at on WSI boundary + flag = np.ones([boxes.shape[0], 4], dtype=np.int32) + flag = unset_removal_flag(boxes, flag) + info = deque([[boxes, flag]]) + + # * create vertical boxes at tile boundary and + # * flag top and bottom removal, excluding those + # * on the WSI boundary + # ------------------- + # | =|= =|= | + # | =|= =|= | + # | >=|= >=|= | + # ------------------- + # | >=|= >=|= | + # | =|= =|= | + # | >=|= >=|= | + # ------------------- + # | >=|= >=|= | + # | =|= =|= | + # | =|= =|= | + # ------------------- + # only select boxes having right edges removed + sel_indices = np.nonzero(flag[..., 3]) + _boxes = np.concatenate( + [ + boxes_tr[sel_indices] - np.array([margin, 0])[None], + boxes_br[sel_indices] + np.array([margin, 0])[None], + ], + axis=-1, + ) + _flag = np.full([_boxes.shape[0], 4], 0, dtype=np.int32) + _flag[:, [0, 1]] = 1 + _flag = unset_removal_flag(_boxes, _flag) + info.append([_boxes, _flag]) + + # * create horizontal boxes at tile boundary and + # * flag left and right removal, excluding those + # * on the WSI boundary + # ------------- + # | | | | + # | v|v v|v | + # |===|===|===| + # ------------- + # |===|===|===| + # | | | | + # | | | | + # ------------- + # only select boxes having bottom edges removed + sel_indices = np.nonzero(flag[..., 1]) + # top bottom left right + _boxes = 
np.concatenate( + [ + boxes_bl[sel_indices] - np.array([0, margin])[None], + boxes_br[sel_indices] + np.array([0, margin])[None], + ], + axis=-1, + ) + _flag = np.full([_boxes.shape[0], 4], 0, dtype=np.int32) + _flag[:, [2, 3]] = 1 + _flag = unset_removal_flag(_boxes, _flag) + info.append([_boxes, _flag]) + + # * create boxes at tile cross-section and all sides + # ------------------------ + # | | | | | + # | v| | | | + # | > =|= =|= =|= | + # -----=-=---=-=---=-=---- + # | =|= =|= =|= | + # | | | | | + # | =|= =|= =|= | + # -----=-=---=-=---=-=---- + # | =|= =|= =|= | + # | | | | | + # | | | | | + # ------------------------ + + # only select boxes having both right and bottom edges removed + sel_indices = np.nonzero(np.prod(flag[:, [1, 3]], axis=-1)) + _boxes = np.concatenate( + [ + boxes_br[sel_indices] - np.array([2 * margin, 2 * margin])[None], + boxes_br[sel_indices] + np.array([2 * margin, 2 * margin])[None], + ], + axis=-1, + ) + flag = np.full([_boxes.shape[0], 4], 1, dtype=np.int32) + info.append([_boxes, flag]) + + return info + def build_post_process_raw_predictions( self: MultiTaskSegmentor, post_process_predictions: list[tuple], @@ -1284,3 +1544,183 @@ def _save_annotation_store( store.dump(output_path) return output_path + +def _process_instance_predictions( + inst_dict: dict, + ioconfig: IOSegmentorConfig, + tile_shape: list, + tile_flag: list, + tile_mode: int, + tile_tl: tuple, + ref_inst_dict: dict, +) -> list | tuple: + """Function to merge new tile prediction with existing prediction. + + Args: + inst_dict (dict): Dictionary containing instance information. + ioconfig (:class:`IOSegmentorConfig`): Object defines information + about input and output placement of patches. + tile_shape (list): A list of the tile shape. + tile_flag (list): A list of flag to indicate if instances within + an area extended from each side (by `ioconfig.margin`) of + the tile should be replaced by those within the same spatial + region in the accumulated output this run. The format is + [top, bottom, left, right], 1 indicates removal while 0 is not. + For example, [1, 1, 0, 0] denotes replacing top and bottom instances + within `ref_inst_dict` with new ones after this processing. + tile_mode (int): A flag to indicate the type of this tile. There + are 4 flags: + - 0: A tile from tile grid without any overlapping, it is not + an overlapping tile from tile generation. The predicted + instances are immediately added to accumulated output. + - 1: Vertical tile strip that stands between two normal tiles + (flag 0). It has the same height as normal tile but + less width (hence vertical strip). + - 2: Horizontal tile strip that stands between two normal tiles + (flag 0). It has the same width as normal tile but + less height (hence horizontal strip). + - 3: tile strip stands at the cross-section of four normal tiles + (flag 0). + tile_tl (tuple): Top left coordinates of the current tile. + ref_inst_dict (dict): Dictionary contains accumulated output. The + expected format is {instance_id: {type: int, + contour: List[List[int]], centroid:List[float], box:List[int]}. + + Returns: + new_inst_dict (dict): A dictionary contain new instances to be accumulated. + The expected format is {instance_id: {type: int, + contour: List[List[int]], centroid:List[float], box:List[int]}. + remove_insts_in_orig (list): List of instance id within `ref_inst_dict` + to be removed to prevent overlapping predictions. These instances + are those get cutoff at the boundary due to the tiling process. 
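+    Example:
+        A minimal sketch with fabricated values: one nucleus on a
+        1000 x 1000 tile at the WSI origin. Only ``margin`` is read from
+        the ioconfig here, so a stand-in namespace suffices.
+
+        >>> from types import SimpleNamespace
+        >>> import numpy as np
+        >>> inst_dict = {
+        ...     1: {
+        ...         "box": np.array([200, 200, 260, 260]),
+        ...         "centroid": np.array([230.0, 230.0]),
+        ...         "contours": np.array([[200, 200], [260, 200], [260, 260]]),
+        ...         "type": 1,
+        ...     },
+        ... }
+        >>> new_dict, removed = _process_instance_predictions(
+        ...     inst_dict,
+        ...     ioconfig=SimpleNamespace(margin=128),
+        ...     tile_shape=[1000, 1000],
+        ...     tile_flag=[0, 0, 0, 0],
+        ...     tile_mode=0,
+        ...     tile_tl=np.array([0, 0]),
+        ...     ref_inst_dict={},
+        ... )
+        >>> len(new_dict), removed
+        (1, [])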
+
+    """
+    # should be rare, no nuclei detected in input images
+    if len(inst_dict) == 0:
+        return {}, []
+
+    # !
+    m = ioconfig.margin
+    w, h = tile_shape
+    inst_boxes = [v["box"] for v in inst_dict.values()]
+    inst_boxes = np.array(inst_boxes)
+
+    geometries = [shapely_box(*bounds) for bounds in inst_boxes]
+    tile_rtree = STRtree(geometries)
+    # !
+
+    # create margin bounding box, ordering should match with
+    # created tile info flag (top, bottom, left, right)
+    boundary_lines = [
+        shapely_box(0, 0, w, 1),  # top edge
+        shapely_box(0, h - 1, w, h),  # bottom edge
+        shapely_box(0, 0, 1, h),  # left
+        shapely_box(w - 1, 0, w, h),  # right
+    ]
+    margin_boxes = [
+        shapely_box(0, 0, w, m),  # top edge
+        shapely_box(0, h - m, w, h),  # bottom edge
+        shapely_box(0, 0, m, h),  # left
+        shapely_box(w - m, 0, w, h),  # right
+    ]
+    # ! this is w.r.t. the WSI coord space, not the tile
+    margin_lines = [
+        [[m, m], [w - m, m]],  # top edge
+        [[m, h - m], [w - m, h - m]],  # bottom edge
+        [[m, m], [m, h - m]],  # left
+        [[w - m, m], [w - m, h - m]],  # right
+    ]
+    margin_lines = np.array(margin_lines) + tile_tl[None, None]
+    margin_lines = [shapely_box(*v.flatten().tolist()) for v in margin_lines]
+
+    # the ids within this match with those within `inst_map`, not UUID
+    sel_indices = []
+    if tile_mode in [0, 3]:
+        # for `full grid` and `cross section` tiles
+        # -- extend from the boundary by the margin size, remove
+        #    nuclei whose entire contours lie within the margin area
+        sel_boxes = [
+            box
+            for idx, box in enumerate(margin_boxes)
+            if tile_flag[idx] or tile_mode == 3  # noqa: PLR2004
+        ]
+
+        sel_indices = [
+            geo
+            for bounds in sel_boxes
+            for geo in tile_rtree.query(bounds)
+            if bounds.contains(geometries[geo])
+        ]
+    elif tile_mode in [1, 2]:
+        # for `horizontal/vertical strip` tiles
+        # -- extend from the marked edges (top/bot or left/right) by
+        #    the margin size, remove all nuclei lying within the margin
+        #    area (including on the margin line)
+        # -- remove all nuclei on the boundary also
+
+        sel_boxes = [
+            margin_boxes[idx] if flag else boundary_lines[idx]
+            for idx, flag in enumerate(tile_flag)
+        ]
+
+        sel_indices = [geo for bounds in sel_boxes for geo in tile_rtree.query(bounds)]
+    else:
+        msg = f"Unknown tile mode {tile_mode}."
+        raise ValueError(msg)
+
+    def retrieve_sel_uids(sel_indices: list, inst_dict: dict) -> list:
+        """Helper to retrieve selected instance uids."""
+        if len(sel_indices) > 0:
+            # not sure how costly this is in large dict
+            inst_uids = list(inst_dict.keys())
+            return [inst_uids[idx] for idx in sel_indices]
+        return []
+
+    remove_insts_in_tile = retrieve_sel_uids(sel_indices, inst_dict)
+
+    # external removal only for tile at cross-sections
+    # this one should contain UUID with the reference database
+    remove_insts_in_orig = []
+    if tile_mode == 3:  # noqa: PLR2004
+        inst_boxes = [v["box"] for v in ref_inst_dict.values()]
+        inst_boxes = np.array(inst_boxes)
+
+        geometries = [shapely_box(*bounds) for bounds in inst_boxes]
+        ref_inst_rtree = STRtree(geometries)
+        sel_indices = [
+            geo for bounds in margin_lines for geo in ref_inst_rtree.query(bounds)
+        ]
+
+        remove_insts_in_orig = retrieve_sel_uids(sel_indices, ref_inst_dict)
+
+    # move inst position from tile space back to WSI space
+    # and also generate universal uid as replacement for storage
+    new_inst_dict = {}
+    for inst_uid, inst_info in inst_dict.items():
+        if inst_uid not in remove_insts_in_tile:
+            inst_info["box"] += np.concatenate([tile_tl] * 2)
+            if "centroid" in inst_info:
+                inst_info["centroid"] += tile_tl
+            inst_info["contours"] += tile_tl
+            inst_uuid = uuid.uuid4().hex
+            new_inst_dict[inst_uuid] = inst_info
+    return new_inst_dict, remove_insts_in_orig
+
+
+def _get_inst_dicts(post_process_output: tuple[dict]) -> list:
+    inst_dicts = []
+    for _output in post_process_output:
+        keys_ = list(_output["info_dict"].keys())
+
+        inst_dicts.extend(
+            [
+                {
+                    i + 1: {
+                        key: values[i]
+                        for key, values in _output["info_dict"].items()
+                    }
+                    for i in range(len(_output["info_dict"][keys_[0]]))
+                }
+            ]
+        )
+
+    return inst_dicts

From e660bb31534d2644f604ff7058ff55fee1821be1 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]"
 <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 26 Jan 2026 17:27:39 +0000
Subject: [PATCH 045/156] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 tiatoolbox/models/architecture/hovernet.py    |  4 +++-
 .../models/architecture/hovernetplus.py       |  8 +++++--
 .../models/engine/multi_task_segmentor.py     | 22 ++++++++++++-------
 3 files changed, 23 insertions(+), 11 deletions(-)

diff --git a/tiatoolbox/models/architecture/hovernet.py b/tiatoolbox/models/architecture/hovernet.py
index c23ee1323..076b216c7 100644
--- a/tiatoolbox/models/architecture/hovernet.py
+++ b/tiatoolbox/models/architecture/hovernet.py
@@ -808,7 +808,9 @@ def postproc(self: HoVerNet, raw_maps: list[np.ndarray]) -> tuple[dict, ...]:
 
         nuclei_seg = {
             "task_type": self.tasks[0],
-            "predictions": da.array(pred_inst) if isinstance(raw_maps[0], da.Array) else pred_inst,
+            "predictions": da.array(pred_inst)
+            if isinstance(raw_maps[0], da.Array)
+            else pred_inst,
             "info_dict": nuc_inst_info_dict_,
         }
 
diff --git a/tiatoolbox/models/architecture/hovernetplus.py b/tiatoolbox/models/architecture/hovernetplus.py
index 37477c10c..a73f92ab5 100644
--- a/tiatoolbox/models/architecture/hovernetplus.py
+++ b/tiatoolbox/models/architecture/hovernetplus.py
@@ -367,7 +367,9 @@ def postproc(self: HoVerNetPlus, raw_maps: list[np.ndarray]) -> tuple[dict, ...]
nuclei_seg = { "task_type": self.tasks[0], - "predictions": da.array(pred_inst) if isinstance(raw_maps[0], da.Array) else pred_inst, + "predictions": da.array(pred_inst) + if isinstance(raw_maps[0], da.Array) + else pred_inst, "info_dict": nuc_inst_info_dict_, } @@ -389,7 +391,9 @@ def postproc(self: HoVerNetPlus, raw_maps: list[np.ndarray]) -> tuple[dict, ...] layer_seg = { "task_type": self.tasks[1], - "predictions": da.array(pred_layer) if isinstance(raw_maps[0], da.Array) else pred_layer, + "predictions": da.array(pred_layer) + if isinstance(raw_maps[0], da.Array) + else pred_layer, "info_dict": layer_info_dict_, } diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index a3174d391..0bcdaf72f 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -4,6 +4,7 @@ import gc import uuid +from collections import deque from pathlib import Path from typing import TYPE_CHECKING @@ -13,12 +14,15 @@ import torch import zarr from dask import compute +from shapely.geometry import box as shapely_box from shapely.geometry import shape as feature2geometry +from shapely.strtree import STRtree from typing_extensions import Unpack from tiatoolbox import logger from tiatoolbox.annotation import SQLiteStore from tiatoolbox.annotation.storage import Annotation +from tiatoolbox.tools.patchextraction import PatchExtractor from tiatoolbox.utils.misc import get_tqdm, make_valid_poly from tiatoolbox.wsicore.wsireader import is_zarr @@ -29,10 +33,6 @@ merge_batch_to_canvas, store_probabilities, ) -from tiatoolbox.tools.patchextraction import PatchExtractor -from shapely.geometry import box as shapely_box -from shapely.strtree import STRtree -from collections import deque if TYPE_CHECKING: # pragma: no cover import os @@ -342,7 +342,9 @@ def post_process_wsi( # skipcq: PYL-R0201 probabilities_is_zarr = False for probabilities_ in probabilities: - if any("from-zarr" in str(key) for key in probabilities_.dask.layers.keys()): + if any( + "from-zarr" in str(key) for key in probabilities_.dask.layers.keys() + ): probabilities_is_zarr = True break @@ -409,7 +411,11 @@ def _process_tile_mode( post_process_output = self.model.postproc_func(head_raws) # create a list for info dict for each task - wsi_info_dict = [{} for _ in post_process_output] if wsi_info_dict is None else wsi_info_dict + wsi_info_dict = ( + [{} for _ in post_process_output] + if wsi_info_dict is None + else wsi_info_dict + ) inst_dicts = _get_inst_dicts(post_process_output=post_process_output) tile_mode = set_idx @@ -1545,6 +1551,7 @@ def _save_annotation_store( return output_path + def _process_instance_predictions( inst_dict: dict, ioconfig: IOSegmentorConfig, @@ -1715,8 +1722,7 @@ def _get_inst_dicts(post_process_output: tuple[dict]) -> list: [ { i + 1: { - key: values[i] - for key, values in _output["info_dict"].items() + key: values[i] for key, values in _output["info_dict"].items() } for i in range(len(_output["info_dict"][keys_[0]])) } From be79c7f2c4ffc67ea7dd53922a33ebc5c6bff363 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Mon, 26 Jan 2026 18:49:00 +0000 Subject: [PATCH 046/156] [skip ci] :construction: Restructure dictionary output --- tiatoolbox/models/engine/multi_task_segmentor.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index 
0bcdaf72f..3159b78e1 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -441,6 +441,15 @@ def _process_tile_mode( for inst_uuid in remove_uuid_lists[inst_id]: wsi_info_dict[inst_id].pop(inst_uuid, None) + a = wsi_info_dict[0] + keys = ["box", "centroid", "contours", "prob", "type"] + info_dict = {} + for key in keys: + # Extract the list of values for this key across all instances + values = [a[i][key] for i in a] + # Convert to a Dask array (single chunk of size N) + info_dict[key] = values + return wsi_info_dict @staticmethod From 55f7d7f326fc24c98557cce0c8597ed1250e4744 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Mon, 26 Jan 2026 22:46:10 +0000 Subject: [PATCH 047/156] [skip ci] :construction: Fix structure of the output --- tests/engines/test_multi_task_segmentor.py | 2 +- tiatoolbox/data/pretrained_model.yaml | 2 +- .../models/engine/multi_task_segmentor.py | 71 +++++++++++++------ 3 files changed, 51 insertions(+), 24 deletions(-) diff --git a/tests/engines/test_multi_task_segmentor.py b/tests/engines/test_multi_task_segmentor.py index 5c7bab463..d02144f46 100644 --- a/tests/engines/test_multi_task_segmentor.py +++ b/tests/engines/test_multi_task_segmentor.py @@ -281,7 +281,7 @@ def test_wsi_mtsegmentor_zarr( track_tmp_path: Path, ) -> None: """Test MultiTaskSegmentor for WSIs with zarr output.""" - wsi4_512_512_svs = remote_sample("wsi2_4k_4k_svs") + wsi4_512_512_svs = remote_sample("wsi4_512_512_svs") mtsegmentor = MultiTaskSegmentor( model="hovernetplus-oed", batch_size=64, diff --git a/tiatoolbox/data/pretrained_model.yaml b/tiatoolbox/data/pretrained_model.yaml index 2e10ba199..5493d175e 100644 --- a/tiatoolbox/data/pretrained_model.yaml +++ b/tiatoolbox/data/pretrained_model.yaml @@ -779,7 +779,7 @@ hovernetplus-oed: - {"units": "mpp", "resolution": 0.50} - {"units": "mpp", "resolution": 0.50} margin: 128 - tile_shape: [512, 512] + tile_shape: [300, 300] patch_input_shape: [256, 256] patch_output_shape: [164, 164] stride_shape: [164, 164] diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index 3159b78e1..71c825cc9 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -6,7 +6,7 @@ import uuid from collections import deque from pathlib import Path -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import dask.array as da import numpy as np @@ -342,9 +342,7 @@ def post_process_wsi( # skipcq: PYL-R0201 probabilities_is_zarr = False for probabilities_ in probabilities: - if any( - "from-zarr" in str(key) for key in probabilities_.dask.layers.keys() - ): + if any("from-zarr" in str(key) for key in probabilities_.dask.layers): probabilities_is_zarr = True break @@ -379,10 +377,11 @@ def _process_tile_mode( probabilities: list[da.Array | np.ndarray], *, return_predictions: bool = False, - ) -> list[dict]: + ) -> list[dict] | None: """Helper function to process WSI in tile mode.""" highest_input_resolution = self.ioconfig.highest_input_resolution wsi_reader = self.dataloader.dataset.reader + _ = return_predictions # assume ioconfig has already been converted to `baseline` for `tile` mode wsi_proc_shape = wsi_reader.slide_dimensions(**highest_input_resolution) @@ -410,12 +409,13 @@ def _process_tile_mode( ] post_process_output = self.model.postproc_func(head_raws) - # create a list for info dict for each task - 
wsi_info_dict = ( - [{} for _ in post_process_output] - if wsi_info_dict is None - else wsi_info_dict + # create a list of info dict for each task + wsi_info_dict = _create_wsi_info_dict( + post_process_output=post_process_output, + wsi_info_dict=wsi_info_dict, + wsi_proc_shape=wsi_proc_shape, ) + inst_dicts = _get_inst_dicts(post_process_output=post_process_output) tile_mode = set_idx @@ -428,7 +428,7 @@ def _process_tile_mode( tile_flag, tile_mode, tile_tl, - wsi_info_dict[inst_id], + wsi_info_dict[inst_id]["info_dict"], ) new_inst_dicts.append(new_inst_dict) remove_insts_in_origs.append(remove_insts_in_orig) @@ -437,18 +437,22 @@ def _process_tile_mode( for new_inst_dicts, remove_uuid_lists in merged: for inst_id, new_inst_dict in enumerate(new_inst_dicts): - wsi_info_dict[inst_id].update(new_inst_dict) + wsi_info_dict[inst_id]["info_dict"].update(new_inst_dict) for inst_uuid in remove_uuid_lists[inst_id]: - wsi_info_dict[inst_id].pop(inst_uuid, None) - - a = wsi_info_dict[0] - keys = ["box", "centroid", "contours", "prob", "type"] - info_dict = {} - for key in keys: - # Extract the list of values for this key across all instances - values = [a[i][key] for i in a] - # Convert to a Dask array (single chunk of size N) - info_dict[key] = values + wsi_info_dict[inst_id]["info_dict"].pop(inst_uuid, None) + + for idx, wsi_info_dict_ in enumerate(wsi_info_dict): + info_dict_keys: list[str] = wsi_info_dict_["info_dict_keys"] + info_dict = {} + for key in info_dict_keys: + # Extract the list of values for this key across all instances + values = [ + da.array(wsi_info_dict_["info_dict"][i][key]) + for i in wsi_info_dict_["info_dict"] + ] + info_dict[key] = values + wsi_info_dict[idx]["info_dict"] = info_dict + wsi_info_dict_.pop("info_dict_keys") return wsi_info_dict @@ -1739,3 +1743,26 @@ def _get_inst_dicts(post_process_output: tuple[dict]) -> list: ) return inst_dicts + + +def _create_wsi_info_dict( + post_process_output: tuple[dict], + wsi_info_dict: tuple[dict] | None, + wsi_proc_shape: tuple[int, ...], +) -> tuple[dict[str, dict[Any, Any] | list[Any] | Any], ...]: + """Helper function to create wsi info dict.""" + if wsi_info_dict is not None: + return wsi_info_dict + + return tuple( + { + "task_type": post_process_output_["task_type"], + "predictions": da.zeros( + shape=wsi_proc_shape, + dtype=post_process_output_["predictions"].dtype, + ), + "info_dict": {}, + "info_dict_keys": list(post_process_output_["info_dict"]), + } + for post_process_output_ in post_process_output + ) From 7a2fb6c01837826a22855eb76de6b6c1fa6e297a Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Mon, 26 Jan 2026 22:50:43 +0000 Subject: [PATCH 048/156] [skip ci] :construction: Fix structure of the output --- tests/engines/test_multi_task_segmentor.py | 3 +++ tiatoolbox/data/pretrained_model.yaml | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/engines/test_multi_task_segmentor.py b/tests/engines/test_multi_task_segmentor.py index d02144f46..f2e52b97f 100644 --- a/tests/engines/test_multi_task_segmentor.py +++ b/tests/engines/test_multi_task_segmentor.py @@ -288,6 +288,8 @@ def test_wsi_mtsegmentor_zarr( verbose=False, num_workers=1, ) + ioconfig = mtsegmentor.ioconfig + ioconfig.tile_shape = (300, 300) # Return Probabilities is False output = mtsegmentor.run( images=[wsi4_512_512_svs], @@ -299,6 +301,7 @@ def test_wsi_mtsegmentor_zarr( batch_size=2, output_type="zarr", memory_threshold=1, + ioconfig=ioconfig, ) output_ = 
zarr.open(output[wsi4_512_512_svs], mode="r") diff --git a/tiatoolbox/data/pretrained_model.yaml b/tiatoolbox/data/pretrained_model.yaml index 5493d175e..9f715593e 100644 --- a/tiatoolbox/data/pretrained_model.yaml +++ b/tiatoolbox/data/pretrained_model.yaml @@ -779,7 +779,7 @@ hovernetplus-oed: - {"units": "mpp", "resolution": 0.50} - {"units": "mpp", "resolution": 0.50} margin: 128 - tile_shape: [300, 300] + tile_shape: [2048, 2048] patch_input_shape: [256, 256] patch_output_shape: [164, 164] stride_shape: [164, 164] From 189d5b3a97a4248de423f0887f892b2960efdd29 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Thu, 29 Jan 2026 13:55:42 +0100 Subject: [PATCH 049/156] :memo: Add docstring for `_get_inst_info_dicts` --- tiatoolbox/models/engine/multi_task_segmentor.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index 71c825cc9..7e1f5a1e4 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -416,7 +416,9 @@ def _process_tile_mode( wsi_proc_shape=wsi_proc_shape, ) - inst_dicts = _get_inst_dicts(post_process_output=post_process_output) + inst_dicts = _get_inst_info_dicts( + post_process_output=post_process_output + ) tile_mode = set_idx new_inst_dicts, remove_insts_in_origs = [], [] @@ -1726,7 +1728,13 @@ def retrieve_sel_uids(sel_indices: list, inst_dict: dict) -> list: return new_inst_dict, remove_insts_in_orig -def _get_inst_dicts(post_process_output: tuple[dict]) -> list: +def _get_inst_info_dicts(post_process_output: tuple[dict]) -> list: + """Helper to convert post processing output to dictionary list. + + This function makes the info_dict compatible with tile based processing of + info_dictionaries from HoVerNet. + + """ inst_dicts = [] for _output in post_process_output: keys_ = list(_output["info_dict"].keys()) From 13aa48c74e19a2e959c39cd922b414375d1f7073 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Thu, 29 Jan 2026 15:56:19 +0100 Subject: [PATCH 050/156] :construction: Working on testing tile based output --- .../models/architecture/hovernetplus.py | 59 +++++++++++-------- .../models/engine/multi_task_segmentor.py | 11 +++- 2 files changed, 44 insertions(+), 26 deletions(-) diff --git a/tiatoolbox/models/architecture/hovernetplus.py b/tiatoolbox/models/architecture/hovernetplus.py index a73f92ab5..8fbdb4f34 100644 --- a/tiatoolbox/models/architecture/hovernetplus.py +++ b/tiatoolbox/models/architecture/hovernetplus.py @@ -333,10 +333,13 @@ def postproc(self: HoVerNetPlus, raw_maps: list[np.ndarray]) -> tuple[dict, ...] """ np_map, hv_map, tp_map, ls_map = raw_maps - np_map = np_map.compute() if isinstance(np_map, da.Array) else np_map - hv_map = hv_map.compute() if isinstance(hv_map, da.Array) else hv_map - tp_map = tp_map.compute() if isinstance(tp_map, da.Array) else tp_map - ls_map = ls_map.compute() if isinstance(ls_map, da.Array) else ls_map + # Assumes raw_maps is a tuple of dask or numpy arrays. 
+ is_dask = isinstance(raw_maps[0], da.Array) + + np_map = np_map.compute() if is_dask else np_map + hv_map = hv_map.compute() if is_dask else hv_map + tp_map = tp_map.compute() if is_dask else tp_map + ls_map = ls_map.compute() if is_dask else ls_map pred_inst = HoVerNetPlus._proc_np_hv(np_map, hv_map, scale_factor=0.5) # fx=0.5 as nuclear processing is at 0.5 mpp instead of 0.25 mpp @@ -350,50 +353,58 @@ def postproc(self: HoVerNetPlus, raw_maps: list[np.ndarray]) -> tuple[dict, ...] nuc_inst_info_dict_ = {} if not nuc_inst_info_dict: nuc_inst_info_dict_ = { # inst_id should start at 1 - "box": da.empty(shape=0), - "centroid": da.empty(shape=0), - "contours": da.empty(shape=0), - "prob": da.empty(shape=0), - "type": da.empty(shape=0), + "box": da.empty(shape=0) if is_dask else np.empty(0), + "centroid": da.empty(shape=0) if is_dask else np.empty(0), + "contours": da.empty(shape=0) if is_dask else np.empty(0), + "prob": da.empty(shape=0) if is_dask else np.empty(0), + "type": da.empty(shape=0) if is_dask else np.empty(0), } else: # dask dataframe does not support transpose nuc_inst_info_df = pd.DataFrame(nuc_inst_info_dict).transpose() for key, col in nuc_inst_info_df.items(): - nuc_inst_info_dict_[key] = da.from_array( - col.to_numpy(), - chunks=(len(col),), # one chunk, avoids auto-rechunking + col_np = col.to_numpy() + nuc_inst_info_dict_[key] = ( + da.from_array( + col_np, + chunks=(len(col),), # one chunk, avoids + # auto-rechunking + ) + if is_dask + else col_np ) nuclei_seg = { "task_type": self.tasks[0], - "predictions": da.array(pred_inst) - if isinstance(raw_maps[0], da.Array) - else pred_inst, + "predictions": da.array(pred_inst) if is_dask else pred_inst, "info_dict": nuc_inst_info_dict_, } layer_info_dict_ = {} if not layer_info_dict: layer_info_dict_ = { # inst_id should start at 1 - "box": da.empty(shape=0), - "contours": da.empty(shape=0), - "type": da.empty(shape=0), + "box": da.empty(shape=0) if is_dask else np.empty(0), + "contours": da.empty(shape=0) if is_dask else np.empty(0), + "type": da.empty(shape=0) if is_dask else np.empty(0), } else: # dask dataframe does not support transpose layer_info_df = pd.DataFrame(layer_info_dict).transpose() for key, col in layer_info_df.items(): - layer_info_dict_[key] = da.from_array( - col.to_numpy(), - chunks=(len(col),), # one chunk, avoids auto-rechunking + col_np = col.to_numpy() + layer_info_dict_[key] = ( + da.from_array( + col.to_numpy(), + chunks=(len(col),), # one chunk, avoids + # auto-rechunking + ) + if is_dask + else col_np ) layer_seg = { "task_type": self.tasks[1], - "predictions": da.array(pred_layer) - if isinstance(raw_maps[0], da.Array) - else pred_layer, + "predictions": da.array(pred_layer) if is_dask else pred_layer, "info_dict": layer_info_dict_, } diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index 7e1f5a1e4..e41edf2bc 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -404,7 +404,7 @@ def _process_tile_mode( tile_bounds[1] : tile_bounds[3], tile_bounds[0] : tile_bounds[2], :, - ] + ].compute() for probabilities_ in probabilities ] post_process_output = self.model.postproc_func(head_raws) @@ -449,7 +449,7 @@ def _process_tile_mode( for key in info_dict_keys: # Extract the list of values for this key across all instances values = [ - da.array(wsi_info_dict_["info_dict"][i][key]) + wsi_info_dict_["info_dict"][i][key] for i in wsi_info_dict_["info_dict"] ] info_dict[key] = 
values @@ -1725,6 +1725,13 @@ def retrieve_sel_uids(sel_indices: list, inst_dict: dict) -> list: inst_info["contours"] += tile_tl inst_uuid = uuid.uuid4().hex new_inst_dict[inst_uuid] = inst_info + + for inst_uid, inst_info in new_inst_dict.items(): + for key, value in inst_info.items(): + new_inst_dict[inst_uid][key] = da.asarray( + value, + ) + return new_inst_dict, remove_insts_in_orig From 44456a7564a415149f6921a7da72cbbcd640034f Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Thu, 29 Jan 2026 17:44:35 +0100 Subject: [PATCH 051/156] :construction: Finalise wsi_info_dict structure to work with save_predictions. --- tests/engines/test_multi_task_segmentor.py | 9 +++--- .../models/architecture/hovernetplus.py | 6 ++-- .../models/engine/multi_task_segmentor.py | 28 +++++++------------ 3 files changed, 17 insertions(+), 26 deletions(-) diff --git a/tests/engines/test_multi_task_segmentor.py b/tests/engines/test_multi_task_segmentor.py index f2e52b97f..47e5f8cfa 100644 --- a/tests/engines/test_multi_task_segmentor.py +++ b/tests/engines/test_multi_task_segmentor.py @@ -281,7 +281,7 @@ def test_wsi_mtsegmentor_zarr( track_tmp_path: Path, ) -> None: """Test MultiTaskSegmentor for WSIs with zarr output.""" - wsi4_512_512_svs = remote_sample("wsi4_512_512_svs") + wsi4_1k_1k_svs = remote_sample("wsi4_1k_1k_svs") mtsegmentor = MultiTaskSegmentor( model="hovernetplus-oed", batch_size=64, @@ -289,10 +289,10 @@ def test_wsi_mtsegmentor_zarr( num_workers=1, ) ioconfig = mtsegmentor.ioconfig - ioconfig.tile_shape = (300, 300) + ioconfig.tile_shape = (512, 512) # Return Probabilities is False output = mtsegmentor.run( - images=[wsi4_512_512_svs], + images=[wsi4_1k_1k_svs], return_probabilities=False, return_labels=False, device=device, @@ -304,7 +304,7 @@ def test_wsi_mtsegmentor_zarr( ioconfig=ioconfig, ) - output_ = zarr.open(output[wsi4_512_512_svs], mode="r") + output_ = zarr.open(output[wsi4_1k_1k_svs], mode="r") assert 15 < np.mean(output_["nuclei_segmentation"]["predictions"][:]) < 19 assert 0.57 < np.mean(output_["layer_segmentation"]["predictions"][:]) < 0.61 assert "probabilities" not in output_ @@ -312,6 +312,7 @@ def test_wsi_mtsegmentor_zarr( assert "count" not in output_["nuclei_segmentation"] assert "canvas" not in output_["layer_segmentation"] assert "count" not in output_["layer_segmentation"] + wsi4_1k_1k_svs.unlink() def test_multi_input_wsi_mtsegmentor_zarr( diff --git a/tiatoolbox/models/architecture/hovernetplus.py b/tiatoolbox/models/architecture/hovernetplus.py index 8fbdb4f34..7ae42447a 100644 --- a/tiatoolbox/models/architecture/hovernetplus.py +++ b/tiatoolbox/models/architecture/hovernetplus.py @@ -367,8 +367,7 @@ def postproc(self: HoVerNetPlus, raw_maps: list[np.ndarray]) -> tuple[dict, ...] nuc_inst_info_dict_[key] = ( da.from_array( col_np, - chunks=(len(col),), # one chunk, avoids - # auto-rechunking + chunks=(len(col),), ) if is_dask else col_np @@ -395,8 +394,7 @@ def postproc(self: HoVerNetPlus, raw_maps: list[np.ndarray]) -> tuple[dict, ...] 
layer_info_dict_[key] = ( da.from_array( col.to_numpy(), - chunks=(len(col),), # one chunk, avoids - # auto-rechunking + chunks=(len(col),), ) if is_dask else col_np diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index e41edf2bc..9fa91fbd1 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -10,6 +10,7 @@ import dask.array as da import numpy as np +import pandas as pd import psutil import torch import zarr @@ -444,17 +445,15 @@ def _process_tile_mode( wsi_info_dict[inst_id]["info_dict"].pop(inst_uuid, None) for idx, wsi_info_dict_ in enumerate(wsi_info_dict): - info_dict_keys: list[str] = wsi_info_dict_["info_dict_keys"] - info_dict = {} - for key in info_dict_keys: - # Extract the list of values for this key across all instances - values = [ - wsi_info_dict_["info_dict"][i][key] - for i in wsi_info_dict_["info_dict"] - ] - info_dict[key] = values - wsi_info_dict[idx]["info_dict"] = info_dict - wsi_info_dict_.pop("info_dict_keys") + info_df = pd.DataFrame(wsi_info_dict_["info_dict"]).transpose() + dict_info_wsi = {} + for key, col in info_df.items(): + col_np = col.to_numpy() + dict_info_wsi[key] = da.from_array( + col_np, + chunks=(len(col),), + ) + wsi_info_dict[idx]["info_dict"] = dict_info_wsi return wsi_info_dict @@ -1726,12 +1725,6 @@ def retrieve_sel_uids(sel_indices: list, inst_dict: dict) -> list: inst_uuid = uuid.uuid4().hex new_inst_dict[inst_uuid] = inst_info - for inst_uid, inst_info in new_inst_dict.items(): - for key, value in inst_info.items(): - new_inst_dict[inst_uid][key] = da.asarray( - value, - ) - return new_inst_dict, remove_insts_in_orig @@ -1777,7 +1770,6 @@ def _create_wsi_info_dict( dtype=post_process_output_["predictions"].dtype, ), "info_dict": {}, - "info_dict_keys": list(post_process_output_["info_dict"]), } for post_process_output_ in post_process_output ) From 4dbf9c4d7e0497ed39b8f0378cf795f820775eec Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Thu, 29 Jan 2026 17:56:05 +0100 Subject: [PATCH 052/156] :white_check_mark: Update hovernet postprocessing to be compatible with multitask design --- tiatoolbox/models/architecture/hovernet.py | 21 +++++++++++++------ .../models/architecture/hovernetplus.py | 1 + 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/tiatoolbox/models/architecture/hovernet.py b/tiatoolbox/models/architecture/hovernet.py index 076b216c7..10fdc7cc9 100644 --- a/tiatoolbox/models/architecture/hovernet.py +++ b/tiatoolbox/models/architecture/hovernet.py @@ -781,9 +781,13 @@ def postproc(self: HoVerNet, raw_maps: list[np.ndarray]) -> tuple[dict, ...]: tp_map = None np_map, hv_map = raw_maps - np_map = np_map.compute() if isinstance(np_map, da.Array) else np_map - hv_map = hv_map.compute() if isinstance(hv_map, da.Array) else hv_map - pred_type = tp_map.compute() if isinstance(tp_map, da.Array) else tp_map + # Assumes raw_maps is a tuple of dask or numpy arrays. + # Only return dask if it's required. 
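+        # Note: the flag below is derived from the first map only; all maps
+        # in raw_maps are assumed to share the same backing type
+        # (all dask or all numpy).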
+        is_dask = isinstance(raw_maps[0], da.Array)
+
+        np_map = np_map.compute() if is_dask else np_map
+        hv_map = hv_map.compute() if is_dask else hv_map
+        # tp_map may be None when no type prediction head is present.
+        pred_type = tp_map.compute() if is_dask and tp_map is not None else tp_map
 
         pred_inst = HoVerNet._proc_np_hv(np_map, hv_map)
         nuc_inst_info_dict = HoVerNet.get_instance_info(pred_inst, pred_type)
@@ -801,9 +805,14 @@ def postproc(self: HoVerNet, raw_maps: list[np.ndarray]) -> tuple[dict, ...]:
             # dask dataframe does not support transpose
             nuc_inst_info_df = pd.DataFrame(nuc_inst_info_dict).transpose()
             for key, col in nuc_inst_info_df.items():
-                nuc_inst_info_dict_[key] = da.from_array(
-                    col.to_numpy(),
-                    chunks=(len(col),),  # one chunk, avoids auto-rechunking
+                col_np = col.to_numpy()
+                nuc_inst_info_dict_[key] = (
+                    da.from_array(
+                        col_np,
+                        chunks=(len(col),),
+                    )
+                    if is_dask
+                    else col_np
                 )
 
         nuclei_seg = {
diff --git a/tiatoolbox/models/architecture/hovernetplus.py b/tiatoolbox/models/architecture/hovernetplus.py
index 7ae42447a..43e0b7aec 100644
--- a/tiatoolbox/models/architecture/hovernetplus.py
+++ b/tiatoolbox/models/architecture/hovernetplus.py
@@ -334,6 +334,7 @@ def postproc(self: HoVerNetPlus, raw_maps: list[np.ndarray]) -> tuple[dict, ...]
         np_map, hv_map, tp_map, ls_map = raw_maps
 
         # Assumes raw_maps is a tuple of dask or numpy arrays.
+        # Only return dask if it's required.
         is_dask = isinstance(raw_maps[0], da.Array)

From 4dbf9c4d0e4497ad39b8f0378cf795f820775eec Mon Sep 17 00:00:00 2001
From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com>
Date: Fri, 30 Jan 2026 10:57:51 +0100
Subject: [PATCH 053/156] :construction: Initial implementation of merged
 predictions

---
 tests/engines/test_multi_task_segmentor.py    | 20 ++++-
 .../models/engine/multi_task_segmentor.py     | 86 +++++++++++++++----
 tiatoolbox/utils/misc.py                      | 64 ++++++++++++++
 3 files changed, 162 insertions(+), 8 deletions(-)

diff --git a/tests/engines/test_multi_task_segmentor.py b/tests/engines/test_multi_task_segmentor.py
index 47e5f8cfa..4a22eebb8 100644
--- a/tests/engines/test_multi_task_segmentor.py
+++ b/tests/engines/test_multi_task_segmentor.py
@@ -289,9 +289,24 @@ def test_wsi_mtsegmentor_zarr(
         num_workers=1,
     )
     ioconfig = mtsegmentor.ioconfig
+    # Return Probabilities is False
+    output_full = mtsegmentor.run(
+        images=[wsi4_1k_1k_svs],
+        return_probabilities=False,
+        return_labels=False,
+        device=device,
+        patch_mode=False,
+        save_dir=track_tmp_path / "wsi_out_check",
+        batch_size=2,
+        output_type="zarr",
+        ioconfig=ioconfig,
+    )
+
+    # Redefine tile size to force tile-based processing.
ioconfig.tile_shape = (512, 512) # Return Probabilities is False - output = mtsegmentor.run( + output_tile = mtsegmentor.run( images=[wsi4_1k_1k_svs], return_probabilities=False, return_labels=False, @@ -304,7 +319,7 @@ def test_wsi_mtsegmentor_zarr( ioconfig=ioconfig, ) - output_ = zarr.open(output[wsi4_1k_1k_svs], mode="r") + output_ = zarr.open(output_full[wsi4_1k_1k_svs], mode="r") assert 15 < np.mean(output_["nuclei_segmentation"]["predictions"][:]) < 19 assert 0.57 < np.mean(output_["layer_segmentation"]["predictions"][:]) < 0.61 assert "probabilities" not in output_ @@ -312,6 +327,7 @@ def test_wsi_mtsegmentor_zarr( assert "count" not in output_["nuclei_segmentation"] assert "canvas" not in output_["layer_segmentation"] assert "count" not in output_["layer_segmentation"] + _ = zarr.open(output_tile[wsi4_1k_1k_svs], mode="r") wsi4_1k_1k_svs.unlink() diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index 9fa91fbd1..fa06bcb7d 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -24,7 +24,7 @@ from tiatoolbox.annotation import SQLiteStore from tiatoolbox.annotation.storage import Annotation from tiatoolbox.tools.patchextraction import PatchExtractor -from tiatoolbox.utils.misc import get_tqdm, make_valid_poly +from tiatoolbox.utils.misc import create_smart_array, get_tqdm, make_valid_poly from tiatoolbox.wsicore.wsireader import is_zarr from .semantic_segmentor import ( @@ -335,8 +335,8 @@ def post_process_patches( # skipcq: PYL-R0201 def post_process_wsi( # skipcq: PYL-R0201 self: MultiTaskSegmentor, raw_predictions: dict, - save_path: Path, # noqa: ARG002 - **kwargs: Unpack[SemanticSegmentorRunParams], # noqa: ARG002 + save_path: Path, + **kwargs: Unpack[SemanticSegmentorRunParams], ) -> dict: """Post-process raw patch predictions from inference.""" probabilities = raw_predictions["probabilities"] @@ -351,7 +351,12 @@ def post_process_wsi( # skipcq: PYL-R0201 if not probabilities_is_zarr: post_process_predictions = self.model.postproc_func(probabilities) else: - post_process_predictions = self._process_tile_mode(probabilities) + post_process_predictions = self._process_tile_mode( + probabilities, + save_path=save_path.with_suffix(".zarr"), + memory_threshold=kwargs.get("memory_threshold", 80), + return_predictions=kwargs.get("return_predictions", False), + ) tasks = set() for seg in post_process_predictions: @@ -376,6 +381,8 @@ def post_process_wsi( # skipcq: PYL-R0201 def _process_tile_mode( self: MultiTaskSegmentor, probabilities: list[da.Array | np.ndarray], + save_path: Path, + memory_threshold: float = 80, *, return_predictions: bool = False, ) -> list[dict] | None: @@ -415,6 +422,14 @@ def _process_tile_mode( post_process_output=post_process_output, wsi_info_dict=wsi_info_dict, wsi_proc_shape=wsi_proc_shape, + save_path=save_path, + memory_threshold=memory_threshold, + ) + + wsi_info_dict = _update_tile_based_predictions_array( + post_process_output=post_process_output, + wsi_info_dict=wsi_info_dict, + bounds=tile_bounds, ) inst_dicts = _get_inst_info_dicts( @@ -1757,19 +1772,78 @@ def _create_wsi_info_dict( post_process_output: tuple[dict], wsi_info_dict: tuple[dict] | None, wsi_proc_shape: tuple[int, ...], + save_path: Path, + memory_threshold: float = 80, ) -> tuple[dict[str, dict[Any, Any] | list[Any] | Any], ...]: - """Helper function to create wsi info dict.""" + """Create or reuse WSI info dictionaries for post-processed outputs. 
+ + This function constructs a tuple of WSI information dictionaries, one for each + element in `post_process_output`. If an existing `wsi_info_dict` is provided, + it is returned unchanged. Otherwise, a new dictionary is created for each item, + containing task metadata, an allocated prediction array (NumPy or Zarr, chosen + based on available memory), and an empty `info_dict` for downstream metadata. + + Args: + post_process_output (tuple[dict]): + A tuple of dictionaries produced by the post-processing step. Each + dictionary must contain at least: + - "task_type": str + - "predictions": array-like with a `.dtype` and `.shape` attribute + wsi_info_dict (tuple[dict] | None): + Existing WSI info dictionaries. If provided, they are returned as-is. + wsi_proc_shape (tuple[int, ...]): + The full shape of the WSI-level prediction array to allocate for each + output item. + save_path (Path): + Filesystem path where Zarr arrays will be stored if disk-backed + allocation is required. + memory_threshold (float, optional): + Fraction of available RAM allowed for in-memory allocation. Must be + between 0.0 and 100. Defaults to 80. + + Returns: + tuple[dict[str, dict[Any, Any] | list[Any] | Any], ...]: + A tuple of dictionaries, one per post-processing output. Each dictionary + contains: + - "task_type": str + - "predictions": allocated NumPy or Zarr array. + - "info_dict": an empty dictionary for additional metadata. + + """ if wsi_info_dict is not None: return wsi_info_dict return tuple( { "task_type": post_process_output_["task_type"], - "predictions": da.zeros( + "predictions": create_smart_array( shape=wsi_proc_shape, dtype=post_process_output_["predictions"].dtype, + memory_threshold=memory_threshold, + zarr_path=save_path, + chunks=post_process_output_["predictions"].shape, ), "info_dict": {}, } for post_process_output_ in post_process_output ) + + +def _update_tile_based_predictions_array( + post_process_output: tuple[dict], + wsi_info_dict: tuple[dict], + bounds: tuple[int, int, int, int], +) -> tuple[dict]: + """Helper function to update tile based predictions array.""" + x_start, y_start, x_end, y_end = bounds + + for idx, post_process_output_ in enumerate(post_process_output): + max_h, max_w = wsi_info_dict[idx]["predictions"].shape + x_end, y_end = min(x_end, max_w), min(y_end, max_h) + wsi_info_dict[idx]["predictions"][y_start:y_end, x_start:x_end, :] = ( + post_process_output_["predictions"][ + 0 : y_end - y_start, 0 : x_end - x_start, : + ] + ) + + return wsi_info_dict diff --git a/tiatoolbox/utils/misc.py b/tiatoolbox/utils/misc.py index aa2c3e9dd..148fe7eda 100644 --- a/tiatoolbox/utils/misc.py +++ b/tiatoolbox/utils/misc.py @@ -15,6 +15,7 @@ import joblib import numpy as np import pandas as pd +import psutil import requests import tifffile import yaml @@ -1669,3 +1670,66 @@ def cast_to_min_dtype(array: np.ndarray | da.Array) -> np.ndarray | da.Array: return array.astype(dtype) return array + + +def create_smart_array( + shape: tuple[int, ...], + dtype: np.dtype | str, + memory_threshold: float, + zarr_path: str | Path = "array.zarr", + chunks: tuple[int, ...] | None = None, +) -> np.ndarray | zarr.Array: + """Allocate a NumPy or Zarr array depending on available memory and a threshold. + + This function estimates the memory required for an array of the given shape and + dtype. If the required memory is below the allowed fraction of available RAM + (defined by `memory_threshold`), a NumPy array is created in memory. Otherwise, + a Zarr array is created on disk. 
This enables seamless scaling between in-memory + and out-of-core workflows. + + Args: + shape (tuple(int,...)): + Shape of the array to allocate, e.g., (height, width, channels). + dtype (np.dtype | str): + NumPy dtype or dtype string for the array, e.g., np.float32 or "float32". + memory_threshold (float): + Fraction of available RAM allowed for this allocation. Must be between + 0.0 and 100. A value of 100 allows using all available RAM; 0.0 forces + Zarr allocation. + zarr_path (str | None): + Filesystem path where the Zarr array will be created if needed. + Defaults to "array.zarr". + chunks (tuple(int,...) | None): + Chunk shape for the Zarr array. If None, a reasonable default is chosen + based on the array shape. + + Returns: + np.ndarray | zarr.core.Array: + - The allocated array (NumPy or Zarr). + + """ + # Compute required bytes + bytes_needed = np.prod(shape) * np.dtype(dtype).itemsize + + # Available memory + available = psutil.virtual_memory().available + allowed = available * (memory_threshold / 100.0) + + fits_in_memory = bytes_needed <= allowed + + if fits_in_memory: + # Allocate in-memory NumPy array + arr = np.zeros(shape, dtype=dtype) + else: + if zarr_path is None: + temp_dir = tempfile.mkdtemp(prefix="smartarray_") + zarr_path = Path(str(temp_dir)) / "array.zarr" + + # Allocate Zarr array on disk + if chunks is None: + # Default chunking: try to chunk along spatial dims + chunks = (*(min(s, 512) for s in shape[:-1]), shape[-1]) + + arr = zarr.open(zarr_path, mode="w", shape=shape, chunks=chunks, dtype=dtype) + + return arr From 820bb6e8830f5f67e9fcac2d7b0d7db7820fdf4a Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Fri, 30 Jan 2026 12:54:48 +0100 Subject: [PATCH 054/156] :construction: Initial implementation of merged predictions --- tests/engines/test_multi_task_segmentor.py | 31 ++++--- .../models/engine/multi_task_segmentor.py | 87 +++++++++++++++---- 2 files changed, 91 insertions(+), 27 deletions(-) diff --git a/tests/engines/test_multi_task_segmentor.py b/tests/engines/test_multi_task_segmentor.py index 4a22eebb8..13e5d4c04 100644 --- a/tests/engines/test_multi_task_segmentor.py +++ b/tests/engines/test_multi_task_segmentor.py @@ -296,15 +296,24 @@ def test_wsi_mtsegmentor_zarr( return_labels=False, device=device, patch_mode=False, - save_dir=track_tmp_path / "wsi_out_check", + save_dir=track_tmp_path / "wsi_out_full", batch_size=2, output_type="zarr", - memory_threshold=1, ioconfig=ioconfig, ) + output_ = zarr.open(output_full[wsi4_1k_1k_svs], mode="r") + assert 37 < np.mean(output_["nuclei_segmentation"]["predictions"][:]) < 41 + assert 0.87 < np.mean(output_["layer_segmentation"]["predictions"][:]) < 0.91 + assert "probabilities" not in output_ + assert "canvas" not in output_["nuclei_segmentation"] + assert "count" not in output_["nuclei_segmentation"] + assert "canvas" not in output_["layer_segmentation"] + assert "count" not in output_["layer_segmentation"] + # Redefine tile size to force tile-based processing. 
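+    # (Note: _get_tile_info floors the tile shape to a multiple of the patch
+    # output shape, 164 px for hovernetplus-oed, so the effective tile here
+    # is 492 x 492 rather than 512 x 512.)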
ioconfig.tile_shape = (512, 512) + # Return Probabilities is False output_tile = mtsegmentor.run( images=[wsi4_1k_1k_svs], @@ -312,21 +321,19 @@ def test_wsi_mtsegmentor_zarr( return_labels=False, device=device, patch_mode=False, - save_dir=track_tmp_path / "wsi_out_check", + save_dir=track_tmp_path / "wsi_out_tile_based", batch_size=2, output_type="zarr", - memory_threshold=1, + memory_threshold=1, # Memory threshold forces tile_mode ioconfig=ioconfig, + # HoVerNet does not return predictions once + # contours have been calculated in original implementation. + # It's also not straight forward to keep track of instances + # Prediction masks can be tracked and saved as for layer segmentation in + # HoVerNet Plus. + return_predictions=(False, True), ) - output_ = zarr.open(output_full[wsi4_1k_1k_svs], mode="r") - assert 15 < np.mean(output_["nuclei_segmentation"]["predictions"][:]) < 19 - assert 0.57 < np.mean(output_["layer_segmentation"]["predictions"][:]) < 0.61 - assert "probabilities" not in output_ - assert "canvas" not in output_["nuclei_segmentation"] - assert "count" not in output_["nuclei_segmentation"] - assert "canvas" not in output_["layer_segmentation"] - assert "count" not in output_["layer_segmentation"] _ = zarr.open(output_tile[wsi4_1k_1k_svs], mode="r") wsi4_1k_1k_svs.unlink() diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index fa06bcb7d..9ce69eb39 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -49,6 +49,54 @@ from .io_config import IOSegmentorConfig +class MultiTaskSegmentorRunParams(SemanticSegmentorRunParams, total=False): + """Runtime parameters for configuring the `MultiTaskSegmentor.run()` method. + + This class extends `SemanticSegmentorRunParams`, and adds parameters specific + to multitask segmentation workflows. + + Attributes: + auto_get_mask (bool): + Whether to automatically generate segmentation masks using + `wsireader.tissue_mask()` during processing. + batch_size (int): + Number of image patches to feed to the model in a forward pass. + class_dict (dict): + Optional dictionary mapping classification outputs to class names. + device (str): + Device to run the model on (e.g., "cpu", "cuda"). + labels (list): + Optional labels for input images. Only a single label per image + is supported. + memory_threshold (int): + Memory usage threshold (in percentage) to trigger caching behavior. + num_workers (int): + Number of workers used in DataLoader. + output_file (str): + Output file name for saving results (e.g., .zarr or .db). + output_resolutions (Resolution): + Resolution used for writing output predictions. + patch_output_shape (tuple[int, int]): + Shape of output patches (height, width). + return_labels (bool): + Whether to return labels with predictions. + return_probabilities (bool): + Whether to return per-class probabilities. + return_predictions (tuple(bool, ...): + Whether to return array predictions for individual tasks. + scale_factor (tuple[float, float]): + Scale factor for converting annotations to baseline resolution. + Typically model_mpp / slide_mpp. + stride_shape (tuple[int, int]): + Stride used during WSI processing. Defaults to patch_input_shape. + verbose (bool): + Whether to output logging information. + + """ + + return_predictions: tuple[bool, ...] 
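+
+    # Illustrative usage (values are hypothetical): for a two-task model such
+    # as HoVerNet+, passing return_predictions=(False, True) to run() keeps
+    # only the second task's prediction array in the output.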
+ + class MultiTaskSegmentor(SemanticSegmentor): """A multitask segmentation engine for models like hovernet and hovernetplus.""" @@ -165,7 +213,7 @@ def infer_wsi( self: SemanticSegmentor, dataloader: DataLoader, save_path: Path, - **kwargs: Unpack[SemanticSegmentorRunParams], + **kwargs: Unpack[MultiTaskSegmentorRunParams], ) -> dict[str, da.Array]: """Perform model inference on a whole slide image (WSI).""" # Default Memory threshold percentage is 80. @@ -297,7 +345,7 @@ def infer_wsi( def post_process_patches( # skipcq: PYL-R0201 self: MultiTaskSegmentor, raw_predictions: dict, - **kwargs: Unpack[SemanticSegmentorRunParams], # noqa: ARG002 + **kwargs: Unpack[MultiTaskSegmentorRunParams], # noqa: ARG002 ) -> dict: """Post-process raw patch predictions from inference. @@ -336,7 +384,7 @@ def post_process_wsi( # skipcq: PYL-R0201 self: MultiTaskSegmentor, raw_predictions: dict, save_path: Path, - **kwargs: Unpack[SemanticSegmentorRunParams], + **kwargs: Unpack[MultiTaskSegmentorRunParams], ) -> dict: """Post-process raw patch predictions from inference.""" probabilities = raw_predictions["probabilities"] @@ -348,6 +396,7 @@ def post_process_wsi( # skipcq: PYL-R0201 break # If dask array can fit in memory process without tiling. + # This ignores post-processing tile size even if it is smaller. if not probabilities_is_zarr: post_process_predictions = self.model.postproc_func(probabilities) else: @@ -355,7 +404,7 @@ def post_process_wsi( # skipcq: PYL-R0201 probabilities, save_path=save_path.with_suffix(".zarr"), memory_threshold=kwargs.get("memory_threshold", 80), - return_predictions=kwargs.get("return_predictions", False), + return_predictions=kwargs.get("return_predictions", None), ) tasks = set() @@ -384,12 +433,11 @@ def _process_tile_mode( save_path: Path, memory_threshold: float = 80, *, - return_predictions: bool = False, + return_predictions: tuple[bool, ...] 
| None = None, ) -> list[dict] | None: """Helper function to process WSI in tile mode.""" highest_input_resolution = self.ioconfig.highest_input_resolution wsi_reader = self.dataloader.dataset.reader - _ = return_predictions # assume ioconfig has already been converted to `baseline` for `tile` mode wsi_proc_shape = wsi_reader.slide_dimensions(**highest_input_resolution) @@ -424,6 +472,7 @@ def _process_tile_mode( wsi_proc_shape=wsi_proc_shape, save_path=save_path, memory_threshold=memory_threshold, + return_predictions=return_predictions, ) wsi_info_dict = _update_tile_based_predictions_array( @@ -730,7 +779,7 @@ def _save_predictions_as_dict_zarr( processed_predictions: dict, output_type: str, save_path: Path | None = None, - **kwargs: Unpack[SemanticSegmentorRunParams], + **kwargs: Unpack[MultiTaskSegmentorRunParams], ) -> dict | AnnotationStore | Path | list[Path]: """Helper function to save predictions as dictionary or zarr.""" if output_type.lower() == "dict": @@ -778,7 +827,7 @@ def _save_predictions_as_annotationstore( processed_predictions: dict, task_name: str | None = None, save_path: Path | None = None, - **kwargs: Unpack[SemanticSegmentorRunParams], + **kwargs: Unpack[MultiTaskSegmentorRunParams], ) -> dict | AnnotationStore | Path | list[Path]: """Helper function to save predictions as annotationstore.""" # scale_factor set from kwargs @@ -848,7 +897,7 @@ def save_predictions( processed_predictions: dict, output_type: str, save_path: Path | None = None, - **kwargs: Unpack[SemanticSegmentorRunParams], + **kwargs: Unpack[MultiTaskSegmentorRunParams], ) -> dict | AnnotationStore | Path | list[Path]: """Save model predictions to disk or return them in memory. @@ -983,7 +1032,7 @@ def run( save_dir: os.PathLike | Path | None = None, overwrite: bool = False, output_type: str = "dict", - **kwargs: Unpack[SemanticSegmentorRunParams], + **kwargs: Unpack[MultiTaskSegmentorRunParams], ) -> AnnotationStore | Path | str | dict | list[Path]: """Run the semantic segmentation engine on input images. @@ -1017,7 +1066,7 @@ def run( output_type (str): Desired output format: "dict", "zarr", or "annotationstore". Default is "dict". - **kwargs (SemanticSegmentorRunParams): + **kwargs (MultiTaskSegmentorRunParams): Additional runtime parameters to configure segmentation. Optional Keys: @@ -1773,6 +1822,7 @@ def _create_wsi_info_dict( wsi_info_dict: tuple[dict] | None, wsi_proc_shape: tuple[int, ...], save_path: Path, + return_predictions: tuple[bool, ...] | None, memory_threshold: float = 80, ) -> tuple[dict[str, dict[Any, Any] | list[Any] | Any], ...]: """Create or reuse WSI info dictionaries for post-processed outputs. 
@@ -1813,10 +1863,15 @@ def _create_wsi_info_dict( if wsi_info_dict is not None: return wsi_info_dict + # Convert to tuple for each task + if return_predictions is None: + return_predictions = [None for _ in post_process_output] + return tuple( { "task_type": post_process_output_["task_type"], - "predictions": create_smart_array( + "predictions": None if return_predictions[idx] is None else + create_smart_array( shape=wsi_proc_shape, dtype=post_process_output_["predictions"].dtype, memory_threshold=memory_threshold, @@ -1825,7 +1880,7 @@ def _create_wsi_info_dict( ), "info_dict": {}, } - for post_process_output_ in post_process_output + for idx, post_process_output_ in enumerate(post_process_output) ) @@ -1838,11 +1893,13 @@ def _update_tile_based_predictions_array( x_start, y_start, x_end, y_end = bounds for idx, post_process_output_ in enumerate(post_process_output): + if wsi_info_dict[idx]["predictions"] is None: + continue max_h, max_w = wsi_info_dict[idx]["predictions"].shape x_end, y_end = min(x_end, max_w), min(y_end, max_h) - wsi_info_dict[idx]["predictions"][y_start:y_end, x_start:x_end, :] = ( + wsi_info_dict[idx]["predictions"][y_start:y_end, x_start:x_end] = ( post_process_output_["predictions"][ - 0 : y_end - y_start, 0 : x_end - x_start, : + 0 : y_end - y_start, 0 : x_end - x_start ] ) From 30c2ae77964e2ff136a98cb412900579db271d0b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 30 Jan 2026 11:55:33 +0000 Subject: [PATCH 055/156] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tiatoolbox/models/engine/multi_task_segmentor.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index 9ce69eb39..0c22ce3d0 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -404,7 +404,7 @@ def post_process_wsi( # skipcq: PYL-R0201 probabilities, save_path=save_path.with_suffix(".zarr"), memory_threshold=kwargs.get("memory_threshold", 80), - return_predictions=kwargs.get("return_predictions", None), + return_predictions=kwargs.get("return_predictions"), ) tasks = set() @@ -1870,8 +1870,9 @@ def _create_wsi_info_dict( return tuple( { "task_type": post_process_output_["task_type"], - "predictions": None if return_predictions[idx] is None else - create_smart_array( + "predictions": None + if return_predictions[idx] is None + else create_smart_array( shape=wsi_proc_shape, dtype=post_process_output_["predictions"].dtype, memory_threshold=memory_threshold, From 2c0dccdc128dec5dd5496a9a3da63ab0bd83f6c2 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Fri, 30 Jan 2026 14:42:34 +0100 Subject: [PATCH 056/156] :white_check_mark: Add support for tile based "predictions" output --- tests/engines/test_multi_task_segmentor.py | 39 ++++++++++++++----- .../models/engine/multi_task_segmentor.py | 29 ++++++++++++-- 2 files changed, 55 insertions(+), 13 deletions(-) diff --git a/tests/engines/test_multi_task_segmentor.py b/tests/engines/test_multi_task_segmentor.py index 13e5d4c04..01295cf8f 100644 --- a/tests/engines/test_multi_task_segmentor.py +++ b/tests/engines/test_multi_task_segmentor.py @@ -300,19 +300,21 @@ def test_wsi_mtsegmentor_zarr( batch_size=2, output_type="zarr", ioconfig=ioconfig, + return_predictions=(True, True), # True for both tasks. 
) - output_ = zarr.open(output_full[wsi4_1k_1k_svs], mode="r") - assert 37 < np.mean(output_["nuclei_segmentation"]["predictions"][:]) < 41 - assert 0.87 < np.mean(output_["layer_segmentation"]["predictions"][:]) < 0.91 - assert "probabilities" not in output_ - assert "canvas" not in output_["nuclei_segmentation"] - assert "count" not in output_["nuclei_segmentation"] - assert "canvas" not in output_["layer_segmentation"] - assert "count" not in output_["layer_segmentation"] + output_full_ = zarr.open(output_full[wsi4_1k_1k_svs], mode="r") + assert 37 < np.mean(output_full_["nuclei_segmentation"]["predictions"][:]) < 41 + assert 0.50 < np.mean(output_full_["layer_segmentation"]["predictions"][:]) < 0.54 + assert "probabilities" not in output_full_ + assert "canvas" not in output_full_["nuclei_segmentation"] + assert "count" not in output_full_["nuclei_segmentation"] + assert "canvas" not in output_full_["layer_segmentation"] + assert "count" not in output_full_["layer_segmentation"] # Redefine tile size to force tile-based processing. ioconfig.tile_shape = (512, 512) + mtsegmentor.drop_keys = [] # Return Probabilities is False output_tile = mtsegmentor.run( @@ -334,7 +336,25 @@ def test_wsi_mtsegmentor_zarr( return_predictions=(False, True), ) - _ = zarr.open(output_tile[wsi4_1k_1k_svs], mode="r") + output_tile_ = zarr.open(output_tile[wsi4_1k_1k_svs], mode="r") + assert "predictions" not in output_tile_["nuclei_segmentation"] + assert 0.87 < np.mean(output_tile_["layer_segmentation"]["predictions"][:]) < 0.91 + predictions_tile = output_tile_["layer_segmentation"]["predictions"] + # Full predictions are usually larger in size with extra padding as it's faster to + # process full arrays if they can be divided into rectangular chunks in dask/zarr + predictions_full = output_full_["layer_segmentation"]["predictions"][ + 0 : predictions_tile.shape[0], 0 : predictions_tile.shape[1] + ] + overlap_pct = np.mean(predictions_full == predictions_tile) * 100 + assert overlap_pct > 99 + assert len(output_full_["layer_segmentation"]["contours"]) == len( + output_tile_["layer_segmentation"]["contours"] + ) + assert ( + len(output_tile_["nuclei_segmentation"]["contours"]) + / len(output_full_["nuclei_segmentation"]["contours"]) + > 0.9 + ) wsi4_1k_1k_svs.unlink() @@ -371,6 +391,7 @@ def test_multi_input_wsi_mtsegmentor_zarr( output_type="zarr", stride_shape=(160, 160), verbose=True, + return_predictions=(True,), ) output_ = zarr.open(output[wsi4_512_512_svs], mode="r") diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index 0c22ce3d0..7dd054fda 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -395,13 +395,17 @@ def post_process_wsi( # skipcq: PYL-R0201 probabilities_is_zarr = True break + return_predictions = kwargs.get("return_predictions") # If dask array can fit in memory process without tiling. # This ignores post-processing tile size even if it is smaller. 
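+        # "probabilities_is_zarr" indicates that the dask graph is backed by
+        # an on-disk zarr store (the full-resolution maps may not fit in RAM),
+        # in which case the tiled path below is taken.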
if not probabilities_is_zarr: - post_process_predictions = self.model.postproc_func(probabilities) + post_process_predictions = self._process_full_wsi( + probabilities=probabilities, + return_predictions=return_predictions, + ) else: post_process_predictions = self._process_tile_mode( - probabilities, + probabilities=probabilities, save_path=save_path.with_suffix(".zarr"), memory_threshold=kwargs.get("memory_threshold", 80), return_predictions=kwargs.get("return_predictions"), @@ -427,6 +431,20 @@ def post_process_wsi( # skipcq: PYL-R0201 return raw_predictions + def _process_full_wsi( + self: MultiTaskSegmentor, + probabilities: list[da.Array | np.ndarray], + *, + return_predictions: tuple[bool, ...] | None = None, + ) -> list[dict] | None: + """Helper function to post process WSI when it can fit in memory.""" + post_process_predictions = self.model.postproc_func(probabilities) + if return_predictions is None: + return_predictions = [False for _ in post_process_predictions] + for idx, return_predictions_ in enumerate(return_predictions): + if not return_predictions_: + del post_process_predictions[idx]["predictions"] + def _process_tile_mode( self: MultiTaskSegmentor, probabilities: list[da.Array | np.ndarray], @@ -1847,6 +1865,9 @@ def _create_wsi_info_dict( save_path (Path): Filesystem path where Zarr arrays will be stored if disk-backed allocation is required. + return_predictions (tuple[bool, ...]): + Whether to return predictions for individual tasks. Default is None, + which returns no predictions. memory_threshold (float, optional): Fraction of available RAM allowed for in-memory allocation. Must be between 0.0 and 100. Defaults to 80. @@ -1865,13 +1886,13 @@ def _create_wsi_info_dict( # Convert to tuple for each task if return_predictions is None: - return_predictions = [None for _ in post_process_output] + return_predictions = [False for _ in post_process_output] return tuple( { "task_type": post_process_output_["task_type"], "predictions": None - if return_predictions[idx] is None + if not return_predictions[idx] else create_smart_array( shape=wsi_proc_shape, dtype=post_process_output_["predictions"].dtype, From c963514660f24523982dcef7179cfc5ce66b49a0 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Fri, 30 Jan 2026 14:45:32 +0100 Subject: [PATCH 057/156] :white_check_mark: Add support for tile based "predictions" output --- tiatoolbox/models/engine/multi_task_segmentor.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index 7dd054fda..837af8fa0 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -445,6 +445,8 @@ def _process_full_wsi( if not return_predictions_: del post_process_predictions[idx]["predictions"] + return post_process_predictions + def _process_tile_mode( self: MultiTaskSegmentor, probabilities: list[da.Array | np.ndarray], From 0aaaf5bc01df74b1cd9c87d31ae57902a0a3dde3 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Fri, 30 Jan 2026 14:56:16 +0100 Subject: [PATCH 058/156] :bug: Fix deepsource error --- tiatoolbox/models/architecture/hovernet.py | 40 ++++++++++++------ .../models/architecture/hovernetplus.py | 41 +++++++------------ 2 files changed, 41 insertions(+), 40 deletions(-) diff --git a/tiatoolbox/models/architecture/hovernet.py b/tiatoolbox/models/architecture/hovernet.py index 
10fdc7cc9..9fe39adc9 100644 --- a/tiatoolbox/models/architecture/hovernet.py +++ b/tiatoolbox/models/architecture/hovernet.py @@ -802,19 +802,11 @@ def postproc(self: HoVerNet, raw_maps: list[np.ndarray]) -> tuple[dict, ...]: "type": da.empty(shape=0), } else: - # dask dataframe does not support transpose - nuc_inst_info_df = pd.DataFrame(nuc_inst_info_dict).transpose() - for key, col in nuc_inst_info_df.items(): - col_np = col.to_numpy() - nuc_inst_info_dict_[key] = ( - da.from_array( - col_np, - chunks=(len(col),), - ) - if is_dask - else col_np - ) - + nuc_inst_info_dict_ = _inst_dict_for_dask_processing( + inst_info_dict=nuc_inst_info_dict, + inst_info_dict_=nuc_inst_info_dict_, + is_dask=is_dask, + ) nuclei_seg = { "task_type": self.tasks[0], "predictions": da.array(pred_inst) @@ -875,3 +867,25 @@ def infer_batch( # skipcq: PYL-W0221 if "tp" in pred_dict: return pred_dict["np"], pred_dict["hv"], pred_dict["tp"] return pred_dict["np"], pred_dict["hv"] + + +def _inst_dict_for_dask_processing( + inst_info_dict: dict, + inst_info_dict_: dict, + *, + is_dask: bool, +) -> dict: + """Helper function to convert dictionary with numpy arrays to dask arrays.""" + # dask dataframe does not support transpose + inst_info_df = pd.DataFrame(inst_info_dict).transpose() + for key, col in inst_info_df.items(): + col_np = col.to_numpy() + inst_info_dict_[key] = ( + da.from_array( + col_np, + chunks=(len(col),), + ) + if is_dask + else col_np + ) + return inst_info_dict_ diff --git a/tiatoolbox/models/architecture/hovernetplus.py b/tiatoolbox/models/architecture/hovernetplus.py index 43e0b7aec..7b27e6d87 100644 --- a/tiatoolbox/models/architecture/hovernetplus.py +++ b/tiatoolbox/models/architecture/hovernetplus.py @@ -7,13 +7,15 @@ import cv2 import dask.array as da import numpy as np -import pandas as pd import torch import torch.nn.functional as F # noqa: N812 from skimage import morphology from torch import nn -from tiatoolbox.models.architecture.hovernet import HoVerNet +from tiatoolbox.models.architecture.hovernet import ( + HoVerNet, + _inst_dict_for_dask_processing, +) from tiatoolbox.models.architecture.utils import UpSample2x from tiatoolbox.utils.misc import get_bounding_box @@ -361,18 +363,11 @@ def postproc(self: HoVerNetPlus, raw_maps: list[np.ndarray]) -> tuple[dict, ...] "type": da.empty(shape=0) if is_dask else np.empty(0), } else: - # dask dataframe does not support transpose - nuc_inst_info_df = pd.DataFrame(nuc_inst_info_dict).transpose() - for key, col in nuc_inst_info_df.items(): - col_np = col.to_numpy() - nuc_inst_info_dict_[key] = ( - da.from_array( - col_np, - chunks=(len(col),), - ) - if is_dask - else col_np - ) + nuc_inst_info_dict_ = _inst_dict_for_dask_processing( + inst_info_dict=nuc_inst_info_dict, + inst_info_dict_=nuc_inst_info_dict_, + is_dask=is_dask, + ) nuclei_seg = { "task_type": self.tasks[0], @@ -388,19 +383,11 @@ def postproc(self: HoVerNetPlus, raw_maps: list[np.ndarray]) -> tuple[dict, ...] 
"type": da.empty(shape=0) if is_dask else np.empty(0), } else: - # dask dataframe does not support transpose - layer_info_df = pd.DataFrame(layer_info_dict).transpose() - for key, col in layer_info_df.items(): - col_np = col.to_numpy() - layer_info_dict_[key] = ( - da.from_array( - col.to_numpy(), - chunks=(len(col),), - ) - if is_dask - else col_np - ) - + layer_info_dict_ = _inst_dict_for_dask_processing( + inst_info_dict=layer_info_dict, + inst_info_dict_=layer_info_dict_, + is_dask=is_dask, + ) layer_seg = { "task_type": self.tasks[1], "predictions": da.array(pred_layer) if is_dask else pred_layer, From e3f9458f4bd56e32e6882fc5de3137f342e18005 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Fri, 30 Jan 2026 15:15:30 +0100 Subject: [PATCH 059/156] :bug: Fix deepsource error --- .../models/engine/multi_task_segmentor.py | 118 ++++++++++++------ 1 file changed, 80 insertions(+), 38 deletions(-) diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index 837af8fa0..2173e7f59 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -1653,27 +1653,32 @@ def _save_annotation_store( def _process_instance_predictions( inst_dict: dict, ioconfig: IOSegmentorConfig, - tile_shape: list, - tile_flag: list, + tile_shape: tuple[int, int], + tile_flag: tuple[int, int, int, int], tile_mode: int, - tile_tl: tuple, + tile_tl: tuple[int, int], ref_inst_dict: dict, ) -> list | tuple: """Function to merge new tile prediction with existing prediction. Args: - inst_dict (dict): Dictionary containing instance information. - ioconfig (:class:`IOSegmentorConfig`): Object defines information + inst_dict (dict): + Dictionary containing instance information. + ioconfig (:class:`IOSegmentorConfig`): + Object defines information about input and output placement of patches. - tile_shape (list): A list of the tile shape. - tile_flag (list): A list of flag to indicate if instances within + tile_shape (tuple(int, int)): + A list of the tile shape. + tile_flag (list): + A list of flag to indicate if instances within an area extended from each side (by `ioconfig.margin`) of the tile should be replaced by those within the same spatial region in the accumulated output this run. The format is [top, bottom, left, right], 1 indicates removal while 0 is not. For example, [1, 1, 0, 0] denotes replacing top and bottom instances within `ref_inst_dict` with new ones after this processing. - tile_mode (int): A flag to indicate the type of this tile. There + tile_mode (int): + A flag to indicate the type of this tile. There are 4 flags: - 0: A tile from tile grid without any overlapping, it is not an overlapping tile from tile generation. The predicted @@ -1686,16 +1691,20 @@ def _process_instance_predictions( less height (hence horizontal strip). - 3: tile strip stands at the cross-section of four normal tiles (flag 0). - tile_tl (tuple): Top left coordinates of the current tile. - ref_inst_dict (dict): Dictionary contains accumulated output. The + tile_tl (tuple): + Top left coordinates of the current tile. + ref_inst_dict (dict): + Dictionary contains accumulated output. The expected format is {instance_id: {type: int, contour: List[List[int]], centroid:List[float], box:List[int]}. Returns: - new_inst_dict (dict): A dictionary contain new instances to be accumulated. + new_inst_dict (dict): + A dictionary contain new instances to be accumulated. 
            The expected format is {instance_id: {type: int,
            contour: List[List[int]], centroid:List[float], box:List[int]}.
-        remove_insts_in_orig (list): List of instance id within `ref_inst_dict`
+        remove_insts_in_orig (list):
+            List of instance ids within `ref_inst_dict`
            to be removed to prevent overlapping predictions. These instances
            are those that get cut off at the boundary due to the tiling process.
@@ -1704,7 +1713,56 @@
    if len(inst_dict) == 0:
        return {}, []

-    # !
+    sel_indices, margin_lines = _get_sel_indices(
+        ioconfig=ioconfig,
+        tile_shape=tile_shape,
+        inst_dict=inst_dict,
+        tile_tl=tile_tl,
+        tile_mode=tile_mode,
+        tile_flag=tile_flag,
+    )
+
+    def retrieve_sel_uids(sel_indices: list, inst_dict: dict) -> list:
+        """Helper to retrieve selected instance uids."""
+        if len(sel_indices) > 0:
+            # not sure how costly this is in large dict
+            inst_uids = list(inst_dict.keys())
+            return [inst_uids[idx] for idx in sel_indices]
+
+    remove_insts_in_tile = retrieve_sel_uids(sel_indices, inst_dict)
+
+    # external removal only for tile at cross-sections
+    # this one should contain UUID with the reference database
+    remove_insts_in_orig = []
+    if tile_mode == 3:  # noqa: PLR2004
+        inst_boxes = [v["box"] for v in ref_inst_dict.values()]
+        inst_boxes = np.array(inst_boxes)
+
+        geometries = [shapely_box(*bounds) for bounds in inst_boxes]
+        ref_inst_rtree = STRtree(geometries)
+        sel_indices = [
+            geo for bounds in margin_lines for geo in ref_inst_rtree.query(bounds)
+        ]
+
+        remove_insts_in_orig = retrieve_sel_uids(sel_indices, ref_inst_dict)
+
+    new_inst_dict = _move_tile_space_to_wsi_space(
+        inst_dict=inst_dict,
+        tile_tl=tile_tl,
+        remove_insts_in_tile=remove_insts_in_tile,
+    )
+
+    return new_inst_dict, remove_insts_in_orig
+
+
+def _get_sel_indices(
+    ioconfig: IOSegmentorConfig,
+    tile_shape: tuple[int, int],
+    tile_flag: tuple[int, int, int, int],
+    tile_mode: int,
+    tile_tl: tuple[int, int],
+    inst_dict: dict,
+) -> tuple[list, list]:
    m = ioconfig.margin
    w, h = tile_shape
    inst_boxes = [v["box"] for v in inst_dict.values()]
    inst_boxes = np.array(inst_boxes)
@@ -1773,32 +1831,17 @@ def _process_instance_predictions(
        msg = f"Unknown tile mode {tile_mode}."
raise ValueError(msg) - def retrieve_sel_uids(sel_indices: list, inst_dict: dict) -> list: - """Helper to retrieved selected instance uids.""" - if len(sel_indices) > 0: - # not sure how costly this is in large dict - inst_uids = list(inst_dict.keys()) - return [inst_uids[idx] for idx in sel_indices] + return sel_indices, margin_lines - remove_insts_in_tile = retrieve_sel_uids(sel_indices, inst_dict) - - # external removal only for tile at cross-sections - # this one should contain UUID with the reference database - remove_insts_in_orig = [] - if tile_mode == 3: # noqa: PLR2004 - inst_boxes = [v["box"] for v in ref_inst_dict.values()] - inst_boxes = np.array(inst_boxes) - - geometries = [shapely_box(*bounds) for bounds in inst_boxes] - ref_inst_rtree = STRtree(geometries) - sel_indices = [ - geo for bounds in margin_lines for geo in ref_inst_rtree.query(bounds) - ] - - remove_insts_in_orig = retrieve_sel_uids(sel_indices, ref_inst_dict) +def _move_tile_space_to_wsi_space( + inst_dict: dict, + tile_tl: tuple, + remove_insts_in_tile: list, +) -> dict: + """Helper function to move inst dict from tile space to wsi space.""" # move inst position from tile space back to WSI space - # an also generate universal uid as replacement for storage + # and also generate universal uid as replacement for storage new_inst_dict = {} for inst_uid, inst_info in inst_dict.items(): if inst_uid not in remove_insts_in_tile: @@ -1808,8 +1851,7 @@ def retrieve_sel_uids(sel_indices: list, inst_dict: dict) -> list: inst_info["contours"] += tile_tl inst_uuid = uuid.uuid4().hex new_inst_dict[inst_uuid] = inst_info - - return new_inst_dict, remove_insts_in_orig + return new_inst_dict def _get_inst_info_dicts(post_process_output: tuple[dict]) -> list: From f5fe07298343be249f3f6f523e237548f75e389a Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Fri, 30 Jan 2026 15:38:07 +0100 Subject: [PATCH 060/156] :bug: Fix `mypy` error --- .../models/engine/multi_task_segmentor.py | 2 ++ tiatoolbox/utils/misc.py | 30 ++++++++++++------- 2 files changed, 21 insertions(+), 11 deletions(-) diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index 2173e7f59..557e93bd5 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -1763,6 +1763,7 @@ def _get_sel_indices( tile_tl: tuple[int, int], inst_dict: dict, ) -> tuple[list, list]: + """Helper function to retrieve margin lines and selected indices within bounds.""" m = ioconfig.margin w, h = tile_shape inst_boxes = [v["box"] for v in inst_dict.values()] @@ -1943,6 +1944,7 @@ def _create_wsi_info_dict( memory_threshold=memory_threshold, zarr_path=save_path, chunks=post_process_output_["predictions"].shape, + name="predictions", ), "info_dict": {}, } diff --git a/tiatoolbox/utils/misc.py b/tiatoolbox/utils/misc.py index 148fe7eda..ed6f9471e 100644 --- a/tiatoolbox/utils/misc.py +++ b/tiatoolbox/utils/misc.py @@ -1676,6 +1676,7 @@ def create_smart_array( shape: tuple[int, ...], dtype: np.dtype | str, memory_threshold: float, + name: str | None, zarr_path: str | Path = "array.zarr", chunks: tuple[int, ...] | None = None, ) -> np.ndarray | zarr.Array: @@ -1702,6 +1703,8 @@ def create_smart_array( chunks (tuple(int,...) | None): Chunk shape for the Zarr array. If None, a reasonable default is chosen based on the array shape. + name (str | None): + Name for the zarr dataset. 
Returns: np.ndarray | zarr.core.Array: @@ -1719,17 +1722,22 @@ def create_smart_array( if fits_in_memory: # Allocate in-memory NumPy array - arr = np.zeros(shape, dtype=dtype) - else: - if zarr_path is None: - temp_dir = tempfile.mkdtemp(prefix="smartarray_") - zarr_path = Path(str(temp_dir)) / "array.zarr" + return np.zeros(shape, dtype=dtype) + + if zarr_path is None: + temp_dir = tempfile.mkdtemp(prefix="smartarray_") + zarr_path = Path(str(temp_dir)) / "array.zarr" - # Allocate Zarr array on disk - if chunks is None: - # Default chunking: try to chunk along spatial dims - chunks = (*(min(s, 512) for s in shape[:-1]), shape[-1]) + # Allocate Zarr array on disk + if chunks is None: + # Default chunking: try to chunk along spatial dims + chunks = (*(min(s, 512) for s in shape[:-1]), shape[-1]) - arr = zarr.open(zarr_path, mode="w", shape=shape, chunks=chunks, dtype=dtype) + zarr_group = zarr.open(zarr_path, mode="a") - return arr + return zarr_group.create_dataset( + name=name, + shape=shape, + chunks=chunks, + dtype=dtype, + ) From 46834169d9c8c72deabe3fa9b281d7f8ee855a6b Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Fri, 30 Jan 2026 15:49:41 +0100 Subject: [PATCH 061/156] :bug: Fix deepsource errors --- .../models/architecture/hovernetplus.py | 10 ++--- .../models/engine/multi_task_segmentor.py | 38 ++++++++++--------- 2 files changed, 24 insertions(+), 24 deletions(-) diff --git a/tiatoolbox/models/architecture/hovernetplus.py b/tiatoolbox/models/architecture/hovernetplus.py index 7b27e6d87..49f8dfbf6 100644 --- a/tiatoolbox/models/architecture/hovernetplus.py +++ b/tiatoolbox/models/architecture/hovernetplus.py @@ -333,16 +333,14 @@ def postproc(self: HoVerNetPlus, raw_maps: list[np.ndarray]) -> tuple[dict, ...] >>> output = model.postproc(output) """ - np_map, hv_map, tp_map, ls_map = raw_maps - # Assumes raw_maps is a tuple of dask or numpy arrays. # Only return dask if it's required. 
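        # NOTE: materialising the head maps once up front matters here because
        # the OpenCV/scikit-image post-processing below operates on in-memory
        # NumPy arrays rather than on lazy dask chunks.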
     is_dask = isinstance(raw_maps[0], da.Array)
+    raw_maps = [
+        raw_maps_.compute() if is_dask else raw_maps_ for raw_maps_ in raw_maps
+    ]
 
-    np_map = np_map.compute() if is_dask else np_map
-    hv_map = hv_map.compute() if is_dask else hv_map
-    tp_map = tp_map.compute() if is_dask else tp_map
-    ls_map = ls_map.compute() if is_dask else ls_map
+    np_map, hv_map, tp_map, ls_map = raw_maps
 
     pred_inst = HoVerNetPlus._proc_np_hv(np_map, hv_map, scale_factor=0.5)
     # fx=0.5 as nuclear processing is at 0.5 mpp instead of 0.25 mpp
diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py
index 557e93bd5..85e097291 100644
--- a/tiatoolbox/models/engine/multi_task_segmentor.py
+++ b/tiatoolbox/models/engine/multi_task_segmentor.py
@@ -1713,7 +1713,7 @@ def _process_instance_predictions(
     if len(inst_dict) == 0:
         return {}, []
 
-    sel_indices, margin_lines = _get_sel_indices(
+    sel_indices, margin_lines = _get_sel_indices_margin_lines(
         ioconfig=ioconfig,
         tile_shape=tile_shape,
         inst_dict=inst_dict,
@@ -1755,7 +1755,7 @@ def retrieve_sel_uids(sel_indices: list, inst_dict: dict) -> list:
     return new_inst_dict, remove_insts_in_orig
 
 
-def _get_sel_indices(
+def _get_sel_indices_margin_lines(
     ioconfig: IOSegmentorConfig,
     tile_shape: tuple[int, int],
     tile_flag: tuple[int, int, int, int],
     tile_mode: int,
@@ -1798,7 +1798,10 @@ def _get_sel_indices_margin_lines(
     margin_lines = [shapely_box(*v.flatten().tolist()) for v in margin_lines]
 
     # the ids within this match with those within `inst_map`, not UUID
-    sel_indices = []
+    if tile_mode not in [0, 1, 2, 3]:
+        msg = f"Unknown tile mode {tile_mode}."
+        raise ValueError(msg)
+
     if tile_mode in [0, 3]:
         # for `full grid` and `cross section` tiles
         # -- extend from the boundary by the margin size, remove
@@ -1815,22 +1818,21 @@ def _get_sel_indices_margin_lines(
         for geo in tile_rtree.query(bounds)
         if bounds.contains(geometries[geo])
     ]
-    elif tile_mode in [1, 2]:
-        # for `horizontal/vertical strip` tiles
-        # -- extend from the marked edges (top/bot or left/right) by
-        # the margin size, remove all nuclei lie within the margin
-        # area (including on the margin line)
-        # -- remove all nuclei on the boundary also
-
-        sel_boxes = [
-            margin_boxes[idx] if flag else boundary_lines[idx]
-            for idx, flag in enumerate(tile_flag)
-        ]
+        return sel_indices, margin_lines
+
+    # otherwise if tile_mode in [1, 2]:
+    # for `horizontal/vertical strip` tiles
+    # -- extend from the marked edges (top/bot or left/right) by
+    # the margin size, remove all nuclei that lie within the margin
+    # area (including on the margin line)
+    # -- remove all nuclei on the boundary also
+
+    sel_boxes = [
+        margin_boxes[idx] if flag else boundary_lines[idx]
+        for idx, flag in enumerate(tile_flag)
+    ]
 
-    sel_indices = [geo for bounds in sel_boxes for geo in tile_rtree.query(bounds)]
-    else:
-        msg = f"Unknown tile mode {tile_mode}."
- raise ValueError(msg) + sel_indices = [geo for bounds in sel_boxes for geo in tile_rtree.query(bounds)] return sel_indices, margin_lines From af8b67be5ba4adcb137d3122c88701b442e06c07 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Fri, 30 Jan 2026 16:28:51 +0100 Subject: [PATCH 062/156] :bug: Fix coverage for `misc.py` --- tests/engines/test_multi_task_segmentor.py | 2 +- tests/test_utils.py | 73 ++++++++++++++++++- .../models/engine/multi_task_segmentor.py | 2 +- tiatoolbox/utils/misc.py | 7 +- 4 files changed, 77 insertions(+), 7 deletions(-) diff --git a/tests/engines/test_multi_task_segmentor.py b/tests/engines/test_multi_task_segmentor.py index 01295cf8f..5b17173ef 100644 --- a/tests/engines/test_multi_task_segmentor.py +++ b/tests/engines/test_multi_task_segmentor.py @@ -326,7 +326,7 @@ def test_wsi_mtsegmentor_zarr( save_dir=track_tmp_path / "wsi_out_tile_based", batch_size=2, output_type="zarr", - memory_threshold=1, # Memory threshold forces tile_mode + memory_threshold=0, # Memory threshold forces tile_mode ioconfig=ioconfig, # HoVerNet does not return predictions once # contours have been calculated in original implementation. diff --git a/tests/test_utils.py b/tests/test_utils.py index 32f6e90ec..e40abc880 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -5,6 +5,7 @@ import hashlib import json import shutil +import tempfile from pathlib import Path from typing import TYPE_CHECKING, NoReturn @@ -13,6 +14,7 @@ import joblib import numpy as np import pandas as pd +import psutil import pytest import shapely import tifffile @@ -35,7 +37,7 @@ ) from tiatoolbox.utils import misc from tiatoolbox.utils.exceptions import FileNotSupportedError -from tiatoolbox.utils.misc import cast_to_min_dtype +from tiatoolbox.utils.misc import cast_to_min_dtype, create_smart_array from tiatoolbox.utils.transforms import locsize2bounds if TYPE_CHECKING: @@ -2248,3 +2250,72 @@ def test_cast_to_min_dtype_numpy_large_value() -> None: result = cast_to_min_dtype(large_value) assert result == large_value assert result.dtype == object + + +def test_returns_numpy_when_fits_in_memory( + monkeypatch: pytest.MonkeyPatch, tmp_path: Path +) -> None: + """Test that a NumPy array is returned when the array fits in memory.""" + shape = (10, 10, 3) + dtype = np.float32 + bytes_needed = np.prod(shape) * np.dtype(dtype).itemsize + + # Mock available memory to be very large + class FakeVM: + available = bytes_needed * 10 + + monkeypatch.setattr(psutil, "virtual_memory", lambda: FakeVM()) + + arr = create_smart_array( + shape=shape, + dtype=dtype, + memory_threshold=100, # allow full RAM + name="test", + zarr_path=tmp_path / "array.zarr", + ) + + assert isinstance(arr, np.ndarray) + assert arr.shape == shape + assert arr.dtype == dtype + + +def test_creates_temp_dir_when_zarr_path_none( + monkeypatch: pytest.MonkeyPatch, tmp_path: Path +) -> None: + """Test that a temporary directory is created when zarr_path is None.""" + shape = (1000, 1000, 3) + dtype = np.float32 + + # Force fits_in_memory = False by mocking available memory to be tiny + class FakeVM: + available = 1 + + monkeypatch.setattr(psutil, "virtual_memory", lambda: FakeVM()) + + # Track calls to tempfile.mkdtemp + created_dirs = [] + + def fake_mkdtemp(prefix: str) -> str: + """Fake mkdtemp method.""" + _ = prefix + path = tmp_path / "tempdir" + created_dirs.append(path) + return str(path) + + monkeypatch.setattr(tempfile, "mkdtemp", fake_mkdtemp) + + arr = create_smart_array( + shape=shape, + 
        dtype=dtype,
+        memory_threshold=0,  # force Zarr allocation
+        name="test",
+        zarr_path=None,
+    )
+
+    # Ensure mkdtemp was called
+    assert len(created_dirs) == 1
+
+    # Ensure returned object is a Zarr array
+    assert isinstance(arr, zarr.Array)
+    assert arr.shape == shape
+    assert arr.dtype == dtype
diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py
index 85e097291..4272f7457 100644
--- a/tiatoolbox/models/engine/multi_task_segmentor.py
+++ b/tiatoolbox/models/engine/multi_task_segmentor.py
@@ -1946,7 +1946,7 @@ def _create_wsi_info_dict(
             memory_threshold=memory_threshold,
             zarr_path=save_path,
             chunks=post_process_output_["predictions"].shape,
-            name="predictions",
+            name=f"{post_process_output_['task_type']}/predictions",
         ),
         "info_dict": {},
     }
diff --git a/tiatoolbox/utils/misc.py b/tiatoolbox/utils/misc.py
index ed6f9471e..564c1fda7 100644
--- a/tiatoolbox/utils/misc.py
+++ b/tiatoolbox/utils/misc.py
@@ -1677,7 +1677,7 @@ def create_smart_array(
     dtype: np.dtype | str,
     memory_threshold: float,
     name: str | None,
-    zarr_path: str | Path = "array.zarr",
+    zarr_path: str | Path | None = None,
     chunks: tuple[int, ...] | None = None,
 ) -> np.ndarray | zarr.Array:
     """Allocate a NumPy or Zarr array depending on available memory and a threshold.
@@ -1729,9 +1729,8 @@ def create_smart_array(
     zarr_path = Path(str(temp_dir)) / "array.zarr"
 
     # Allocate Zarr array on disk
-    if chunks is None:
-        # Default chunking: try to chunk along spatial dims
-        chunks = (*(min(s, 512) for s in shape[:-1]), shape[-1])
+    # Default chunking: use the full array shape as a single chunk
+    chunks = shape if chunks is None else chunks
 
     zarr_group = zarr.open(zarr_path, mode="a")
 
From 58201826174becf9461dddf7a896dd5e249636d5 Mon Sep 17 00:00:00 2001
From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com>
Date: Fri, 30 Jan 2026 17:09:26 +0100
Subject: [PATCH 063/156] :bug: Fix coverage for `multi_task_segmentor.py`

---
 tests/engines/test_multi_task_segmentor.py    | 20 +++++++
 tests/test_utils.py                           | 10 ++--
 .../models/engine/multi_task_segmentor.py     | 60 ++++++++++++------
 3 files changed, 63 insertions(+), 27 deletions(-)

diff --git a/tests/engines/test_multi_task_segmentor.py b/tests/engines/test_multi_task_segmentor.py
index 5b17173ef..f140cd942 100644
--- a/tests/engines/test_multi_task_segmentor.py
+++ b/tests/engines/test_multi_task_segmentor.py
@@ -17,6 +17,7 @@
 from tiatoolbox.models.engine.multi_task_segmentor import (
     MultiTaskSegmentor,
     _clear_zarr,
+    _get_sel_indices_margin_lines,
     _save_multitask_vertical_to_cache,
 )
 from tiatoolbox.utils import env_detection as toolbox_env
@@ -499,6 +500,25 @@ def test_raise_value_error_return_labels_wsi(
         output_type="zarr",
     )
 
+    # inst_dict must contain boxes
+    inst_dict = {
+        1: {"box": np.array([81, 0, 96, 9])},
+        2: {"box": np.array([138, 0, 151, 8])},
+    }
+
+    invalid_tile_mode = 99  # not in [0,1,2,3]
+    ioconfig = mtsegmentor.ioconfig
+    ioconfig.margin = 128
+    with pytest.raises(ValueError, match=r".*Unknown tile mode.*"):
+        _get_sel_indices_margin_lines(
+            ioconfig=ioconfig,
+            tile_shape=(492, 492),
+            tile_flag=(0, 1, 0, 1),
+            tile_mode=invalid_tile_mode,
+            tile_tl=(0, 0),
+            inst_dict=inst_dict,
+        )
+
 
 def test_clear_zarr() -> None:
     """Test _clear_zarr working appropriately.
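For reference, the two allocation-path tests above share one monkeypatching
pattern. A condensed sketch of that pattern is shown below; it is illustrative
only (the test name `test_smart_array_allocation` is invented here) and assumes
the `create_smart_array` signature introduced in this series.

    import numpy as np
    import psutil
    import zarr

    from tiatoolbox.utils.misc import create_smart_array


    def test_smart_array_allocation(monkeypatch, tmp_path):
        """Force both allocation branches by faking the available memory."""

        class HugeVM:
            available = 2**60  # pretend plenty of RAM is free

        class TinyVM:
            available = 1  # pretend almost no RAM is free

        # Plenty of memory -> in-memory NumPy allocation.
        monkeypatch.setattr(psutil, "virtual_memory", HugeVM)
        arr = create_smart_array(
            shape=(8, 8, 3),
            dtype=np.float32,
            memory_threshold=100,
            name="probe",
            zarr_path=tmp_path / "array.zarr",
        )
        assert isinstance(arr, np.ndarray)

        # Almost no memory -> disk-backed Zarr dataset.
        monkeypatch.setattr(psutil, "virtual_memory", TinyVM)
        arr = create_smart_array(
            shape=(8, 8, 3),
            dtype=np.float32,
            memory_threshold=0,
            name="probe",
            zarr_path=tmp_path / "array.zarr",
        )
        assert isinstance(arr, zarr.Array)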
diff --git a/tests/test_utils.py b/tests/test_utils.py
index e40abc880..c84670ce3 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -2260,11 +2260,12 @@ def test_returns_numpy_when_fits_in_memory(
     dtype = np.float32
     bytes_needed = np.prod(shape) * np.dtype(dtype).itemsize
 
-    # Mock available memory to be very large
     class FakeVM:
+        """Mock available memory to be very large."""
+
         available = bytes_needed * 10
 
-    monkeypatch.setattr(psutil, "virtual_memory", lambda: FakeVM())
+    monkeypatch.setattr(psutil, "virtual_memory", FakeVM)
 
     arr = create_smart_array(
         shape=shape,
@@ -2286,11 +2287,12 @@ def test_creates_temp_dir_when_zarr_path_none(
     shape = (1000, 1000, 3)
     dtype = np.float32
 
-    # Force fits_in_memory = False by mocking available memory to be tiny
     class FakeVM:
+        """Force fits_in_memory = False by mocking available memory to be tiny."""
+
         available = 1
 
-    monkeypatch.setattr(psutil, "virtual_memory", lambda: FakeVM())
+    monkeypatch.setattr(psutil, "virtual_memory", FakeVM)
 
     # Track calls to tempfile.mkdtemp
     created_dirs = []
diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py
index 4272f7457..d856a6f8e 100644
--- a/tiatoolbox/models/engine/multi_task_segmentor.py
+++ b/tiatoolbox/models/engine/multi_task_segmentor.py
@@ -1764,8 +1764,12 @@ def _get_sel_indices_margin_lines(
     inst_dict: dict,
 ) -> tuple[list, list]:
     """Helper function to retrieve margin lines and selected indices within bounds."""
-    m = ioconfig.margin
-    w, h = tile_shape
+    if tile_mode not in [0, 1, 2, 3]:
+        msg = f"Unknown tile mode {tile_mode}."
+        raise ValueError(msg)
+
+    margin = ioconfig.margin
+    width, height = tile_shape
     inst_boxes = [v["box"] for v in inst_dict.values()]
     inst_boxes = np.array(inst_boxes)
 
@@ -1776,32 +1780,25 @@ def _get_sel_indices_margin_lines(
     # create margin bounding box, ordering should match with
     # created tile info flag (top, bottom, left, right)
     boundary_lines = [
-        shapely_box(0, 0, w, 1),  # top egde
-        shapely_box(0, h - 1, w, h),  # bottom edge
-        shapely_box(0, 0, 1, h),  # left
-        shapely_box(w - 1, 0, w, h),  # right
+        shapely_box(0, 0, width, 1),  # top edge
+        shapely_box(0, height - 1, width, height),  # bottom edge
+        shapely_box(0, 0, 1, height),  # left
+        shapely_box(width - 1, 0, width, height),  # right
     ]
     margin_boxes = [
-        shapely_box(0, 0, w, m),  # top egde
-        shapely_box(0, h - m, w, h),  # bottom edge
-        shapely_box(0, 0, m, h),  # left
-        shapely_box(w - m, 0, w, h),  # right
+        shapely_box(0, 0, width, margin),  # top edge
+        shapely_box(0, height - margin, width, height),  # bottom edge
+        shapely_box(0, 0, margin, height),  # left
+        shapely_box(width - margin, 0, width, height),  # right
     ]
-    # ! this is wrt to WSI coord space, not tile
-    margin_lines = [
-        [[m, m], [w - m, m]],  # top egde
-        [[m, h - m], [w - m, h - m]],  # bottom edge
-        [[m, m], [m, h - m]],  # left
-        [[w - m, m], [w - m, h - m]],  # right
-    ]
-    margin_lines = np.array(margin_lines) + tile_tl[None, None]
-    margin_lines = [shapely_box(*v.flatten().tolist()) for v in margin_lines]
+    margin_lines = _get_margin_lines(
+        margin=margin,
+        height=height,
+        width=width,
+        tile_tl=tile_tl,
+    )
 
     # the ids within this match with those within `inst_map`, not UUID
-    if tile_mode not in [0, 1, 2, 3]:
-        msg = f"Unknown tile mode {tile_mode}."
-        raise ValueError(msg)
-
     if tile_mode in [0, 3]:
         # for `full grid` and `cross section` tiles
         # -- extend from the boundary by the margin size, remove
@@ -1837,6 +1834,23 @@ def _get_sel_indices_margin_lines(
     return sel_indices, margin_lines
 
 
+def _get_margin_lines(
+    margin: int,
+    height: int,
+    width: int,
+    tile_tl: tuple[int, int],
+) -> list:
+    # ! this is wrt WSI coord space, not tile
+    margin_lines = [
+        [[margin, margin], [width - margin, margin]],  # top edge
+        [[margin, height - margin], [width - margin, height - margin]],  # bottom edge
+        [[margin, margin], [margin, height - margin]],  # left
+        [[width - margin, margin], [width - margin, height - margin]],  # right
+    ]
+    margin_lines = np.array(margin_lines) + tile_tl[None, None]
+    return [shapely_box(*v.flatten().tolist()) for v in margin_lines]
+
+
 def _move_tile_space_to_wsi_space(
     inst_dict: dict,
     tile_tl: tuple,

From ce865bd81fc6dee900ed5c6840a1b80ae74429aa Mon Sep 17 00:00:00 2001
From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com>
Date: Fri, 30 Jan 2026 17:50:51 +0100
Subject: [PATCH 064/156] :memo: Update docstrings.

---
 .../models/engine/multi_task_segmentor.py     | 161 ++++++++++++++++--
 .../models/engine/semantic_segmentor.py       |   2 +-
 2 files changed, 147 insertions(+), 16 deletions(-)

diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py
index d856a6f8e..c49a3f7d6 100644
--- a/tiatoolbox/models/engine/multi_task_segmentor.py
+++ b/tiatoolbox/models/engine/multi_task_segmentor.py
@@ -98,7 +98,106 @@ class MultiTaskSegmentorRunParams(SemanticSegmentorRunParams, total=False):
 
 
 class MultiTaskSegmentor(SemanticSegmentor):
-    """A multitask segmentation engine for models like hovernet and hovernetplus."""
+    """MultiTask segmentation engine to run models like hovernet and hovernetplus.
+
+    MultiTaskSegmentor performs segmentation across multiple model heads
+    (e.g., semantic, instance, edge). It abstracts model invocation,
+    preprocessing, and output postprocessing for multi-head segmentation.
+
+    Args:
+        model (str | ModelABC):
+            A PyTorch model instance or name of a pretrained model from TIAToolbox.
+            The user can request pretrained models from the toolbox model zoo using
+            the list of pretrained models available at this `link
+            `_
+            By default, the corresponding pretrained weights will also
+            be downloaded. However, you can override with your own set
+            of weights using the `weights` parameter. Default is `None`.
+        batch_size (int):
+            Number of image patches processed per forward pass. Default is 8.
+        num_workers (int):
+            Number of workers for data loading. Default is 0.
+        weights (str | Path | None):
+            Path to model weights. If None, default weights are used.
+
+            >>> engine = MultiTaskSegmentor(
+            ...     model="pretrained-model",
+            ...     weights="/path/to/pretrained-local-weights.pth"
+            ... )
+
+        device (str):
+            Device to run the model on (e.g., "cpu", "cuda"). Default is "cpu".
+        verbose (bool):
+            Whether to enable verbose logging. Default is True.
+
+    Attributes:
+        images (list[str | Path] | np.ndarray):
+            Input image patches or WSI paths.
+        masks (list[str | Path] | np.ndarray):
+            Optional tissue masks for WSI processing.
+            These are only utilized when patch_mode is False.
+            If not provided, then a tissue mask will be automatically
+            generated for whole slide images.
+        patch_mode (bool):
+            Whether input is treated as patches (`True`) or WSIs (`False`).
+        model (ModelABC):
+            Loaded PyTorch model.
+ ioconfig (ModelIOConfigABC): + IO configuration for patch extraction and resolution. + return_labels (bool): + Whether to include labels in the output. + input_resolutions (list[dict]): + Resolution settings for model input. Supported + units are `level`, `power` and `mpp`. Keys should be "units" and + "resolution" e.g., [{"units": "mpp", "resolution": 0.25}]. Please see + :class:`WSIReader` for details. + patch_input_shape (tuple[int, int]): + Shape of input patches (height, width). Patches are at + requested read resolution, not with respect to level 0, + and must be positive. + stride_shape (tuple[int, int]): + Stride used during patch extraction. Stride is + at requested read resolution, not with respect to + level 0, and must be positive. If not provided, + `stride_shape=patch_input_shape`. + labels (list | None): + Optional labels for input images. + Only a single label per image is supported. + drop_keys (list): + Keys to exclude from model output. + output_type (str): + Format of output ("dict", "zarr", "annotationstore"). + output_locations (list | None): + Coordinates of output patches used during WSI processing. + + Examples: + >>> # list of 2 image patches as input + >>> wsis = ['path/img.svs', 'path/img.svs'] + >>> mtsegmentor = MultiTaskSegmentor(model="hovernetplus-oed") + >>> output = mtsegmentor.run(wsis, patch_mode=False) + + >>> # array of list of 2 image patches as input + >>> image_patches = [np.ndarray, np.ndarray] + >>> mtsegmentor = MultiTaskSegmentor(model="hovernetplus-oed") + >>> output = mtsegmentor.run(image_patches, patch_mode=True) + + >>> # list of 2 image patch files as input + >>> data = ['path/img.png', 'path/img.png'] + >>> mtsegmentor = MultiTaskSegmentor(model="hovernet_fast-pannuke") + >>> output = mtsegmentor.run(data, patch_mode=False) + + >>> # list of 2 image tile files as input + >>> tile_file = ['path/tile1.png', 'path/tile2.png'] + >>> mtsegmentor = MultiTaskSegmentor(model="hovernet_fast-pannuke") + >>> output = mtsegmentor.run(tile_file, patch_mode=False) + + >>> # list of 2 wsi files as input + >>> wsis = ['path/wsi1.svs', 'path/wsi2.svs'] + >>> mtsegmentor = MultiTaskSegmentor(model="hovernet_fast-pannuke") + >>> output = mtsegmentor.run(wsis, patch_mode=False) + + + """ def __init__( self: MultiTaskSegmentor, @@ -110,7 +209,25 @@ def __init__( device: str = "cpu", verbose: bool = True, ) -> None: - """Initialize :class:`NucleusInstanceSegmentor`.""" + """Initialize :class:`MultiTaskSegmentor`. + + Args: + model (str | ModelABC): + A PyTorch model instance or name of a pretrained model from TIAToolbox. + If a string is provided, the corresponding pretrained weights will be + downloaded unless overridden via `weights`. + batch_size (int): + Number of image patches processed per forward pass. Default is 8. + num_workers (int): + Number of workers for data loading. Default is 0. + weights (str | Path | None): + Path to model weights. If None, default weights are used. + device (str): + Device to run the model on (e.g., "cpu", "cuda"). Default is "cpu". + verbose (bool): + Whether to enable verbose logging. Default is True. + + """ self.tasks = set() super().__init__( model=model, @@ -127,26 +244,40 @@ def infer_patches( *, return_coordinates: bool = False, ) -> dict[str, list[da.Array]]: - """Run model inference on image patches and return predictions. + """Run inference on a batch of image patches using the multitask model. - This method performs batched inference using a PyTorch DataLoader, - and accumulates predictions in Dask arrays. 
It supports optional inclusion - of coordinates and labels in the output. + This method processes patches provided by a PyTorch ``DataLoader`` and runs + them through the model's ``infer_batch`` method. Models with multiple heads + (e.g., semantic, instance, edge) may return multiple outputs per patch. + Outputs are collected as Dask arrays for efficient large-scale aggregation. Args: dataloader (DataLoader): - PyTorch DataLoader containing image patches for inference. + A PyTorch dataloader that yields dicts containing ``"image"`` tensors + and optionally other metadata (e.g., coordinates). return_coordinates (bool): - Whether to include coordinates in the output. Required when - called by `infer_wsi` and `patch_mode` is False. + Whether to return the spatial coordinates associated with each patch + (when available from the dataset). Default is False. Returns: - dict[str, dask.array.Array]: - Dictionary containing prediction results as Dask arrays. - Keys include: - - "probabilities": Model output probabilities. - - "coordinates": Patch coordinates (if `return_coordinates` is - True). + dict[str, list[da.Array]]: + A dictionary containing the model outputs for all patches. + + Keys: + probabilities (list[da.Array]): + A list of Dask arrays containing model outputs for each head. + Each array has shape ``(N, C, H, W)`` depending on the model. + coordinates (da.Array): + Returned only when ``return_coordinates=True``. + A Dask array of shape ``(N, 2)`` or ``(N, 4)`` depending on + how patch coordinates are stored in the dataset. + + Notes: + - The number of model outputs (heads) is inferred dynamically from the + first forward pass. + - Outputs are stacked via ``dask.array.concatenate`` for scalability. + - This method does not perform postprocessing; raw logits/probabilities + are returned exactly as produced by the model. """ keys = ["probabilities"] diff --git a/tiatoolbox/models/engine/semantic_segmentor.py b/tiatoolbox/models/engine/semantic_segmentor.py index 2405c9374..57efd8994 100644 --- a/tiatoolbox/models/engine/semantic_segmentor.py +++ b/tiatoolbox/models/engine/semantic_segmentor.py @@ -244,7 +244,7 @@ class SemanticSegmentor(PatchPredictor): >>> # array of list of 2 image patches as input >>> image_patches = [np.ndarray, np.ndarray] >>> segmentor = SemanticSegmentor(model="fcn-tissue_mask") - >>> output = segmentor.run(data, patch_mode=True) + >>> output = segmentor.run(image_patches, patch_mode=True) >>> # list of 2 image patch files as input >>> data = ['path/img.png', 'path/img.png'] From 260bbc947af65bc142af9168fb462053b540f78d Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Fri, 30 Jan 2026 22:15:30 +0100 Subject: [PATCH 065/156] :memo: Update docstrings. --- .../models/engine/multi_task_segmentor.py | 468 ++++++++++++++++-- 1 file changed, 426 insertions(+), 42 deletions(-) diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index c49a3f7d6..6b34b32c7 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -80,10 +80,10 @@ class MultiTaskSegmentorRunParams(SemanticSegmentorRunParams, total=False): Shape of output patches (height, width). return_labels (bool): Whether to return labels with predictions. - return_probabilities (bool): - Whether to return per-class probabilities. return_predictions (tuple(bool, ...): Whether to return array predictions for individual tasks. 
+ return_probabilities (bool): + Whether to return per-class probabilities. scale_factor (tuple[float, float]): Scale factor for converting annotations to baseline resolution. Typically model_mpp / slide_mpp. @@ -346,7 +346,96 @@ def infer_wsi( save_path: Path, **kwargs: Unpack[MultiTaskSegmentorRunParams], ) -> dict[str, da.Array]: - """Perform model inference on a whole slide image (WSI).""" + """Perform model inference on a whole slide image (WSI). + + This method iterates over WSI patches produced by a DataLoader, + runs each patch through the model's ``infer_batch`` callback, and + incrementally assembles full-resolution model outputs for each model + head (e.g., semantic, instance, edge). Patch-level outputs are merged + row-by-row using horizontal stitching, optionally spilling intermediate + results to disk when memory usage exceeds a threshold. After all rows + are processed, vertical merging is performed to generate the final + probability maps for each multitask head. + + Raw probabilities and patch coordinates are returned as Dask arrays. + This method does not perform any post-processing; downstream calls to + ``post_process_wsi`` are required to convert model logits into + task-specific outputs (e.g., instances, contours, or label maps). + + Args: + dataloader (DataLoader): + A PyTorch dataloader yielding dictionaries with keys such as + ``"image"`` and ``"output_locs"`` that correspond to extracted + WSI patches and their placement metadata. + save_path (Path): + A filesystem path used to store temporary Zarr cache data when + memory spilling is triggered. The directory is created if needed. + **kwargs (MultiTaskSegmentorRunParams): + Additional runtime parameters to configure segmentation. + + Optional Keys: + auto_get_mask (bool): + Automatically generate segmentation masks using + `wsireader.tissue_mask()` during processing. + batch_size (int): + Number of image patches per forward pass. + class_dict (dict): + Mapping of classification outputs to class names. + device (str): + Device to run the model on (e.g., "cpu", "cuda"). + labels (list): + Optional labels for input images. Only a single label per image + is supported. + memory_threshold (int): + Memory usage threshold (percentage) to trigger caching behavior. + num_workers (int): + Number of workers for DataLoader and post-processing. + output_file (str): + Filename for saving output (e.g., ".zarr" or ".db"). + output_resolutions (Resolution): + Resolution used for writing output predictions. + patch_output_shape (tuple[int, int]): + Shape of output patches (height, width). + return_labels (bool): + Whether to return labels with predictions. + return_predictions (tuple(bool, ...): + Whether to return array predictions for individual tasks. + return_probabilities (bool): + Whether to return per-class probabilities. + scale_factor (tuple[float, float]): + Scale factor for annotations (model_mpp / slide_mpp). + Used to convert coordinates to baseline resolution. + stride_shape (tuple[int, int]): + Stride used during WSI processing. + Defaults to `patch_input_shape` if not provided. + verbose (bool): + Whether to enable verbose logging. + + Returns: + dict[str, da.Array]: + A dictionary containing the raw multitask model outputs. + + Keys: + probabilities (list[da.Array]): + One Dask array per model head, each representing the final + WSI-sized probability map for that task. Each array has + shape ``(H, W, C)`` depending on the head's channel count. 
+ coordinates (da.Array): + A Dask array of shape ``(N, 2)`` or ``(N, 4)``, containing + accumulated patch coordinate metadata produced during the + WSI dataloader iteration. + + Notes: + - The number of model heads is inferred from the first + ``infer_batch`` call. + - Patch predictions are merged horizontally when the x-coordinate + changes row, and vertically after all rows are processed. + - Large WSIs may trigger spilling intermediate canvas data to disk + when memory exceeds ``memory_threshold``. + - This function returns *raw probabilities only*. For task-specific + segmentation or instance extraction, call ``post_process_wsi``. + + """ # Default Memory threshold percentage is 80. memory_threshold = kwargs.get("memory_threshold", 80) @@ -478,21 +567,87 @@ def post_process_patches( # skipcq: PYL-R0201 raw_predictions: dict, **kwargs: Unpack[MultiTaskSegmentorRunParams], # noqa: ARG002 ) -> dict: - """Post-process raw patch predictions from inference. + """Post-process raw patch-level predictions for multitask segmentation. - This method applies a post-processing function (e.g., smoothing, filtering) - to the raw model predictions. It supports delayed execution using Dask - and returns a Dask array for efficient computation. + This method applies the model's ``postproc_func`` to per-patch probability + maps produced by ``infer_patches``. For multitask models (multiple heads), + it zips the per-head probability arrays across patches and invokes + ``postproc_func`` to obtain one or more task dictionaries per patch (e.g., + semantic labels, instance info, edges). The per-patch outputs are then + reorganized into a task-centric structure using + ``build_post_process_raw_predictions`` for downstream saving. Args: - raw_predictions (dask.array.Array): - Raw model predictions as a dask array. - **kwargs (EngineABCRunParams): - Additional runtime parameters used for post-processing. + raw_predictions (dict): + Dictionary containing raw model outputs from ``infer_patches``. + Expected keys: + - ``"probabilities"`` (list[da.Array]): + One Dask array per model head. Each array typically has shape + ``(N, H, W, C)`` for ``N`` patches, with head-specific channels. + These are *raw* logits/probabilities and are not normalized + beyond what the model provides. + **kwargs (MultiTaskSegmentorRunParams): + Additional runtime parameters to configure segmentation. + + Optional Keys: + auto_get_mask (bool): + Automatically generate segmentation masks using + `wsireader.tissue_mask()` during processing. + batch_size (int): + Number of image patches per forward pass. + class_dict (dict): + Mapping of classification outputs to class names. + device (str): + Device to run the model on (e.g., "cpu", "cuda"). + labels (list): + Optional labels for input images. Only a single label per image + is supported. + memory_threshold (int): + Memory usage threshold (percentage) to trigger caching behavior. + num_workers (int): + Number of workers for DataLoader and post-processing. + output_file (str): + Filename for saving output (e.g., ".zarr" or ".db"). + output_resolutions (Resolution): + Resolution used for writing output predictions. + patch_output_shape (tuple[int, int]): + Shape of output patches (height, width). + return_labels (bool): + Whether to return labels with predictions. + return_predictions (tuple(bool, ...): + Whether to return array predictions for individual tasks. + return_probabilities (bool): + Whether to return per-class probabilities. 
+ scale_factor (tuple[float, float]): + Scale factor for annotations (model_mpp / slide_mpp). + Used to convert coordinates to baseline resolution. + stride_shape (tuple[int, int]): + Stride used during WSI processing. + Defaults to `patch_input_shape` if not provided. + verbose (bool): + Whether to enable verbose logging. Returns: - dask.array.Array: - Post-processed predictions as a Dask array. + dict: + A task-organized dictionary suitable for saving, where each entry + corresponds to a task produced by ``postproc_func``. For each task + (e.g., ``"semantic"``, ``"instance"``), keys and value types depend + on the model's post-processing output. Typical patterns include: + - ``"predictions"``: list[da.Array] with per-patch outputs, + if the model returns patch-level prediction arrays. + - ``"info_dict"``: list[dict] with per-patch metadata dictionaries + (e.g., instance tables, properties). Lists are aligned to the + number of input patches. + Any pre-existing keys in ``raw_predictions`` (e.g., ``"coordinates"``) + are preserved as returned by ``build_post_process_raw_predictions``. + + Notes: + - This method is *patch-level* post-processing only; it does not perform + WSI-scale tiling or stitching. For WSI outputs, use ``post_process_wsi``. + - Inputs are typically Dask arrays; computation remains lazy until an + explicit save step or ``dask.compute`` is invoked downstream. + - The exact set of task keys and payload shapes are determined by the + model's ``postproc_func`` for each head. """ probabilities = raw_predictions["probabilities"] @@ -501,23 +656,106 @@ def post_process_patches( # skipcq: PYL-R0201 for probs_for_idx in zip(*probabilities, strict=False) ] - raw_predictions = self.build_post_process_raw_predictions( + return self.build_post_process_raw_predictions( post_process_predictions=post_process_predictions, raw_predictions=raw_predictions, ) - # Need to update info_dict - _ = raw_predictions - - return raw_predictions - def post_process_wsi( # skipcq: PYL-R0201 self: MultiTaskSegmentor, raw_predictions: dict, save_path: Path, **kwargs: Unpack[MultiTaskSegmentorRunParams], ) -> dict: - """Post-process raw patch predictions from inference.""" + """Post-process whole slide image (WSI) predictions for multitask segmentation. + + This method converts raw WSI-scale probability maps (produced by + ``infer_wsi``) into task-specific outputs using the model's + ``postproc_func``. If the probability maps are fully in memory, the method + processes the entire WSI at once. If they are Zarr-backed (spilled during + inference) or too large, it switches to tile mode: it iterates over WSI + tiles, applies ``postproc_func`` per tile, merges instance predictions + across tile boundaries, and optionally writes intermediate arrays to Zarr + under ``save_path.with_suffix(".zarr")`` for memory efficiency. + + The result is organized into a task-centric dictionary (e.g., semantic, + instance) with arrays and/or metadata suitable for saving or further use. + + Args: + raw_predictions (dict): + Dictionary containing WSI-scale model outputs from ``infer_wsi``. + Expected key: + - ``"probabilities"`` (tuple[da.Array]): + One Dask array per model head. Each array is either + memory-backed (Dask→NumPy) or Zarr-backed depending on + memory spilling during inference. + save_path (Path): + Base path for writing intermediate Zarr arrays in tile mode and + for allocating per-task outputs when disk-backed arrays are needed. 
+ **kwargs (MultiTaskSegmentorRunParams): + Additional runtime parameters to configure segmentation. + + Optional Keys: + auto_get_mask (bool): + Automatically generate segmentation masks using + `wsireader.tissue_mask()` during processing. + batch_size (int): + Number of image patches per forward pass. + class_dict (dict): + Mapping of classification outputs to class names. + device (str): + Device to run the model on (e.g., "cpu", "cuda"). + labels (list): + Optional labels for input images. Only a single label per image + is supported. + memory_threshold (int): + Memory usage threshold (percentage) to trigger caching behavior. + num_workers (int): + Number of workers for DataLoader and post-processing. + output_file (str): + Filename for saving output (e.g., ".zarr" or ".db"). + output_resolutions (Resolution): + Resolution used for writing output predictions. + patch_output_shape (tuple[int, int]): + Shape of output patches (height, width). + return_labels (bool): + Whether to return labels with predictions. + return_predictions (tuple(bool, ...): + Whether to return array predictions for individual tasks. + return_probabilities (bool): + Whether to return per-class probabilities. + scale_factor (tuple[float, float]): + Scale factor for annotations (model_mpp / slide_mpp). + Used to convert coordinates to baseline resolution. + stride_shape (tuple[int, int]): + Stride used during WSI processing. + Defaults to `patch_input_shape` if not provided. + verbose (bool): + Whether to enable verbose logging. + + Returns: + dict: + A task-organized dictionary of WSI-scale outputs. For each task + (e.g., ``"semantic"``, ``"instance"``), typical entries include: + - ``"predictions"`` (da.Array or np.ndarray, optional): + Full-resolution task prediction map, present only where + enabled by ``return_predictions``. + - Additional task-specific keys (e.g., ``"info_dict"``, + per-instance dictionaries, contours, classes, probabilities). + The set of keys and their exact shapes/types are determined by the + model's ``postproc_func``. + + Notes: + - Full-WSI mode is selected when probability maps are not Zarr-backed; + otherwise tile mode is used. + - Tile mode uses model-specific merging of instances across tile + boundaries and may write intermediate arrays under a ``.zarr`` group + next to ``save_path``. + - Probability maps themselves are not modified here; this method produces + task-centric outputs from them. Use ``save_predictions`` to persist + results as ``dict``, ``zarr``, or ``annotationstore``. + + """ probabilities = raw_predictions["probabilities"] probabilities_is_zarr = False @@ -568,7 +806,50 @@ def _process_full_wsi( *, return_predictions: tuple[bool, ...] | None = None, ) -> list[dict] | None: - """Helper function to post process WSI when it can fit in memory.""" + """Convert full-WSI probability maps into task-specific outputs in memory. + + This helper is used when the WSI-scale probability maps (one per model head) + fit in memory without requiring Zarr-backed tiling. It invokes the model's + ``postproc_func`` once on the complete list of head maps and returns a list + of per-task dictionaries (e.g., semantic, instance). Optionally, it drops + the ``"predictions"`` array for tasks where returning the full-resolution + map is not requested. + + Args: + probabilities (list[da.Array | np.ndarray]): + Full-resolution probability maps, one per model head. Each element + is either a Dask array or NumPy array with shape ``(H, W, C)``, + where ``C`` is head-specific. 
These are the outputs of + ``infer_wsi`` after horizontal/vertical stitching. + return_predictions (tuple[bool, ...] | None): + Per-task flags indicating whether to keep the task's + full-resolution ``"predictions"`` array in the result. If + ``None``, no task predictions are returned (all ``"predictions"`` + keys are removed). The tuple length must match the number of + task dictionaries returned by ``postproc_func``. + + Returns: + list[dict] | None: + A list of task dictionaries returned by the model's + ``postproc_func``. Each dictionary must include + ``"task_type"`` and may include keys such as + ``"predictions"`` (``np.ndarray`` or ``da.Array``) and/or an + ``"info_dict"`` with task-specific metadata. If all task + predictions are dropped and no other outputs are produced, + this may return ``None``. + + Notes: + - This function performs no tiling or disk spilling; it assumes the + inputs fit in memory. For large WSIs or Zarr-backed probability + maps, use ``_process_tile_mode`` instead. + - The exact set of task keys and value types is model-dependent and + determined by ``postproc_func``. + - When ``return_predictions`` is provided, it is applied positionally + to the sequence of task dictionaries emitted by ``postproc_func``: + if a task's flag is ``False``, that task's ``"predictions"`` key is + removed from the output. + + """ post_process_predictions = self.model.postproc_func(probabilities) if return_predictions is None: return_predictions = [False for _ in post_process_predictions] @@ -586,7 +867,72 @@ def _process_tile_mode( *, return_predictions: tuple[bool, ...] | None = None, ) -> list[dict] | None: - """Helper function to process WSI in tile mode.""" + """Convert WSI probability maps into outputs using tile-mode processing. + + This helper is used when WSI-scale probability maps are Zarr-backed or too + large to fit comfortably in memory. It iterates over WSI tiles, extracts the + corresponding sub-arrays from each model head, applies the model's + ``postproc_func`` per tile, and merges task outputs across tile boundaries. + For instance-type tasks, it removes duplicated/cut instances near tile + margins using configuration from ``IOSegmentorConfig`` (tile flags, margin) + and consolidates detections into the slide coordinate system. + + Optionally, full-resolution per-task prediction arrays (e.g., dense label or + probability maps) are allocated as NumPy or Zarr via ``create_smart_array`` + and incrementally filled at the appropriate tile locations. Allocation and + spilling behavior are governed by ``memory_threshold``. + + Args: + probabilities (list[da.Array | np.ndarray]): + WSI-scale probability maps, one per model head, with shape + ``(H, W, C)`` per head. These are the outputs of ``infer_wsi`` + (after horizontal/vertical stitching) and may be Zarr-backed. + save_path (Path): + Base path used for creating a ``.zarr`` group to store + disk-backed arrays when memory usage exceeds the threshold and for + per-task predictions when requested by ``return_predictions``. + memory_threshold (float): + Maximum allowed RAM usage (percentage) for in-memory arrays before + switching to or continuing with Zarr-backed allocation. Defaults to 80. + return_predictions (tuple[bool, ...] | None): + Per-task flags indicating whether to retain a full-resolution + ``"predictions"`` array for each task. If ``None``, no task-level + prediction arrays are retained (i.e., they are set to ``None`` and not + allocated). 
The tuple length must match the number of task dictionaries + produced by ``postproc_func``. + + Returns: + list[dict] | None: + A list of task dictionaries (one per multitask head output as produced + by ``postproc_func``) with fields such as: + - ``"task_type"`` (str): Name/type of the task (e.g., + ``"semantic"``, ``"instance"``). + - ``"predictions"`` (np.ndarray or Zarr-backed array | None): + Full-resolution task prediction array if enabled by + ``return_predictions``; otherwise ``None``. + - ``"info_dict"`` (dict): Task-specific metadata accumulated across + tiles. For instance tasks, this includes merged instance tables + (e.g., boxes, centroids, contours) keyed by UUIDs in WSI space. + + Returns ``None`` only if ``postproc_func`` yields no outputs. + + Notes: + - Tile layout is derived from the engine IO config; each tile's bounds + are used to slice per-head probability maps and to place results back + into WSI space. + - For instance tasks, objects near tile margins are pruned/merged using + per-tile flags and a configurable margin to avoid duplicates across + tiles. Instance coordinates (boxes, centroids, contours) are translated + from tile space to WSI space prior to consolidation. + - When ``return_predictions`` requests any task array, allocation is done + via ``create_smart_array`` to choose between NumPy and Zarr based on + ``memory_threshold``. Arrays are filled tile-by-tile using the tile + bounds. + - Computation remains lazy for Dask-backed inputs until explicitly + computed or saved downstream. Probability maps themselves are not + modified in this method; it only derives task-centric outputs. + + """ highest_input_resolution = self.ioconfig.highest_input_resolution wsi_reader = self.dataloader.dataset.reader @@ -856,33 +1202,71 @@ def build_post_process_raw_predictions( post_process_predictions: list[tuple], raw_predictions: dict, ) -> dict: - """Merge per-image outputs into a task-organized prediction structure. - - This function takes a list of outputs, where each element corresponds to one - image and contains one or more segmentation dictionaries. Each segmentation - dictionary must include a ``"task_type"`` key along with any number of - additional fields (e.g., ``"predictions"``, ``"info_dict"``, or others). - - The function reorganizes these outputs into ``raw_predictions`` by grouping - entries under their respective task types. For each task, all keys except - ``"task_type"`` are stored in dictionaries indexed by ``img_id``. Existing - content in ``raw_predictions`` is preserved and extended as needed. + """Merge per-image, per-task outputs into a task-organized prediction structure. + + This function takes a list of outputs where each element corresponds to one + image and contains one or more task dictionaries returned by the model's + post-processing step (e.g., semantic, instance). Each task dictionary must + include a ``"task_type"`` key along with any number of task-specific fields + (for example, ``"predictions"``, ``"info_dict"``, or additional metadata). + The function reorganizes this data into ``raw_predictions`` by grouping + entries under their respective task types and aligning values across images. + + The merging logic is as follows: + 1) For each task (identified by ``"task_type"``), values for keys other than + ``"task_type"`` are temporarily collected into lists, one entry per image. 
+ 2) After all images are processed, list entries are normalized: + + - If all entries for a key are array-like (``np.ndarray`` or + ``dask.array.Array``), + they are stacked along a new leading dimension (image axis). + - If all entries for a key are dictionaries, their subkeys are expanded + into separate lists aligned across images (the original composite key + is removed). + 3) Existing content in ``raw_predictions`` is preserved and extended as + needed. Args: post_process_predictions (list[tuple]): - A list where each element represents one image. Each element is an - iterable of segmentation dictionaries. Each segmentation dictionary - must contain a ``"task_type"`` field and may contain any number of - additional fields. + A list where each element represents a single image. Each element is + an iterable of task dictionaries. Every task dictionary **must** + contain: + - ``"task_type"`` (str): Name/type of the task + (e.g., ``"semantic"``, ``"instance"``, ``"edge"``). + and **may** contain any number of additional fields, such as: + - ``"predictions"``: array-like output for that task + - ``"info_dict"``: dictionary of task-specific metadata + - Any other task-dependent keys raw_predictions (dict): - A dictionary that will be updated in-place. It may already contain - task entries or other unrelated keys. New tasks and new fields are - added dynamically as they appear in ``outputs``. + Dictionary that will be updated **in-place**. It may already contain + task entries or unrelated keys (e.g., ``"probabilities"``, + ``"coordinates"``). New tasks and fields are added as they appear. Returns: dict: - The updated ``raw_predictions`` dictionary, containing all tasks and - their associated per-image fields. + The updated ``raw_predictions`` dictionary containing one entry per + task type. Under each task name, keys hold per-image arrays (stacked + as Dask/NumPy where applicable) or lists/dicts aligned across images. + Example structure: + { + "semantic": { + "predictions": da.Array | np.ndarray, # stacked over images + "info_dict": [dict, dict, ...] # or expanded subkeys + }, + "instance": { + "info_dict": [...], # per-image metadata + "contours": [...], "classes": [...], # task-dependent keys + }, + "coordinates": da.Array, # if previously present + } + + Notes: + - Array stacking occurs only when **all** per-image entries for a key are + array-like; mixed types remain as lists. + - Dictionary expansion occurs only when **all** per-image entries for a key + are dictionaries; subkeys are promoted to top-level keys under the task + and aligned across images. + - The set ``self.tasks`` is updated to include all encountered task types. """ tasks = set() From 21b33985608041a5b6d07a9ab0a1cd154ca3c1cb Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Fri, 30 Jan 2026 22:24:21 +0100 Subject: [PATCH 066/156] :memo: Update docstrings. --- .../models/engine/multi_task_segmentor.py | 171 ++++++++++++++---- 1 file changed, 140 insertions(+), 31 deletions(-) diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index 6b34b32c7..08cacba12 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -1436,20 +1436,44 @@ def save_predictions( ) -> dict | AnnotationStore | Path | list[Path]: """Save model predictions to disk or return them in memory. 
-        Depending on the output type, this method saves predictions as a zarr group,
-        an AnnotationStore (SQLite database), or returns them as a dictionary.
+        Depending on ``output_type``, this method either:
+        - returns a Python dictionary (``"dict"``),
+        - writes a Zarr group to disk and returns the path (``"zarr"``), or
+        - writes one or more SQLite-backed AnnotationStore ``.db`` files and
+          returns the resulting path(s) (``"annotationstore"``).
+
+        For multitask outputs, this function also:
+        - Preserves task separation when saving to Zarr (one group per task).
+        - Optionally saves raw probability maps if ``return_probabilities=True``
+          (as Zarr only; probabilities cannot be written to AnnotationStore).
+        - Merges per-task keys for saving to AnnotationStore, including optional
+          coordinates to establish slide origin.
 
         Args:
             processed_predictions (dict):
-                Dictionary containing processed model predictions.
+                Task-organized dictionary produced by post-processing (e.g., from
+                ``post_process_patches`` or ``post_process_wsi``). For multitask
+                models this typically includes:
+                - ``"probabilities"`` (optional): list[da.Array] of WSI maps,
+                  present if preserved for saving.
+                - Per-task sub-dicts (e.g., ``"semantic"``, ``"instance"``),
+                  each containing task-specific arrays/metadata such as
+                  ``"predictions"``, ``"info_dict"``, etc.
+                - ``"coordinates"`` (optional): Dask/NumPy array used to set
+                  spatial origin when saving vector outputs.
             output_type (str):
-                Desired output format.
-                Supported values are "dict", "zarr", and "annotationstore".
+                Desired output format. Supported values are:
+                ``"dict"``, ``"zarr"``, or ``"annotationstore"`` (case-sensitive).
             save_path (Path | None):
-                Path to save the output file.
-                Required for "zarr" and "annotationstore" formats.
-            **kwargs (EngineABCRunParams):
-                Additional runtime parameters to update engine attributes.
+                Base filesystem path for file outputs. Required for
+                ``"zarr"`` and ``"annotationstore"``. For Zarr, a
+                ``save_path.with_suffix(".zarr")`` group is used. For
+                AnnotationStore, ``.db`` files are written (one per image in
+                patch mode, one per WSI in WSI mode). Ignored when
+                ``output_type="dict"``.
+            **kwargs (MultiTaskSegmentorRunParams):
+                Additional runtime parameters to configure segmentation.
+
                 Optional Keys:
                     auto_get_mask (bool):
                         Automatically generate segmentation masks using
@@ -1460,34 +1484,62 @@
                         internal heuristics if masks are not provided.
                     batch_size (int):
                         Number of image patches to process in a single batch.
                     class_dict (dict):
                         Mapping of classification outputs to class names.
                     device (str):
                         Device to run the model on (e.g., "cpu", "cuda").
-                        See :class:`torch.device` for more details.
+                    labels (list):
+                        Optional labels for input images. Only a single label per image
+                        is supported.
                     memory_threshold (int):
                         Memory usage threshold (percentage) to trigger caching behavior.
                     num_workers (int):
                         Number of workers for DataLoader and post-processing.
                     output_file (str):
-                        Filename for saving output (e.g., "zarr" or "annotationstore").
+                        Filename for saving output (e.g., ".zarr" or ".db").
+                    output_resolutions (Resolution):
+                        Resolution used for writing output predictions.
+                    patch_output_shape (tuple[int, int]):
+                        Shape of output patches (height, width).
+                    return_labels (bool):
+                        Whether to return labels with predictions.
+                    return_predictions (tuple[bool, ...]):
+                        Whether to return array predictions for individual tasks.
+                    return_probabilities (bool):
+                        Whether to return per-class probabilities.
                     scale_factor (tuple[float, float]):
                         Scale factor for annotations (model_mpp / slide_mpp).
-                Used to convert coordinates from non-baseline to baseline
-                resolution.
-            stride_shape (IntPair):
-                Stride used during WSI processing, at requested read resolution.
-                Must be positive. Defaults to `patch_input_shape` if not
-                provided.
+                Used to convert coordinates to baseline resolution.
+            stride_shape (tuple[int, int]):
+                Stride used during WSI processing.
+                Defaults to `patch_input_shape` if not provided.
             verbose (bool):
                 Whether to enable verbose logging.
 
         Returns:
-            dict | AnnotationStore | Path | list [Path]:
-                - If output_type is "dict": returns predictions as a dictionary.
-                - If output_type is "zarr": returns path to saved zarr file.
-                - If output_type is "annotationstore": returns an AnnotationStore
-                  or path to .db file.
+            dict | AnnotationStore | Path | list[Path]:
+                - If ``output_type == "dict"``:
+                    Returns the (possibly simplified) prediction dictionary.
+                    For a single task, the task level is flattened.
+                - If ``output_type == "zarr"``:
+                    Returns the ``Path`` to the saved ``.zarr`` group.
+                - If ``output_type == "annotationstore"``:
+                    Returns a list of paths to saved ``.db`` files (patch mode),
+                    or a single path / store handle for WSI mode. If probability
+                    maps were requested for saving, the Zarr path holding those
+                    maps may also be included.
 
         Raises:
             TypeError:
-                If an unsupported output_type is provided.
+                If an unsupported ``output_type`` is provided.
+
+        Notes:
+            - For ``"dict"`` and ``"zarr"``, saving is delegated to
+              ``_save_predictions_as_dict_zarr`` to keep behavior aligned across
+              engines.
+            - When ``output_type == "annotationstore"``, arrays are first computed
+              (via a Zarr/dict pass) to obtain concrete NumPy payloads suitable
+              for vector export, after which per-task stores are written using
+              ``_save_predictions_as_annotationstore``.
+            - If ``return_probabilities=True``, probability maps are written only
+              to Zarr, never to AnnotationStore. A guidance message is logged
+              describing how to visualize heatmaps (e.g., converting to OME-TIFF).
 
         """
         if output_type in ["dict", "zarr"]:
@@ -1569,7 +1621,7 @@ def run(
         output_type: str = "dict",
         **kwargs: Unpack[MultiTaskSegmentorRunParams],
     ) -> AnnotationStore | Path | str | dict | list[Path]:
-        """Run the semantic segmentation engine on input images.
+        """Run the `MultiTaskSegmentor` engine on input images.
 
         This method orchestrates the full inference pipeline, including
         preprocessing, model inference, post-processing, and saving results. It
         supports both
@@ -1614,7 +1666,9 @@
                 Mapping of classification outputs to class names.
             device (str):
                 Device to run the model on (e.g., "cpu", "cuda").
-
+            labels (list):
+                Optional labels for input images. Only a single label per image
+                is supported.
             memory_threshold (int):
                 Memory usage threshold (percentage) to trigger caching behavior.
             num_workers (int):
@@ -1626,7 +1680,9 @@
             patch_output_shape (tuple[int, int]):
                 Shape of output patches (height, width).
             return_labels (bool):
-                Whether to return labels with predictions. Should be False.
+                Whether to return labels with predictions.
+            return_predictions (tuple[bool, ...]):
+                Whether to return array predictions for individual tasks.
             return_probabilities (bool):
                 Whether to return per-class probabilities.
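+                For example, ``return_probabilities=True`` additionally writes
+                the per-head probability maps to the output Zarr; they are
+                never written to an AnnotationStore.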
scale_factor (tuple[float, float]): @@ -1647,12 +1703,12 @@ def run( Examples: >>> wsis = ['wsi1.svs', 'wsi2.svs'] >>> image_patches = [np.ndarray, np.ndarray] - >>> segmentor = SemanticSegmentor(model="fcn-tissue_mask") - >>> output = segmentor.run(image_patches, patch_mode=True) + >>> mtsegmentor = MultiTaskSegmentor(model="hovernet_fast-pannuke") + >>> output = mtsegmentor.run(image_patches, patch_mode=True) >>> output ... "/path/to/Output.db" - >>> output = segmentor.run( + >>> output = mtsegmentor.run( ... image_patches, ... patch_mode=True, ... output_type="zarr" @@ -1660,7 +1716,7 @@ def run( >>> output ... "/path/to/Output.zarr" - >>> output = segmentor.run(wsis, patch_mode=False) + >>> output = mtsegmentor.run(wsis, patch_mode=False) >>> output.keys() ... ['wsi1.svs', 'wsi2.svs'] >>> output['wsi1.svs'] @@ -1697,7 +1753,60 @@ def dict_to_store( origin: tuple[float, float] = (0, 0), scale_factor: tuple[float, float] = (1, 1), ) -> AnnotationStore: - """Helper function to convert dict to store.""" + """Write polygonal multitask predictions into an SQLite-backed AnnotationStore. + + Converts a task dictionary (with per-object fields) into `Annotation` records, + applying coordinate scaling and translation to move predictions into the slide's + baseline coordinate space. Each geometry is created from the per-object + `"contours"` entry, validated, and shifted by `origin`. All remaining keys in + `processed_predictions` are attached as annotation properties; the `"type"` key + can be mapped via `class_dict`. + + Expected `processed_predictions` structure: + - "contours": list-like of polygon coordinates per object, where each item + is shaped like `[[x0, y0], [x1, y1], ..., [xN, yN]]`. These are interpreted + according to `"geom_type"` (default `"Polygon"`). + - Optional "geom_type": str (e.g., "Polygon", "MultiPolygon"). + Defaults to "Polygon". + - Additional per-object fields (e.g., "type", "probability", scores, attributes) + with list-like values aligned to `contours` length. + + Args: + store (SQLiteStore): + Target annotation store that will receive the converted annotations. + processed_predictions (dict): + Dictionary containing per-object fields. Must include `"contours"`; + may include `"geom_type"` and any number of additional fields to be + written as properties. + class_dict (dict | None): + Optional mapping for the `"type"` field. When provided and when + `"type"` is present in `processed_predictions`, each `"type"` value is + replaced by `class_dict[type_id]` in the saved annotation properties. + origin (tuple[float, float]): + `(x0, y0)` offset to add to the final geometry coordinates (in pixels) + after scaling. Typically corresponds to the tile/patch origin in WSI + space. + scale_factor (tuple[float, float]): + `(sx, sy)` factors applied to coordinates before translation, used to + convert from model space to baseline slide resolution (e.g., + `model_mpp / slide_mpp`). + + Returns: + AnnotationStore: + The input `store` after appending all converted annotations. + + Notes: + - Geometries are constructed from `processed_predictions["contours"]` using + `geom_type` (default `"Polygon"`), scaled by `scale_factor`, and translated + by `origin`. Invalid geometries are auto-corrected using `make_valid_poly`. + - Per-object properties are created by taking the i-th element from each + remaining key in `processed_predictions`. Scalars are coerced to arrays + first, then converted with `.tolist()` to ensure JSON-serializable values. 
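+    - A minimal illustrative payload (assumed, not exhaustive):
+      ``{"contours": [[[0, 0], [0, 4], [4, 4]]], "type": [1]}`` together with
+      ``class_dict={1: "nucleus"}`` yields one triangular annotation whose
+      ``"type"`` property is saved as ``"nucleus"``.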
+ - If `class_dict` is provided and a `"type"` key exists, `"type"` values are + mapped prior to saving. + - All annotations are appended in a single batch via `store.append_many(...)`. + + """ contour = processed_predictions.pop("contours") ann = [] From c6b7d72355e105531197f0e3d880db43793757af Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Fri, 30 Jan 2026 22:46:13 +0100 Subject: [PATCH 067/156] :memo: Update docstrings. --- .../models/engine/multi_task_segmentor.py | 278 ++++++++++++++++-- 1 file changed, 260 insertions(+), 18 deletions(-) diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index 08cacba12..12443e970 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -1848,30 +1848,74 @@ def prepare_multitask_full_batch( *, is_last: bool, ) -> tuple[list[np.ndarray], np.ndarray, np.ndarray]: - """Prepare full-sized output and count arrays for a batch of patch predictions. - - This function aligns patch-level predictions with global output locations when - a mask (e.g., auto_get_mask) is applied. It initializes full-sized arrays and - fills them using matched indices. If the batch is the last in the sequence, - it pads the arrays to cover remaining locations. + """Align patch predictions to the global output index and pad to cover gaps. + + This helper prepares a *full-sized* set of outputs for the current batch by + aligning patch-level predictions with the remaining global output locations. + It uses the provided `full_output_locs` (the outstanding locations yet to be + filled) to place each patch's predictions at the correct indices, returning + arrays sized to the current span. If this is the final batch (`is_last=True`), + it pads the arrays with zeros to cover any remaining, unmatched output + locations and appends those locations to `output_locs`. + + Concretely: + 1) A lookup is built over `full_output_locs` so each row in `batch_locs` + maps to a unique index (“match”). + 2) For each head in `batch_output`, an appropriately sized zero-initialized + array is created and the matched batch predictions are placed at the + computed indices. + 3) `output_locs` is extended by the portion of `full_output_locs` covered + in this call; `full_output_locs` is advanced accordingly. + 4) If `is_last=True`, the function also appends any remaining locations to + `output_locs` and pads the per-head arrays with zeros so their first + dimension matches the updated number of locations. Args: - batch_output (np.ndarray): - Patch-level model predictions of shape (N, H, W, C). + batch_output (tuple[np.ndarray]): + Tuple of per-head patch predictions for the current batch. Each + element has shape ``(N, H, W, C)`` (head-specific), where ``N`` is + the number of patches in the batch. batch_locs (np.ndarray): - Output locations corresponding to `batch_output`. + Array of output locations (e.g., patch output boxes) corresponding + to `batch_output`. Each row must uniquely identify a location and + match rows in `full_output_locs`. full_output_locs (np.ndarray): - Remaining global output locations to be matched. + The remaining global output location array, carrying the canonical + order of all locations that should be filled. This is progressively + consumed from the front as batches are placed. output_locs (np.ndarray): - Accumulated output location array across batches. 
+ Accumulated output location array across previous batches. This is + extended in-place with the portion of `full_output_locs` filled in + this call, and with any remaining tail (zeros padded in outputs) + when `is_last=True`. is_last (bool): - Flag indicating whether this is the final batch. + Whether this is the final batch. When True, any locations left in + `full_output_locs` after placing matches are appended to + `output_locs`, and the per-head output arrays are padded with zeros + to match the total number of output locations. Returns: tuple[list[np.ndarray], np.ndarray, np.ndarray]: - - full_batch_output: Full-sized output array with predictions placed. - - full_output_locs: Updated remaining global output locations. - - output_locs: Updated accumulated output locations. + - full_batch_output (list[np.ndarray]): + One array per head containing the aligned outputs for this call. + Each has shape ``(M, H, W, C)``, where ``M`` is the number of + locations consumed (and possibly padded to include the remaining + tail when `is_last=True`). + - full_output_locs (np.ndarray): + Updated remaining global output locations (the unconsumed tail). + - output_locs (np.ndarray): + Updated accumulated output locations including those added by + this call (and any final tail when `is_last=True`). + + Notes: + - Ordering is defined by `full_output_locs`. The number of rows + consumed during this call equals ``max(match_indices) + 1``. + - Padding on the last batch is performed with zeros of the same dtype + as each head's predictions (uint8 for the padded section in the + implementation). + - This function is agnostic to the semantic meaning of locations; it + only ensures that per-head arrays and the accumulated location index + remain consistent across batches. """ # Use np.intersect1d once numpy version is upgraded to 2.0 @@ -1919,7 +1963,87 @@ def merge_multitask_horizontal( output_locs: np.ndarray, change_indices: np.ndarray | list[int], ) -> tuple[list[da.Array], list[da.Array], list[np.ndarray], np.ndarray, np.ndarray]: - """Merge horizontal patches incrementally for each row of patches.""" + """Merge horizontally a run of patch outputs into per-head row blocks. + + This helper performs **row-wise stitching** of patch predictions for + multitask heads. It consumes the leftmost segment of ``canvas_np`` (per head) + up to each index in ``change_indices``—which mark where the dataloader + advanced to a new row of output patches—and merges that segment into a + horizontally concatenated row block for each head. The merged blocks and + their per-pixel hit counts are appended to ``canvas`` and ``count`` (as + Dask arrays with chunking equal to the merged row height), while the consumed + portion is removed from ``canvas_np``. The function also updates and returns + ``output_locs`` (with the consumed locations removed) and accumulates the + vertical extents of each merged row in ``output_locs_y_``. + + For each row segment: + 1) The function determines the row's horizontal span from + ``output_locs`` (min x0, max x1). + 2) For each head, it calls ``merge_batch_to_canvas`` to place the segment's + patch outputs into a contiguous row block and an aligned count map. + 3) The row block and count map are wrapped as Dask arrays and appended to + the running lists in ``canvas`` and ``count`` (one list per head). + 4) The segment is removed from ``canvas_np`` and ``output_locs``; the + segment's vertical bounds ``(y0, y1)`` are appended to ``output_locs_y_``. 
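+
+    For example (illustrative only), ``change_indices=[4, 8]`` marks two
+    completed rows: patches ``0:4`` and ``4:8`` are merged into two row
+    blocks and removed from ``canvas_np`` and ``output_locs``.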
+ + Args: + canvas (list[da.Array] | list[None]): + Accumulated per-head row blocks (probability/logit sums) as Dask + arrays. Each entry grows along the first axis with each merged row. + Pass ``None`` for each head on the first call. + count (list[da.Array] | list[None]): + Accumulated per-head row count maps, aligned with ``canvas``. + Pass ``None`` for each head on the first call. + output_locs_y_ (np.ndarray): + Accumulated vertical extents of already-merged rows. Each appended + element is ``[y0, y1]`` corresponding to the merged row's span. + Pass ``None`` on the first call; it will be initialized internally + via concatenation. + canvas_np (list[np.ndarray]): + In-memory patch outputs awaiting merge, one list entry per head. + Each head's entry is a NumPy array of stacked patch outputs for the + **current** unmerged part of the row, with shape + ``(N_seg, H, W, C)`` for the segment being merged. + output_locs (np.ndarray): + Output placement boxes for the awaiting patches in ``canvas_np``, + shaped ``(N_pending, 4)`` as ``[x0, y0, x1, y1]``. The function + consumes from the front up to each ``change_indices`` boundary and + returns the remaining tail. + change_indices (np.ndarray | list[int]): + Sorted indices (relative to the current ``output_locs``) where a + **row change** occurs. Each index marks the end of a contiguous row + segment to be merged in this call. + + Returns: + tuple[list[da.Array], list[da.Array], list[np.ndarray], np.ndarray, np.ndarray]: + - ``canvas``: + Updated list of per-head Dask arrays containing concatenated row + blocks (values are sums; normalization happens later). + - ``count``: + Updated list of per-head Dask arrays containing concatenated row + hit counts for normalization. + - ``canvas_np``: + Updated in-memory per-head arrays with consumed segment removed. + - ``output_locs``: + Updated placement boxes with the consumed segment removed. + - ``output_locs_y_``: + Updated array of accumulated vertical row extents, with the new + row's ``[y0, y1]`` appended. + + Notes: + - The merged row block shape per head is + ``(row_height, row_width, C)``, where: + * ``row_height`` is the head's patch output height, + * ``row_width`` is ``max(x1) - min(x0)`` for the row, + * ``C`` is the number of channels for that head. + - ``merge_batch_to_canvas`` handles placement and accumulation of + overlapping patch outputs and produces a matching count map. + - Normalization (division by counts) is **not** performed here; it is + done later during vertical merging to form the final probability maps. + - Dask chunking is set to the full row height to facilitate subsequent + vertical concatenation and overlap handling. + + """ start_idx = 0 for c_idx in change_indices: output_locs_ = output_locs[: c_idx - start_idx] @@ -1962,7 +2086,60 @@ def save_multitask_to_cache( count_zarr: list[zarr.Array | None], save_path: str | Path = "temp.zarr", ) -> tuple[list[zarr.Array], list[zarr.Array]]: - """Save computed canvas and count list of arrays to Zarr cache.""" + """Write accumulated horizontal row blocks to a Zarr cache on disk. + + This function is called when intermediate per-head accumulators + (``canvas`` and ``count``) become large enough to risk exceeding the + memory threshold. It computes the current Dask arrays for each head, + writes them to Zarr datasets under ``save_path``, and updates + ``canvas_zarr`` / ``count_zarr`` so later merges operate directly on + Zarr-backed arrays rather than holding everything in memory. 
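+
+    For example (illustrative trigger only), with ``memory_threshold=80`` a
+    spill occurs once RAM usage exceeds 80%, after which merging continues
+    against the Zarr-backed copies.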
+ + For each head: + 1) The corresponding ``canvas`` and ``count`` Dask arrays are fully + computed. + 2) If this is the first time spilling for that head, new Zarr datasets + are created using chunk shapes consistent with the canvas rows. + 3) The computed rows are appended to the Zarr datasets by resizing the + arrays and writing the new rows at the end. + 4) The updated Zarr arrays are returned to be wrapped by Dask in later + steps. + + Args: + canvas (list[da.Array]): + Accumulated per-head row blocks (probability/logit sums). Each + head's entry has shape ``(N_rows, H, W, C)`` where ``N_rows`` grows + as horizontal rows are merged. + count (list[da.Array]): + Accumulated per-head row hit counts aligned with ``canvas``, + with matching shape and chunking. + canvas_zarr (list[zarr.Array | None]): + List of Zarr datasets for storing accumulated ``canvas`` values + per head. ``None`` entries indicate that no Zarr datasets have + been created yet for those heads. + count_zarr (list[zarr.Array | None]): + List of Zarr datasets mirroring ``canvas_zarr`` but storing hit + counts instead of accumulated values. + save_path (str | Path): + Path to the Zarr group used for caching. A new group is created + if needed on the first spill. + + Returns: + tuple[list[zarr.Array], list[zarr.Array]]: + Updated ``canvas_zarr`` and ``count_zarr`` lists, where each head + now has a Zarr dataset containing all accumulated rows up to this + point. + + Notes: + - Chunking for the Zarr datasets follows the Dask chunk size along + the row axis to allow efficient later vertical merging. + - This function does **not** normalize probabilities; normalization + happens in the final vertical merge via + ``merge_multitask_vertical_chunkwise``. + - After spilling, upstream functions will reset in-memory ``canvas`` + and ``count`` to free RAM and continue populating new entries. + + """ zarr_group = None for idx, canvas_ in enumerate(canvas): computed_values = compute(*[canvas_, count[idx]]) @@ -2016,7 +2193,71 @@ def merge_multitask_vertical_chunkwise( save_path: Path, memory_threshold: int = 80, ) -> list[da.Array]: - """Merge vertically chunked arrays into a single probability map.""" + """Merge horizontally stitched row blocks into final WSI probability maps. + + After horizontal stitching, each head has a stack of row blocks (values) and + matching row-wise count maps. This function merges those rows **vertically**, + resolving overlaps between adjacent rows using the provided `output_locs_y_` + spans. For each head and row boundary, overlapping rows are summed in the + overlap region, then normalized by the corresponding summed counts. The + normalized row is appended to a Zarr-backed or Dask-backed accumulator to + build the final full-height probability map. + + Concretely, for each head: + 1) Iterate across row boundaries using `output_locs_y_`, compute overlap height. + 2) If there is an overlap with the next row, add overlapping slices from + the next row's canvas and count into the tail of the current row. + 3) Normalize the current row by its count map (with zero-division guarded). + 4) Append normalized rows to Zarr (or keep in-memory) via `store_probabilities`. + 5) Periodically spill in-memory arrays to Zarr when memory exceeds + `memory_threshold` (via `_save_multitask_vertical_to_cache`). + 6) After processing all rows, clear temporary Zarr datasets for canvas/count + and return a Dask view (from Zarr if spilled, otherwise from memory). 
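+
+    For example (illustrative numbers), rows spanning ``[0, 256]`` and
+    ``[224, 480]`` overlap by ``256 - 224 = 32`` pixels: values and counts
+    are summed over that band before the row is normalized.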
+ + Args: + canvas (list[da.Array]): + Per-head Dask arrays of horizontally merged **row blocks** (sums). + For each head `h`, `canvas[h]` has shape + `(N_rows, row_height, row_width, C)`, chunked along the row axis. + count (list[da.Array]): + Per-head Dask arrays of **row-wise hit counts** matching `canvas`. + output_locs_y_ (np.ndarray): + Array of shape `(N_rows, 2)` where each row is `[y0, y1]` indicating + the vertical extent of the corresponding row block in slide + coordinates. Overlaps are computed as `prev_y1 - next_y0`. + zarr_group (zarr.Group): + Zarr group used to create/append the per-head probability datasets + (under `"probabilities/{idx}"`) and to clear temporary `"canvas"` and + `"count"` datasets after finalization. + save_path (Path): + Base path of the Zarr store (used when spilling additional data and + when returning Zarr-backed Dask arrays). + memory_threshold (int): + Maximum allowed RAM usage (percentage) before converting in-memory + probability accumulators to Zarr-backed arrays. Default is 80. + + Returns: + list[da.Array]: + One Dask array per head, each representing the **final** WSI-sized + probability map with shape `(H, W, C)`. If spilling occurred, these + are backed by Zarr datasets created under `zarr_group`; otherwise + they are in-memory Dask arrays. + + Notes: + - Overlaps along the vertical direction are handled by **additive merge** + of both values and counts, followed by normalization. Non-overlapping + regions are passed through unchanged. + - Zero counts are guarded by replacing with 1 during normalization to + avoid division by zero; this is safe because values are zero where + counts are zero. + - Chunking along the first axis (row blocks) is preserved to facilitate + incremental appends and memory spill; final arrays are exposed with + appropriate Dask chunking for downstream use. + - Temporary row-level `"canvas/*"` and `"count/*"` datasets are deleted + before returning when Zarr-backed accumulators are used (see + `_clear_zarr`). + + """ y0s, y1s = np.unique(output_locs_y_[:, 0]), np.unique(output_locs_y_[:, 1]) overlaps = np.append(y1s[:-1] - y0s[1:], 0) @@ -2464,6 +2705,7 @@ def _get_margin_lines( width: int, tile_tl: tuple[int, int], ) -> list: + """Helper function to get margin lines.""" # ! this is wrt to WSI coord space, not tile margin_lines = [ [[margin, margin], [width - margin, margin]], # top egde From c2f0ac6cde4e0fe68211909ab541ac16acc629e1 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Fri, 30 Jan 2026 22:49:52 +0100 Subject: [PATCH 068/156] :memo: Update docstrings. --- .../models/engine/multi_task_segmentor.py | 113 +++++++++++++++++- 1 file changed, 112 insertions(+), 1 deletion(-) diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index 12443e970..954adc538 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -1,4 +1,115 @@ -"""This module enables multi-task segmentor.""" +"""Multi-task segmentation engine for computational pathology. + +This module implements the :class:`MultiTaskSegmentor` and supporting utilities to +run multi-head segmentation models (e.g., HoVerNet/HoVerNetplus-style architectures) +on histology images in both patch and whole slide image (WSI) workflows. 
+It provides consistent orchestration for data loading, model invocation, tiled +stitching, memory-aware caching, post-processing per task, and saving outputs +to in-memory dictionaries, Zarr, or AnnotationStore (.db). + +Overview + - **Patch mode**: `infer_patches` runs a model on batches of image patches, + producing one probability/logit tensor per model head. `post_process_patches` + converts these into task-specific outputs (e.g., semantic maps, instances). + - **WSI mode**: `infer_wsi` iterates over all WSI patches, assembles head + outputs via horizontal row-merge and vertical normalization, and returns + WSI-scale probability maps per head. `post_process_wsi` consumes these + maps in either full-WSI or tile mode (for Zarr-backed arrays), deriving + task-centric outputs and merging instances across tile boundaries. + - **Memory awareness**: intermediate accumulators spill to Zarr automatically + once usage exceeds a configurable `memory_threshold`, enabling processing + of very large slides on limited RAM. + +Key Classes + MultiTaskSegmentor + Core engine for multi-head segmentation. Extends :class:`SemanticSegmentor` + to run models with multiple output heads and to produce task-centric + predictions after post-processing. Supports patch and WSI workflows, + dict/Zarr/AnnotationStore outputs, and device/batch/stride configuration. + + MultiTaskSegmentorRunParams + TypedDict of runtime parameters used across the engine. Extends + :class:`SemanticSegmentorRunParams` with additional multitask options: + `return_predictions`, `return_probabilities`, `memory_threshold`, etc. + +Important Functions + infer_patches(dataloader, *, return_coordinates=False) -> dict + Run model on a collection of patches; returns per-head probabilities + as Dask arrays and optionally patch coordinates. + + infer_wsi(dataloader, save_path, **kwargs) -> dict + Run model on a WSI via patch extraction and incremental stitching, with + optional Zarr caching when memory pressure is high. + + post_process_patches(raw_predictions, **kwargs) -> dict + Apply the model's post-processing per patch and reorganize results into + a task-centric dictionary (e.g., "semantic", "instance"). + + post_process_wsi(raw_predictions, save_path, **kwargs) -> dict + Convert WSI-scale head maps into task-specific outputs, either in memory + (full-WSI) or via tile-mode with instance de-duplication across tile + boundaries. + + save_predictions(processed_predictions, output_type, save_path=None, **kwargs) + Persist results as `dict`, `zarr`, or `annotationstore`. Probability maps + are saved to Zarr; vector outputs are written to AnnotationStore. + + Helper utilities + - build_post_process_raw_predictions(...) + Group per-image outputs by task and normalize array/dict payloads. + - prepare_multitask_full_batch(...) + Align a batch's predictions to global output indices and pad the tail. + - merge_multitask_horizontal(...) + Row-wise stitching of patch predictions for each head. + - save_multitask_to_cache(...) + Spill accumulated row blocks (canvas/count) to Zarr. + - merge_multitask_vertical_chunkwise(...) + Normalize and merge rows vertically into final WSI probability maps. + - dict_to_store(...) + Convert polygonal task predictions to :class:`AnnotationStore` records. + +Inputs and Outputs + - **Inputs**: lists of file paths or :class:`WSIReader` instances (WSI mode), + or `np.ndarray` patches (NHWC) in patch mode. Optional masks and IO configs + control extraction resolution, patch/tile shapes, and stride. 
+ - **Raw outputs**: per-head probability maps/logits as Dask arrays + (patch- or WSI-scale). + - **Post-processed outputs**: task-centric dictionaries (e.g., instance tables, + semantic predictions), optionally including full-resolution prediction arrays + if requested via `return_predictions`. + - **Saved outputs**: + * `dict`: in-memory Python structures + * `zarr`: hierarchical arrays (optionally with probability maps) + * `annotationstore`: SQLite-backed vector annotations (.db) + +Examples: + Patch-mode prediction: + >>> patches = [np.ndarray, np.ndarray] # NHWC + >>> mt = MultiTaskSegmentor(model="hovernetplus-oed", device="cuda") + >>> out = mt.run(patches, patch_mode=True, output_type="dict") + + WSI-mode prediction with Zarr caching and AnnotationStore output: + >>> wsis = [Path("slide1.svs"), Path("slide2.svs")] + >>> mt = MultiTaskSegmentor(model="hovernet_fast-pannuke", device="cuda") + >>> out = mt.run( + ... wsis, + ... patch_mode=False, + ... save_dir=Path("outputs/"), + ... output_type="annotationstore", + ... memory_threshold=80, + ... auto_get_mask=True, + ... overwrite=True, + ... ) + +Notes: + - The engine infers the number of model heads from the first `infer_batch` + call and maintains per-head arrays throughout merging. + - Probability normalization is performed during the final vertical merge + (row accumulation divided by row counts). + - Probability maps are not written to AnnotationStore; use Zarr to persist + them and convert to OME-TIFF separately if needed for visualization. + +""" from __future__ import annotations From d8ed2f5a745a5919fc1bd1631b3363022ef4215a Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Fri, 30 Jan 2026 23:35:04 +0100 Subject: [PATCH 069/156] :pushpin: Pin `dask>=2026.1.2` - dask version 2026.1.2 has breaking changes --- requirements/requirements.txt | 2 +- tiatoolbox/models/engine/engine_abc.py | 8 ++------ 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/requirements/requirements.txt b/requirements/requirements.txt index bca19411e..f48dda33e 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -2,7 +2,7 @@ aiohttp>=3.8.1 albumentations>=1.3.0 bokeh>=3.8.2 Click>=8.2.0 -dask>=2025.12.0 +dask>=2026.1.2 defusedxml>=0.7.1 filelock>=3.9.0 flask>=2.2.2 diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index 4b7ba16d1..01722cabf 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -763,9 +763,7 @@ def _get_tasks_for_saving_zarr( url=save_path, component=component, compute=False, - zarr_array_kwargs={ - "object_codec": object_codec, - }, + object_codec=object_codec, # zarr kwargs ) write_tasks.append(task) @@ -781,9 +779,7 @@ def _get_tasks_for_saving_zarr( url=save_path, component=component, compute=False, - zarr_array_kwargs={ - "object_codec": object_codec, - }, + object_codec=object_codec, # zarr kwargs ) write_tasks.append(task) From ccc7fd499159de9c08036a03a17b15482fae2f5c Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Sat, 31 Jan 2026 02:17:57 +0100 Subject: [PATCH 070/156] :white_check_mark: Add CLI support. 
---
 tests/engines/test_multi_task_segmentor.py |  43 ++++-
 tiatoolbox/cli/__init__.py                 |   2 +
 tiatoolbox/cli/common.py                   |  65 ++++++++
 tiatoolbox/cli/multitask_segmentor.py      | 150 ++++++++++++++++++
 .../models/engine/multi_task_segmentor.py  |  38 ++++-
 5 files changed, 289 insertions(+), 9 deletions(-)
 create mode 100644 tiatoolbox/cli/multitask_segmentor.py

diff --git a/tests/engines/test_multi_task_segmentor.py b/tests/engines/test_multi_task_segmentor.py
index f140cd942..01c333075 100644
--- a/tests/engines/test_multi_task_segmentor.py
+++ b/tests/engines/test_multi_task_segmentor.py
@@ -12,7 +12,9 @@
 import pytest
 import torch
 import zarr
+from click.testing import CliRunner
 
+from tiatoolbox import cli
 from tiatoolbox.annotation import SQLiteStore
 from tiatoolbox.models.engine.multi_task_segmentor import (
     MultiTaskSegmentor,
@@ -356,7 +358,6 @@ def test_wsi_mtsegmentor_zarr(
         / len(output_full_["nuclei_segmentation"]["contours"])
         > 0.9
     )
-    wsi4_1k_1k_svs.unlink()
 
 
 def test_multi_input_wsi_mtsegmentor_zarr(
@@ -746,3 +747,43 @@ def assert_annotation_store_patch_output(
     else:
         assert annotations_geometry_type == []
         assert annotations_list == []
+
+
+# -------------------------------------------------------------------------------------
+# Command Line Interface
+# -------------------------------------------------------------------------------------
+
+
+def test_cli_model_single_file(remote_sample: Callable, track_tmp_path: Path) -> None:
+    """Test multitask segmentor CLI single file."""
+    wsi4_512_512_svs = remote_sample("wsi4_512_512_svs")
+    runner = CliRunner()
+    models_wsi_result = runner.invoke(
+        cli.main,
+        [
+            "multitask-segmentor",
+            "--img-input",
+            str(wsi4_512_512_svs),
+            "--patch-mode",
+            "False",
+            "--output-path",
+            str(track_tmp_path / "output"),
+            "--return-predictions",
+            "False, True",
+        ],
+    )
+
+    assert models_wsi_result.exit_code == 0
+    assert (
+        track_tmp_path / "output" / f"{wsi4_512_512_svs.stem}_layer_segmentation.db"
+    ).exists()
+    assert (
+        track_tmp_path / "output" / f"{wsi4_512_512_svs.stem}_nuclei_segmentation.db"
+    ).exists()
+    zarr_group = zarr.open(
+        str(track_tmp_path / "output" / f"{wsi4_512_512_svs.stem}.zarr"), mode="r"
+    )
+    assert "probabilities" in zarr_group
+    assert "nuclei_segmentation" not in zarr_group
+    assert "layer_segmentation" in zarr_group
+    assert "predictions" in zarr_group["layer_segmentation"]
diff --git a/tiatoolbox/cli/__init__.py b/tiatoolbox/cli/__init__.py
index eb908f458..74e01240f 100644
--- a/tiatoolbox/cli/__init__.py
+++ b/tiatoolbox/cli/__init__.py
@@ -8,6 +8,7 @@
 from tiatoolbox import __version__
 from tiatoolbox.cli.common import tiatoolbox_cli
 from tiatoolbox.cli.deep_feature_extractor import deep_feature_extractor
+from tiatoolbox.cli.multitask_segmentor import multitask_segmentor
 from tiatoolbox.cli.nucleus_detector import nucleus_detector
 from tiatoolbox.cli.nucleus_instance_segment import nucleus_instance_segment
 from tiatoolbox.cli.patch_predictor import patch_predictor
@@ -45,6 +46,7 @@ def main() -> int:
 main.add_command(read_bounds)
 main.add_command(save_tiles)
 main.add_command(semantic_segmentor)
+main.add_command(multitask_segmentor)
 main.add_command(nucleus_detector)
 main.add_command(deep_feature_extractor)
 main.add_command(slide_info)
diff --git a/tiatoolbox/cli/common.py b/tiatoolbox/cli/common.py
index d239a6ec5..f4396a7b0 100644
--- a/tiatoolbox/cli/common.py
+++ b/tiatoolbox/cli/common.py
@@ -527,6 +527,71 @@ def cli_return_probabilities(
     )
 
 
+def parse_bool_list(
+    ctx: click.Context,  # noqa: ARG001
+    param:
click.Parameter, # noqa: ARG001 + value: str | None, +) -> tuple[bool, ...] | None: + """Parse a comma-separated list of boolean values for a Click option. + + This function is intended for use as a Click callback. It converts a + comma-separated string (e.g., ``"true,false,1,0"``) into a tuple of Python + booleans. Each item is stripped, lowercased, and validated against a set of + accepted truthy and falsy representations. + + Accepted truthy values: + ``"true"``, ``"1"``, ``"yes"``, ``"y"`` + + Accepted falsy values: + ``"false"``, ``"0"``, ``"no"``, ``"n"`` + + Args: + ctx (click.Context): + The Click context object (unused but required by Click callback API). + param (click.Parameter): + The Click parameter object (unused but required by Click callback API). + value (str | None): + The raw string provided by the user. If ``None``, the function returns + ``None`` unchanged. + + Returns: + tuple[bool, ...] | None: + A tuple of parsed boolean values, or ``None`` if no value was provided. + + Raises: + click.BadParameter: + If any item in the comma-separated list is not a valid boolean string. + """ + if value is None: + return None + items = value.split(",") + out = [] + for item in items: + item_ = item.strip().lower() + if item_ in ("true", "1", "yes", "y"): + out.append(True) + elif item_ in ("false", "0", "no", "n"): + out.append(False) + else: + msg = f"Invalid boolean: {item_}" + raise click.BadParameter(msg) + return tuple(out) + + +def cli_return_predictions( + usage_help: str = "Whether to return raw model probabilities.", + *, + default: tuple[bool, ...] | None = None, +) -> Callable: + """Enables --return-probabilities option for cli.""" + return click.option( + "--return-predictions", + callback=parse_bool_list, + help=add_default_to_usage_help(usage_help, default=default), + default=default, + ) + + def cli_merge_predictions( usage_help: str = "Whether to merge the predictions to form a 2-dimensional map.", *, diff --git a/tiatoolbox/cli/multitask_segmentor.py b/tiatoolbox/cli/multitask_segmentor.py new file mode 100644 index 000000000..6647d144a --- /dev/null +++ b/tiatoolbox/cli/multitask_segmentor.py @@ -0,0 +1,150 @@ +"""Command line interface for multitask segmentation.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from tiatoolbox.cli.common import ( + cli_auto_get_mask, + cli_batch_size, + cli_device, + cli_file_type, + cli_img_input, + cli_input_resolutions, + cli_masks, + cli_memory_threshold, + cli_model, + cli_num_workers, + cli_output_file, + cli_output_path, + cli_output_resolutions, + cli_output_type, + cli_overwrite, + cli_patch_input_shape, + cli_patch_mode, + cli_patch_output_shape, + cli_return_predictions, + cli_return_probabilities, + cli_scale_factor, + cli_stride_shape, + cli_verbose, + cli_weights, + cli_yaml_config_path, + prepare_ioconfig, + prepare_model_cli, + tiatoolbox_cli, +) + +if TYPE_CHECKING: # pragma: no cover + from tiatoolbox.type_hints import IntPair + + +@tiatoolbox_cli.command() +@cli_img_input() +@cli_output_path( + usage_help="Output directory where model segmentation will be saved.", + default="semantic_segmentation", +) +@cli_output_file(default=None) +@cli_file_type( + default="*.png, *.jpg, *.jpeg, *.tif, *.tiff, *.svs, *.ndpi, *.jp2, *.mrxs", +) +@cli_input_resolutions(default=None) +@cli_output_resolutions(default=None) +@cli_model(default="hovernetplus-oed") +@cli_weights() +@cli_device(default="cpu") +@cli_batch_size(default=1) +@cli_yaml_config_path() +@cli_masks(default=None) 
+@cli_num_workers(default=0)
+@cli_output_type(
+    default="AnnotationStore",
+)
+@cli_memory_threshold(default=80)
+@cli_patch_input_shape(default=None)
+@cli_patch_output_shape(default=None)
+@cli_stride_shape(default=None)
+@cli_scale_factor(default=None)
+@cli_patch_mode(default=False)
+@cli_return_predictions(default=None)
+@cli_return_probabilities(default=True)
+@cli_auto_get_mask(default=True)
+@cli_overwrite(default=False)
+@cli_verbose(default=True)
+def multitask_segmentor(
+    model: str,
+    weights: str,
+    img_input: str,
+    file_types: str,
+    input_resolutions: list[dict],
+    output_resolutions: list[dict],
+    masks: str | None,
+    output_path: str,
+    patch_input_shape: IntPair | None,
+    patch_output_shape: tuple[int, int] | None,
+    stride_shape: IntPair | None,
+    scale_factor: tuple[float, float] | None,
+    batch_size: int,
+    yaml_config_path: str,
+    num_workers: int,
+    device: str,
+    output_type: str,
+    memory_threshold: int,
+    output_file: str | None,
+    *,
+    patch_mode: bool,
+    return_predictions: tuple[bool, ...] | None,
+    return_probabilities: bool,
+    auto_get_mask: bool,
+    verbose: bool,
+    overwrite: bool,
+) -> None:
+    """Process a set of input images with a semantic segmentation engine."""
+    from tiatoolbox.models import IOSegmentorConfig, MultiTaskSegmentor  # noqa: PLC0415
+
+    files_all, masks_all, output_path = prepare_model_cli(
+        img_input=img_input,
+        output_path=output_path,
+        masks=masks,
+        file_types=file_types,
+    )
+
+    ioconfig = prepare_ioconfig(
+        IOSegmentorConfig,
+        pretrained_weights=weights,
+        yaml_config_path=yaml_config_path,
+    )
+
+    mtsegmentor = MultiTaskSegmentor(
+        model=model,
+        weights=weights,
+        batch_size=batch_size,
+        num_workers=num_workers,
+        verbose=verbose,
+    )
+
+    _ = mtsegmentor.run(
+        images=files_all,
+        masks=masks_all,
+        patch_mode=patch_mode,
+        patch_input_shape=patch_input_shape,
+        patch_output_shape=patch_output_shape,
+        input_resolutions=input_resolutions,
+        output_resolutions=output_resolutions,
+        batch_size=batch_size,
+        ioconfig=ioconfig,
+        device=device,
+        save_dir=output_path,
+        output_type=output_type,
+        return_predictions=return_predictions,
+        return_probabilities=return_probabilities,
+        auto_get_mask=auto_get_mask,
+        memory_threshold=memory_threshold,
+        num_workers=num_workers,
+        output_file=output_file,
+        scale_factor=scale_factor,
+        stride_shape=stride_shape,
+        overwrite=overwrite,
+        verbose=verbose,
+    )
diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py
index 954adc538..cb5ca1aac 100644
--- a/tiatoolbox/models/engine/multi_task_segmentor.py
+++ b/tiatoolbox/models/engine/multi_task_segmentor.py
@@ -257,6 +257,9 @@ class MultiTaskSegmentor(SemanticSegmentor):
         IO configuration for patch extraction and resolution.
     return_labels (bool):
         Whether to include labels in the output.
+    return_predictions_dict (dict):
+        Tracks which tasks should keep their full-resolution ``predictions``
+        arrays in the output, keyed by task name.
     input_resolutions (list[dict]):
         Resolution settings for model input. Supported units are `level`,
         `power` and `mpp`.
Keys should be "units" and @@ -340,6 +343,7 @@ def __init__( """ self.tasks = set() + self.return_predictions_dict = {} super().__init__( model=model, batch_size=batch_size, @@ -676,7 +680,7 @@ def infer_wsi( def post_process_patches( # skipcq: PYL-R0201 self: MultiTaskSegmentor, raw_predictions: dict, - **kwargs: Unpack[MultiTaskSegmentorRunParams], # noqa: ARG002 + **kwargs: Unpack[MultiTaskSegmentorRunParams], ) -> dict: """Post-process raw patch-level predictions for multitask segmentation. @@ -770,6 +774,7 @@ def post_process_patches( # skipcq: PYL-R0201 return self.build_post_process_raw_predictions( post_process_predictions=post_process_predictions, raw_predictions=raw_predictions, + return_predictions=kwargs.get("return_predictions"), ) def post_process_wsi( # skipcq: PYL-R0201 @@ -892,8 +897,11 @@ def post_process_wsi( # skipcq: PYL-R0201 ) tasks = set() - for seg in post_process_predictions: + for idx, seg in enumerate(post_process_predictions): task_name = seg["task_type"] + self.return_predictions_dict[task_name] = ( + return_predictions[idx] if return_predictions is not None else False + ) tasks.add(task_name) raw_predictions[task_name] = {} @@ -1312,6 +1320,7 @@ def build_post_process_raw_predictions( self: MultiTaskSegmentor, post_process_predictions: list[tuple], raw_predictions: dict, + return_predictions: tuple[bool, ...] | None, ) -> dict: """Merge per-image, per-task outputs into a task-organized prediction structure. @@ -1352,6 +1361,8 @@ def build_post_process_raw_predictions( Dictionary that will be updated **in-place**. It may already contain task entries or unrelated keys (e.g., ``"probabilities"``, ``"coordinates"``). New tasks and fields are added as they appear. + return_predictions (tuple[bool, ...]): + Whether to return array predictions for individual tasks. Returns: dict: @@ -1382,8 +1393,11 @@ def build_post_process_raw_predictions( """ tasks = set() for seg_list in post_process_predictions: - for seg in seg_list: + for idx, seg in enumerate(seg_list): task = seg["task_type"] + self.return_predictions_dict[task] = ( + return_predictions[idx] if return_predictions is not None else False + ) tasks.add(task) # Initialize task entry if needed @@ -1486,13 +1500,16 @@ def _save_predictions_as_annotationstore( logger.info("Saving predictions as AnnotationStore.") - # predictions are not required when saving to AnnotationStore. - for key in ("canvas", "count", "predictions"): + for key in ("canvas", "count"): processed_predictions.pop(key, None) keys_to_compute = list(processed_predictions.keys()) if "probabilities" in keys_to_compute: keys_to_compute.remove("probabilities") + if "predictions" in keys_to_compute: + if not self.return_predictions_dict.get(task_name): + processed_predictions.pop("predictions") + keys_to_compute.remove("predictions") if self.patch_mode: for idx, curr_image in enumerate(self.images): values = [processed_predictions[key][idx] for key in keys_to_compute] @@ -1663,16 +1680,20 @@ def save_predictions( # Save to AnnotationStore return_probabilities = kwargs.get("return_probabilities", False) + return_predictions = kwargs.get("return_predictions", (False,)) + return_predictions_ = any(rp_ is True for rp_ in return_predictions) output_type_ = ( "zarr" - if is_zarr(save_path.with_suffix(".zarr")) or return_probabilities + if is_zarr(save_path.with_suffix(".zarr")) + or return_probabilities + or return_predictions_ else "dict" ) # This runs dask.compute and returns numpy arrays # for saving annotationstore output. 
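+        # For a single-task model, the per-task class_dict is unwrapped below,
+        # e.g. (hypothetical mapping) {"instance": {0: "background"}} becomes
+        # {0: "background"}.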
class_dict = kwargs.get("class_dict", self.model.class_dict) - if len(self.tasks) == 1: + if len(self.tasks) == 1 and class_dict is not None: kwargs["class_dict"] = class_dict[next(iter(self.tasks))] else: kwargs["class_dict"] = class_dict @@ -1707,7 +1728,8 @@ def save_predictions( **kwargs, ) save_paths += out_path - del processed_predictions[task_name] + if not self.return_predictions_dict[task_name]: + del processed_predictions[task_name] return save_paths From 4ac2ef446ccdfc6fcbd404405aa0b3b871986671 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Sat, 31 Jan 2026 02:26:09 +0100 Subject: [PATCH 071/156] :white_check_mark: Update CLI tests. --- tests/test_cli.py | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/tests/test_cli.py b/tests/test_cli.py index 58b716c36..f05cfc43d 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -10,6 +10,7 @@ cli_class_dict, cli_input_resolutions, cli_output_resolutions, + parse_bool_list, ) @@ -143,3 +144,35 @@ def test_cli_resolutions_not_list(option: str) -> None: ) assert result.exit_code != 0 assert "Must be a JSON list of dictionaries" in result.output + + +def test_parse_bool_list_none() -> None: + """parse_bool_list should return None when value is None.""" + result = parse_bool_list(ctx=None, param=None, value=None) + assert result is None + + +@pytest.mark.parametrize( + ("input_str", "expected"), + [ + ("true,false", (True, False)), + ("1,0", (True, False)), + ("yes,no", (True, False)), + ("y,n", (True, False)), + (" true , 0 , YES ", (True, False, True)), + ], +) +def test_parse_bool_list_valid( + input_str: str, + expected: tuple[bool, ...], +) -> None: + """parse_bool_list should correctly parse valid boolean lists.""" + result = parse_bool_list(ctx=None, param=None, value=input_str) + assert result == expected + + +@pytest.mark.parametrize("bad_value", ["foo", "true,bar", "1,2", "yes,maybe"]) +def test_parse_bool_list_invalid(bad_value: str) -> None: + """parse_bool_list should raise BadParameter on invalid tokens.""" + with pytest.raises(click.BadParameter): + parse_bool_list(ctx=None, param=None, value=bad_value) From 1cbb01be3a58ae01ad9433f5b2b94fa39108f671 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Sat, 31 Jan 2026 07:02:00 +0100 Subject: [PATCH 072/156] :bug: Fix ARG001 error --- tiatoolbox/cli/common.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tiatoolbox/cli/common.py b/tiatoolbox/cli/common.py index f4396a7b0..65b8b026f 100644 --- a/tiatoolbox/cli/common.py +++ b/tiatoolbox/cli/common.py @@ -528,8 +528,8 @@ def cli_return_probabilities( def parse_bool_list( - ctx: click.Context, # noqa: ARG001 - param: click.Parameter, # noqa: ARG001 + _ctx: click.Context, + _param: click.Parameter, value: str | None, ) -> tuple[bool, ...] | None: """Parse a comma-separated list of boolean values for a Click option. From 78019448e4fdd0aa6fb1c87d53289f25b40d3ddf Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Sat, 31 Jan 2026 07:12:00 +0100 Subject: [PATCH 073/156] :bug: Fix deepsource and `mypy` errors. 
--- tests/test_cli.py | 6 +++--- tiatoolbox/cli/common.py | 2 +- tiatoolbox/models/engine/multi_task_segmentor.py | 15 ++++++++++++++- 3 files changed, 18 insertions(+), 5 deletions(-) diff --git a/tests/test_cli.py b/tests/test_cli.py index f05cfc43d..beb6d831c 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -148,7 +148,7 @@ def test_cli_resolutions_not_list(option: str) -> None: def test_parse_bool_list_none() -> None: """parse_bool_list should return None when value is None.""" - result = parse_bool_list(ctx=None, param=None, value=None) + result = parse_bool_list(_ctx=None, _param=None, value=None) assert result is None @@ -167,7 +167,7 @@ def test_parse_bool_list_valid( expected: tuple[bool, ...], ) -> None: """parse_bool_list should correctly parse valid boolean lists.""" - result = parse_bool_list(ctx=None, param=None, value=input_str) + result = parse_bool_list(_ctx=None, _param=None, value=input_str) assert result == expected @@ -175,4 +175,4 @@ def test_parse_bool_list_valid( def test_parse_bool_list_invalid(bad_value: str) -> None: """parse_bool_list should raise BadParameter on invalid tokens.""" with pytest.raises(click.BadParameter): - parse_bool_list(ctx=None, param=None, value=bad_value) + parse_bool_list(_ctx=None, _param=None, value=bad_value) diff --git a/tiatoolbox/cli/common.py b/tiatoolbox/cli/common.py index 65b8b026f..aa3253a63 100644 --- a/tiatoolbox/cli/common.py +++ b/tiatoolbox/cli/common.py @@ -587,7 +587,7 @@ def cli_return_predictions( return click.option( "--return-predictions", callback=parse_bool_list, - help=add_default_to_usage_help(usage_help, default=default), + help=add_default_to_usage_help(usage_help, default=None), default=default, ) diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index cb5ca1aac..90f539f93 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -1415,6 +1415,20 @@ def build_post_process_raw_predictions( raw_predictions[task][key].append(value) + raw_predictions = self._rearrange_raw_predictions_to_per_task_dict( + tasks=tasks, + raw_predictions=raw_predictions, + ) + + self.tasks = tasks + return raw_predictions + + @staticmethod + def _rearrange_raw_predictions_to_per_task_dict( + tasks: set, + raw_predictions: dict, + ) -> dict: + """Rearranges `raw_predictions` to per-task output.""" for task in tasks: task_dict = raw_predictions[task] for key in list(task_dict.keys()): @@ -1431,7 +1445,6 @@ def build_post_process_raw_predictions( del raw_predictions[task][key] - self.tasks = tasks return raw_predictions def _save_predictions_as_dict_zarr( From 3cf850460a965bd8d0b1fed2bdd1e0b68490a749 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Mon, 2 Feb 2026 11:33:13 +0000 Subject: [PATCH 074/156] :green_heart: Address Co-Pilot review comments. 
--- tests/test_utils.py | 44 ------------------- tiatoolbox/cli/common.py | 4 +- tiatoolbox/cli/multitask_segmentor.py | 5 +-- tiatoolbox/models/architecture/hovernet.py | 2 +- .../models/architecture/hovernetplus.py | 2 +- tiatoolbox/models/engine/engine_abc.py | 2 - .../models/engine/multi_task_segmentor.py | 2 +- tiatoolbox/utils/misc.py | 10 ++--- 8 files changed, 9 insertions(+), 62 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index c84670ce3..cde89f120 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -5,7 +5,6 @@ import hashlib import json import shutil -import tempfile from pathlib import Path from typing import TYPE_CHECKING, NoReturn @@ -2278,46 +2277,3 @@ class FakeVM: assert isinstance(arr, np.ndarray) assert arr.shape == shape assert arr.dtype == dtype - - -def test_creates_temp_dir_when_zarr_path_none( - monkeypatch: pytest.MonkeyPatch, tmp_path: Path -) -> None: - """Test that a temporary directory is created when zarr_path is None.""" - shape = (1000, 1000, 3) - dtype = np.float32 - - class FakeVM: - """Force fits_in_memory = False by mocking available memory to be tiny.""" - - available = 1 - - monkeypatch.setattr(psutil, "virtual_memory", FakeVM) - - # Track calls to tempfile.mkdtemp - created_dirs = [] - - def fake_mkdtemp(prefix: str) -> str: - """Fake mkdtemp method.""" - _ = prefix - path = tmp_path / "tempdir" - created_dirs.append(path) - return str(path) - - monkeypatch.setattr(tempfile, "mkdtemp", fake_mkdtemp) - - arr = create_smart_array( - shape=shape, - dtype=dtype, - memory_threshold=0, # force Zarr allocation - name="test", - zarr_path=None, - ) - - # Ensure mkdtemp was called - assert len(created_dirs) == 1 - - # Ensure returned object is a Zarr array - assert isinstance(arr, zarr.Array) - assert arr.shape == shape - assert arr.dtype == dtype diff --git a/tiatoolbox/cli/common.py b/tiatoolbox/cli/common.py index aa3253a63..5b8260d4d 100644 --- a/tiatoolbox/cli/common.py +++ b/tiatoolbox/cli/common.py @@ -579,11 +579,11 @@ def parse_bool_list( def cli_return_predictions( - usage_help: str = "Whether to return raw model probabilities.", + usage_help: str = "Whether to return predictions for individual tasks.", *, default: tuple[bool, ...] 
| None = None, ) -> Callable: - """Enables --return-probabilities option for cli.""" + """Enables --return-predictions option for cli.""" return click.option( "--return-predictions", callback=parse_bool_list, diff --git a/tiatoolbox/cli/multitask_segmentor.py b/tiatoolbox/cli/multitask_segmentor.py index 6647d144a..d9b2c018e 100644 --- a/tiatoolbox/cli/multitask_segmentor.py +++ b/tiatoolbox/cli/multitask_segmentor.py @@ -100,7 +100,7 @@ def multitask_segmentor( verbose: bool, overwrite: bool, ) -> None: - """Process a set of input images with a semantic segmentation engine.""" + """Process a set of input images with a multitask segmentation engine.""" from tiatoolbox.models import IOSegmentorConfig, MultiTaskSegmentor # noqa: PLC0415 files_all, masks_all, output_path = prepare_model_cli( @@ -132,7 +132,6 @@ def multitask_segmentor( patch_output_shape=patch_output_shape, input_resolutions=input_resolutions, output_resolutions=output_resolutions, - batch_size=batch_size, ioconfig=ioconfig, device=device, save_dir=output_path, @@ -141,10 +140,8 @@ def multitask_segmentor( return_probabilities=return_probabilities, auto_get_mask=auto_get_mask, memory_threshold=memory_threshold, - num_workers=num_workers, output_file=output_file, scale_factor=scale_factor, stride_shape=stride_shape, overwrite=overwrite, - verbose=verbose, ) diff --git a/tiatoolbox/models/architecture/hovernet.py b/tiatoolbox/models/architecture/hovernet.py index 9fe39adc9..8e500f18f 100644 --- a/tiatoolbox/models/architecture/hovernet.py +++ b/tiatoolbox/models/architecture/hovernet.py @@ -819,7 +819,7 @@ def postproc(self: HoVerNet, raw_maps: list[np.ndarray]) -> tuple[dict, ...]: @staticmethod def infer_batch( # skipcq: PYL-W0221 - model: nn.Module, batch_data: np.ndarray, device: str + model: nn.Module, batch_data: np.ndarray | torch.Tensor, *, device: str ) -> tuple: """Run inference on an input batch. diff --git a/tiatoolbox/models/architecture/hovernetplus.py b/tiatoolbox/models/architecture/hovernetplus.py index 49f8dfbf6..74e8f49fb 100644 --- a/tiatoolbox/models/architecture/hovernetplus.py +++ b/tiatoolbox/models/architecture/hovernetplus.py @@ -396,7 +396,7 @@ def postproc(self: HoVerNetPlus, raw_maps: list[np.ndarray]) -> tuple[dict, ...] @staticmethod def infer_batch( # skipcq: PYL-W0221 - model: nn.Module, batch_data: np.ndarray, *, device: str + model: nn.Module, batch_data: np.ndarray | torch.Tensor, *, device: str ) -> tuple: """Run inference on an input batch. diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index 01722cabf..0270f76b4 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -73,8 +73,6 @@ from tiatoolbox.models.models_abc import ModelABC from tiatoolbox.type_hints import IntPair, Resolution, Units -dask.config.set({"dataframe.convert-string": False}) - class EngineABCRunParams(TypedDict, total=False): """Parameters for configuring the :func:`EngineABC.run()` method. 
diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index 90f539f93..b4bb56474 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -421,7 +421,7 @@ def infer_patches( tqdm_loop = ( tqdm_(dataloader, leave=False, desc="Inferring patches") if self.verbose - else self.dataloader + else dataloader ) for batch_data in tqdm_loop: diff --git a/tiatoolbox/utils/misc.py b/tiatoolbox/utils/misc.py index 564c1fda7..a01107760 100644 --- a/tiatoolbox/utils/misc.py +++ b/tiatoolbox/utils/misc.py @@ -1677,7 +1677,7 @@ def create_smart_array( dtype: np.dtype | str, memory_threshold: float, name: str | None, - zarr_path: str | Path | None = None, + zarr_path: str | Path, chunks: tuple[int, ...] | None = None, ) -> np.ndarray | zarr.Array: """Allocate a NumPy or Zarr array depending on available memory and a threshold. @@ -1697,14 +1697,14 @@ def create_smart_array( Fraction of available RAM allowed for this allocation. Must be between 0.0 and 100. A value of 100 allows using all available RAM; 0.0 forces Zarr allocation. + name (str | None): + Name for the zarr dataset. zarr_path (str | None): Filesystem path where the Zarr array will be created if needed. Defaults to "array.zarr". chunks (tuple(int,...) | None): Chunk shape for the Zarr array. If None, a reasonable default is chosen based on the array shape. - name (str | None): - Name for the zarr dataset. Returns: np.ndarray | zarr.core.Array: @@ -1724,10 +1724,6 @@ def create_smart_array( # Allocate in-memory NumPy array return np.zeros(shape, dtype=dtype) - if zarr_path is None: - temp_dir = tempfile.mkdtemp(prefix="smartarray_") - zarr_path = Path(str(temp_dir)) / "array.zarr" - # Allocate Zarr array on disk # Default chunking: try to chunk along spatial dims chunks = shape if chunks is None else chunks From 46eb4dc11dcd9362b0f99ee2f012bbf49b30cc0f Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Tue, 3 Feb 2026 12:17:43 +0000 Subject: [PATCH 075/156] :green_heart: Address Co-Pilot comments --- tiatoolbox/models/architecture/hovernet.py | 16 +++++++++------- tiatoolbox/models/engine/engine_abc.py | 2 +- tiatoolbox/models/engine/multi_task_segmentor.py | 2 +- tiatoolbox/utils/misc.py | 1 - 4 files changed, 11 insertions(+), 10 deletions(-) diff --git a/tiatoolbox/models/architecture/hovernet.py b/tiatoolbox/models/architecture/hovernet.py index 8e500f18f..9f35ae5e1 100644 --- a/tiatoolbox/models/architecture/hovernet.py +++ b/tiatoolbox/models/architecture/hovernet.py @@ -787,19 +787,21 @@ def postproc(self: HoVerNet, raw_maps: list[np.ndarray]) -> tuple[dict, ...]: np_map = np_map.compute() if is_dask else np_map hv_map = hv_map.compute() if is_dask else hv_map - pred_type = tp_map.compute() if is_dask else tp_map + pred_type = tp_map.compute() if tp_map is not None and is_dask else tp_map pred_inst = HoVerNet._proc_np_hv(np_map, hv_map) nuc_inst_info_dict = HoVerNet.get_instance_info(pred_inst, pred_type) nuc_inst_info_dict_ = {} if not nuc_inst_info_dict: - nuc_inst_info_dict_ = { # inst_id should start at 1 - "box": da.empty(shape=0), - "centroid": da.empty(shape=0), - "contours": da.empty(shape=0), - "prob": da.empty(shape=0), - "type": da.empty(shape=0), + # inst_id should start at 1; use NumPy or Dask empty arrays + empty_array = da.empty(shape=0) if is_dask else np.empty(shape=0) + nuc_inst_info_dict_ = { + "box": empty_array, + "centroid": empty_array, + 
"contours": empty_array, + "prob": empty_array, + "type": empty_array, } else: nuc_inst_info_dict_ = _inst_dict_for_dask_processing( diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index 0270f76b4..08b1d5579 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -815,7 +815,7 @@ def save_predictions_as_zarr( keys_to_compute = [k for k in keys_to_compute if k not in zarr_group] # If the task group already exists, only compute missing keys - if task_name in zarr_group: + if task_name is not None and task_name in zarr_group: task_group = zarr_group[task_name] keys_to_compute = [k for k in keys_to_compute if k not in task_group] diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index b4bb56474..89a3268a4 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -2094,7 +2094,7 @@ def prepare_multitask_full_batch( old_arr=full_batch_output[idx], new_arr=np.zeros( shape=(len(full_output_locs), *batch_output_.shape[1:]), - dtype=np.uint8, + dtype=batch_output_.dtype, ), ) diff --git a/tiatoolbox/utils/misc.py b/tiatoolbox/utils/misc.py index 6b008cb80..896223f06 100644 --- a/tiatoolbox/utils/misc.py +++ b/tiatoolbox/utils/misc.py @@ -1713,7 +1713,6 @@ def create_smart_array( Name for the zarr dataset. zarr_path (str | None): Filesystem path where the Zarr array will be created if needed. - Defaults to "array.zarr". chunks (tuple(int,...) | None): Chunk shape for the Zarr array. If None, a reasonable default is chosen based on the array shape. From a60b50f9f2f15589a23b1e1454975bf1c73a8303 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Tue, 3 Feb 2026 13:18:28 +0000 Subject: [PATCH 076/156] :zap: Use `prepare_full_batch` from `semantic_segmentor`. 
---
 .../models/engine/multi_task_segmentor.py    | 66 +++++++++----------
 .../models/engine/semantic_segmentor.py      |  1 +
 2 files changed, 32 insertions(+), 35 deletions(-)

diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py
index 89a3268a4..fbf7b367e 100644
--- a/tiatoolbox/models/engine/multi_task_segmentor.py
+++ b/tiatoolbox/models/engine/multi_task_segmentor.py
@@ -143,6 +143,7 @@
     SemanticSegmentorRunParams,
     concatenate_none,
     merge_batch_to_canvas,
+    prepare_full_batch,
     store_probabilities,
 )
@@ -605,10 +606,13 @@ def infer_wsi(
             # Interpolate outputs for masked regions
             full_batch_output, full_output_locs, output_locs = (
                 prepare_multitask_full_batch(
-                    batch_output,
-                    batch_locs,
-                    full_output_locs,
-                    output_locs,
+                    batch_output=batch_output,
+                    batch_locs=batch_locs,
+                    full_output_locs=full_output_locs,
+                    output_locs=output_locs,
+                    canvas_np=canvas_np,
+                    save_path=save_path.with_name("full_batch_tmp"),
+                    memory_threshold=memory_threshold,
                     is_last=(batch_idx == (len(dataloader) - 1)),
                 )
             )
@@ -1991,6 +1995,9 @@ def prepare_multitask_full_batch(
     batch_locs: np.ndarray,
     full_output_locs: np.ndarray,
     output_locs: np.ndarray,
+    canvas_np: list[np.ndarray | zarr.Array | None] | None = None,
+    save_path: Path | str = "temp_fullbatch",
+    memory_threshold: int = 80,
     *,
     is_last: bool,
 ) -> tuple[list[np.ndarray], np.ndarray, np.ndarray]:
@@ -2034,6 +2041,14 @@
             extended in-place with the portion of `full_output_locs` filled
             in this call, and with any remaining tail (zeros padded in
             outputs) when `is_last=True`.
+        canvas_np (list[np.ndarray | zarr.Array | None] | None):
+            List of accumulated canvas arrays from previous batches. Used to check
+            total memory footprint when deciding numpy vs zarr.
+        save_path (Path | str):
+            Path to a directory; a unique temp subfolder will be created within it
+            to store the temporary full-batch zarr for this batch.
+        memory_threshold (int):
+            Memory usage threshold (in percentage) to trigger caching behavior.
         is_last (bool):
             Whether this is the final batch. When True, any locations left in
             `full_output_locs` after placing matches are appended to
@@ -2064,41 +2079,22 @@
         remain consistent across batches.
""" - # Use np.intersect1d once numpy version is upgraded to 2.0 - full_output_dict = {tuple(row): i for i, row in enumerate(full_output_locs)} - matches = [full_output_dict[tuple(row)] for row in batch_locs] - - total_size = np.max(matches).astype(np.uint32) + 1 - full_batch_output = [np.empty(0) for _ in range(len(batch_output))] - + full_output_locs_ = full_output_locs.copy() + output_locs_ = output_locs for idx, batch_output_ in enumerate(batch_output): - # Initialize full output array - full_batch_output[idx] = np.zeros( - shape=(total_size, *batch_output_.shape[1:]), - dtype=batch_output_.dtype, + full_batch_output[idx], full_output_locs_, output_locs_ = prepare_full_batch( + batch_output=batch_output_, + batch_locs=batch_locs, + full_output_locs=full_output_locs, + output_locs=output_locs, + canvas_np=canvas_np[idx], + save_path=save_path.with_name(f"_{idx}"), + memory_threshold=memory_threshold, + is_last=is_last, ) - # Place matching outputs using matching indices - full_batch_output[idx][matches] = batch_output_ - - output_locs = concatenate_none( - old_arr=output_locs, new_arr=full_output_locs[:total_size] - ) - full_output_locs = full_output_locs[total_size:] - - if is_last: - output_locs = concatenate_none(old_arr=output_locs, new_arr=full_output_locs) - for idx, batch_output_ in enumerate(batch_output): - full_batch_output[idx] = concatenate_none( - old_arr=full_batch_output[idx], - new_arr=np.zeros( - shape=(len(full_output_locs), *batch_output_.shape[1:]), - dtype=batch_output_.dtype, - ), - ) - - return full_batch_output, full_output_locs, output_locs + return full_batch_output, full_output_locs_, output_locs_ def merge_multitask_horizontal( diff --git a/tiatoolbox/models/engine/semantic_segmentor.py b/tiatoolbox/models/engine/semantic_segmentor.py index beebdca98..787d5e9ec 100644 --- a/tiatoolbox/models/engine/semantic_segmentor.py +++ b/tiatoolbox/models/engine/semantic_segmentor.py @@ -1421,6 +1421,7 @@ def prepare_full_batch( """ # Map batch locations back to indices in the full output grid. # Use a dict to avoid allocating a huge dense array when locations are sparse. + # Use np.intersect1d once numpy version is upgraded to 2.0 full_output_dict = {tuple(row): i for i, row in enumerate(full_output_locs)} matches = np.array([full_output_dict[tuple(row)] for row in batch_locs]) From 3a68f0a91ea23afe1f1c03a4fb4b851a0f10e3cf Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Tue, 3 Feb 2026 13:36:44 +0000 Subject: [PATCH 077/156] :zap: Use `save_to_cache` from `semantic_segmentor`. --- .../models/engine/multi_task_segmentor.py | 48 ++++--------------- .../models/engine/semantic_segmentor.py | 10 ++-- 2 files changed, 15 insertions(+), 43 deletions(-) diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index fbf7b367e..6cacebc22 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -125,7 +125,6 @@ import psutil import torch import zarr -from dask import compute from shapely.geometry import box as shapely_box from shapely.geometry import shape as feature2geometry from shapely.strtree import STRtree @@ -144,6 +143,7 @@ concatenate_none, merge_batch_to_canvas, prepare_full_batch, + save_to_cache, store_probabilities, ) @@ -2282,47 +2282,15 @@ def save_multitask_to_cache( and ``count`` to free RAM and continue populating new entries. 
""" - zarr_group = None for idx, canvas_ in enumerate(canvas): - computed_values = compute(*[canvas_, count[idx]]) - canvas_computed, count_computed = computed_values - - chunk_shape = tuple(chunk[0] for chunk in canvas_.chunks) - if canvas_zarr[idx] is None: - # Only open zarr for first canvas. - zarr_group = zarr.open(str(save_path), mode="w") if idx == 0 else zarr_group - - canvas_zarr[idx] = zarr_group.create_dataset( - name=f"canvas/{idx}", - shape=(0, *canvas_computed.shape[1:]), - chunks=(chunk_shape[0], *canvas_computed.shape[1:]), - dtype=canvas_computed.dtype, - overwrite=True, - ) - - count_zarr[idx] = zarr_group.create_dataset( - name=f"count/{idx}", - shape=(0, *count_computed.shape[1:]), - dtype=count_computed.dtype, - chunks=(chunk_shape[0], *count_computed.shape[1:]), - overwrite=True, - ) - - canvas_zarr[idx].resize( - ( - canvas_zarr[idx].shape[0] + canvas_computed.shape[0], - *canvas_zarr[idx].shape[1:], - ) - ) - canvas_zarr[idx][-canvas_computed.shape[0] :] = canvas_computed - - count_zarr[idx].resize( - ( - count_zarr[idx].shape[0] + count_computed.shape[0], - *count_zarr[idx].shape[1:], - ) + canvas_zarr[idx], count_zarr[idx] = save_to_cache( + canvas=canvas_, + count=count[idx], + canvas_zarr=canvas_zarr[idx], + count_zarr=count_zarr[idx], + save_path=save_path, + zarr_dataset_name=(f"canvas/{idx}", f"count/{idx}"), ) - count_zarr[idx][-count_computed.shape[0] :] = count_computed return canvas_zarr, count_zarr diff --git a/tiatoolbox/models/engine/semantic_segmentor.py b/tiatoolbox/models/engine/semantic_segmentor.py index 787d5e9ec..f52285986 100644 --- a/tiatoolbox/models/engine/semantic_segmentor.py +++ b/tiatoolbox/models/engine/semantic_segmentor.py @@ -1119,6 +1119,7 @@ def save_to_cache( canvas_zarr: zarr.Array, count_zarr: zarr.Array, save_path: str | Path = "temp.zarr", + zarr_dataset_name: tuple[str, str] = ("canvas", "count"), ) -> tuple[zarr.Array, zarr.Array]: """Incrementally save computed canvas and count arrays to Zarr cache. @@ -1138,6 +1139,9 @@ def save_to_cache( Existing Zarr dataset for count data. If None, a new one is created. save_path (str | Path): Path to the Zarr group for saving datasets. Defaults to "temp.zarr". + zarr_dataset_name (tuple[str, str]): + Tuple of name for zarr dataset to save canvas and count. + Defaults to ("canvas", "count"). Returns: tuple[zarr.Array, zarr.Array]: @@ -1146,7 +1150,7 @@ def save_to_cache( chunk0 = canvas.chunks[0][0] if canvas_zarr is None: - zarr_group = zarr.open(str(save_path), mode="w") + zarr_group = zarr.open(str(save_path), mode="a") # Peek first block shapes to initialise datasets without computing all rows. # Blocks are 3D: (row_chunk, col_chunk, channel_chunk). Grab the first. @@ -1154,7 +1158,7 @@ def save_to_cache( first_count_block = count.blocks[0, 0, 0].compute() canvas_zarr = zarr_group.create_dataset( - name="canvas", + name=zarr_dataset_name[0], # Append along axis 0 (height); keep width/channels fixed. 
shape=(0, *first_canvas_block.shape[1:]), chunks=(chunk0, *first_canvas_block.shape[1:]), @@ -1163,7 +1167,7 @@ def save_to_cache( ) count_zarr = zarr_group.create_dataset( - name="count", + name=zarr_dataset_name[1], shape=(0, *first_count_block.shape[1:]), dtype=first_count_block.dtype, chunks=(chunk0, *first_count_block.shape[1:]), From 18368164c84cdd5b08335e1f62bd076810ef1f6b Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Tue, 3 Feb 2026 13:41:23 +0000 Subject: [PATCH 078/156] :zap: Use `max(vm.available, 1)` to calculate memory usage. --- tests/engines/test_multi_task_segmentor.py | 2 +- tiatoolbox/models/engine/multi_task_segmentor.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/engines/test_multi_task_segmentor.py b/tests/engines/test_multi_task_segmentor.py index 01c333075..7b4c49d93 100644 --- a/tests/engines/test_multi_task_segmentor.py +++ b/tests/engines/test_multi_task_segmentor.py @@ -572,7 +572,7 @@ def test_vertical_save_branch_without_patch( class FakeVM: """Fake psutil.virtual_memory() with extremely low free memory.""" - free = 1 # force used_percent > memory_threshold + available = 0 # force used_percent > memory_threshold monkeypatch.setattr(psutil, "virtual_memory", FakeVM) diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index 6cacebc22..c27bf7c9d 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -2452,7 +2452,7 @@ def _save_multitask_vertical_to_cache( vm = psutil.virtual_memory() # Calculate total bytes for all outputs total_bytes = sum(0 if arr is None else arr.nbytes for arr in probabilities_da) - used_percent = (total_bytes / vm.free) * 100 + used_percent = (total_bytes / max(vm.available, 1)) * 100 if probabilities_zarr[idx] is None and used_percent > memory_threshold: msg = ( f"Current Memory usage: {used_percent} % " @@ -2549,7 +2549,7 @@ def _check_and_update_for_memory_overload( vm = psutil.virtual_memory() used_percent = vm.percent total_bytes = sum(arr.nbytes for arr in canvas) if canvas else 0 - canvas_used_percent = (total_bytes / vm.free) * 100 + canvas_used_percent = (total_bytes / max(vm.available, 1)) * 100 if not (used_percent > memory_threshold or canvas_used_percent > memory_threshold): return canvas, count, canvas_zarr, count_zarr, tqdm_loop From 9fb9ea69bdcc507b01b4646dd535e56023d772c5 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Tue, 3 Feb 2026 14:01:48 +0000 Subject: [PATCH 079/156] :zap: Use `merge_horizontal` from `semantic_segmentor`. 
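Same de-duplication as for `prepare_full_batch`: the hand-rolled row merging
in `merge_multitask_horizontal` is replaced by one `merge_horizontal` call per
model head. All heads share a single grid of patch locations, so each call
receives the same `output_locs`, `output_locs_y` and `change_indices`, while
`canvas`, `count` and `canvas_np` stay head-indexed; the location arrays
produced by the final call are returned for all heads. Condensed sketch of
the resulting control flow (argument names as in the diff below; assumes
numpy is imported as `np`):

    output_locs_ = np.empty(0)
    output_locs_y_ = np.empty(0)
    for idx, canvas_np_ in enumerate(canvas_np):
        # Per-head pixel buffers, shared location bookkeeping.
        canvas[idx], count[idx], canvas_np[idx], output_locs_, output_locs_y_ = (
            merge_horizontal(
                canvas_np=canvas_np_,
                canvas=canvas[idx],
                count=count[idx],
                output_locs_y=output_locs_y,
                output_locs=output_locs,
                change_indices=change_indices,
            )
        )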
--- .../models/engine/multi_task_segmentor.py | 49 ++++++------------- .../models/engine/semantic_segmentor.py | 10 ++-- 2 files changed, 21 insertions(+), 38 deletions(-) diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index c27bf7c9d..b12c995f2 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -141,7 +141,7 @@ SemanticSegmentor, SemanticSegmentorRunParams, concatenate_none, - merge_batch_to_canvas, + merge_horizontal, prepare_full_batch, save_to_cache, store_probabilities, @@ -2100,7 +2100,7 @@ def prepare_multitask_full_batch( def merge_multitask_horizontal( canvas: list[None] | list[da.Array], count: list[None] | list[da.Array], - output_locs_y_: np.ndarray, + output_locs_y: np.ndarray, canvas_np: list[np.ndarray], output_locs: np.ndarray, change_indices: np.ndarray | list[int], @@ -2136,7 +2136,7 @@ def merge_multitask_horizontal( count (list[da.Array] | list[None]): Accumulated per-head row count maps, aligned with ``canvas``. Pass ``None`` for each head on the first call. - output_locs_y_ (np.ndarray): + output_locs_y (np.ndarray): Accumulated vertical extents of already-merged rows. Each appended element is ``[y0, y1]`` corresponding to the merged row's span. Pass ``None`` on the first call; it will be initialized internally @@ -2186,39 +2186,22 @@ def merge_multitask_horizontal( vertical concatenation and overlap handling. """ - start_idx = 0 - for c_idx in change_indices: - output_locs_ = output_locs[: c_idx - start_idx] - - batch_xs = np.min(output_locs[:, 0], axis=0) - batch_xe = np.max(output_locs[:, 2], axis=0) - - for idx, canvas_np_ in enumerate(canvas_np): - canvas_np__ = canvas_np_[: c_idx - start_idx] - merged_shape = ( - canvas_np__.shape[1], - batch_xe - batch_xs, - canvas_np__.shape[3], + output_locs_ = np.empty(0) + output_locs_y_ = np.empty(0) + + for idx, canvas_np_ in enumerate(canvas_np): + canvas[idx], count[idx], canvas_np[idx], output_locs_, output_locs_y_ = ( + merge_horizontal( + canvas_np=canvas_np_, + canvas=canvas[idx], + count=count[idx], + output_locs_y=output_locs_y, + output_locs=output_locs, + change_indices=change_indices, ) - canvas_merge, count_merge = merge_batch_to_canvas( - blocks=canvas_np__, - output_locations=output_locs_, - merged_shape=merged_shape, - ) - canvas_merge = da.from_array(canvas_merge, chunks=canvas_merge.shape) - count_merge = da.from_array(count_merge, chunks=count_merge.shape) - canvas[idx] = concatenate_none(old_arr=canvas[idx], new_arr=canvas_merge) - count[idx] = concatenate_none(old_arr=count[idx], new_arr=count_merge) - canvas_np[idx] = canvas_np[idx][c_idx - start_idx :] - - output_locs_y_ = concatenate_none( - old_arr=output_locs_y_, new_arr=output_locs[:, (1, 3)] ) - output_locs = output_locs[c_idx - start_idx :] - start_idx = c_idx - - return canvas, count, canvas_np, output_locs, output_locs_y_ + return canvas, count, canvas_np, output_locs_, output_locs_y_ def save_multitask_to_cache( diff --git a/tiatoolbox/models/engine/semantic_segmentor.py b/tiatoolbox/models/engine/semantic_segmentor.py index f52285986..c7d052f31 100644 --- a/tiatoolbox/models/engine/semantic_segmentor.py +++ b/tiatoolbox/models/engine/semantic_segmentor.py @@ -1046,7 +1046,7 @@ def merge_batch_to_canvas( def merge_horizontal( canvas: None | da.Array, count: None | da.Array, - output_locs_y_: np.ndarray, + output_locs_y: np.ndarray, canvas_np: np.ndarray, output_locs: np.ndarray, change_indices: np.ndarray | 
list[int], @@ -1063,7 +1063,7 @@ def merge_horizontal( Existing Dask array for canvas data, or None if uninitialized. count (None | da.Array): Existing Dask array for count data, or None if uninitialized. - output_locs_y_ (np.ndarray): + output_locs_y (np.ndarray): Array tracking vertical output locations for merged patches. canvas_np (np.ndarray): NumPy array of canvas patches to be merged. @@ -1102,15 +1102,15 @@ def merge_horizontal( canvas = concatenate_none(old_arr=canvas, new_arr=canvas_merge) count = concatenate_none(old_arr=count, new_arr=count_merge) - output_locs_y_ = concatenate_none( - old_arr=output_locs_y_, new_arr=output_locs_[:, (1, 3)] + output_locs_y = concatenate_none( + old_arr=output_locs_y, new_arr=output_locs_[:, (1, 3)] ) canvas_np = canvas_np[c_idx - start_idx :] output_locs = output_locs[c_idx - start_idx :] start_idx = c_idx - return canvas, count, canvas_np, output_locs, output_locs_y_ + return canvas, count, canvas_np, output_locs, output_locs_y def save_to_cache( From f7e35be184c5314aa62b60a8d752992ed6bd7a33 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Tue, 3 Feb 2026 14:39:01 +0000 Subject: [PATCH 080/156] :bug: Fix saving and deletion of temp_zarr name. --- tiatoolbox/models/engine/multi_task_segmentor.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index b12c995f2..70b5a9a49 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -114,6 +114,7 @@ from __future__ import annotations import gc +import shutil import uuid from collections import deque from pathlib import Path @@ -679,6 +680,9 @@ def infer_wsi( raw_predictions["coordinates"] = da.concatenate(coordinates, axis=0) + if save_path.with_name("full_batch_tmp").exists(): + shutil.rmtree(save_path.with_name("full_batch_tmp")) + return raw_predictions def post_process_patches( # skipcq: PYL-R0201 @@ -2089,7 +2093,7 @@ def prepare_multitask_full_batch( full_output_locs=full_output_locs, output_locs=output_locs, canvas_np=canvas_np[idx], - save_path=save_path.with_name(f"_{idx}"), + save_path=save_path, memory_threshold=memory_threshold, is_last=is_last, ) From dd3e5a537105aa945a6a58df019133fb0b38dd97 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Tue, 3 Feb 2026 14:56:12 +0000 Subject: [PATCH 081/156] :white_check_mark: Use `hovernet_original-kumar` --- tests/engines/test_multi_task_segmentor.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/tests/engines/test_multi_task_segmentor.py b/tests/engines/test_multi_task_segmentor.py index 7b4c49d93..0365e0083 100644 --- a/tests/engines/test_multi_task_segmentor.py +++ b/tests/engines/test_multi_task_segmentor.py @@ -16,6 +16,7 @@ from tiatoolbox import cli from tiatoolbox.annotation import SQLiteStore +from tiatoolbox.models.architecture import fetch_pretrained_weights from tiatoolbox.models.engine.multi_task_segmentor import ( MultiTaskSegmentor, _clear_zarr, @@ -414,8 +415,11 @@ def test_wsi_segmentor_annotationstore( ) -> None: """Test MultiTaskSegmentor for WSIs with AnnotationStore output.""" wsi4_512_512_svs = remote_sample("wsi4_512_512_svs") + # testing different configuration for hovernet. 
+ # kumar only has two probability maps + model_name = "hovernet_original-kumar" mtsegmentor = MultiTaskSegmentor( - model="hovernet_fast-pannuke", + model=model_name, batch_size=32, verbose=False, ) @@ -443,6 +447,9 @@ def test_wsi_segmentor_annotationstore( assert store_file_path.exists() assert store_file_path == output[wsi4_512_512_svs][0] + weights_path = Path(fetch_pretrained_weights(model_name=model_name)) + weights_path.unlink() + def test_wsi_segmentor_annotationstore_probabilities( remote_sample: Callable, track_tmp_path: Path, caplog: pytest.CaptureFixture From 27d6161bb2b14acefa1ba2a85a7989eb5f80efe6 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Tue, 3 Feb 2026 15:32:42 +0000 Subject: [PATCH 082/156] :white_check_mark: Update `NucleusInstanceSegmentor`. --- .../test_nucleus_instance_segmentor.py | 80 ++ tiatoolbox/models/__init__.py | 3 +- tiatoolbox/models/dataset/__init__.py | 2 - tiatoolbox/models/dataset/dataset_abc.py | 149 +-- .../models/engine/multi_task_segmentor.py | 4 +- .../engine/nucleus_instance_segmentor.py | 912 +++--------------- 6 files changed, 220 insertions(+), 930 deletions(-) create mode 100644 tests/engines/test_nucleus_instance_segmentor.py diff --git a/tests/engines/test_nucleus_instance_segmentor.py b/tests/engines/test_nucleus_instance_segmentor.py new file mode 100644 index 000000000..8121046d7 --- /dev/null +++ b/tests/engines/test_nucleus_instance_segmentor.py @@ -0,0 +1,80 @@ +"""Test NucleusInstanceSegmentor.""" + +from __future__ import annotations + +from pathlib import Path +from typing import TYPE_CHECKING, Any, Final + +import numpy as np +import torch + +from tiatoolbox.models.engine.nucleus_instance_segmentor import ( + NucleusInstanceSegmentor, +) +from tiatoolbox.utils import env_detection as toolbox_env +from tiatoolbox.wsicore import WSIReader + +from .test_multi_task_segmentor import ( + assert_output_lengths, + assert_predictions_and_boxes, +) + +if TYPE_CHECKING: + from collections.abc import Callable + + import pytest + +device = "cuda" if toolbox_env.has_gpu() else "cpu" +OutputType = dict[str, Any] | Any + + +def test_mtsegmentor_init(caplog: pytest.LogCaptureFixture) -> None: + """Tests NucleusInstanceSegmentor initialization.""" + segmentor = NucleusInstanceSegmentor(model="hovernetplus-oed", device=device) + + assert isinstance(segmentor, NucleusInstanceSegmentor) + assert isinstance(segmentor.model, torch.nn.Module) + assert ( + "NucleusInstanceSegmentor is deprecated and will be removed in " + "a future release." 
in caplog.text + ) + + +def test_mtsegmentor_patches(remote_sample: Callable) -> None: + """Tests NucleusInstanceSegmentor on image patches.""" + mtsegmentor = NucleusInstanceSegmentor( + model="hovernet_fast-pannuke", batch_size=32, verbose=False, device=device + ) + + mini_wsi_svs = Path(remote_sample("wsi4_1k_1k_svs")) + mini_wsi = WSIReader.open(mini_wsi_svs) + size = (256, 256) + resolution = 0.50 + units: Final = "mpp" + + patch1 = mini_wsi.read_rect( + location=(0, 0), size=size, resolution=resolution, units=units + ) + patch2 = mini_wsi.read_rect( + location=(512, 512), size=size, resolution=resolution, units=units + ) + patch3 = np.zeros_like(patch1) + patches = np.stack([patch1, patch2, patch3], axis=0) + + assert not mtsegmentor.patch_mode + + output_dict = mtsegmentor.run( + images=patches, + return_probabilities=True, + return_labels=False, + device=device, + patch_mode=True, + ) + + expected_counts_nuclei = [62, 33, 0] + assert_output_lengths( + output_dict, + expected_counts_nuclei, + fields=["box", "centroid", "contours", "prob", "type"], + ) + assert_predictions_and_boxes(output_dict, expected_counts_nuclei, is_zarr=False) diff --git a/tiatoolbox/models/__init__.py b/tiatoolbox/models/__init__.py index cd852cffa..58d0f2e25 100644 --- a/tiatoolbox/models/__init__.py +++ b/tiatoolbox/models/__init__.py @@ -10,7 +10,7 @@ from .architecture.micronet import MicroNet from .architecture.nuclick import NuClick from .architecture.sccnn import SCCNN -from .dataset import PatchDataset, WSIPatchDataset, WSIStreamDataset +from .dataset import PatchDataset, WSIPatchDataset from .engine.deep_feature_extractor import DeepFeatureExtractor from .engine.io_config import ( IOInstanceSegmentorConfig, @@ -43,7 +43,6 @@ "PatchPredictor", "SemanticSegmentor", "WSIPatchDataset", - "WSIStreamDataset", "architecture", "dataset", "engine", diff --git a/tiatoolbox/models/dataset/__init__.py b/tiatoolbox/models/dataset/__init__.py index 16c80fd18..ab593855f 100644 --- a/tiatoolbox/models/dataset/__init__.py +++ b/tiatoolbox/models/dataset/__init__.py @@ -6,7 +6,6 @@ PatchDataset, PatchDatasetABC, WSIPatchDataset, - WSIStreamDataset, ) from .info import DatasetInfoABC, KatherPatchDataset @@ -16,6 +15,5 @@ "PatchDataset", "PatchDatasetABC", "WSIPatchDataset", - "WSIStreamDataset", "predefined_preproc_func", ] diff --git a/tiatoolbox/models/dataset/dataset_abc.py b/tiatoolbox/models/dataset/dataset_abc.py index f81312827..2ed8dffc2 100644 --- a/tiatoolbox/models/dataset/dataset_abc.py +++ b/tiatoolbox/models/dataset/dataset_abc.py @@ -2,7 +2,6 @@ from __future__ import annotations -import copy import os from abc import ABC, abstractmethod from pathlib import Path @@ -11,20 +10,17 @@ import cv2 import numpy as np import torch -import torch.utils.data as torch_data from tiatoolbox import logger from tiatoolbox.tools.patchextraction import PatchExtractor from tiatoolbox.utils import imread from tiatoolbox.utils.exceptions import DimensionMismatchError -from tiatoolbox.wsicore.wsireader import VirtualWSIReader, WSIMeta, WSIReader +from tiatoolbox.wsicore.wsireader import VirtualWSIReader, WSIReader if TYPE_CHECKING: # pragma: no cover from collections.abc import Callable, Iterable - from multiprocessing.managers import Namespace from typing import TypeGuard - from tiatoolbox.models.engine.io_config import IOSegmentorConfig from tiatoolbox.type_hints import IntPair, Resolution, Units input_type = list[str | Path | np.ndarray] | np.ndarray @@ -191,149 +187,6 @@ def __getitem__(self: PatchDatasetABC, idx: int) 
-> None: ... # pragma: no cover -class WSIStreamDataset(torch_data.Dataset): - """Reading a wsi in parallel mode with persistent workers. - - To speed up the inference process for multiple WSIs. The - `torch.utils.data.Dataloader` is set to run in persistent mode. - Normally, this will prevent workers from altering their initial - states (such as provided input etc.). To sidestep this, we use a - shared parallel workspace context manager to send data and signal - from the main thread, thus allowing each worker to load a new wsi as - well as corresponding patch information. - - Args: - mp_shared_space (:class:`Namespace`): - A shared multiprocessing space, must be from - `torch.multiprocessing`. - ioconfig (:class:`IOSegmentorConfig`): - An object which contains I/O placement for patches. - wsi_paths (list): List of paths pointing to a WSI or tiles. - preproc (Callable): - Pre-processing function to be applied to a patch. - mode (str): - Either `"wsi"` or `"tile"` to indicate the format of images - in `wsi_paths`. - - Examples: - >>> ioconfig = IOSegmentorConfig( - ... input_resolutions=[{"units": "baseline", "resolution": 1.0}], - ... output_resolutions=[{"units": "baseline", "resolution": 1.0}], - ... patch_input_shape=[2048, 2048], - ... patch_output_shape=[1024, 1024], - ... stride_shape=[512, 512], - ... ) - >>> mp_manager = torch_mp.Manager() - >>> mp_shared_space = mp_manager.Namespace() - >>> mp_shared_space.signal = 1 # adding variable to the shared space - >>> wsi_paths = ['A.svs', 'B.svs'] - >>> wsi_dataset = WSIStreamDataset(ioconfig, wsi_paths, mp_shared_space) - - """ - - def __init__( - self: WSIStreamDataset, - ioconfig: IOSegmentorConfig, - wsi_paths: list[str | Path], - mp_shared_space: Namespace, - preproc: Callable[[np.ndarray], np.ndarray] | None = None, - mode: str = "wsi", - ) -> None: - """Initialize :class:`WSIStreamDataset`.""" - super().__init__() - self.mode = mode - self.preproc = preproc - self.ioconfig = copy.deepcopy(ioconfig) - - if mode == "tile": - logger.warning( - "WSIPatchDataset only reads image tile at " - '`units="baseline"`. Resolutions will be converted ' - "to baseline value.", - stacklevel=2, - ) - self.ioconfig = self.ioconfig.to_baseline() - - self.mp_shared_space = mp_shared_space - self.wsi_paths = wsi_paths - self.wsi_idx = None # to be received externally via thread communication - self.reader = None - - def _get_reader(self: WSIStreamDataset, img_path: str | Path) -> WSIReader: - """Get appropriate reader for input path.""" - img_path = Path(img_path) - if self.mode == "wsi": - return WSIReader.open(img_path) - img = imread(img_path) - # initialise metadata for VirtualWSIReader. - # here, we simulate a whole-slide image, but with a single level. - metadata = WSIMeta( - mpp=np.array([1.0, 1.0]), - objective_power=10, - axes="YXS", - slide_dimensions=np.array(img.shape[:2][::-1]), - level_downsamples=[1.0], - level_dimensions=[np.array(img.shape[:2][::-1])], - ) - return VirtualWSIReader( - img, - info=metadata, - ) - - def __len__(self: WSIStreamDataset) -> int: - """Return the length of the instance attributes.""" - return len(self.mp_shared_space.patch_inputs) - - @staticmethod - def collate_fn(batch: list | np.ndarray) -> torch.Tensor: - """Prototype to handle reading exception. - - This will exclude any sample with `None` from the batch. As - such, wrapping `__getitem__` with try-catch and return `None` - upon exceptions will prevent crashing the entire program. But as - a side effect, the batch may not have the size as defined. 
- - """ - batch = [v for v in batch if v is not None] - return torch.utils.data.dataloader.default_collate(batch) - - def __getitem__(self: WSIStreamDataset, idx: int) -> tuple: - """Get an item from the dataset.""" - # ! no need to lock as we do not modify source value in shared space - if self.wsi_idx != self.mp_shared_space.wsi_idx: - self.wsi_idx = int(self.mp_shared_space.wsi_idx.item()) - self.reader = self._get_reader(self.wsi_paths[self.wsi_idx]) - - # this is in XY and at requested resolution (not baseline) - bounds = self.mp_shared_space.patch_inputs[idx] - bounds = bounds.numpy() # expected to be a torch.Tensor - - # be the same as bounds br-tl, unless bounds are of float - patch_data_ = [] - scale_factors = self.ioconfig.scale_to_highest( - self.ioconfig.input_resolutions, - self.ioconfig.resolution_unit, - ) - for idy, resolution in enumerate(self.ioconfig.input_resolutions): - resolution_bounds = np.round(bounds * scale_factors[idy]) - patch_data = self.reader.read_bounds( - resolution_bounds.astype(np.int32), - coord_space="resolution", - pad_constant_values=0, # expose this ? - **resolution, - ) - - if self.preproc is not None: - patch_data = patch_data.copy() - patch_data = self.preproc(patch_data) - patch_data_.append(patch_data) - if len(patch_data_) == 1: - patch_data_ = patch_data_[0] - - bound = self.mp_shared_space.patch_outputs[idx] - return patch_data_, bound - - class WSIPatchDataset(PatchDatasetABC): """Define a WSI-level patch dataset. diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index 70b5a9a49..7caa492a6 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -233,7 +233,7 @@ class MultiTaskSegmentor(SemanticSegmentor): weights (str | Path | None): Path to model weights. If None, default weights are used. - >>> engine = SemanticSegmentor( + >>> engine = MultiTaskSegmentor( ... model="pretrained-model", ... weights="/path/to/pretrained-local-weights.pth" ... ) @@ -259,7 +259,7 @@ class MultiTaskSegmentor(SemanticSegmentor): IO configuration for patch extraction and resolution. return_labels (bool): Whether to include labels in the output. - return_predictions (dict): + return_predictions_dict (dict): This dictionary helps keep track of which tasks require predictions in the output. 
        input_resolutions (list[dict]):
diff --git a/tiatoolbox/models/engine/nucleus_instance_segmentor.py b/tiatoolbox/models/engine/nucleus_instance_segmentor.py
index ce74355ae..5729547d1 100644
--- a/tiatoolbox/models/engine/nucleus_instance_segmentor.py
+++ b/tiatoolbox/models/engine/nucleus_instance_segmentor.py
@@ -2,813 +2,173 @@

 from __future__ import annotations

-import uuid
-from collections import deque
+import warnings
 from typing import TYPE_CHECKING

-# replace with the sql database once the PR in place
-import joblib
-import numpy as np
-import torch
-import tqdm
-from shapely.geometry import box as shapely_box
-from shapely.strtree import STRtree
+from tiatoolbox import logger

-from tiatoolbox.models.dataset.dataset_abc import WSIStreamDataset
-from tiatoolbox.models.engine.semantic_segmentor import SemanticSegmentor
-from tiatoolbox.tools.patchextraction import PatchExtractor
+from .multi_task_segmentor import MultiTaskSegmentor

 if TYPE_CHECKING:  # pragma: no cover
-    from collections.abc import Callable
+    from pathlib import Path

-    from .io_config import IOInstanceSegmentorConfig, IOSegmentorConfig
+    from tiatoolbox.models.models_abc import ModelABC


-def _process_instance_predictions(
-    inst_dict: dict,
-    ioconfig: IOSegmentorConfig,
-    tile_shape: list,
-    tile_flag: list,
-    tile_mode: int,
-    tile_tl: tuple,
-    ref_inst_dict: dict,
-) -> list | tuple:
-    """Function to merge new tile prediction with existing prediction.
+class NucleusInstanceSegmentor(MultiTaskSegmentor):
+    """NucleusInstanceSegmentor is a segmentation engine to run models like hovernet.

-    Args:
-        inst_dict (dict): Dictionary containing instance information.
-        ioconfig (:class:`IOSegmentorConfig`): Object defines information
-            about input and output placement of patches.
-        tile_shape (list): A list of the tile shape.
-        tile_flag (list): A list of flag to indicate if instances within
-            an area extended from each side (by `ioconfig.margin`) of
-            the tile should be replaced by those within the same spatial
-            region in the accumulated output this run. The format is
-            [top, bottom, left, right], 1 indicates removal while 0 is not.
-            For example, [1, 1, 0, 0] denotes replacing top and bottom instances
-            within `ref_inst_dict` with new ones after this processing.
-        tile_mode (int): A flag to indicate the type of this tile. There
-            are 4 flags:
-            - 0: A tile from tile grid without any overlapping, it is not
-                an overlapping tile from tile generation. The predicted
-                instances are immediately added to accumulated output.
-            - 1: Vertical tile strip that stands between two normal tiles
-                (flag 0). It has the same height as normal tile but
-                less width (hence vertical strip).
-            - 2: Horizontal tile strip that stands between two normal tiles
-                (flag 0). It has the same width as normal tile but
-                less height (hence horizontal strip).
-            - 3: tile strip stands at the cross-section of four normal tiles
-                (flag 0).
-        tile_tl (tuple): Top left coordinates of the current tile.
-        ref_inst_dict (dict): Dictionary contains accumulated output. The
-            expected format is {instance_id: {type: int,
-            contour: List[List[int]], centroid:List[float], box:List[int]}.
-
-    Returns:
-        new_inst_dict (dict): A dictionary contain new instances to be accumulated.
-            The expected format is {instance_id: {type: int,
-            contour: List[List[int]], centroid:List[float], box:List[int]}.
-        remove_insts_in_orig (list): List of instance id within `ref_inst_dict`
-            to be removed to prevent overlapping predictions.
These instances - are those get cutoff at the boundary due to the tiling process. - - """ - # should be rare, no nuclei detected in input images - if len(inst_dict) == 0: - return {}, [] - - # ! - m = ioconfig.margin - w, h = tile_shape - inst_boxes = [v["box"] for v in inst_dict.values()] - inst_boxes = np.array(inst_boxes) - - geometries = [shapely_box(*bounds) for bounds in inst_boxes] - tile_rtree = STRtree(geometries) - # ! - - # create margin bounding box, ordering should match with - # created tile info flag (top, bottom, left, right) - boundary_lines = [ - shapely_box(0, 0, w, 1), # top egde - shapely_box(0, h - 1, w, h), # bottom edge - shapely_box(0, 0, 1, h), # left - shapely_box(w - 1, 0, w, h), # right - ] - margin_boxes = [ - shapely_box(0, 0, w, m), # top egde - shapely_box(0, h - m, w, h), # bottom edge - shapely_box(0, 0, m, h), # left - shapely_box(w - m, 0, w, h), # right - ] - # ! this is wrt to WSI coord space, not tile - margin_lines = [ - [[m, m], [w - m, m]], # top egde - [[m, h - m], [w - m, h - m]], # bottom edge - [[m, m], [m, h - m]], # left - [[w - m, m], [w - m, h - m]], # right - ] - margin_lines = np.array(margin_lines) + tile_tl[None, None] - margin_lines = [shapely_box(*v.flatten().tolist()) for v in margin_lines] - - # the ids within this match with those within `inst_map`, not UUID - sel_indices = [] - if tile_mode in [0, 3]: - # for `full grid` tiles `cross section` tiles - # -- extend from the boundary by the margin size, remove - # nuclei whose entire contours lie within the margin area - sel_boxes = [ - box - for idx, box in enumerate(margin_boxes) - if tile_flag[idx] or tile_mode == 3 # noqa: PLR2004 - ] - - sel_indices = [ - geo - for bounds in sel_boxes - for geo in tile_rtree.query(bounds) - if bounds.contains(geometries[geo]) - ] - elif tile_mode in [1, 2]: - # for `horizontal/vertical strip` tiles - # -- extend from the marked edges (top/bot or left/right) by - # the margin size, remove all nuclei lie within the margin - # area (including on the margin line) - # -- remove all nuclei on the boundary also - - sel_boxes = [ - margin_boxes[idx] if flag else boundary_lines[idx] - for idx, flag in enumerate(tile_flag) - ] - - sel_indices = [geo for bounds in sel_boxes for geo in tile_rtree.query(bounds)] - else: - msg = f"Unknown tile mode {tile_mode}." - raise ValueError(msg) - - def retrieve_sel_uids(sel_indices: list, inst_dict: dict) -> list: - """Helper to retrieved selected instance uids.""" - if len(sel_indices) > 0: - # not sure how costly this is in large dict - inst_uids = list(inst_dict.keys()) - return [inst_uids[idx] for idx in sel_indices] - - remove_insts_in_tile = retrieve_sel_uids(sel_indices, inst_dict) + .. deprecated:: 2.1.0 + `NucleusInstanceSegmentor` will be removed in a future release. + Use :class:`MultiTaskSegmentor` instead. 
- # external removal only for tile at cross-sections - # this one should contain UUID with the reference database - remove_insts_in_orig = [] - if tile_mode == 3: # noqa: PLR2004 - inst_boxes = [v["box"] for v in ref_inst_dict.values()] - inst_boxes = np.array(inst_boxes) - - geometries = [shapely_box(*bounds) for bounds in inst_boxes] - ref_inst_rtree = STRtree(geometries) - sel_indices = [ - geo for bounds in margin_lines for geo in ref_inst_rtree.query(bounds) - ] - - remove_insts_in_orig = retrieve_sel_uids(sel_indices, ref_inst_dict) - - # move inst position from tile space back to WSI space - # an also generate universal uid as replacement for storage - new_inst_dict = {} - for inst_uid, inst_info in inst_dict.items(): - if inst_uid not in remove_insts_in_tile: - inst_info["box"] += np.concatenate([tile_tl] * 2) - inst_info["centroid"] += tile_tl - inst_info["contour"] += tile_tl - inst_uuid = uuid.uuid4().hex - new_inst_dict[inst_uuid] = inst_info - return new_inst_dict, remove_insts_in_orig - - -# Python is yet to be able to natively pickle Object method/static -# method. Only top-level function is passable to multiprocessing as -# caller. May need 3rd party libraries to use method/static method -# otherwise. -def _process_tile_predictions( - ioconfig: IOSegmentorConfig, - tile_bounds: np.ndarray, - tile_flag: list, - tile_mode: int, - tile_output: list, - # this would be replaced by annotation store - # in the future - ref_inst_dict: dict, - postproc: Callable, - merge_predictions: Callable, -) -> tuple[dict, list]: - """Function to merge new tile prediction with existing prediction. + NucleusInstanceSegmentor inherits MultiTaskSegmentor as it is a special case of + MultiTaskSegmentor with a single task. Args: - ioconfig (:class:`IOSegmentorConfig`): - Object defines information about input and output placement - of patches. - tile_bounds (:class:`numpy.array`): - Boundary of the current tile, defined as `(top_left_x, - top_left_y, bottom_x, bottom_y)`. - tile_flag (list): - A list of flag to indicate if instances within an area - extended from each side (by `ioconfig.margin`) of the tile - should be replaced by those within the same spatial region - in the accumulated output this run. The format is `[top, - bottom, left, right]`, 1 indicates removal while 0 is not. - For example, `[1, 1, 0, 0]` denotes replacing top and bottom - instances within `ref_inst_dict` with new ones after this - processing. - tile_mode (int): - A flag to indicate the type of this tile. There are 4 flags: - - 0: A tile from tile grid without any overlapping, it is - not an overlapping tile from tile generation. The - predicted instances are immediately added to - accumulated output. - - 1: Vertical tile strip that stands between two normal - tiles (flag 0). It has the same height as normal tile - but less width (hence vertical strip). - - 2: Horizontal tile strip that stands between two normal - tiles (flag 0). It has the same width as normal tile - but less height (hence horizontal strip). - - 3: Tile strip stands at the cross-section of four normal - tiles (flag 0). - tile_output (list): - A list of patch predictions, that lie within this tile, to - be merged and processed. - ref_inst_dict (dict): - Dictionary contains accumulated output. The expected format - is `{instance_id: {type: int, contour: List[List[int]], - centroid:List[float], box:List[int]}`. - postproc (callable): - Function to post-process the raw assembled tile. 
-        merge_predictions (callable):
-            Function to merge the `tile_output` into raw tile
-            prediction.
-
-    Returns:
-        tuple:
-            - :py:obj:`dict` - New instances dictionary:
-                A dictionary contain new instances to be accumulated.
-                The expected format is `{instance_id: {type: int,
-                contour: List[List[int]], centroid:List[float],
-                box:List[int]}`.
-            - :py:obj:`list` - Instances IDs to remove:
-                List of instance IDs within `ref_inst_dict` to be
-                removed to prevent overlapping predictions. These
-                instances are those get cut off at the boundary due to
-                the tiling process.
-
-    """
-    locations, predictions = list(zip(*tile_output, strict=False))
-
-    # convert from WSI space to tile space
-    tile_tl = tile_bounds[:2]
-    tile_br = tile_bounds[2:]
-    locations = [np.reshape(loc, (2, -1)) for loc in locations]
-    locations_in_tile = [loc - tile_tl[None] for loc in locations]
-    locations_in_tile = [loc.flatten() for loc in locations_in_tile]
-    locations_in_tile = np.array(locations_in_tile)
-
-    tile_shape = tile_br - tile_tl  # in width height
-
-    # As the placement output is calculated wrt the highest possible
-    # resolution within input, the output will need to re-calibrate if
-    # it is at different resolution than the input.
-    ioconfig = ioconfig.to_baseline()
-    fx_list = [v["resolution"] for v in ioconfig.output_resolutions]
-
-    head_raws = []
-    for idx, fx in enumerate(fx_list):
-        head_tile_shape = np.ceil(tile_shape * fx).astype(np.int32)
-        head_locations = np.ceil(locations_in_tile * fx).astype(np.int32)
-        head_predictions = [v[idx][0] for v in predictions]
-        head_raw = merge_predictions(
-            head_tile_shape[::-1],
-            head_predictions,
-            head_locations,
-        )
-        head_raws.append(head_raw)
-    _, inst_dict = postproc(head_raws)
-
-    new_inst_dict, remove_insts_in_orig = _process_instance_predictions(
-        inst_dict,
-        ioconfig,
-        tile_shape,
-        tile_flag,
-        tile_mode,
-        tile_tl,
-        ref_inst_dict,
-    )
+        model (str | ModelABC):
+            A PyTorch model instance or name of a pretrained model from TIAToolbox.
+            The user can request pretrained models from the toolbox model zoo using
+            the list of pretrained models available at this `link
+            `_
+            By default, the corresponding pretrained weights will also
+            be downloaded. However, you can override with your own set
+            of weights using the `weights` parameter. Default is `None`.
+        batch_size (int):
+            Number of image patches processed per forward pass. Default is 8.
+        num_workers (int):
+            Number of workers for data loading. Default is 0.
+        weights (str | Path | None):
+            Path to model weights. If None, default weights are used.
+
+            >>> engine = NucleusInstanceSegmentor(
+            ...    model="pretrained-model",
+            ...    weights="/path/to/pretrained-local-weights.pth"
+            ... )
+
+        device (str):
+            Device to run the model on (e.g., "cpu", "cuda"). Default is "cpu".
+        verbose (bool):
+            Whether to enable verbose logging. Default is True.
+
+    Attributes:
+        images (list[str | Path] | np.ndarray):
+            Input image patches or WSI paths.
+        masks (list[str | Path] | np.ndarray):
+            Optional tissue masks for WSI processing.
+            These are only utilized when patch_mode is False.
+            If not provided, then a tissue mask will be automatically
+            generated for whole slide images.
+        patch_mode (bool):
+            Whether input is treated as patches (`True`) or WSIs (`False`).
+        model (ModelABC):
+            Loaded PyTorch model.
+        ioconfig (ModelIOConfigABC):
+            IO configuration for patch extraction and resolution.
+        return_labels (bool):
+            Whether to include labels in the output.
+ return_predictions_dict (dict): + This dictionary helps keep track of which tasks require predictions in + the output. + input_resolutions (list[dict]): + Resolution settings for model input. Supported + units are `level`, `power` and `mpp`. Keys should be "units" and + "resolution" e.g., [{"units": "mpp", "resolution": 0.25}]. Please see + :class:`WSIReader` for details. + patch_input_shape (tuple[int, int]): + Shape of input patches (height, width). Patches are at + requested read resolution, not with respect to level 0, + and must be positive. + stride_shape (tuple[int, int]): + Stride used during patch extraction. Stride is + at requested read resolution, not with respect to + level 0, and must be positive. If not provided, + `stride_shape=patch_input_shape`. + labels (list | None): + Optional labels for input images. + Only a single label per image is supported. + drop_keys (list): + Keys to exclude from model output. + output_type (str): + Format of output ("dict", "zarr", "annotationstore"). + output_locations (list | None): + Coordinates of output patches used during WSI processing. - return new_inst_dict, remove_insts_in_orig + Examples: + >>> # list of 2 image patches as input + >>> wsis = ['path/img.svs', 'path/img.svs'] + >>> mtsegmentor = MultiTaskSegmentor(model="hovernetplus-oed") + >>> output = mtsegmentor.run(wsis, patch_mode=False) + >>> # array of list of 2 image patches as input + >>> image_patches = [np.ndarray, np.ndarray] + >>> mtsegmentor = MultiTaskSegmentor(model="hovernetplus-oed") + >>> output = mtsegmentor.run(image_patches, patch_mode=True) -class NucleusInstanceSegmentor(SemanticSegmentor): - """An engine specifically designed to handle tiles or WSIs inference. + >>> # list of 2 image patch files as input + >>> data = ['path/img.png', 'path/img.png'] + >>> mtsegmentor = MultiTaskSegmentor(model="hovernet_fast-pannuke") + >>> output = mtsegmentor.run(data, patch_mode=False) - Note, if `model` is supplied in the arguments, it will ignore the - `pretrained_model` and `pretrained_weights` arguments. Additionally, - unlike `SemanticSegmentor`, this engine assumes each input model - will ultimately predict one single target: the nucleus instance - within the tiles/WSIs. Each WSI prediction will be store under a - `.dat` file which contains a dictionary of form: + >>> # list of 2 image tile files as input + >>> tile_file = ['path/tile1.png', 'path/tile2.png'] + >>> mtsegmentor = MultiTaskSegmentor(model="hovernet_fast-pannuke") + >>> output = mtsegmentor.run(tile_file, patch_mode=False) - .. code-block:: yaml + >>> # list of 2 wsi files as input + >>> wsis = ['path/wsi1.svs', 'path/wsi2.svs'] + >>> mtsegmentor = MultiTaskSegmentor(model="hovernet_fast-pannuke") + >>> output = mtsegmentor.run(wsis, patch_mode=False) - inst_uid: - # top left and bottom right of bounding box - box: (start_x, start_y, end_x, end_y) - # centroid coordinates - centroid: (x, y) - # array/list of points - contour: [(x1, y1), (x2, y2), ...] - # the type of nuclei - type: int - # the probabilities of being this nuclei type - prob: float - - Args: - model (nn.Module): - Use externally defined PyTorch model for prediction with. - weights already loaded. Default is `None`. If provided, - `pretrained_model` argument is ignored. - pretrained_model (str): - Name of the existing models support by tiatoolbox for - processing the data. For a full list of pretrained models, - refer to the `docs - `_. - By default, the corresponding pretrained weights will also - be downloaded. 
However, you can override with your own set - of weights via the `pretrained_weights` argument. Argument - is case insensitive. - pretrained_weights (str): - Path to the weight of the corresponding `pretrained_model`. - batch_size (int) : - Number of images fed into the model each time. - num_loader_workers (int): - Number of workers to load the data. - Take note that they will also perform preprocessing. - num_postproc_workers (int): - Number of workers to post-process predictions. - verbose (bool): - Whether to output logging information. - dataset_class (obj): - Dataset class to be used instead of default. - auto_generate_mask (bool): - To automatically generate tile/WSI tissue mask - if is not provided. - - Examples: - >>> # Sample output of a network - >>> wsis = ['A/wsi.svs', 'B/wsi.svs'] - >>> predictor = SemanticSegmentor(model='hovernet_fast-pannuke') - >>> output = predictor.predict(wsis, mode='wsi') - >>> list(output.keys()) - [('A/wsi.svs', 'output/0') , ('B/wsi.svs', 'output/1')] - >>> # Each output of 'A/wsi.svs' - >>> # will be respectively stored in 'output/0.dat', 'output/0.dat' """ def __init__( - self: NucleusInstanceSegmentor, + self: MultiTaskSegmentor, + model: str | ModelABC, batch_size: int = 8, - num_loader_workers: int = 0, - num_postproc_workers: int = 0, - model: torch.nn.Module | None = None, - pretrained_model: str | None = None, - pretrained_weights: str | None = None, - dataset_class: Callable = WSIStreamDataset, + num_workers: int = 0, + weights: str | Path | None = None, *, + device: str = "cpu", verbose: bool = True, - auto_generate_mask: bool = False, ) -> None: - """Initialize :class:`NucleusInstanceSegmentor`.""" - super().__init__( - batch_size=batch_size, - num_loader_workers=num_loader_workers, - num_postproc_workers=num_postproc_workers, - model=model, - pretrained_model=pretrained_model, - pretrained_weights=pretrained_weights, - verbose=verbose, - auto_generate_mask=auto_generate_mask, - dataset_class=dataset_class, - ) - # default is None in base class and is un-settable - # hence we redefine the namespace here - self.num_postproc_workers = ( - num_postproc_workers if num_postproc_workers > 0 else None - ) - - # adding more runtime placeholder - self._wsi_inst_info = None - self._futures = [] - - @staticmethod - def _get_tile_info( - image_shape: list[int] | np.ndarray, - ioconfig: IOInstanceSegmentorConfig, - ) -> list[list, ...]: - """Generating tile information. - - To avoid out of memory problem when processing WSI-scale in - general, the predictor will perform the inference and assemble - on a large image tiles (each may have size of 4000x4000 compared - to patch output of 256x256) first before stitching every tiles - by the end to complete the WSI output. For nuclei instance - segmentation, the stitching process will require removal of - predictions within some bounding areas. This function generates - both the tile placement and the flag to indicate how the removal - should be done to achieve the above goal. + """Initialize :class:`MultiTaskSegmentor`. Args: - image_shape (:class:`numpy.ndarray`, list(int)): - The shape of WSI to extract the tile from, assumed to be - in `[width, height]`. - ioconfig (:obj:IOInstanceSegmentorConfig): - The input and output configuration objects. 
- - Returns: - list: - - :py:obj:`list` - Tiles and flags - - :class:`numpy.ndarray` - Grid tiles - - :class:`numpy.ndarray` - Removal flags - - :py:obj:`list` - Tiles and flags - - :class:`numpy.ndarray` - Vertical strip tiles - - :class:`numpy.ndarray` - Removal flags - - :py:obj:`list` - Tiles and flags - - :class:`numpy.ndarray` - Horizontal strip tiles - - :class:`numpy.ndarray` - Removal flags - - :py:obj:`list` - Tiles and flags - - :class:`numpy.ndarray` - Cross section tiles - - :class:`numpy.ndarray` - Removal flags + model (str | ModelABC): + A PyTorch model instance or name of a pretrained model from TIAToolbox. + If a string is provided, the corresponding pretrained weights will be + downloaded unless overridden via `weights`. + batch_size (int): + Number of image patches processed per forward pass. Default is 8. + num_workers (int): + Number of workers for data loading. Default is 0. + weights (str | Path | None): + Path to model weights. If None, default weights are used. + device (str): + Device to run the model on (e.g., "cpu", "cuda"). Default is "cpu". + verbose (bool): + Whether to enable verbose logging. Default is True. """ - margin = np.array(ioconfig.margin) - tile_shape = np.array(ioconfig.tile_shape) - tile_shape = ( - np.floor(tile_shape / ioconfig.patch_output_shape) - * ioconfig.patch_output_shape - ).astype(np.int32) - image_shape = np.array(image_shape) - tile_outputs = PatchExtractor.get_coordinates( - image_shape=image_shape, - patch_input_shape=tile_shape, - patch_output_shape=tile_shape, - stride_shape=tile_shape, - ) - - # * === Now generating the flags to indicate which side should - # * === be removed in postproc callback - boxes = tile_outputs[1] - - # This saves computation time if the image is smaller than the expected tile - if np.all(image_shape <= tile_shape): - flag = np.zeros([boxes.shape[0], 4], dtype=np.int32) - return [[boxes, flag]] - - # * remove all sides for boxes - # unset for those lie within the selection - def unset_removal_flag(boxes: tuple, removal_flag: np.ndarray) -> np.ndarray: - """Unset removal flags for tiles intersecting image boundaries.""" - sel_boxes = [ - shapely_box(0, 0, w, 0), # top edge - shapely_box(0, h, w, h), # bottom edge - shapely_box(0, 0, 0, h), # left - shapely_box(w, 0, w, h), # right - ] - geometries = [shapely_box(*bounds) for bounds in boxes] - spatial_indexer = STRtree(geometries) - - for idx, sel_box in enumerate(sel_boxes): - sel_indices = list(spatial_indexer.query(sel_box)) - removal_flag[sel_indices, idx] = 0 - return removal_flag - - w, h = image_shape - boxes = tile_outputs[1] - # expand to full four corners - boxes_br = boxes[:, 2:] - boxes_tr = np.dstack([boxes[:, 2], boxes[:, 1]])[0] - boxes_bl = np.dstack([boxes[:, 0], boxes[:, 3]])[0] - - # * remove edges on all sides, excluding edges at on WSI boundary - flag = np.ones([boxes.shape[0], 4], dtype=np.int32) - flag = unset_removal_flag(boxes, flag) - info = deque([[boxes, flag]]) - - # * create vertical boxes at tile boundary and - # * flag top and bottom removal, excluding those - # * on the WSI boundary - # ------------------- - # | =|= =|= | - # | =|= =|= | - # | >=|= >=|= | - # ------------------- - # | >=|= >=|= | - # | =|= =|= | - # | >=|= >=|= | - # ------------------- - # | >=|= >=|= | - # | =|= =|= | - # | =|= =|= | - # ------------------- - # only select boxes having right edges removed - sel_indices = np.nonzero(flag[..., 3]) - _boxes = np.concatenate( - [ - boxes_tr[sel_indices] - np.array([margin, 0])[None], - 
boxes_br[sel_indices] + np.array([margin, 0])[None], - ], - axis=-1, + warnings.warn( + "NucleusInstanceSegmentor is deprecated and will be " + "removed in a future release. " + "Use MultiTaskSegmentor instead.", + DeprecationWarning, + stacklevel=2, ) - _flag = np.full([_boxes.shape[0], 4], 0, dtype=np.int32) - _flag[:, [0, 1]] = 1 - _flag = unset_removal_flag(_boxes, _flag) - info.append([_boxes, _flag]) - - # * create horizontal boxes at tile boundary and - # * flag left and right removal, excluding those - # * on the WSI boundary - # ------------- - # | | | | - # | v|v v|v | - # |===|===|===| - # ------------- - # |===|===|===| - # | | | | - # | | | | - # ------------- - # only select boxes having bottom edges removed - sel_indices = np.nonzero(flag[..., 1]) - # top bottom left right - _boxes = np.concatenate( - [ - boxes_bl[sel_indices] - np.array([0, margin])[None], - boxes_br[sel_indices] + np.array([0, margin])[None], - ], - axis=-1, - ) - _flag = np.full([_boxes.shape[0], 4], 0, dtype=np.int32) - _flag[:, [2, 3]] = 1 - _flag = unset_removal_flag(_boxes, _flag) - info.append([_boxes, _flag]) - - # * create boxes at tile cross-section and all sides - # ------------------------ - # | | | | | - # | v| | | | - # | > =|= =|= =|= | - # -----=-=---=-=---=-=---- - # | =|= =|= =|= | - # | | | | | - # | =|= =|= =|= | - # -----=-=---=-=---=-=---- - # | =|= =|= =|= | - # | | | | | - # | | | | | - # ------------------------ - - # only select boxes having both right and bottom edges removed - sel_indices = np.nonzero(np.prod(flag[:, [1, 3]], axis=-1)) - _boxes = np.concatenate( - [ - boxes_br[sel_indices] - np.array([2 * margin, 2 * margin])[None], - boxes_br[sel_indices] + np.array([2 * margin, 2 * margin])[None], - ], - axis=-1, - ) - flag = np.full([_boxes.shape[0], 4], 1, dtype=np.int32) - info.append([_boxes, flag]) - - return info - - def _to_shared_space( - self: NucleusInstanceSegmentor, - wsi_idx: int, - patch_inputs: list, - patch_outputs: list, - ) -> None: - """Helper functions to transfer variable to shared space. - - We modify the shared space so that we can update worker info - without needing to re-create the worker. There should be no - race-condition because only by looping `self._loader` in main - thread will trigger querying new data from each worker, and this - portion should still be in sequential execution order in the - main thread. - Args: - wsi_idx (int): - The index of the WSI to be processed. This is used to - retrieve the file path. - patch_inputs (list): - A list of coordinates in `[start_x, start_y, end_x, - end_y]` format indicating the read location of the patch - in the WSI image. The coordinates are in the highest - resolution defined in `self.ioconfig`. - patch_outputs (list): - A list of coordinates in `[start_x, start_y, end_x, - end_y]` format indicating the write location of the - patch in the WSI image. The coordinates are in the - highest resolution defined in `self.ioconfig`. 
- - """ - patch_inputs = torch.from_numpy(patch_inputs).share_memory_() - patch_outputs = torch.from_numpy(patch_outputs).share_memory_() - self._mp_shared_space.patch_inputs = patch_inputs - self._mp_shared_space.patch_outputs = patch_outputs - self._mp_shared_space.wsi_idx = torch.Tensor([wsi_idx]).share_memory_() - - def _infer_once(self: NucleusInstanceSegmentor) -> list: - """Running the inference only once for the currently active dataloader.""" - num_steps = len(self._loader) - - pbar_desc = "Process Batch: " - pbar = tqdm.tqdm( - desc=pbar_desc, - leave=True, - total=int(num_steps), - ncols=80, - ascii=True, - position=0, + logger.warning( + "NucleusInstanceSegmentor is deprecated and will be " + "removed in a future release." ) - - cum_output = [] - for _, batch_data in enumerate(self._loader): - sample_datas, sample_infos = batch_data - batch_size = sample_infos.shape[0] - # ! depending on the protocol of the output within infer_batch - # ! this may change, how to enforce/document/expose this in a - # ! sensible way? - - # assume to return a list of L output, - # each of shape N x etc. (N=batch size) - sample_outputs = self.model.infer_batch( - self._model, - sample_datas, - device=self._device, - ) - # repackage so that it's a N list, each contains - # L x etc. output - sample_outputs = [np.split(v, batch_size, axis=0) for v in sample_outputs] - sample_outputs = list(zip(*sample_outputs, strict=False)) - - # tensor to numpy, costly? - sample_infos = sample_infos.numpy() - sample_infos = np.split(sample_infos, batch_size, axis=0) - - sample_outputs = list(zip(sample_infos, sample_outputs, strict=False)) - cum_output.extend(sample_outputs) - pbar.update() - pbar.close() - return cum_output - - def _predict_one_wsi( - self: NucleusInstanceSegmentor, - wsi_idx: int, - ioconfig: IOSegmentorConfig, - save_path: str, - mode: str, - ) -> None: - """Make a prediction on tile/wsi. - - Args: - wsi_idx (int): - Index of the tile/wsi to be processed within `self`. - ioconfig (IOInstanceSegmentorConfig): - Object which defines I/O placement during inference and - when assembling back to full tile/wsi. - save_path (str): - Location to save output prediction as well as possible - intermediate results. - mode (str): - `tile` or `wsi` to indicate run mode. - - """ - wsi_path = self.imgs[wsi_idx] - mask_path = None if self.masks is None else self.masks[wsi_idx] - wsi_reader, mask_reader = self.get_reader( - wsi_path, - mask_path, - mode, - auto_get_mask=self.auto_generate_mask, + super().__init__( + model=model, + batch_size=batch_size, + num_workers=num_workers, + weights=weights, + device=device, + verbose=verbose, ) - - # assume ioconfig has already been converted to `baseline` for `tile` mode - resolution = ioconfig.highest_input_resolution - wsi_proc_shape = wsi_reader.slide_dimensions(**resolution) - - # * retrieve patch placement - # this is in XY - (patch_inputs, patch_outputs) = self.get_coordinates(wsi_proc_shape, ioconfig) - if mask_reader is not None: - sel = self.filter_coordinates(mask_reader, patch_outputs, **resolution) - patch_outputs = patch_outputs[sel] - patch_inputs = patch_inputs[sel] - - # assume to be in [top_left_x, top_left_y, bot_right_x, bot_right_y] - geometries = [shapely_box(*bounds) for bounds in patch_outputs] - spatial_indexer = STRtree(geometries) - - # * retrieve tile placement and tile info flag - # tile shape will always be corrected to be multiple of output - tile_info_sets = self._get_tile_info(wsi_proc_shape, ioconfig) - - # ! 
running order of each set matters ! - self._futures = [] - - # ! DEPRECATION: - # ! will be deprecated upon finalization of SQL annotation store - self._wsi_inst_info = {} - # ! - - for set_idx, (set_bounds, set_flags) in enumerate(tile_info_sets): - for tile_idx, tile_bounds in enumerate(set_bounds): - tile_flag = set_flags[tile_idx] - - # select any patches that have their output - # within the current tile - sel_box = shapely_box(*tile_bounds) - sel_indices = list(spatial_indexer.query(sel_box)) - - # there is nothing in the tile - # Ignore coverage as the condition is difficult - # to reproduce on travis. - if len(sel_indices) == 0: # pragma: no cover - continue - - tile_patch_inputs = patch_inputs[sel_indices] - tile_patch_outputs = patch_outputs[sel_indices] - self._to_shared_space(wsi_idx, tile_patch_inputs, tile_patch_outputs) - - tile_infer_output = self._infer_once() - - self._process_tile_predictions( - ioconfig, - tile_bounds, - tile_flag, - set_idx, - tile_infer_output, - ) - - self._merge_post_process_results() - joblib.dump(self._wsi_inst_info, f"{save_path}.dat") - # may need to chain it with parents - self._wsi_inst_info = None # clean up - - def _process_tile_predictions( - self: NucleusInstanceSegmentor, - ioconfig: IOSegmentorConfig, - tile_bounds: np.ndarray, - tile_flag: list, - tile_mode: int, - tile_output: list, - ) -> None: - """Function to dispatch parallel post processing.""" - args = [ - ioconfig, - tile_bounds, - tile_flag, - tile_mode, - tile_output, - self._wsi_inst_info, - self.model.postproc_func, - self.merge_prediction, - ] - if self._postproc_workers is not None: - future = self._postproc_workers.submit(_process_tile_predictions, *args) - else: - future = _process_tile_predictions(*args) - self._futures.append(future) - - def _merge_post_process_results(self: NucleusInstanceSegmentor) -> None: - """Helper to aggregate results from parallel workers.""" - - def callback(new_inst_dict: dict, remove_uuid_list: list) -> None: - """Helper to aggregate worker's results.""" - # ! DEPRECATION: - # ! will be deprecated upon finalization of SQL annotation store - self._wsi_inst_info.update(new_inst_dict) - for inst_uuid in remove_uuid_list: - self._wsi_inst_info.pop(inst_uuid, None) - # ! - - for future in self._futures: - # not actually future but the results - if self._postproc_workers is None: - callback(*future) - continue - - # some errors happen, log it and propagate exception - # ! this will lead to discard a bunch of - # ! inferred tiles within this current WSI - if future.exception() is not None: - raise future.exception() - - # aggregate the result via callback - result = future.result() - # manually call the callback rather than - # attaching it when receiving/creating the future - callback(*result) From 86111a6818b358109a16db21b801475047c06c8d Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Tue, 3 Feb 2026 15:51:18 +0000 Subject: [PATCH 083/156] :white_check_mark: Add tests for io_config. 
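
These tests pin down the contract of `ModelIOConfigABC.scale_to_highest`: for
"mpp" units a smaller value means higher resolution, so each entry is scaled
as min(values) / value; for "power" a larger value is higher, so value /
max(values); "baseline" values pass through unchanged; and any other unit
raises a ValueError. A doctest-style sketch that mirrors the assertions
below (a summary of the tested behaviour, not anything new):

    >>> import numpy as np
    >>> from tiatoolbox.models import ModelIOConfigABC
    >>> res = [
    ...     {"units": "mpp", "resolution": 0.25},
    ...     {"units": "mpp", "resolution": 0.5},
    ... ]
    >>> np.allclose(
    ...     ModelIOConfigABC.scale_to_highest(res, units="mpp"), [1.0, 0.5]
    ... )
    True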
--- tests/engines/test_ioconfig.py | 100 ++++++++++++++++++++++++++++++++- 1 file changed, 99 insertions(+), 1 deletion(-) diff --git a/tests/engines/test_ioconfig.py b/tests/engines/test_ioconfig.py index 41169298b..6b03cbeab 100644 --- a/tests/engines/test_ioconfig.py +++ b/tests/engines/test_ioconfig.py @@ -1,8 +1,9 @@ """Tests for IOconfig.""" +import numpy as np import pytest -from tiatoolbox.models import ModelIOConfigABC +from tiatoolbox.models import IOSegmentorConfig, ModelIOConfigABC def test_validation_error_io_config() -> None: @@ -21,3 +22,100 @@ def test_validation_error_io_config() -> None: input_resolutions=[{"units": "level", "resolution": 1.0}], patch_input_shape=(224, 224), ) + + +def test_scale_to_highest_mpp() -> None: + """Mpp → min(old_vals) / old_vals.""" + resolutions = [ + {"units": "mpp", "resolution": 0.25}, + {"units": "mpp", "resolution": 0.5}, + ] + result = ModelIOConfigABC.scale_to_highest(resolutions, units="mpp") + + expected = np.array([1.0, 0.5]) # 0.25 / [0.25, 0.5] + np.testing.assert_allclose(result, expected) + + +def test_scale_to_highest_mpp_reversed_order() -> None: + """Ensure order is preserved even when resolutions are reversed.""" + resolutions = [ + {"units": "mpp", "resolution": 0.5}, + {"units": "mpp", "resolution": 0.25}, + ] + result = ModelIOConfigABC.scale_to_highest(resolutions, units="mpp") + + expected = np.array([0.5, 1.0]) # 0.25 / [0.5, 0.25] + np.testing.assert_allclose(result, expected) + + +def test_scale_to_highest_baseline() -> None: + """Baseline → identity.""" + resolutions = [ + {"units": "baseline", "resolution": 2.0}, + {"units": "baseline", "resolution": 4.0}, + ] + result = ModelIOConfigABC.scale_to_highest(resolutions, units="baseline") + + expected = [2.0, 4.0] + assert result == expected + + +def test_scale_to_highest_power() -> None: + """Power → old_vals / max(old_vals).""" + resolutions = [ + {"units": "power", "resolution": 10}, + {"units": "power", "resolution": 5}, + ] + result = ModelIOConfigABC.scale_to_highest(resolutions, units="power") + + expected = np.array([1.0, 0.5]) # [10, 5] / 10 + np.testing.assert_allclose(result, expected) + + +def test_scale_to_highest_invalid_units() -> None: + """Test ModelIOConfigABC for unknown units.""" + resolutions = [{"units": "mpp", "resolution": 1.0}] + with pytest.raises(ValueError, match="Unknown units"): + ModelIOConfigABC.scale_to_highest(resolutions, units="unknown") + + +def test_modelio_to_baseline_without_save_resolution() -> None: + """Test ModelIOConfigABC when save_resolution is None. + + Ensure ModelIOConfigABC.to_baseline does NOT add or convert + save_resolution when it is None. + + """ + cfg = ModelIOConfigABC( + input_resolutions=[{"units": "mpp", "resolution": 0.5}], + output_resolutions=[{"units": "mpp", "resolution": 1.0}], + patch_input_shape=(224, 224), + stride_shape=(224, 224), + ) + + new_cfg = cfg.to_baseline() + + # save_resolution should not appear in the new config + assert not hasattr(new_cfg, "save_resolution") or new_cfg.save_resolution is None + + +def test_ios_to_baseline_without_save_resolution() -> None: + """Test IOSegmentorConfig when save_resolution is None. + + Ensure IOSegmentorConfig.to_baseline leaves save_resolution=None + when no save_resolution is provided. 
+ + """ + cfg = IOSegmentorConfig( + input_resolutions=[{"units": "mpp", "resolution": 0.5}], + output_resolutions=[{"units": "mpp", "resolution": 1.0}], + patch_input_shape=(224, 224), + patch_output_shape=(112, 112), + stride_shape=(224, 224), + save_resolution=None, + ) + + new_cfg = cfg.to_baseline() + + # save_resolution should remain None after conversion + assert new_cfg.save_resolution is None From de65370c9114f55317196bb12211c285dad32b57 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Tue, 3 Feb 2026 16:48:01 +0000 Subject: [PATCH 084/156] :white_check_mark: Add `cli` for `nucleus_instance_segmentor`. --- .../test_nucleus_instance_segmentor.py | 37 +++++ tiatoolbox/cli/__init__.py | 1 + tiatoolbox/cli/common.py | 14 -- tiatoolbox/cli/nucleus_instance_segment.py | 129 ++++++++++++------ .../models/engine/multi_task_segmentor.py | 8 +- 5 files changed, 135 insertions(+), 54 deletions(-) diff --git a/tests/engines/test_nucleus_instance_segmentor.py b/tests/engines/test_nucleus_instance_segmentor.py index 8121046d7..8a7411d20 100644 --- a/tests/engines/test_nucleus_instance_segmentor.py +++ b/tests/engines/test_nucleus_instance_segmentor.py @@ -7,7 +7,10 @@ import numpy as np import torch +import zarr +from click.testing import CliRunner +from tiatoolbox import cli from tiatoolbox.models.engine.nucleus_instance_segmentor import ( NucleusInstanceSegmentor, ) @@ -78,3 +81,37 @@ def test_mtsegmentor_patches(remote_sample: Callable) -> None: fields=["box", "centroid", "contours", "prob", "type"], ) assert_predictions_and_boxes(output_dict, expected_counts_nuclei, is_zarr=False) + + +# ------------------------------------------------------------------------------------- +# Command Line Interface +# ------------------------------------------------------------------------------------- + + +def test_cli_model_single_file(remote_sample: Callable, track_tmp_path: Path) -> None: + """Test semantic segmentor CLI single file.""" + wsi4_512_512_svs = remote_sample("wsi4_512_512_svs") + runner = CliRunner() + models_wsi_result = runner.invoke( + cli.main, + [ + "nucleus-instance-segment", + "--img-input", + str(wsi4_512_512_svs), + "--patch-mode", + "False", + "--output-path", + str(track_tmp_path / "output"), + "--return-predictions", + "True", + ], + ) + + assert models_wsi_result.exit_code == 0 + assert (track_tmp_path / "output" / f"{wsi4_512_512_svs.stem}.db").exists() + zarr_group = zarr.open( + str(track_tmp_path / "output" / f"{wsi4_512_512_svs.stem}.zarr"), mode="r" + ) + assert "probabilities" in zarr_group + assert "nuclei_segmentation" not in zarr_group + assert "predictions" in zarr_group diff --git a/tiatoolbox/cli/__init__.py b/tiatoolbox/cli/__init__.py index 74e01240f..838b14dfa 100644 --- a/tiatoolbox/cli/__init__.py +++ b/tiatoolbox/cli/__init__.py @@ -48,6 +48,7 @@ def main() -> int: main.add_command(semantic_segmentor) main.add_command(multitask_segmentor) main.add_command(nucleus_detector) +main.add_command(nucleus_instance_segment) main.add_command(deep_feature_extractor) main.add_command(slide_info) main.add_command(slide_thumbnail) diff --git a/tiatoolbox/cli/common.py b/tiatoolbox/cli/common.py index 5b8260d4d..e88eded62 100644 --- a/tiatoolbox/cli/common.py +++ b/tiatoolbox/cli/common.py @@ -606,20 +606,6 @@ def cli_merge_predictions( ) -def cli_return_labels( - usage_help: str = "Whether to return raw model output as labels.", - *, - default: bool = False, -) -> Callable: - """Enables --return-labels option for cli.""" 
- return click.option( - "--return-labels", - type=bool, - help=add_default_to_usage_help(usage_help, default=default), - default=default, - ) - - def cli_batch_size( usage_help: str = "Number of image patches to feed into the model each time.", default: int = 1, diff --git a/tiatoolbox/cli/nucleus_instance_segment.py b/tiatoolbox/cli/nucleus_instance_segment.py index 707e71f5b..e5c89078f 100644 --- a/tiatoolbox/cli/nucleus_instance_segment.py +++ b/tiatoolbox/cli/nucleus_instance_segment.py @@ -2,74 +2,114 @@ from __future__ import annotations -import click +from typing import TYPE_CHECKING from tiatoolbox.cli.common import ( cli_auto_get_mask, cli_batch_size, + cli_class_dict, cli_device, cli_file_type, cli_img_input, + cli_input_resolutions, cli_masks, - cli_mode, + cli_memory_threshold, + cli_model, cli_num_workers, + cli_output_file, cli_output_path, - cli_pretrained_model, - cli_pretrained_weights, + cli_output_resolutions, + cli_output_type, + cli_overwrite, + cli_patch_input_shape, + cli_patch_mode, + cli_patch_output_shape, + cli_return_predictions, + cli_return_probabilities, + cli_scale_factor, + cli_stride_shape, cli_verbose, + cli_weights, cli_yaml_config_path, prepare_ioconfig, prepare_model_cli, tiatoolbox_cli, ) +if TYPE_CHECKING: # pragma: no cover + from tiatoolbox.type_hints import IntPair + @tiatoolbox_cli.command() @cli_img_input() @cli_output_path( - usage_help="Output directory where model predictions will be saved.", - default="nucleus_instance_segmentation", + usage_help="Output directory where model segmentation will be saved.", + default="semantic_segmentation", ) +@cli_output_file(default=None) @cli_file_type( default="*.png, *.jpg, *.jpeg, *.tif, *.tiff, *.svs, *.ndpi, *.jp2, *.mrxs", ) -@cli_mode( - usage_help="Type of input file to process.", - default="wsi", - input_type=click.Choice(["patch", "wsi", "tile"], case_sensitive=False), -) -@cli_pretrained_model(default="hovernet_fast-pannuke") -@cli_pretrained_weights(default=None) +@cli_input_resolutions(default=None) +@cli_output_resolutions(default=None) +@cli_class_dict(default=None) +@cli_model(default="hovernet_fast-pannuke") +@cli_weights() @cli_device(default="cpu") -@cli_batch_size() +@cli_batch_size(default=1) +@cli_yaml_config_path() @cli_masks(default=None) -@cli_yaml_config_path(default=None) -@cli_num_workers() +@cli_num_workers(default=0) +@cli_output_type( + default="AnnotationStore", +) +@cli_memory_threshold(default=80) +@cli_patch_input_shape(default=None) +@cli_patch_output_shape(default=None) +@cli_stride_shape(default=None) +@cli_scale_factor(default=None) +@cli_patch_mode(default=False) +@cli_return_predictions(default=None) +@cli_return_probabilities(default=True) +@cli_auto_get_mask(default=True) +@cli_overwrite(default=False) @cli_verbose(default=True) -@cli_auto_get_mask(default=False) def nucleus_instance_segment( - pretrained_model: str, - pretrained_weights: str, + model: str, + weights: str, img_input: str, file_types: str, + class_dict: list[tuple[int, str]], + input_resolutions: list[dict], + output_resolutions: list[dict], masks: str | None, - mode: str, output_path: str, + patch_input_shape: IntPair | None, + patch_output_shape: tuple[int, int] | None, + stride_shape: IntPair | None, + scale_factor: tuple[float, float] | None, batch_size: int, yaml_config_path: str, - num_loader_workers: int, + num_workers: int, device: str, + output_type: str, + memory_threshold: int, + output_file: str | None, *, - auto_generate_mask: bool, + patch_mode: bool, + return_predictions: 
tuple[bool, ...] | None, + return_probabilities: bool, + auto_get_mask: bool, verbose: bool, + overwrite: bool, ) -> None: - """Process an image/directory of input images with a patch classification CNN.""" + """Process a set of input images with a multitask segmentation engine.""" from tiatoolbox.models import ( # noqa: PLC0415 - IOInstanceSegmentorConfig, + IOSegmentorConfig, NucleusInstanceSegmentor, ) - from tiatoolbox.utils import save_as_json # noqa: PLC0415 + class_dict = dict(class_dict) if class_dict else None files_all, masks_all, output_path = prepare_model_cli( img_input=img_input, output_path=output_path, @@ -78,27 +118,38 @@ def nucleus_instance_segment( ) ioconfig = prepare_ioconfig( - IOInstanceSegmentorConfig, - pretrained_weights, - yaml_config_path, + IOSegmentorConfig, + pretrained_weights=weights, + yaml_config_path=yaml_config_path, ) - predictor = NucleusInstanceSegmentor( - pretrained_model=pretrained_model, - pretrained_weights=pretrained_weights, + nuc_inst_segmentor = NucleusInstanceSegmentor( + model=model, + weights=weights, batch_size=batch_size, - num_loader_workers=num_loader_workers, - auto_generate_mask=auto_generate_mask, + num_workers=num_workers, verbose=verbose, ) - output = predictor.predict( - imgs=files_all, + _ = nuc_inst_segmentor.run( + images=files_all, masks=masks_all, - mode=mode, + class_dict=class_dict, + patch_mode=patch_mode, + patch_input_shape=patch_input_shape, + patch_output_shape=patch_output_shape, + input_resolutions=input_resolutions, + output_resolutions=output_resolutions, + ioconfig=ioconfig, device=device, save_dir=output_path, - ioconfig=ioconfig, + output_type=output_type, + return_predictions=return_predictions, + return_probabilities=return_probabilities, + auto_get_mask=auto_get_mask, + memory_threshold=memory_threshold, + output_file=output_file, + scale_factor=scale_factor, + stride_shape=stride_shape, + overwrite=overwrite, ) - - save_as_json(output, str(output_path.joinpath("results.json"))) diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index 7caa492a6..f22260ee6 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -1524,11 +1524,17 @@ def _save_predictions_as_annotationstore( for key in ("canvas", "count"): processed_predictions.pop(key, None) + return_predictions = ( + next(iter(self.return_predictions_dict.values())) + if task_name is None and len(self.return_predictions_dict) == 1 + else self.return_predictions_dict.get(task_name) + ) + keys_to_compute = list(processed_predictions.keys()) if "probabilities" in keys_to_compute: keys_to_compute.remove("probabilities") if "predictions" in keys_to_compute: - if not self.return_predictions_dict.get(task_name): + if not return_predictions: processed_predictions.pop("predictions") keys_to_compute.remove("predictions") if self.patch_mode: From 52e21e6d0c54a3434f14bf1d4d20fa41599915a3 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Tue, 3 Feb 2026 16:57:30 +0000 Subject: [PATCH 085/156] :white_check_mark: Add tests for cli --- tests/test_cli.py | 55 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) diff --git a/tests/test_cli.py b/tests/test_cli.py index beb6d831c..00b92dcce 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,6 +1,8 @@ """Tests for cli inputs.""" import json +from pathlib import Path +from unittest.mock import patch import click 
import pytest @@ -11,6 +13,7 @@ cli_input_resolutions, cli_output_resolutions, parse_bool_list, + prepare_model_cli, ) @@ -176,3 +179,55 @@ def test_parse_bool_list_invalid(bad_value: str) -> None: """parse_bool_list should raise BadParameter on invalid tokens.""" with pytest.raises(click.BadParameter): parse_bool_list(_ctx=None, _param=None, value=bad_value) + + +def test_output_path_exists_raises() -> None: + """Ensure FileExistsError when the output_path already exists on disk.""" + img_input = Path("input.jpg") + output_path = Path("out") + masks = None + + # First call: output_path.exists() → True + # Second call: img_input.exists() → True (never reached) + with ( + patch.object(Path, "exists", side_effect=[True, True]), + pytest.raises(FileExistsError), + ): + prepare_model_cli(img_input, output_path, masks, "*.jpg") + + +def test_img_input_not_found_raises() -> None: + """Ensure FileNotFoundError when the img_input path does not exist.""" + img_input = Path("missing.jpg") + output_path = Path("out") + masks = None + + # output_path.exists() → False + # img_input.exists() → False + with ( + patch.object(Path, "exists", side_effect=[False, False]), + pytest.raises(FileNotFoundError), + ): + prepare_model_cli(img_input, output_path, masks, "*.jpg") + + +def test_masks_is_file() -> None: + """Verify that when masks is a file a list containing that mask file is returned.""" + img_input = Path("input.jpg") + output_path = Path("out") + masks = Path("mask.png") + + # output_path.exists() → False + # img_input.exists() → True + with ( + patch.object(Path, "exists", side_effect=[False, True]), + patch.object(Path, "is_file", return_value=True), + patch.object(Path, "is_dir", return_value=False), + ): + files, masks_all, out = prepare_model_cli( + img_input, output_path, masks, "*.jpg" + ) + + assert files == [img_input] + assert masks_all == [masks] + assert out == output_path From 151b61e3f7299794ed84d3c59fd715ef26119027 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Tue, 3 Feb 2026 17:53:27 +0000 Subject: [PATCH 086/156] :white_check_mark: Add tests for cli coverage. --- tiatoolbox/cli/common.py | 47 ----------------------------------- tiatoolbox/cli/tissue_mask.py | 2 +- 2 files changed, 1 insertion(+), 48 deletions(-) diff --git a/tiatoolbox/cli/common.py b/tiatoolbox/cli/common.py index e88eded62..31fff6007 100644 --- a/tiatoolbox/cli/common.py +++ b/tiatoolbox/cli/common.py @@ -435,39 +435,6 @@ def cli_method( ) -def cli_pretrained_model( - usage_help: str = "Name of the predefined model used to process the data. " - "The format is _. For example, " - "`resnet18-kather100K` is a resnet18 model trained on the Kather dataset. " - "Please see " - "https://tia-toolbox.readthedocs.io/en/latest/usage.html#deep-learning-models " - "for a detailed list of available pretrained models." - "By default, the corresponding pretrained weights will also be" - "downloaded. However, you can override with your own set of weights" - "via the `pretrained_weights` argument. Argument is case insensitive.", - default: str = "resnet18-kather100k", -) -> Callable: - """Enables --pretrained-model option for cli.""" - return click.option( - "--model", - help=add_default_to_usage_help(usage_help, default=default), - default=default, - ) - - -def cli_pretrained_weights( - usage_help: str = "Path to the model weight file. 
If not supplied, the default " - "pretrained weight will be used.", - default: str | None = None, -) -> Callable: - """Enables --pretrained-weights option for cli.""" - return click.option( - "--pretrained-weights", - help=add_default_to_usage_help(usage_help, default=default), - default=default, - ) - - def cli_model( usage_help: str = "Name of the predefined model used to process the data. " "The format is _. For example, " @@ -592,20 +559,6 @@ def cli_return_predictions( ) -def cli_merge_predictions( - usage_help: str = "Whether to merge the predictions to form a 2-dimensional map.", - *, - default: bool = True, -) -> Callable: - """Enables --merge-predictions option for cli.""" - return click.option( - "--merge-predictions", - type=bool, - default=default, - help=add_default_to_usage_help(usage_help, default=default), - ) - - def cli_batch_size( usage_help: str = "Number of image patches to feed into the model each time.", default: int = 1, diff --git a/tiatoolbox/cli/tissue_mask.py b/tiatoolbox/cli/tissue_mask.py index bb5bfe3af..9a036ae92 100644 --- a/tiatoolbox/cli/tissue_mask.py +++ b/tiatoolbox/cli/tissue_mask.py @@ -53,7 +53,7 @@ def get_masker( default="power", input_type=click.Choice(["mpp", "power"], case_sensitive=False), ) -@cli_mode(default="show") +@cli_mode(default="show", input_type=click.Choice(["show", "save"])) @cli_file_type(default="*.svs, *.ndpi, *.jp2, *.png, *.jpg, *.tif, *.tiff") # inputs specific to this function @click.option( From 6015d5faf337cb4a7e21edbaddab71ad4f18a80a Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Wed, 4 Feb 2026 14:00:37 +0000 Subject: [PATCH 087/156] :white_check_mark: Add progress bar for saving annotationstore. --- tiatoolbox/models/engine/multi_task_segmentor.py | 6 +++++- tiatoolbox/models/engine/nucleus_detector.py | 2 +- tiatoolbox/utils/misc.py | 12 ++++++++++-- 3 files changed, 16 insertions(+), 4 deletions(-) diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index f22260ee6..c8aee2b63 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -1970,7 +1970,11 @@ def dict_to_store( contour = processed_predictions.pop("contours") ann = [] - for i, contour_ in enumerate(contour): + tqdm_ = get_tqdm() + + for i, contour_ in enumerate( + tqdm_(contour, leave=False, desc="Converting outputs to AnnotationStore.") + ): ann_ = Annotation( make_valid_poly( feature2geometry( diff --git a/tiatoolbox/models/engine/nucleus_detector.py b/tiatoolbox/models/engine/nucleus_detector.py index 05b81a535..fc8364c59 100644 --- a/tiatoolbox/models/engine/nucleus_detector.py +++ b/tiatoolbox/models/engine/nucleus_detector.py @@ -879,7 +879,7 @@ def make_points(xs_batch: np.ndarray, ys_batch: np.ndarray) -> list[Point]: ] tqdm = get_tqdm() - tqdm_loop = tqdm(range(0, n, batch_size), desc="Writing detections to store") + tqdm_loop = tqdm(range(0, n, batch_size), desc="Writing detections to store.") written = 0 for i in tqdm_loop: j = min(i + batch_size, n) diff --git a/tiatoolbox/utils/misc.py b/tiatoolbox/utils/misc.py index 896223f06..e9728b256 100644 --- a/tiatoolbox/utils/misc.py +++ b/tiatoolbox/utils/misc.py @@ -1220,7 +1220,11 @@ def patch_predictions_as_annotations( ) -> list: """Helper function to generate annotation per patch predictions.""" annotations = [] - for i, _ in enumerate(patch_coords): + tqdm_ = get_tqdm() + + for i, _ in enumerate( + tqdm_(patch_coords, 
leave=False, desc="Converting outputs to AnnotationStore.") + ): if "probabilities" in keys: props = { f"prob_{class_dict[j]}": class_probs[i][j] for j in classes_predicted @@ -1390,7 +1394,11 @@ def dict_to_store_semantic_segmentor( annotations_list: list[Annotation] = [] - for type_class in layer_list: + tqdm_ = get_tqdm() + + for type_class in tqdm_( + layer_list, leave=False, desc="Converting outputs to AnnotationStore." + ): class_id = int(type_class) class_label = class_dict.get(class_id, class_id) layer = da.where(preds == type_class, 1, 0).astype("uint8").compute() From a0ae9d1c8eee9c82215052e2a63c0bd1acfa8df9 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Wed, 4 Feb 2026 16:35:18 +0000 Subject: [PATCH 088/156] :white_check_mark: Update ioconfig --- tiatoolbox/data/pretrained_model.yaml | 1 + tiatoolbox/models/engine/io_config.py | 11 ++++++++--- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/tiatoolbox/data/pretrained_model.yaml b/tiatoolbox/data/pretrained_model.yaml index 9f715593e..46f8f6358 100644 --- a/tiatoolbox/data/pretrained_model.yaml +++ b/tiatoolbox/data/pretrained_model.yaml @@ -825,6 +825,7 @@ mapde-crchisto: - { "units": "mpp", "resolution": 0.5 } output_resolutions: - { "units": "mpp", "resolution": 0.5 } + tile_shape: [2048, 2048] patch_input_shape: [ 252, 252 ] patch_output_shape: [ 252, 252 ] stride_shape: [ 150, 150 ] diff --git a/tiatoolbox/models/engine/io_config.py b/tiatoolbox/models/engine/io_config.py index e8d8b2399..8c3d49473 100644 --- a/tiatoolbox/models/engine/io_config.py +++ b/tiatoolbox/models/engine/io_config.py @@ -230,6 +230,8 @@ class IOSegmentorConfig(ModelIOConfigABC): Shape of the largest output in (height, width). save_resolution (dict): Resolution to save all output. + tile_shape (tuple(int, int)): + Tile shape to process the WSI. Attributes: input_resolutions (list(dict)): @@ -250,6 +252,10 @@ class IOSegmentorConfig(ModelIOConfigABC): Highest resolution to process the image based on input and output resolutions. This helps to read the image at the optimal resolution and improves performance. + tile_shape (tuple(int, int)): + Tile shape to process the WSI. + margin (int): + Tile margin to accumulate the output. Examples: >>> # Defining io for a network having 1 input and 1 output at the @@ -285,6 +291,8 @@ class IOSegmentorConfig(ModelIOConfigABC): patch_output_shape: list[int] | np.ndarray | tuple[int, int] = None save_resolution: dict = None + tile_shape: tuple[int, int] | None = None + margin: int | None = None def to_baseline(self: IOSegmentorConfig) -> IOSegmentorConfig: """Returns a new config object converted to baseline form. @@ -440,9 +448,6 @@ class IOInstanceSegmentorConfig(IOSegmentorConfig): """ - margin: int = None - tile_shape: tuple[int, int] = None - def to_baseline(self: IOInstanceSegmentorConfig) -> IOInstanceSegmentorConfig: """Returns a new config object converted to baseline form. 
From f48c4e023cc15aaf4a50f50e0c45a76dbb5c9ddf Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Wed, 4 Feb 2026 17:25:46 +0000 Subject: [PATCH 089/156] :white_check_mark: Update batchsize for cli and tests --- tests/models/test_arch_micronet.py | 31 +++++---- tiatoolbox/cli/multitask_segmentor.py | 2 +- tiatoolbox/cli/nucleus_detector.py | 2 +- tiatoolbox/cli/nucleus_instance_segment.py | 2 +- tiatoolbox/cli/patch_predictor.py | 2 +- tiatoolbox/cli/semantic_segmentor.py | 2 +- tiatoolbox/models/architecture/micronet.py | 79 +++++++++++++++++----- 7 files changed, 86 insertions(+), 34 deletions(-) diff --git a/tests/models/test_arch_micronet.py b/tests/models/test_arch_micronet.py index 9cc64cc96..4399057a9 100644 --- a/tests/models/test_arch_micronet.py +++ b/tests/models/test_arch_micronet.py @@ -6,6 +6,7 @@ import numpy as np import pytest import torch +import zarr from tiatoolbox.models import MicroNet, NucleusInstanceSegmentor from tiatoolbox.models.architecture import fetch_pretrained_weights @@ -59,28 +60,34 @@ def test_micronet_output(remote_sample: Callable, track_tmp_path: Path) -> None: """Test the output of MicroNet.""" svs_1_small = Path(remote_sample("svs-1-small")) micronet_output = Path(remote_sample("micronet-output")) - pretrained_model = "micronet-consep" - batch_size = 5 - num_loader_workers = 0 - num_postproc_workers = 0 + model = "micronet-consep" + batch_size = 64 + num_workers = 0 - predictor = NucleusInstanceSegmentor( - pretrained_model=pretrained_model, + ninst_seg = NucleusInstanceSegmentor( + model=model, batch_size=batch_size, - num_loader_workers=num_loader_workers, - num_postproc_workers=num_postproc_workers, + num_workers=num_workers, ) - output = predictor.predict( - imgs=[ + output = ninst_seg.run( + images=[ svs_1_small, ], save_dir=track_tmp_path / "output", + patch_mode=False, + verbose=True, + device=select_device(on_gpu=ON_GPU), + return_predictions=(True,), + return_probabilities=True, + output_type="zarr", ) - output = np.load(output[0][1] + ".raw.0.npy") + output = zarr.open(output[svs_1_small], mode="r") output_on_server = np.load(str(micronet_output)) output_on_server = np.round(output_on_server, decimals=3) - new_output = np.round(output[500:1000, 1000:1500, :], decimals=3) + new_output = np.round( + output["probabilities"][0][1000:2000:2, 2000:3000:2, :], decimals=3 + ) diff = new_output - output_on_server assert diff.mean() < 1e-5 diff --git a/tiatoolbox/cli/multitask_segmentor.py b/tiatoolbox/cli/multitask_segmentor.py index d9b2c018e..4b747ccb3 100644 --- a/tiatoolbox/cli/multitask_segmentor.py +++ b/tiatoolbox/cli/multitask_segmentor.py @@ -54,7 +54,7 @@ @cli_model(default="hovernetplus-oed") @cli_weights() @cli_device(default="cpu") -@cli_batch_size(default=1) +@cli_batch_size(default=64) @cli_yaml_config_path() @cli_masks(default=None) @cli_num_workers(default=0) diff --git a/tiatoolbox/cli/nucleus_detector.py b/tiatoolbox/cli/nucleus_detector.py index 290ebddd3..be721588a 100644 --- a/tiatoolbox/cli/nucleus_detector.py +++ b/tiatoolbox/cli/nucleus_detector.py @@ -59,7 +59,7 @@ @cli_model(default="mapde-conic") @cli_weights() @cli_device(default="cpu") -@cli_batch_size(default=1) +@cli_batch_size(default=64) @cli_yaml_config_path() @cli_masks(default=None) @cli_num_workers(default=0) diff --git a/tiatoolbox/cli/nucleus_instance_segment.py b/tiatoolbox/cli/nucleus_instance_segment.py index e5c89078f..ea33055ee 100644 --- a/tiatoolbox/cli/nucleus_instance_segment.py +++ 
b/tiatoolbox/cli/nucleus_instance_segment.py @@ -56,7 +56,7 @@ @cli_model(default="hovernet_fast-pannuke") @cli_weights() @cli_device(default="cpu") -@cli_batch_size(default=1) +@cli_batch_size(default=64) @cli_yaml_config_path() @cli_masks(default=None) @cli_num_workers(default=0) diff --git a/tiatoolbox/cli/patch_predictor.py b/tiatoolbox/cli/patch_predictor.py index c682ebfa8..1e80036e1 100644 --- a/tiatoolbox/cli/patch_predictor.py +++ b/tiatoolbox/cli/patch_predictor.py @@ -52,7 +52,7 @@ @cli_model(default="resnet18-kather100k") @cli_weights() @cli_device(default="cpu") -@cli_batch_size(default=1) +@cli_batch_size(default=64) @cli_yaml_config_path() @cli_masks(default=None) @cli_num_workers(default=0) diff --git a/tiatoolbox/cli/semantic_segmentor.py b/tiatoolbox/cli/semantic_segmentor.py index 86db63175..a6f2a0d88 100644 --- a/tiatoolbox/cli/semantic_segmentor.py +++ b/tiatoolbox/cli/semantic_segmentor.py @@ -55,7 +55,7 @@ @cli_model(default="fcn-tissue_mask") @cli_weights() @cli_device(default="cpu") -@cli_batch_size(default=1) +@cli_batch_size(default=64) @cli_yaml_config_path() @cli_masks(default=None) @cli_num_workers(default=0) diff --git a/tiatoolbox/models/architecture/micronet.py b/tiatoolbox/models/architecture/micronet.py index 6065fcd46..9a319b083 100644 --- a/tiatoolbox/models/architecture/micronet.py +++ b/tiatoolbox/models/architecture/micronet.py @@ -10,6 +10,7 @@ from collections import OrderedDict +import dask.array as da import numpy as np import torch from scipy import ndimage @@ -17,8 +18,12 @@ from torch import nn from torch.nn import functional -from tiatoolbox.models.architecture.hovernet import HoVerNet +from tiatoolbox.models.architecture.hovernet import ( + HoVerNet, + _inst_dict_for_dask_processing, +) from tiatoolbox.models.models_abc import ModelABC +from tiatoolbox.utils.misc import get_tqdm def group1_forward_branch( @@ -451,6 +456,7 @@ def __init__( raise ValueError(msg) self.__num_output_channels = num_output_channels self.in_ch = num_input_channels + self.tasks = ["nuclei_segmentation"] module_dict = OrderedDict() module_dict["b1"] = group1_arch_branch( @@ -568,13 +574,12 @@ def forward( # skipcq: PYL-W0221 return [out, aux1, aux2, aux3] - @staticmethod - def postproc(image: np.ndarray) -> tuple[np.ndarray, dict]: + def postproc(self: MicroNet, raw_maps: list[np.ndarray | da.Array]) -> tuple[dict]: """Post-processing script for MicroNet. Args: - image (ndarray): - Input image of type numpy array. + raw_maps (list[ndarray | da.Array]): + A list of prediction outputs of each head from inference model. Returns: :class:`numpy.ndarray`: @@ -582,16 +587,60 @@ def postproc(image: np.ndarray) -> tuple[np.ndarray, dict]: prediction. 
""" - pred_bin = np.argmax(image[0], axis=2) + is_dask = isinstance(raw_maps[0], da.Array) + pred_map = raw_maps[0].compute() if is_dask else raw_maps[0] + pred_bin = np.argmax(pred_map, axis=2) pred_inst = ndimage.label(pred_bin)[0] pred_inst = morphology.remove_small_objects(pred_inst, min_size=50) canvas = np.zeros(pred_inst.shape[:2], dtype=np.int32) - for inst_id in range(1, np.max(pred_inst) + 1): - inst_map = np.array(pred_inst == inst_id, dtype=np.uint8) - inst_map = ndimage.binary_fill_holes(inst_map) - canvas[inst_map > 0] = inst_id + tqdm_ = get_tqdm() + for inst_id in tqdm_( + range(1, np.max(pred_inst) + 1), + leave=False, + desc="Performing morphological operations to improve segmentation quality.", + ): + # Get coordinates of this instance + ys, xs = np.where(pred_inst == inst_id) + if len(xs) == 0: + continue # skip empty IDs + # Bounding box + y1, y2 = ys.min(), ys.max() + 1 + x1, x2 = xs.min(), xs.max() + 1 + # Crop region + crop = pred_inst[y1:y2, x1:x2] == inst_id + # Fill holes only inside this small region + filled = ndimage.binary_fill_holes(crop) + # Paste back into canvas + canvas[y1:y2, x1:x2][filled] = inst_id + nuc_inst_info_dict = HoVerNet.get_instance_info(canvas) - return canvas, nuc_inst_info_dict + + nuc_inst_info_dict_ = {} + if not nuc_inst_info_dict: + # inst_id should start at 1; use NumPy or Dask empty arrays + empty_array = da.empty(shape=0) if is_dask else np.empty(shape=0) + nuc_inst_info_dict_ = { + "box": empty_array, + "centroid": empty_array, + "contours": empty_array, + "prob": empty_array, + "type": empty_array, + } + else: + nuc_inst_info_dict_ = _inst_dict_for_dask_processing( + inst_info_dict=nuc_inst_info_dict, + inst_info_dict_=nuc_inst_info_dict_, + is_dask=is_dask, + ) + + nuclei_seg = { + "task_type": self.tasks[0], + "predictions": da.array(pred_inst) + if isinstance(raw_maps[0], da.Array) + else pred_inst, + "info_dict": nuc_inst_info_dict_, + } + return (nuclei_seg,) @staticmethod def preproc(image: np.ndarray) -> np.ndarray: @@ -629,7 +678,7 @@ def infer_batch( # skipcq: PYL-W0221 batch_data: torch.Tensor, *, device: str, - ) -> list[np.ndarray]: + ) -> tuple[np.ndarray]: """Run inference on an input batch. This contains logic for forward operation as well as batch I/O @@ -660,8 +709,4 @@ def infer_batch( # skipcq: PYL-W0221 pred, _, _, _ = model(patch_imgs_gpu) pred = pred.permute(0, 2, 3, 1).contiguous() - pred = pred.cpu().numpy() - - return [ - pred, - ] + return (pred.cpu().numpy(),) From e962fe59cf0c59b031419c325117a275a4c0a6bd Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Wed, 4 Feb 2026 17:29:57 +0000 Subject: [PATCH 090/156] :bug: Fix deepsource bug --- tiatoolbox/models/architecture/micronet.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tiatoolbox/models/architecture/micronet.py b/tiatoolbox/models/architecture/micronet.py index 9a319b083..9eb3248b1 100644 --- a/tiatoolbox/models/architecture/micronet.py +++ b/tiatoolbox/models/architecture/micronet.py @@ -574,6 +574,7 @@ def forward( # skipcq: PYL-W0221 return [out, aux1, aux2, aux3] + # skipcq: PYL-W0221 # noqa: ERA001 def postproc(self: MicroNet, raw_maps: list[np.ndarray | da.Array]) -> tuple[dict]: """Post-processing script for MicroNet. 
From 73f7d61b545efef75268b3844f1bfbc8e7b24bc1 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Thu, 5 Feb 2026 09:52:23 +0000 Subject: [PATCH 091/156] :bug: Fix tests --- tests/models/test_arch_micronet.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/models/test_arch_micronet.py b/tests/models/test_arch_micronet.py index 4399057a9..f28f75642 100644 --- a/tests/models/test_arch_micronet.py +++ b/tests/models/test_arch_micronet.py @@ -41,8 +41,11 @@ def test_functionality( pretrained = torch.load(weights_path, map_location=map_location) model.load_state_dict(pretrained) output = model.infer_batch(model, batch, device=map_location) - output, _ = model.postproc(output[0]) - assert np.max(np.unique(output)) == 46 + output_ = model.postproc(list(output[0])) + assert output_[0]["task_type"] == "nuclei_segmentation" + assert np.max(np.unique(output_[0]["predictions"])) == 46 + assert len(output_[0]["info_dict"]["centroid"]) == 27 + assert len(output_[0]["info_dict"]["contours"]) == 27 Path(weights_path).unlink() From aae35295a4d972ec92ba9bae23af9744a7a54186 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Thu, 5 Feb 2026 10:29:55 +0000 Subject: [PATCH 092/156] :lipstick: Update description for tqdm --- tiatoolbox/models/engine/multi_task_segmentor.py | 6 ++++-- tiatoolbox/models/engine/nucleus_detector.py | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index c8aee2b63..fa8ca4ff2 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -1973,7 +1973,7 @@ def dict_to_store( tqdm_ = get_tqdm() for i, contour_ in enumerate( - tqdm_(contour, leave=False, desc="Converting outputs to AnnotationStore.") + tqdm_(contour, leave=False, desc="Converting outputs to AnnotationStore") ): ann_ = Annotation( make_valid_poly( @@ -2376,7 +2376,9 @@ def merge_multitask_vertical_chunkwise( chunk_shape = tuple(chunk[0] for chunk in canvas_.chunks) tqdm_ = get_tqdm() - tqdm_loop = tqdm_(overlaps, leave=False, desc="Merging rows") + tqdm_loop = tqdm_( + overlaps, leave=False, desc=f"Merging rows for probability map {idx}" + ) curr_chunk = canvas_.blocks[0, 0].compute() curr_count = count[idx].blocks[0, 0].compute() diff --git a/tiatoolbox/models/engine/nucleus_detector.py b/tiatoolbox/models/engine/nucleus_detector.py index fc8364c59..05b81a535 100644 --- a/tiatoolbox/models/engine/nucleus_detector.py +++ b/tiatoolbox/models/engine/nucleus_detector.py @@ -879,7 +879,7 @@ def make_points(xs_batch: np.ndarray, ys_batch: np.ndarray) -> list[Point]: ] tqdm = get_tqdm() - tqdm_loop = tqdm(range(0, n, batch_size), desc="Writing detections to store.") + tqdm_loop = tqdm(range(0, n, batch_size), desc="Writing detections to store") written = 0 for i in tqdm_loop: j = min(i + batch_size, n) From 322f4f67bd9dc813861f9588e98a248091c5443c Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Thu, 5 Feb 2026 10:43:37 +0000 Subject: [PATCH 093/156] :lipstick: Update description for tqdm --- tiatoolbox/models/architecture/hovernet.py | 7 +++++-- .../models/engine/multi_task_segmentor.py | 20 ++++++++++++++++--- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/tiatoolbox/models/architecture/hovernet.py b/tiatoolbox/models/architecture/hovernet.py index 
9f35ae5e1..d9fb329fe 100644 --- a/tiatoolbox/models/architecture/hovernet.py +++ b/tiatoolbox/models/architecture/hovernet.py @@ -22,7 +22,7 @@ centre_crop_to_shape, ) from tiatoolbox.models.models_abc import ModelABC -from tiatoolbox.utils.misc import get_bounding_box +from tiatoolbox.utils.misc import get_bounding_box, get_tqdm class TFSamepaddingLayer(nn.Module): @@ -650,7 +650,10 @@ def get_instance_info(pred_inst: np.ndarray, pred_type: np.ndarray = None) -> di """ inst_id_list = np.unique(pred_inst)[1:] # exclude background inst_info_dict = {} - for inst_id in inst_id_list: + tqdm_ = get_tqdm() + for inst_id in tqdm_( + inst_id_list, leave=False, desc="Generating 'info_dict' for instances" + ): inst_map = pred_inst == inst_id inst_box = get_bounding_box(inst_map) inst_box_tl = inst_box[:2] diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index fa8ca4ff2..e31458a73 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -131,7 +131,7 @@ from shapely.strtree import STRtree from typing_extensions import Unpack -from tiatoolbox import logger +from tiatoolbox import DuplicateFilter, logger from tiatoolbox.annotation import SQLiteStore from tiatoolbox.annotation.storage import Annotation from tiatoolbox.tools.patchextraction import PatchExtractor @@ -1073,7 +1073,16 @@ def _process_tile_mode( merged = [] wsi_info_dict = None - for set_idx, (set_bounds, set_flags) in enumerate(tile_info_sets): + duplicate_filter = DuplicateFilter() + logger.addFilter(duplicate_filter) + tqdm_ = get_tqdm() + for set_idx, (set_bounds, set_flags) in enumerate( + tqdm_( + tile_info_sets, + leave=False, + desc="Post-Processing WSI to generate predictions and contours", + ) + ): for tile_idx, tile_bounds in enumerate(set_bounds): tile_flag = set_flags[tile_idx] tile_tl = tile_bounds[:2] @@ -1131,8 +1140,13 @@ def _process_tile_mode( wsi_info_dict[inst_id]["info_dict"].update(new_inst_dict) for inst_uuid in remove_uuid_lists[inst_id]: wsi_info_dict[inst_id]["info_dict"].pop(inst_uuid, None) + logger.removeFilter(duplicate_filter) - for idx, wsi_info_dict_ in enumerate(wsi_info_dict): + for idx, wsi_info_dict_ in enumerate( + tqdm_( + wsi_info_dict, leave=False, desc="Converting 'info_dict' to dask arrays" + ) + ): info_df = pd.DataFrame(wsi_info_dict_["info_dict"]).transpose() dict_info_wsi = {} for key, col in info_df.items(): From 7c62b3dd1a1a9ba76c903638fdb5d4c1efc7ad93 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Thu, 5 Feb 2026 11:49:57 +0000 Subject: [PATCH 094/156] :white_check_mark: Add tests for improved coverage --- tests/models/test_arch_micronet.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/models/test_arch_micronet.py b/tests/models/test_arch_micronet.py index f28f75642..e514d90b5 100644 --- a/tests/models/test_arch_micronet.py +++ b/tests/models/test_arch_micronet.py @@ -46,6 +46,14 @@ def test_functionality( assert np.max(np.unique(output_[0]["predictions"])) == 46 assert len(output_[0]["info_dict"]["centroid"]) == 27 assert len(output_[0]["info_dict"]["contours"]) == 27 + + # For test coverage pass probability map with + # no cell segmentation instance + output_ = model.postproc(np.zeros((1, 252, 252, 2))) + assert output_[0]["task_type"] == "nuclei_segmentation" + assert np.max(np.unique(output_[0]["predictions"])) == 0 + assert len(output_[0]["info_dict"]["centroid"]) == 0 + 
Path(weights_path).unlink() From 199aeb9fbcf5451326b47a1ee75fe42e50d39ab0 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Thu, 5 Feb 2026 17:45:44 +0000 Subject: [PATCH 095/156] :zap: Use dask delayed for parallel post-processing. --- .../models/engine/multi_task_segmentor.py | 164 ++++++++++++++---- 1 file changed, 131 insertions(+), 33 deletions(-) diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index e31458a73..d81e9b68b 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -114,6 +114,7 @@ from __future__ import annotations import gc +import multiprocessing import shutil import uuid from collections import deque @@ -126,12 +127,13 @@ import psutil import torch import zarr +from dask import compute, delayed from shapely.geometry import box as shapely_box from shapely.geometry import shape as feature2geometry from shapely.strtree import STRtree from typing_extensions import Unpack -from tiatoolbox import DuplicateFilter, logger +from tiatoolbox import logger from tiatoolbox.annotation import SQLiteStore from tiatoolbox.annotation.storage import Annotation from tiatoolbox.tools.patchextraction import PatchExtractor @@ -897,10 +899,16 @@ def post_process_wsi( # skipcq: PYL-R0201 return_predictions=return_predictions, ) else: + num_workers = ( + kwargs.get("num_workers", multiprocessing.cpu_count()) + if self.num_workers == 0 + else self.num_workers + ) post_process_predictions = self._process_tile_mode( probabilities=probabilities, save_path=save_path.with_suffix(".zarr"), memory_threshold=kwargs.get("memory_threshold", 80), + num_workers=num_workers, return_predictions=kwargs.get("return_predictions"), ) @@ -991,6 +999,7 @@ def _process_tile_mode( probabilities: list[da.Array | np.ndarray], save_path: Path, memory_threshold: float = 80, + num_workers: int = multiprocessing.cpu_count(), *, return_predictions: tuple[bool, ...] | None = None, ) -> list[dict] | None: @@ -1021,6 +1030,9 @@ def _process_tile_mode( memory_threshold (float): Maximum allowed RAM usage (percentage) for in-memory arrays before switching to or continuing with Zarr-backed allocation. Defaults to 80. + num_workers (int): + Number of workers for data loading. + Default is multiprocessing.cpu_count(). return_predictions (tuple[bool, ...] | None): Per-task flags indicating whether to retain a full-resolution ``"predictions"`` array for each task. 
If ``None``, no task-level @@ -1071,32 +1083,32 @@ def _process_tile_mode( tile_info_sets = self._get_tile_info(wsi_proc_shape, self.ioconfig) ioconfig = self.ioconfig.to_baseline() - merged = [] - wsi_info_dict = None - duplicate_filter = DuplicateFilter() - logger.addFilter(duplicate_filter) tqdm_ = get_tqdm() - for set_idx, (set_bounds, set_flags) in enumerate( - tqdm_( - tile_info_sets, - leave=False, - desc="Post-Processing WSI to generate predictions and contours", - ) + + delayed_results, tile_metadata = _build_delayed_tile_tasks( + probabilities=probabilities, + tile_info_sets=tile_info_sets, + model=self.model, + ) + + wsi_info_dict = None + merge_idx = 0 + for i in tqdm_( + range(0, len(delayed_results), num_workers), + leave=False, + desc="Post-Processing WSI to generate predictions and contours", ): - for tile_idx, tile_bounds in enumerate(set_bounds): - tile_flag = set_flags[tile_idx] - tile_tl = tile_bounds[:2] - tile_br = tile_bounds[2:] - tile_shape = tile_br - tile_tl # in width height - head_raws = [ - probabilities_[ - tile_bounds[1] : tile_bounds[3], - tile_bounds[0] : tile_bounds[2], - :, - ].compute() - for probabilities_ in probabilities - ] - post_process_output = self.model.postproc_func(head_raws) + batch = delayed_results[i : i + num_workers] + + # Compute only this batch in parallel to avoid memory overload. + batch_outputs = compute( + *batch, scheduler="threads", num_workers=num_workers + ) + + # Merge each tile result immediately + for post_process_output in batch_outputs: + tile_bounds, tile_flag, tile_mode = tile_metadata[merge_idx] + merge_idx += 1 # create a list of info dict for each task wsi_info_dict = _create_wsi_info_dict( @@ -1117,8 +1129,10 @@ def _process_tile_mode( inst_dicts = _get_inst_info_dicts( post_process_output=post_process_output ) + tile_tl = tile_bounds[:2] + tile_br = tile_bounds[2:] + tile_shape = tile_br - tile_tl - tile_mode = set_idx new_inst_dicts, remove_insts_in_origs = [], [] for inst_id, inst_dict in enumerate(inst_dicts): new_inst_dict, remove_insts_in_orig = _process_instance_predictions( @@ -1133,14 +1147,10 @@ def _process_tile_mode( new_inst_dicts.append(new_inst_dict) remove_insts_in_origs.append(remove_insts_in_orig) - merged.append((new_inst_dicts, remove_insts_in_origs)) - - for new_inst_dicts, remove_uuid_lists in merged: for inst_id, new_inst_dict in enumerate(new_inst_dicts): wsi_info_dict[inst_id]["info_dict"].update(new_inst_dict) - for inst_uuid in remove_uuid_lists[inst_id]: + for inst_uuid in remove_insts_in_origs[inst_id]: wsi_info_dict[inst_id]["info_dict"].pop(inst_uuid, None) - logger.removeFilter(duplicate_filter) for idx, wsi_info_dict_ in enumerate( tqdm_( @@ -2892,7 +2902,7 @@ def _create_wsi_info_dict( save_path: Path, return_predictions: tuple[bool, ...] | None, memory_threshold: float = 80, -) -> tuple[dict[str, dict[Any, Any] | list[Any] | Any], ...]: +) -> tuple[dict, ...]: """Create or reuse WSI info dictionaries for post-processed outputs. 
This function constructs a tuple of WSI information dictionaries, one for each @@ -2961,7 +2971,7 @@ def _update_tile_based_predictions_array( post_process_output: tuple[dict], wsi_info_dict: tuple[dict], bounds: tuple[int, int, int, int], -) -> tuple[dict]: +) -> tuple[dict, ...]: """Helper function to update tile based predictions array.""" x_start, y_start, x_end, y_end = bounds @@ -2977,3 +2987,91 @@ def _update_tile_based_predictions_array( ) return wsi_info_dict + + +@delayed +def _compute_tile( + probabilities: list[da.Array | np.ndarray], + tile_bounds: tuple[int, int, int, int], + model: ModelABC, +) -> tuple: + """Compute post-processing outputs for a single WSI tile. + + This function performs lazy slicing of the probability maps for the given + tile bounds and applies the model's `postproc_func` to produce per-task + outputs (semantic, instance, etc.). + + Args: + probabilities: + List of WSI-scale probability maps, one per model head. Each element + is either a Dask array or NumPy array with shape (H, W, C). + tile_bounds: + A 4-tuple (x_start, y_start, x_end, y_end) defining the tile region + in WSI coordinates. + model: + The multitask model containing a `postproc_func` method that accepts + a list of tile-level arrays and returns a tuple of task dictionaries. + + Returns: + A tuple of dictionaries, one per task, as produced by `postproc_func`. + Each dictionary typically contains: + - "task_type": str + - "predictions": np.ndarray + - "info_dict": dict + """ + head_raws = [ + p[tile_bounds[1] : tile_bounds[3], tile_bounds[0] : tile_bounds[2], :] + for p in probabilities + ] + return model.postproc_func(head_raws) + + +def _build_delayed_tile_tasks( + probabilities: list[da.Array | np.ndarray], + tile_info_sets: list, + model: ModelABC, +) -> tuple[ + list[Any], # delayed results + list, # metadata +]: + """Build delayed tile-processing tasks and associated metadata. + + This function iterates over all tile sets and constructs: + - a list of delayed tasks (each calling `_compute_tile`) + - a parallel list of metadata entries describing each tile + + Metadata entries contain: + (tile_bounds, tile_flag, tile_mode) + + Args: + probabilities: + List of WSI-scale probability maps, one per model head. + Each element is a Dask array or NumPy array with shape (H, W, C). + tile_info_sets: + A list where each element is a tuple: + (set_bounds, set_flags) + - set_bounds: list of tile bounds (x0, y0, x1, y1) + - set_flags: list of per-tile flags used for instance merging + model: + The multitask model containing a `postproc_func` method. 
+ + Returns: + A tuple: + - delayed_results: list of delayed `_compute_tile` tasks + - tile_metadata: list of metadata tuples + (tile_bounds, tile_flag, tile_mode) + """ + delayed_results: list = [] + tile_metadata: list = [] + + for set_idx, (set_bounds, set_flags) in enumerate(tile_info_sets): + for tile_idx, tile_bounds in enumerate(set_bounds): + tile_flag = set_flags[tile_idx] + + # Create delayed tile compute task + delayed_results.append(_compute_tile(probabilities, tile_bounds, model)) + + # Store metadata for merging + tile_metadata.append((tile_bounds, tile_flag, set_idx)) + + return delayed_results, tile_metadata From 20a2a846eef4436b6fbfc2ada7e2abd5f6475b8f Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Fri, 6 Feb 2026 09:33:13 +0000 Subject: [PATCH 096/156] :lipstick: Replace `tqdm.write` with `tqdm.desc` --- .../models/engine/deep_feature_extractor.py | 3 +-- tiatoolbox/models/engine/multi_task_segmentor.py | 8 ++++---- tiatoolbox/models/engine/semantic_segmentor.py | 15 ++++++++------- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/tiatoolbox/models/engine/deep_feature_extractor.py b/tiatoolbox/models/engine/deep_feature_extractor.py index 587e47b8b..b09c72ca9 100644 --- a/tiatoolbox/models/engine/deep_feature_extractor.py +++ b/tiatoolbox/models/engine/deep_feature_extractor.py @@ -325,7 +325,6 @@ def infer_wsi( used_percent > memory_threshold or probabilities_used_percent > memory_threshold ): - tqdm_loop.desc = "Spill intermediate data to disk" used_percent = ( probabilities_used_percent if (probabilities_used_percent > memory_threshold) @@ -337,7 +336,7 @@ def infer_wsi( f"Saving intermediate results to disk." ) - tqdm.write(msg) + tqdm_loop.desc = msg # Flush data in Memory and clear dask graph probabilities_zarr, coordinates_zarr = save_to_cache( probabilities, diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index d81e9b68b..a1e6aca99 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -2477,12 +2477,13 @@ def _save_multitask_vertical_to_cache( total_bytes = sum(0 if arr is None else arr.nbytes for arr in probabilities_da) used_percent = (total_bytes / max(vm.available, 1)) * 100 if probabilities_zarr[idx] is None and used_percent > memory_threshold: + desc = tqdm_.desc msg = ( f"Current Memory usage: {used_percent} % " f"exceeds specified threshold: {memory_threshold}. " f"Saving intermediate results to disk." ) - tqdm_.write(msg) + tqdm_.desc = msg zarr_group = zarr.open(str(save_path), mode="a") probabilities_zarr[idx] = zarr_group.create_dataset( name=f"probabilities/{idx}", @@ -2492,7 +2493,7 @@ def _save_multitask_vertical_to_cache( overwrite=True, ) probabilities_zarr[idx][:] = probabilities_da[idx].compute() - + tqdm_.desc = desc probabilities_da[idx] = None return probabilities_zarr, probabilities_da @@ -2577,7 +2578,6 @@ def _check_and_update_for_memory_overload( if not (used_percent > memory_threshold or canvas_used_percent > memory_threshold): return canvas, count, canvas_zarr, count_zarr, tqdm_loop - tqdm_loop.desc = "Spill intermediate data to disk" used_percent = ( canvas_used_percent if (canvas_used_percent > memory_threshold) @@ -2588,7 +2588,7 @@ def _check_and_update_for_memory_overload( f"exceeds specified threshold: {memory_threshold}. " f"Saving intermediate results to disk." 
) - tqdm_.write(msg) + tqdm_.desc = msg # Flush data in Memory and clear dask graph canvas_zarr, count_zarr = save_multitask_to_cache( canvas, diff --git a/tiatoolbox/models/engine/semantic_segmentor.py b/tiatoolbox/models/engine/semantic_segmentor.py index c7d052f31..c33ebf6ac 100644 --- a/tiatoolbox/models/engine/semantic_segmentor.py +++ b/tiatoolbox/models/engine/semantic_segmentor.py @@ -460,9 +460,9 @@ def infer_wsi( ) # Inference loop - tqdm = get_tqdm() + tqdm_ = get_tqdm() tqdm_loop = ( - tqdm(dataloader, leave=False, desc="Inferring patches") + tqdm_(dataloader, leave=False, desc="Inferring patches") if self.verbose else dataloader ) @@ -526,7 +526,6 @@ def infer_wsi( used_percent > memory_threshold or canvas_used_percent > memory_threshold ): - tqdm_loop.desc = "Spill intermediate data to disk" used_percent = ( canvas_used_percent if (canvas_used_percent > memory_threshold) @@ -537,7 +536,7 @@ def infer_wsi( f"exceeds specified threshold: {memory_threshold}. " f"Saving intermediate results to disk." ) - tqdm.write(msg) + tqdm_loop.desc = msg # Flush data in Memory and clear dask graph canvas_zarr, count_zarr = save_to_cache( canvas, @@ -1248,8 +1247,8 @@ def merge_vertical_chunkwise( probabilities_zarr, probabilities_da = None, None chunk_shape = tuple(chunk[0] for chunk in canvas.chunks) - tqdm = get_tqdm() - tqdm_loop = tqdm(overlaps, leave=False, desc="Merging rows") + tqdm_ = get_tqdm() + tqdm_loop = tqdm_(overlaps, leave=False, desc="Merging rows") used_percent = 0 @@ -1279,12 +1278,13 @@ def merge_vertical_chunkwise( vm = psutil.virtual_memory() used_percent = (probabilities_da.nbytes / vm.free) * 100 if probabilities_zarr is None and used_percent > memory_threshold: + desc = tqdm_.desc msg = ( f"Current Memory usage: {used_percent} % " f"exceeds specified threshold: {memory_threshold}. " f"Saving intermediate results to disk." 
) - tqdm.write(msg) + tqdm_.desc = msg zarr_group = zarr.open(str(save_path), mode="a") probabilities_zarr = zarr_group.create_dataset( name="probabilities", @@ -1296,6 +1296,7 @@ def merge_vertical_chunkwise( probabilities_zarr[:] = probabilities_da.compute() probabilities_da = None + tqdm_.desc = desc if next_chunk is not None: curr_chunk, curr_count = next_chunk[overlap:], next_count[overlap:] From 63d91578123f119305cb92fbbdcfd141b94948c4 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Fri, 6 Feb 2026 10:12:30 +0000 Subject: [PATCH 097/156] :bug: Fix failing test for DummyTQDM --- tests/engines/test_multi_task_segmentor.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/tests/engines/test_multi_task_segmentor.py b/tests/engines/test_multi_task_segmentor.py index 0365e0083..612615228 100644 --- a/tests/engines/test_multi_task_segmentor.py +++ b/tests/engines/test_multi_task_segmentor.py @@ -597,6 +597,7 @@ class DummyTqdm: """Dummy tqdm with a write() method.""" messages: ClassVar[list[str]] = [] + desc: str = "Test Method" @classmethod def write(cls: DummyTqdm, msg: str) -> None: @@ -615,11 +616,6 @@ def write(cls: DummyTqdm, msg: str) -> None: memory_threshold=0, # ensure branch triggers ) - # --- Assertions --- - # tqdm.write was called - assert len(DummyTqdm.messages) == 1 - assert "Saving intermediate results to disk" in DummyTqdm.messages[0] - # probabilities_da must be set to None assert new_da[idx] is None From 9fda5363804739d67285b12bb3d39ce89ddb8822 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Fri, 6 Feb 2026 12:47:56 +0000 Subject: [PATCH 098/156] :zap: Use dask delayed to save annotations. --- .../models/engine/multi_task_segmentor.py | 157 ++++++++++++++---- 1 file changed, 128 insertions(+), 29 deletions(-) diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index a1e6aca99..206e18e15 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -1561,6 +1561,11 @@ def _save_predictions_as_annotationstore( if not return_predictions: processed_predictions.pop("predictions") keys_to_compute.remove("predictions") + num_workers = ( + kwargs.get("num_workers", multiprocessing.cpu_count()) + if self.num_workers == 0 + else self.num_workers + ) if self.patch_mode: for idx, curr_image in enumerate(self.images): values = [processed_predictions[key][idx] for key in keys_to_compute] @@ -1573,6 +1578,7 @@ def _save_predictions_as_annotationstore( save_path=save_path, class_dict=class_dict, scale_factor=scale_factor, + num_workers=num_workers, ) save_paths.append(output_path) @@ -1588,6 +1594,7 @@ def _save_predictions_as_annotationstore( save_path=save_path, class_dict=class_dict, scale_factor=scale_factor, + num_workers=num_workers, ) save_paths.append(output_path) @@ -1936,6 +1943,7 @@ def dict_to_store( class_dict: dict | None = None, origin: tuple[float, float] = (0, 0), scale_factor: tuple[float, float] = (1, 1), + num_workers: int = multiprocessing.cpu_count(), ) -> AnnotationStore: """Write polygonal multitask predictions into an SQLite-backed AnnotationStore. @@ -1974,6 +1982,9 @@ def dict_to_store( `(sx, sy)` factors applied to coordinates before translation, used to convert from model space to baseline slide resolution (e.g., `model_mpp / slide_mpp`). + num_workers (int): + Number of parallel worker threads to use. 
If set to 0 or None, + defaults to the number of CPU cores. Returns: AnnotationStore: @@ -1991,43 +2002,84 @@ def dict_to_store( - All annotations are appended in a single batch via `store.append_many(...)`. """ - contour = processed_predictions.pop("contours") + contours = processed_predictions.pop("contours") + n = len(contours) + + # Build delayed tasks + delayed_tasks = [ + _build_single_annotation( + i, + contours[i], + processed_predictions, + class_dict, + origin, + scale_factor, + ) + for i in range(n) + ] - ann = [] - tqdm_ = get_tqdm() + ann = compute_dask_delayed_with_progress( + delayed_tasks, num_workers=num_workers, desc="Saving annotations " + ) - for i, contour_ in enumerate( - tqdm_(contour, leave=False, desc="Converting outputs to AnnotationStore") - ): - ann_ = Annotation( - make_valid_poly( - feature2geometry( - { - "type": processed_predictions.get("geom_type", "Polygon"), - "coordinates": scale_factor * np.array([contour_]), - }, - ), - tuple(origin), - ), - { - prop: ( - class_dict[processed_predictions[prop][i]] - if prop == "type" and class_dict is not None - # Intention is convert arrays to list - # There might be int or float values which need to be - # converted to arrays first and then apply tolist(). - else np.array(processed_predictions[prop][i]).tolist() - ) - for prop in processed_predictions - }, - ) - ann.append(ann_) logger.info("Added %d annotations.", len(ann)) store.append_many(ann) return store +def compute_dask_delayed_with_progress( + delayed_tasks: list, + num_workers: int = multiprocessing.cpu_count(), + desc: str = "Computing", + batch_size: int | None = None, +) -> list: + """Compute a list of Dask delayed tasks in parallel while displaying a progress bar. + + This function batches tasks according to `num_workers`, ensuring that only + `num_workers` tasks are computed concurrently. This avoids excessive memory + usage when each delayed task returns a large object (e.g., NumPy arrays, + geometries, or annotations). A tqdm progress bar is updated after each batch. + + Args: + delayed_tasks (list): + A list of Dask delayed objects to compute. + num_workers (int): + Number of parallel worker threads to use. If set to 0 or None, + defaults to the number of CPU cores. + desc (str): + Description string shown in the tqdm progress bar. + batch_size (int | None): + batch_size to process dask delayed. + batch_size is set to num_workers if batch_size is not provided. + + Returns: + A list containing the computed results from all delayed tasks, in order. 
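+
+    Example:
+        A minimal, illustrative sketch; the tasks below are placeholders
+        and not part of the toolbox API::
+
+            from dask import delayed
+
+            tasks = [delayed(abs)(-i) for i in range(100)]
+            results = compute_dask_delayed_with_progress(
+                tasks, num_workers=4, desc="Demo"
+            )
+            assert results == list(range(100))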
+ + """ + total = len(delayed_tasks) + batch_size = num_workers if batch_size is None else batch_size + results: list[Any] = [] + + tqdm_ = get_tqdm() + + with tqdm_(total=total, desc=desc, leave=False) as pbar: + for i in range(0, total, batch_size): + batch = delayed_tasks[i : i + batch_size] + + # Compute this batch in parallel + batch_results = compute( + *batch, + scheduler="threads", + num_workers=num_workers, + ) + + results.extend(batch_results) + pbar.update(len(batch)) + + return results + + def prepare_multitask_full_batch( batch_output: tuple[np.ndarray], batch_locs: np.ndarray, @@ -2614,6 +2666,7 @@ def _save_annotation_store( save_path: Path, class_dict: dict, scale_factor: tuple[float, float], + num_workers: int, ) -> Path: """Helper function to save to annotation store.""" if isinstance(curr_image, Path): @@ -2640,6 +2693,7 @@ def _save_annotation_store( class_dict=class_dict, scale_factor=scale_factor, origin=origin, + num_workers=num_workers, ) store.commit() @@ -3075,3 +3129,48 @@ def _build_delayed_tile_tasks( tile_metadata.append((tile_bounds, tile_flag, set_idx)) return delayed_results, tile_metadata + + +@delayed +def _build_single_annotation( + i: int, + contour: np.ndarray, + processed_predictions: dict[str, Any], + class_dict: dict[int, str] | None, + origin: tuple[float, float], + scale_factor: tuple[float, float], +) -> Annotation: + """Creates a delayed annotation to run with dask. + + Build a single Annotation object for index `i`. + + This function performs: + - geometry creation + - coordinate scaling + translation + - per-object property extraction + - class_dict mapping (if provided) + + Returns: + A single Annotation instance. + + """ + geom = make_valid_poly( + feature2geometry( + { + "type": processed_predictions.get("geom_type", "Polygon"), + "coordinates": scale_factor * np.array([contour]), + } + ), + tuple(origin), + ) + + properties = { + prop: ( + class_dict[processed_predictions[prop][i]] + if prop == "type" and class_dict is not None + else np.array(processed_predictions[prop][i]).tolist() + ) + for prop in processed_predictions + } + + return Annotation(geom, properties) From bbb1bbfe4a3890a0380ca252efa91917fd23adf9 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Fri, 6 Feb 2026 12:58:40 +0000 Subject: [PATCH 099/156] :lipstick: Add tqdm to creating list of tasks --- tiatoolbox/models/engine/multi_task_segmentor.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index 206e18e15..65098e0a9 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -2005,6 +2005,8 @@ def dict_to_store( contours = processed_predictions.pop("contours") n = len(contours) + tqdm_ = get_tqdm() + # Build delayed tasks delayed_tasks = [ _build_single_annotation( @@ -2015,7 +2017,11 @@ def dict_to_store( origin, scale_factor, ) - for i in range(n) + for i in tqdm_( + range(n), + leave=False, + desc="Creating list of delayed tasks for writing annotations.", + ) ] ann = compute_dask_delayed_with_progress( From 2684d926cfe1c39e36b12b80743de60a47e35205 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Sat, 7 Feb 2026 11:56:04 +0000 Subject: [PATCH 100/156] :lipstick: Update `warnings` and `tqdm`. 
--- tiatoolbox/models/architecture/hovernet.py | 8 +++++++- tiatoolbox/models/engine/engine_abc.py | 5 +++++ .../models/engine/multi_task_segmentor.py | 17 +++++++++++++++-- .../models/engine/semantic_segmentor.py | 19 +++++++------------ 4 files changed, 34 insertions(+), 15 deletions(-) diff --git a/tiatoolbox/models/architecture/hovernet.py b/tiatoolbox/models/architecture/hovernet.py index d9fb329fe..cf34d1dbb 100644 --- a/tiatoolbox/models/architecture/hovernet.py +++ b/tiatoolbox/models/architecture/hovernet.py @@ -3,6 +3,7 @@ from __future__ import annotations import math +import warnings from collections import OrderedDict import cv2 @@ -603,7 +604,12 @@ def _proc_np_hv( kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5)) marker = cv2.morphologyEx(marker, cv2.MORPH_OPEN, kernel) marker = ndimage.label(marker)[0] - marker = remove_small_objects(marker, min_size=obj_size) + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + message="Only one label was provided to `remove_small_objects`", + ) + marker = remove_small_objects(marker, min_size=obj_size) return watershed(dist, markers=marker, mask=blb) diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index 08b1d5579..308cfd4dd 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -35,6 +35,7 @@ from __future__ import annotations import copy +import shutil from abc import ABC from pathlib import Path from typing import TYPE_CHECKING, TypedDict @@ -1785,6 +1786,8 @@ def prepare_engines_save_dir( if patch_mode: if save_dir is not None: save_dir = Path(save_dir) + if save_dir.exists() and overwrite: + shutil.rmtree(save_dir) save_dir.mkdir(parents=True, exist_ok=overwrite) return save_dir return None @@ -1804,6 +1807,8 @@ def prepare_engines_save_dir( ) save_dir = Path(save_dir) + if save_dir.exists() and overwrite: + shutil.rmtree(save_dir) save_dir.mkdir(parents=True, exist_ok=overwrite) return save_dir diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index 65098e0a9..95bed3a09 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -3123,9 +3123,22 @@ def _build_delayed_tile_tasks( """ delayed_results: list = [] tile_metadata: list = [] + tqdm_ = get_tqdm() - for set_idx, (set_bounds, set_flags) in enumerate(tile_info_sets): - for tile_idx, tile_bounds in enumerate(set_bounds): + for set_idx, (set_bounds, set_flags) in enumerate( + tqdm_( + tile_info_sets, + leave=False, + desc="Building delayed tile-processing tasks", + ) + ): + for tile_idx, tile_bounds in enumerate( + tqdm_( + set_bounds, + leave=False, + desc=f"Building delayed tile-processing tasks for tile set {set_idx}", + ) + ): tile_flag = set_flags[tile_idx] # Create delayed tile compute task diff --git a/tiatoolbox/models/engine/semantic_segmentor.py b/tiatoolbox/models/engine/semantic_segmentor.py index c33ebf6ac..4e9703b03 100644 --- a/tiatoolbox/models/engine/semantic_segmentor.py +++ b/tiatoolbox/models/engine/semantic_segmentor.py @@ -1148,6 +1148,8 @@ def save_to_cache( """ chunk0 = canvas.chunks[0][0] + tqdm_ = get_tqdm() + if canvas_zarr is None: zarr_group = zarr.open(str(save_path), mode="a") @@ -1186,7 +1188,11 @@ def save_to_cache( # Append remaining blocks one-at-a-time to limit peak memory. 
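+    # Each remaining block is materialised with ``.compute()`` and appended
+    # to the on-disk zarr cache before the next block is loaded, so peak
+    # memory stays at roughly one chunk at a time.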
num_blocks = canvas.numblocks[0] - for block_idx in range(start_idx, num_blocks): + for block_idx in tqdm_( + range(start_idx, num_blocks), + leave=False, + desc="Memory Overload, Spilling to disk", + ): canvas_block = canvas.blocks[block_idx, 0, 0].compute() count_block = count.blocks[block_idx, 0, 0].compute() @@ -1462,17 +1468,6 @@ def prepare_full_batch( dtype=batch_output.dtype, ) else: - # Array too large, use zarr backed by disk to avoid RAM spikes - # Use a unique temp subdirectory per call to avoid chunk-shape clashes - msg = ( - f"Estimated peak memory usage for full batch output: " - f"{peak_bytes / (1024**3):.2f} GB exceeds threshold of " - f"{memory_available / (1024**3):.2f} GB." - f"Allocating full batch output of size " - f"{final_size}x{sample_shape} using Zarr on disk." - ) - logger.info(msg) - save_path_dir = Path(save_path) save_path_dir.mkdir(parents=True, exist_ok=True) temp_dir = Path( From a1ee82dfb84d72bfa038eaf8d7d20fc5cc93652b Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Sat, 7 Feb 2026 16:34:19 +0000 Subject: [PATCH 101/156] :bug: Fix test with same input and output directory. --- tests/engines/test_nucleus_detection_engine.py | 10 ++++++---- tiatoolbox/models/engine/engine_abc.py | 4 ++-- tiatoolbox/models/engine/multi_task_segmentor.py | 2 +- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/tests/engines/test_nucleus_detection_engine.py b/tests/engines/test_nucleus_detection_engine.py index 4141dbaa7..c91993b6a 100644 --- a/tests/engines/test_nucleus_detection_engine.py +++ b/tests/engines/test_nucleus_detection_engine.py @@ -105,7 +105,7 @@ def test_nucleus_detector_patch_annotation_store_output( pretrained_model = "sccnn-conic" - save_dir = track_tmp_path + save_dir = track_tmp_path / "output" nucleus_detector = NucleusDetector(model=pretrained_model) _ = nucleus_detector.run( @@ -127,14 +127,16 @@ def test_nucleus_detector_patch_annotation_store_output( assert len(store_2.values()) == 0 store_2.close() - imwrite(save_dir / "patch_0.png", patch_1) - imwrite(save_dir / "patch_1.png", patch_2) + image_dir = track_tmp_path / "inputs" + image_dir.mkdir() + imwrite(image_dir / "patch_0.png", patch_1) + imwrite(image_dir / "patch_1.png", patch_2) _ = nucleus_detector.run( patch_mode=True, device=device, output_type="annotationstore", memory_threshold=50, - images=[save_dir / "patch_0.png", save_dir / "patch_1.png"], + images=[image_dir / "patch_0.png", image_dir / "patch_1.png"], save_dir=save_dir, overwrite=True, ) diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index 308cfd4dd..eece332de 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -1788,7 +1788,7 @@ def prepare_engines_save_dir( save_dir = Path(save_dir) if save_dir.exists() and overwrite: shutil.rmtree(save_dir) - save_dir.mkdir(parents=True, exist_ok=overwrite) + save_dir.mkdir(parents=True) return save_dir return None @@ -1809,6 +1809,6 @@ def prepare_engines_save_dir( save_dir = Path(save_dir) if save_dir.exists() and overwrite: shutil.rmtree(save_dir) - save_dir.mkdir(parents=True, exist_ok=overwrite) + save_dir.mkdir(parents=True) return save_dir diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index 95bed3a09..f133ab155 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -2020,7 +2020,7 @@ def dict_to_store( 
for i in tqdm_(
         range(n),
         leave=False,
-        desc="Creating list of delayed tasks for writing annotations.",
+        desc="Creating list of delayed tasks for writing annotations",
     )

From 8318a14b10290fb1744a12b3329ba772bda6fb8a Mon Sep 17 00:00:00 2001
From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com>
Date: Sat, 7 Feb 2026 17:02:56 +0000
Subject: [PATCH 102/156] :recycle: Restructure calculation of delayed objects

---
 .../models/engine/multi_task_segmentor.py    | 59 +++++++++----------
 1 file changed, 28 insertions(+), 31 deletions(-)

diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py
index f133ab155..27e27fac8 100644
--- a/tiatoolbox/models/engine/multi_task_segmentor.py
+++ b/tiatoolbox/models/engine/multi_task_segmentor.py
@@ -1085,24 +1085,36 @@ def _process_tile_mode(

         tqdm_ = get_tqdm()

-        delayed_results, tile_metadata = _build_delayed_tile_tasks(
-            probabilities=probabilities,
+        tile_metadata = _build_tile_tasks(
             tile_info_sets=tile_info_sets,
-            model=self.model,
         )

         wsi_info_dict = None
         merge_idx = 0
         for i in tqdm_(
-            range(0, len(delayed_results), num_workers),
+            range(0, len(tile_metadata), num_workers),
             leave=False,
             desc="Post-Processing WSI to generate predictions and contours",
         ):
-            batch = delayed_results[i : i + num_workers]
+            tile_metadata_ = tile_metadata[i : i + num_workers]
+
+            # Build delayed tasks
+            delayed_tasks = [
+                _compute_tile(
+                    probabilities,
+                    _tile_meta[0],
+                    self.model,
+                )
+                for _tile_meta in tqdm_(
+                    tile_metadata_,
+                    leave=False,
+                    desc="Creating list of delayed tasks for writing annotations",
+                )
+            ]

             # Compute only this batch in parallel to avoid memory overload.
             batch_outputs = compute(
-                *batch, scheduler="threads", num_workers=num_workers
+                *delayed_tasks, scheduler="threads", num_workers=num_workers
             )

             # Merge each tile result immediately
@@ -3086,42 +3098,30 @@ def _update_tile_based_predictions_array(
     return wsi_info_dict


-def _build_delayed_tile_tasks(
-    probabilities: list[da.Array | np.ndarray],
+def _build_tile_tasks(
     tile_info_sets: list,
-    model: ModelABC,
-) -> tuple[
-    list[Any],  # delayed results
-    list,  # metadata
+) -> list[
+    tuple,  # metadata
 ]:
-    """Build delayed tile-processing tasks and associated metadata.
+    """Build tasks for delayed tile-processing using associated metadata.

-    This function iterates over all tile sets and constructs:
-    - a list of delayed tasks (each calling `_compute_tile`)
-    - a parallel list of metadata entries describing each tile
-
-    Metadata entries contain:
+    This function iterates over all tile sets and constructs metadata
+    entries containing:
     (tile_bounds, tile_flag, tile_mode)

     Args:
-        probabilities:
-            List of WSI-scale probability maps, one per model head.
-            Each element is a Dask array or NumPy array with shape (H, W, C).
         tile_info_sets:
             A list where each element is a tuple:
                 (set_bounds, set_flags)
                 - set_bounds: list of tile bounds (x0, y0, x1, y1)
                 - set_flags: list of per-tile flags used for instance merging
Returns: - A tuple: - - delayed_results: list of delayed `_compute_tile` tasks - - tile_metadata: list of metadata tuples + list: + tile_metadata: list of metadata tuples (tile_bounds, tile_flag, tile_mode) + """ - delayed_results: list = [] tile_metadata: list = [] tqdm_ = get_tqdm() @@ -3141,13 +3141,10 @@ def _build_delayed_tile_tasks( ): tile_flag = set_flags[tile_idx] - # Create delayed tile compute task - delayed_results.append(_compute_tile(probabilities, tile_bounds, model)) - # Store metadata for merging tile_metadata.append((tile_bounds, tile_flag, set_idx)) - return delayed_results, tile_metadata + return tile_metadata @delayed From bcbfb1a6ad4d5ebf620da216f2158028f004c5c3 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Sat, 7 Feb 2026 19:19:16 +0000 Subject: [PATCH 103/156] :zap: Pass only small arrays to delayed function --- .../models/engine/multi_task_segmentor.py | 67 +++++++------------ 1 file changed, 23 insertions(+), 44 deletions(-) diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index 27e27fac8..7196e1295 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -1098,15 +1098,31 @@ def _process_tile_mode( ): tile_metadata_ = tile_metadata[i : i + num_workers] + head_raws = [] + for _tile_meta in tqdm_( + tile_metadata_, + leave=False, + desc="Computing tiles for post processing", + ): + tile_bounds = _tile_meta[0] + head_raws.append( + tuple( + [ + p[ + tile_bounds[1] : tile_bounds[3], + tile_bounds[0] : tile_bounds[2], + :, + ].compute() + for p in probabilities + ] + ) + ) + # Build delayed tasks delayed_tasks = [ - _compute_tile( - probabilities, - _tile_meta[0], - self.model, - ) - for _tile_meta in tqdm_( - tile_metadata_, + delayed(self.model.postproc_func)(head_raw) + for head_raw in tqdm_( + head_raws, leave=False, desc="Creating list of delayed tasks for writing annotations", ) @@ -3061,43 +3077,6 @@ def _update_tile_based_predictions_array( return wsi_info_dict -@delayed -def _compute_tile( - probabilities: list[da.Array | np.ndarray], - tile_bounds: tuple[int, int, int, int], - model: ModelABC, -) -> tuple: - """Compute post-processing outputs for a single WSI tile. - - This function performs lazy slicing of the probability maps for the given - tile bounds and applies the model's `postproc_func` to produce per-task - outputs (semantic, instance, etc.). - - Args: - probabilities: - List of WSI-scale probability maps, one per model head. Each element - is either a Dask array or NumPy array with shape (H, W, C). - tile_bounds: - A 4-tuple (x_start, y_start, x_end, y_end) defining the tile region - in WSI coordinates. - model: - The multitask model containing a `postproc_func` method that accepts - a list of tile-level arrays and returns a tuple of task dictionaries. - - Returns: - A tuple of dictionaries, one per task, as produced by `postproc_func`. 
- Each dictionary typically contains: - - "task_type": str - - "predictions": np.ndarray - - "info_dict": dict - """ - head_raws = [ - p[tile_bounds[1] : tile_bounds[3], tile_bounds[0] : tile_bounds[2], :] - for p in probabilities - ] - return model.postproc_func(head_raws) - - def _build_tile_tasks( tile_info_sets: list, ) -> list[ From cbc27b9f1455557886fcaf287ecf425d05fb7173 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Sat, 7 Feb 2026 19:34:22 +0000 Subject: [PATCH 104/156] :bug: Fix deep source error. --- .../models/engine/multi_task_segmentor.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index 7196e1295..c22edb8e1 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -1106,16 +1106,14 @@ def _process_tile_mode( ): tile_bounds = _tile_meta[0] head_raws.append( - tuple( - [ - p[ - tile_bounds[1] : tile_bounds[3], - tile_bounds[0] : tile_bounds[2], - :, - ].compute() - for p in probabilities - ] - ) + [ + p[ + tile_bounds[1] : tile_bounds[3], + tile_bounds[0] : tile_bounds[2], + :, + ].compute() + for p in probabilities + ] ) # Build delayed tasks From ae09e90b6ad505591b813a7c1d657f5704665060 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Sat, 7 Feb 2026 21:08:19 +0000 Subject: [PATCH 105/156] :lipstick: Update tqdm --- tiatoolbox/models/architecture/hovernet.py | 22 ++++++++++---- tiatoolbox/models/architecture/micronet.py | 18 +++++++---- .../models/engine/deep_feature_extractor.py | 12 ++++---- tiatoolbox/utils/misc.py | 30 ++++++++++++++++++- 4 files changed, 64 insertions(+), 18 deletions(-) diff --git a/tiatoolbox/models/architecture/hovernet.py b/tiatoolbox/models/architecture/hovernet.py index cf34d1dbb..7ff8ffa96 100644 --- a/tiatoolbox/models/architecture/hovernet.py +++ b/tiatoolbox/models/architecture/hovernet.py @@ -23,7 +23,7 @@ centre_crop_to_shape, ) from tiatoolbox.models.models_abc import ModelABC -from tiatoolbox.utils.misc import get_bounding_box, get_tqdm +from tiatoolbox.utils.misc import get_bounding_box, get_tqdm_full class TFSamepaddingLayer(nn.Module): @@ -614,7 +614,12 @@ def _proc_np_hv( return watershed(dist, markers=marker, mask=blb) @staticmethod - def get_instance_info(pred_inst: np.ndarray, pred_type: np.ndarray = None) -> dict: + def get_instance_info( + pred_inst: np.ndarray, + pred_type: np.ndarray = None, + *, + verbose: bool = True, + ) -> dict: """To collect instance information and store it within a dictionary. Args: @@ -624,6 +629,8 @@ def get_instance_info(pred_inst: np.ndarray, pred_type: np.ndarray = None) -> di pred_type (:class:`numpy.ndarray`): An image of shape (height, width, 1) which contains the probabilities of a pixel being a certain type of nuclei. + verbose (bool): + Whether to display progress bar. 
Returns: dict: @@ -656,10 +663,13 @@ def get_instance_info(pred_inst: np.ndarray, pred_type: np.ndarray = None) -> di """ inst_id_list = np.unique(pred_inst)[1:] # exclude background inst_info_dict = {} - tqdm_ = get_tqdm() - for inst_id in tqdm_( - inst_id_list, leave=False, desc="Generating 'info_dict' for instances" - ): + tqdm_loop = get_tqdm_full( + inst_id_list, + leave=False, + desc="Generating 'info_dict' for instances", + verbose=verbose, + ) + for inst_id in tqdm_loop: inst_map = pred_inst == inst_id inst_box = get_bounding_box(inst_map) inst_box_tl = inst_box[:2] diff --git a/tiatoolbox/models/architecture/micronet.py b/tiatoolbox/models/architecture/micronet.py index 9eb3248b1..01caced79 100644 --- a/tiatoolbox/models/architecture/micronet.py +++ b/tiatoolbox/models/architecture/micronet.py @@ -23,7 +23,7 @@ _inst_dict_for_dask_processing, ) from tiatoolbox.models.models_abc import ModelABC -from tiatoolbox.utils.misc import get_tqdm +from tiatoolbox.utils.misc import get_tqdm_full def group1_forward_branch( @@ -575,12 +575,19 @@ def forward( # skipcq: PYL-W0221 return [out, aux1, aux2, aux3] # skipcq: PYL-W0221 # noqa: ERA001 - def postproc(self: MicroNet, raw_maps: list[np.ndarray | da.Array]) -> tuple[dict]: + def postproc( + self: MicroNet, + raw_maps: list[np.ndarray | da.Array], + *, + verbose: bool = True, + ) -> tuple[dict]: """Post-processing script for MicroNet. Args: raw_maps (list[ndarray | da.Array]): A list of prediction outputs of each head from inference model. + verbose (bool): + Whether to display progress bar. Returns: :class:`numpy.ndarray`: @@ -594,12 +601,13 @@ def postproc(self: MicroNet, raw_maps: list[np.ndarray | da.Array]) -> tuple[dic pred_inst = ndimage.label(pred_bin)[0] pred_inst = morphology.remove_small_objects(pred_inst, min_size=50) canvas = np.zeros(pred_inst.shape[:2], dtype=np.int32) - tqdm_ = get_tqdm() - for inst_id in tqdm_( + tqdm_loop = get_tqdm_full( range(1, np.max(pred_inst) + 1), leave=False, desc="Performing morphological operations to improve segmentation quality.", - ): + verbose=verbose, + ) + for inst_id in tqdm_loop: # Get coordinates of this instance ys, xs = np.where(pred_inst == inst_id) if len(xs) == 0: diff --git a/tiatoolbox/models/engine/deep_feature_extractor.py b/tiatoolbox/models/engine/deep_feature_extractor.py index b09c72ca9..e596e706e 100644 --- a/tiatoolbox/models/engine/deep_feature_extractor.py +++ b/tiatoolbox/models/engine/deep_feature_extractor.py @@ -48,7 +48,7 @@ from dask import compute from typing_extensions import Unpack -from tiatoolbox.utils.misc import get_tqdm +from tiatoolbox.utils.misc import get_tqdm_full from .patch_predictor import PatchPredictor, PredictorRunParams @@ -292,11 +292,11 @@ def infer_wsi( ) # Inference loop - tqdm = get_tqdm() - tqdm_loop = ( - tqdm(dataloader, leave=False, desc="Inferring patches") - if self.verbose - else dataloader + tqdm_loop = get_tqdm_full( + dataloader, + leave=False, + desc="Inferring Patches", + verbose=self.verbose, ) probabilities_zarr, coordinates_zarr = None, None diff --git a/tiatoolbox/utils/misc.py b/tiatoolbox/utils/misc.py index e9728b256..82a08c1a3 100644 --- a/tiatoolbox/utils/misc.py +++ b/tiatoolbox/utils/misc.py @@ -34,7 +34,7 @@ from tiatoolbox.utils.exceptions import FileNotSupportedError if TYPE_CHECKING: # pragma: no cover - from collections.abc import Iterator + from collections.abc import Iterable, Iterator from os import PathLike from shapely import geometry @@ -1658,6 +1658,34 @@ def get_tqdm() -> type[tqdm_notebook | tqdm]: return 
tqdm


+def get_tqdm_full(
+    iterable_input: Iterable,
+    desc: str = "Processing input",
+    *,
+    leave: bool = False,
+    verbose: bool = True,
+) -> type[tqdm_notebook | tqdm] | Iterable:
+    """Helper function to get appropriate tqdm progress bar.
+
+    Args:
+        iterable_input (Iterable):
+            Any iterable input.
+        desc (str):
+            tqdm progress bar description.
+        leave (bool):
+            Whether to leave progress bar after completion.
+        verbose (bool):
+            Whether to return progress bar or the input iterator.
+
+    Returns:
+        Iterable:
+            Iterable of tqdm progress bar if ``verbose`` is True else input Iterable.
+
+    """
+    tqdm_ = tqdm_notebook.tqdm if is_notebook() else tqdm
+    return tqdm_(iterable_input, leave=leave, desc=desc) if verbose else iterable_input
+
+
 def cast_to_min_dtype(array: np.ndarray | da.Array) -> np.ndarray | da.Array:
     """Cast the input array to the minimal data type required to represent its values.

From 6b320cdbf87f50c3393212bba8e07b3926cef861 Mon Sep 17 00:00:00 2001
From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com>
Date: Sat, 7 Feb 2026 21:38:52 +0000
Subject: [PATCH 106/156] :zap: Test computation of arrays using dask delayed.

---
 .../models/engine/multi_task_segmentor.py    | 58 ++++++++++++-------
 1 file changed, 37 insertions(+), 21 deletions(-)

diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py
index c22edb8e1..b1e8c281d 100644
--- a/tiatoolbox/models/engine/multi_task_segmentor.py
+++ b/tiatoolbox/models/engine/multi_task_segmentor.py
@@ -1091,6 +1091,7 @@ def _process_tile_mode(

         wsi_info_dict = None
         merge_idx = 0
+        self._probabilities = probabilities
         for i in tqdm_(
             range(0, len(tile_metadata), num_workers),
             leave=False,
@@ -1098,29 +1099,13 @@ def _process_tile_mode(
         ):
             tile_metadata_ = tile_metadata[i : i + num_workers]

-            head_raws = []
-            for _tile_meta in tqdm_(
-                tile_metadata_,
-                leave=False,
-                desc="Computing tiles for post processing",
-            ):
-                tile_bounds = _tile_meta[0]
-                head_raws.append(
-                    [
-                        p[
-                            tile_bounds[1] : tile_bounds[3],
-                            tile_bounds[0] : tile_bounds[2],
-                            :,
-                        ].compute()
-                        for p in probabilities
-                    ]
-                )
-
             # Build delayed tasks
             delayed_tasks = [
-                delayed(self.model.postproc_func)(head_raw)
-                for head_raw in tqdm_(
-                    head_raws,
+                self._compute_tile(
+                    _tile_meta[0],
+                )
+                for _tile_meta in tqdm_(
+                    tile_metadata_,
                     leave=False,
                     desc="Creating list of delayed tasks for writing annotations",
                 )
@@ -1195,6 +1180,37 @@ def _process_tile_mode(

         return wsi_info_dict

+    @delayed
+    def _compute_tile(
+        self: MultiTaskSegmentor,
+        tile_bounds: tuple[int, int, int, int],
+    ) -> tuple:
+        """Compute post-processing outputs for a single WSI tile.
+
+        This function performs lazy slicing of the probability maps for the given
+        tile bounds and applies the model's `postproc_func` to produce per-task
+        outputs (semantic, instance, etc.).
+
+        Args:
+            tile_bounds:
+                A 4-tuple (x_start, y_start, x_end, y_end) defining the tile region
+                in WSI coordinates.
+
+        Returns:
+            A tuple of dictionaries, one per task, as produced by `postproc_func`. 
+ Each dictionary typically contains: + - "task_type": str + - "predictions": np.ndarray + - "info_dict": dict + """ + head_raws = [ + p[ + tile_bounds[1] : tile_bounds[3], tile_bounds[0] : tile_bounds[2], : + ].compute() + for p in self._probabilities + ] + return self.model.postproc_func(head_raws) + @staticmethod def _get_tile_info( image_shape: list[int] | np.ndarray, From eaa13aff0c6a54851614774ceaa997f864453a3d Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Sat, 7 Feb 2026 21:50:30 +0000 Subject: [PATCH 107/156] :zap: Test computation of arrays using dask delayed. --- .../models/engine/multi_task_segmentor.py | 32 +++++++++++++++++-- .../models/engine/semantic_segmentor.py | 20 +++++++++--- 2 files changed, 44 insertions(+), 8 deletions(-) diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index b1e8c281d..e4f192352 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -137,7 +137,12 @@ from tiatoolbox.annotation import SQLiteStore from tiatoolbox.annotation.storage import Annotation from tiatoolbox.tools.patchextraction import PatchExtractor -from tiatoolbox.utils.misc import create_smart_array, get_tqdm, make_valid_poly +from tiatoolbox.utils.misc import ( + create_smart_array, + get_tqdm, + get_tqdm_full, + make_valid_poly, +) from tiatoolbox.wsicore.wsireader import is_zarr from .semantic_segmentor import ( @@ -652,6 +657,7 @@ def infer_wsi( tqdm_=tqdm_, save_path=save_path, num_expected_output=num_expected_output, + verbose=self.verbose, ) ) @@ -678,6 +684,7 @@ def infer_wsi( output_locs_y_=output_locs_y_, save_path=save_path, memory_threshold=memory_threshold, + verbose=self.verbose, ) raw_predictions["coordinates"] = da.concatenate(coordinates, axis=0) @@ -2348,6 +2355,8 @@ def save_multitask_to_cache( canvas_zarr: list[zarr.Array | None], count_zarr: list[zarr.Array | None], save_path: str | Path = "temp.zarr", + *, + verbose: bool = True, ) -> tuple[list[zarr.Array], list[zarr.Array]]: """Write accumulated horizontal row blocks to a Zarr cache on disk. @@ -2386,6 +2395,8 @@ def save_multitask_to_cache( save_path (str | Path): Path to the Zarr group used for caching. A new group is created if needed on the first spill. + verbose (bool): + Whether to display progress bar. Returns: tuple[list[zarr.Array], list[zarr.Array]]: @@ -2403,7 +2414,12 @@ def save_multitask_to_cache( and ``count`` to free RAM and continue populating new entries. 
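+
+    Example:
+        A schematic sketch with toy shapes; the arrays and the save path
+        below are illustrative assumptions only, not toolbox fixtures::
+
+            import dask.array as da
+
+            canvas = [da.zeros((256, 256, 2), chunks=(64, 256, 2))]
+            count = [da.ones((256, 256, 1), chunks=(64, 256, 1))]
+            canvas_zarr, count_zarr = save_multitask_to_cache(
+                canvas, count, [None], [None], save_path="cache.zarr"
+            )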
""" - for idx, canvas_ in enumerate(canvas): + tqdm_loop = get_tqdm_full( + canvas, + desc="Memory Overload, Spilling to disk", + verbose=verbose, + ) + for idx, canvas_ in enumerate(tqdm_loop): canvas_zarr[idx], count_zarr[idx] = save_to_cache( canvas=canvas_, count=count[idx], @@ -2411,6 +2427,7 @@ def save_multitask_to_cache( count_zarr=count_zarr[idx], save_path=save_path, zarr_dataset_name=(f"canvas/{idx}", f"count/{idx}"), + verbose=verbose, ) return canvas_zarr, count_zarr @@ -2627,12 +2644,18 @@ def _calculate_probabilities( output_locs_y_: np.ndarray, save_path: Path, memory_threshold: int, + *, + verbose: bool, ) -> list[da.Array]: """Helper function to calculate probabilities for MultiTaskSegmentor.""" zarr_group = None if canvas_zarr[0] is not None: canvas_zarr, count_zarr = save_multitask_to_cache( - canvas, count, canvas_zarr, count_zarr + canvas, + count, + canvas_zarr, + count_zarr, + verbose=verbose, ) # Wrap zarr in dask array for idx, canvas_zarr_ in enumerate(canvas_zarr): @@ -2662,6 +2685,8 @@ def _check_and_update_for_memory_overload( tqdm_: type[tqdm_notebook | tqdm], save_path: Path, num_expected_output: int, + *, + verbose: bool = True, ) -> tuple[ list[da.Array | None], list[da.Array | None], @@ -2696,6 +2721,7 @@ def _check_and_update_for_memory_overload( canvas_zarr, count_zarr, save_path=save_path, + verbose=verbose, ) canvas = [None for _ in range(num_expected_output)] count = [None for _ in range(num_expected_output)] diff --git a/tiatoolbox/models/engine/semantic_segmentor.py b/tiatoolbox/models/engine/semantic_segmentor.py index 4e9703b03..f6cfd111a 100644 --- a/tiatoolbox/models/engine/semantic_segmentor.py +++ b/tiatoolbox/models/engine/semantic_segmentor.py @@ -68,6 +68,7 @@ from tiatoolbox.utils.misc import ( dict_to_store_semantic_segmentor, get_tqdm, + get_tqdm_full, ) from tiatoolbox.wsicore.wsireader import is_zarr @@ -544,6 +545,7 @@ def infer_wsi( canvas_zarr, count_zarr, save_path=save_path, + verbose=self.verbose, ) canvas, count = None, None gc.collect() @@ -567,7 +569,11 @@ def infer_wsi( zarr_group = None if canvas_zarr is not None: canvas_zarr, count_zarr = save_to_cache( - canvas, count, canvas_zarr, count_zarr + canvas, + count, + canvas_zarr, + count_zarr, + verbose=self.verbose, ) # Wrap zarr in dask array canvas = da.from_zarr(canvas_zarr, chunks=canvas_zarr.chunks) @@ -1119,6 +1125,8 @@ def save_to_cache( count_zarr: zarr.Array, save_path: str | Path = "temp.zarr", zarr_dataset_name: tuple[str, str] = ("canvas", "count"), + *, + verbose: bool = True, ) -> tuple[zarr.Array, zarr.Array]: """Incrementally save computed canvas and count arrays to Zarr cache. @@ -1141,6 +1149,8 @@ def save_to_cache( zarr_dataset_name (tuple[str, str]): Tuple of name for zarr dataset to save canvas and count. Defaults to ("canvas", "count"). + verbose (bool): + Whether to display progress bar. Returns: tuple[zarr.Array, zarr.Array]: @@ -1148,8 +1158,6 @@ def save_to_cache( """ chunk0 = canvas.chunks[0][0] - tqdm_ = get_tqdm() - if canvas_zarr is None: zarr_group = zarr.open(str(save_path), mode="a") @@ -1188,11 +1196,13 @@ def save_to_cache( # Append remaining blocks one-at-a-time to limit peak memory. 
num_blocks = canvas.numblocks[0] - for block_idx in tqdm_( + tqdm_loop = get_tqdm_full( range(start_idx, num_blocks), leave=False, desc="Memory Overload, Spilling to disk", - ): + verbose=verbose, + ) + for block_idx in tqdm_loop: canvas_block = canvas.blocks[block_idx, 0, 0].compute() count_block = count.blocks[block_idx, 0, 0].compute() From 199674c46e23a97ff01ec9deb0eda89e0ea55a8c Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Sat, 7 Feb 2026 22:12:50 +0000 Subject: [PATCH 108/156] :bug: Fix tqdm_loop --- .../models/engine/semantic_segmentor.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/tiatoolbox/models/engine/semantic_segmentor.py b/tiatoolbox/models/engine/semantic_segmentor.py index f6cfd111a..05b685f57 100644 --- a/tiatoolbox/models/engine/semantic_segmentor.py +++ b/tiatoolbox/models/engine/semantic_segmentor.py @@ -588,6 +588,7 @@ def infer_wsi( zarr_group, save_path, memory_threshold, + verbose=self.verbose, ) raw_predictions["coordinates"] = da.concatenate(coordinates, axis=0) @@ -1226,6 +1227,8 @@ def merge_vertical_chunkwise( zarr_group: zarr.Group, save_path: Path, memory_threshold: int = 80, + *, + verbose: bool = True, ) -> da.Array: """Merge vertically chunked canvas and count arrays into a single probability map. @@ -1249,6 +1252,8 @@ def merge_vertical_chunkwise( is saved in a Zarr file. memory_threshold (int): Memory usage threshold (in percentage) to trigger caching behavior. + verbose (bool): + Whether to display progress bar. Returns: da.Array: @@ -1263,8 +1268,12 @@ def merge_vertical_chunkwise( probabilities_zarr, probabilities_da = None, None chunk_shape = tuple(chunk[0] for chunk in canvas.chunks) - tqdm_ = get_tqdm() - tqdm_loop = tqdm_(overlaps, leave=False, desc="Merging rows") + tqdm_loop = get_tqdm_full( + overlaps, + leave=False, + desc="Merging rows", + verbose=verbose, + ) used_percent = 0 @@ -1294,13 +1303,13 @@ def merge_vertical_chunkwise( vm = psutil.virtual_memory() used_percent = (probabilities_da.nbytes / vm.free) * 100 if probabilities_zarr is None and used_percent > memory_threshold: - desc = tqdm_.desc + desc = tqdm_loop.desc msg = ( f"Current Memory usage: {used_percent} % " f"exceeds specified threshold: {memory_threshold}. " f"Saving intermediate results to disk." 
) - tqdm_.desc = msg + tqdm_loop.desc = msg zarr_group = zarr.open(str(save_path), mode="a") probabilities_zarr = zarr_group.create_dataset( name="probabilities", @@ -1312,7 +1321,7 @@ def merge_vertical_chunkwise( probabilities_zarr[:] = probabilities_da.compute() probabilities_da = None - tqdm_.desc = desc + tqdm_loop.desc = desc if next_chunk is not None: curr_chunk, curr_count = next_chunk[overlap:], next_count[overlap:] From 11513d6976dcec89372db53c96019389468a980c Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Sat, 7 Feb 2026 22:22:13 +0000 Subject: [PATCH 109/156] :lipstick: Update tqdm loop engine_abc.py --- tiatoolbox/models/engine/engine_abc.py | 35 +++++++++++--------------- 1 file changed, 14 insertions(+), 21 deletions(-) diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index eece332de..6037728d2 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -58,7 +58,7 @@ from tiatoolbox.models.models_abc import load_torch_model from tiatoolbox.utils.misc import ( dict_to_store_patch_predictions, - get_tqdm, + get_tqdm_full, ) from tiatoolbox.wsicore.wsireader import WSIReader, is_zarr @@ -532,11 +532,11 @@ def infer_patches( raw_predictions = {key: [] for key in keys} # Inference loop - tqdm = get_tqdm() - tqdm_loop = ( - tqdm(dataloader, leave=False, desc="Inferring patches") - if self.verbose - else self.dataloader + tqdm_loop = get_tqdm_full( + dataloader, + leave=False, + desc="Inferring patches", + verbose=self.verbose, ) infer_batch = self._get_model_attr("infer_batch") @@ -1542,14 +1542,6 @@ def _run_wsi_mode( Output may be a zarr file, SQLite database, or in-memory dictionary. """ - progress_bar = None - tqdm = get_tqdm() - - if self.verbose: - progress_bar = tqdm( - total=len(self.images), - desc="Processing WSIs", - ) suffix = ".zarr" if output_type == "AnnotationStore": suffix = ".db" @@ -1568,7 +1560,14 @@ def get_path(image: Path | WSIReader) -> Path: for image in self.images } - for image_num, image in enumerate(self.images): + tqdm_loop = get_tqdm_full( + self.images, + leave=False, + desc="Processing WSIs", + verbose=self.verbose, + ) + + for image_num, image in enumerate(tqdm_loop): duplicate_filter = DuplicateFilter() logger.addFilter(duplicate_filter) mask = self.masks[image_num] if self.masks is not None else None @@ -1606,12 +1605,6 @@ def get_path(image: Path | WSIReader) -> Path: msg = f"Output file saved at {out[get_path(image)]}." 
logger.info(msg=msg) - if progress_bar: - progress_bar.update() - - if progress_bar: - progress_bar.close() - return out def run( From ca55c74ad69f682db6a7a75685916cbbcc5821a6 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Sat, 7 Feb 2026 22:25:48 +0000 Subject: [PATCH 110/156] :lipstick: Update tqdm loop semantic_segmentor.py --- tiatoolbox/models/engine/semantic_segmentor.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/tiatoolbox/models/engine/semantic_segmentor.py b/tiatoolbox/models/engine/semantic_segmentor.py index 05b685f57..a5634f8a3 100644 --- a/tiatoolbox/models/engine/semantic_segmentor.py +++ b/tiatoolbox/models/engine/semantic_segmentor.py @@ -67,7 +67,6 @@ from tiatoolbox.models.dataset.dataset_abc import WSIPatchDataset from tiatoolbox.utils.misc import ( dict_to_store_semantic_segmentor, - get_tqdm, get_tqdm_full, ) from tiatoolbox.wsicore.wsireader import is_zarr @@ -461,11 +460,11 @@ def infer_wsi( ) # Inference loop - tqdm_ = get_tqdm() - tqdm_loop = ( - tqdm_(dataloader, leave=False, desc="Inferring patches") - if self.verbose - else dataloader + tqdm_loop = get_tqdm_full( + dataloader, + leave=False, + desc="Inferring patches", + verbose=self.verbose, ) canvas_np, output_locs_y_ = None, None From 0ba76278907b31b245ccd03a0a60cf7c22b9cce1 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Sat, 7 Feb 2026 22:35:05 +0000 Subject: [PATCH 111/156] :lipstick: Update tqdm loop semantic_segmentor.py --- tiatoolbox/models/engine/engine_abc.py | 1 + .../models/engine/semantic_segmentor.py | 2 ++ tiatoolbox/utils/misc.py | 33 ++++++++++++++----- 3 files changed, 28 insertions(+), 8 deletions(-) diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index 6037728d2..045afe225 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -737,6 +737,7 @@ def save_predictions( scale_factor, class_dict, save_path, + verbose=self.verbose, ) msg = f"Unsupported output type: {output_type}" diff --git a/tiatoolbox/models/engine/semantic_segmentor.py b/tiatoolbox/models/engine/semantic_segmentor.py index a5634f8a3..2dd0999b5 100644 --- a/tiatoolbox/models/engine/semantic_segmentor.py +++ b/tiatoolbox/models/engine/semantic_segmentor.py @@ -711,6 +711,7 @@ def save_predictions( scale_factor=scale_factor, class_dict=class_dict, save_path=output_path, + verbose=self.verbose, ) save_paths.append(out_file) @@ -720,6 +721,7 @@ def save_predictions( scale_factor=scale_factor, class_dict=class_dict, save_path=save_path.with_suffix(".db"), + verbose=self.verbose, ) save_paths = out_file diff --git a/tiatoolbox/utils/misc.py b/tiatoolbox/utils/misc.py index 82a08c1a3..61b3b90e1 100644 --- a/tiatoolbox/utils/misc.py +++ b/tiatoolbox/utils/misc.py @@ -1217,14 +1217,19 @@ def patch_predictions_as_annotations( patch_coords: list, classes_predicted: list, labels: list, + *, + verbose: bool = True, ) -> list: """Helper function to generate annotation per patch predictions.""" annotations = [] - tqdm_ = get_tqdm() + tqdm_loop = get_tqdm_full( + patch_coords, + leave=False, + desc="Converting outputs to AnnotationStore.", + verbose=verbose, + ) - for i, _ in enumerate( - tqdm_(patch_coords, leave=False, desc="Converting outputs to AnnotationStore.") - ): + for i, _ in enumerate(tqdm_loop): if "probabilities" in keys: props = { f"prob_{class_dict[j]}": class_probs[i][j] for j in classes_predicted 
@@ -1358,6 +1363,8 @@ def dict_to_store_semantic_segmentor( scale_factor: tuple[float, float], class_dict: dict | None = None, save_path: Path | None = None, + *, + verbose: bool = True, ) -> AnnotationStore | Path: """Converts output of TIAToolbox SemanticSegmentor engine to AnnotationStore. @@ -1374,6 +1381,8 @@ def dict_to_store_semantic_segmentor( save_path (str or Path): Optional Output directory to save the Annotation Store results. + verbose (bool): + Whether to display logs and progress bar. Returns: (SQLiteStore or Path): @@ -1394,11 +1403,14 @@ def dict_to_store_semantic_segmentor( annotations_list: list[Annotation] = [] - tqdm_ = get_tqdm() + tqdm_loop = get_tqdm_full( + layer_list, + leave=False, + desc="Converting outputs to AnnotationStore.", + verbose=verbose, + ) - for type_class in tqdm_( - layer_list, leave=False, desc="Converting outputs to AnnotationStore." - ): + for type_class in tqdm_loop: class_id = int(type_class) class_label = class_dict.get(class_id, class_id) layer = da.where(preds == type_class, 1, 0).astype("uint8").compute() @@ -1437,6 +1449,8 @@ def dict_to_store_patch_predictions( scale_factor: tuple[float, float], class_dict: dict | None = None, save_path: Path | None = None, + *, + verbose: bool = True, ) -> AnnotationStore | Path: """Converts output of TIAToolbox PatchPredictor engine to AnnotationStore. @@ -1454,6 +1468,8 @@ def dict_to_store_patch_predictions( save_path (str or Path): Optional Output directory to save the Annotation Store results. + verbose (bool): + Whether to display logs and progress bar. Returns: (SQLiteStore or Path): @@ -1502,6 +1518,7 @@ def dict_to_store_patch_predictions( patch_coords.astype(float), classes_predicted, labels, + verbose=verbose, ) store = SQLiteStore() From 9a035506e726310617e999469e050a5caa24e24e Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Sat, 7 Feb 2026 22:53:58 +0000 Subject: [PATCH 112/156] :lipstick: Use `get_tqdm_full` for nucleus_detector.py --- .../models/engine/multi_task_segmentor.py | 5 ++-- tiatoolbox/models/engine/nucleus_detector.py | 29 +++++++++++++++---- 2 files changed, 26 insertions(+), 8 deletions(-) diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index e4f192352..d2be6a6de 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -1098,7 +1098,8 @@ def _process_tile_mode( wsi_info_dict = None merge_idx = 0 - self._probabilities = probabilities + # Only used for delayed processing. 
+ self._probabilities = probabilities # skipcq: PYL-W0201 for i in tqdm_( range(0, len(tile_metadata), num_workers), leave=False, @@ -1184,7 +1185,7 @@ def _process_tile_mode( chunks=(len(col),), ) wsi_info_dict[idx]["info_dict"] = dict_info_wsi - + delattr(self, "_probabilities") return wsi_info_dict @delayed diff --git a/tiatoolbox/models/engine/nucleus_detector.py b/tiatoolbox/models/engine/nucleus_detector.py index 05b81a535..36a625d50 100644 --- a/tiatoolbox/models/engine/nucleus_detector.py +++ b/tiatoolbox/models/engine/nucleus_detector.py @@ -63,7 +63,7 @@ SemanticSegmentor, SemanticSegmentorRunParams, ) -from tiatoolbox.utils.misc import get_tqdm +from tiatoolbox.utils.misc import get_tqdm_full if TYPE_CHECKING: # pragma: no cover import os @@ -471,7 +471,9 @@ def post_process_wsi( zarr_group = zarr.open(zarr_file, mode="r+") centroid_maps = da.from_zarr(zarr_group["centroid_maps"]) - return self._centroid_maps_to_detection_arrays(centroid_maps) + return self._centroid_maps_to_detection_arrays( + centroid_maps, verbose=self.verbose + ) def save_predictions( self: NucleusDetector, @@ -710,6 +712,8 @@ def _save_predictions_annotation_store( @staticmethod def _centroid_maps_to_detection_arrays( detection_maps: da.Array, + *, + verbose: bool = True, ) -> dict[str, da.Array]: """Convert centroid maps into 1-D detection arrays. @@ -727,6 +731,8 @@ def _centroid_maps_to_detection_arrays( detections. Each non-zero entry encodes both the class channel and its associated probability. This array is expected to be already computed. + verbose (bool): + Whether to display logs and progress bar. Returns: dict[str, da.Array]: @@ -763,8 +769,12 @@ class IDs for each detection (``np.uint32``). classes_list = [] probs_list = [] - tqdm = get_tqdm() - for i in tqdm(range(num_blocks_h), desc="Processing detection blocks"): + tqdm_loop = get_tqdm_full( + range(num_blocks_h), + desc="Processing detection blocks", + verbose=verbose, + ) + for i in tqdm_loop: for j in range(num_blocks_w): # Get block offsets y_offset = sum(detection_maps.chunks[0][:i]) if i > 0 else 0 @@ -814,6 +824,8 @@ def _write_detection_arrays_to_store( scale_factor: tuple[float, float], class_dict: dict[int, str | int] | None, batch_size: int = 5000, + *, + verbose: bool = True, ) -> int: """Write detection arrays to an AnnotationStore in batches. @@ -839,6 +851,8 @@ def _write_detection_arrays_to_store( If `None`, an identity mapping is used for the set of present classes. batch_size (int): Number of records to write per batch. Default is `5000`. + verbose (bool): + Whether to display logs and progress bar. 
Returns: int: @@ -878,8 +892,11 @@ def make_points(xs_batch: np.ndarray, ys_batch: np.ndarray) -> list[Point]: for xx, yy in zip(xs_batch, ys_batch, strict=True) ] - tqdm = get_tqdm() - tqdm_loop = tqdm(range(0, n, batch_size), desc="Writing detections to store") + tqdm_loop = get_tqdm_full( + range(0, n, batch_size), + desc="Writing detections to store", + verbose=verbose, + ) written = 0 for i in tqdm_loop: j = min(i + batch_size, n) From 1480c2c494449152e576771f27c181601a0548bf Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Sat, 7 Feb 2026 23:26:33 +0000 Subject: [PATCH 113/156] :lipstick: Use `get_tqdm_full` for nucleus_detector.py --- tests/engines/test_multi_task_segmentor.py | 1 + .../models/engine/multi_task_segmentor.py | 127 +++++++++++------- tiatoolbox/utils/misc.py | 7 - 3 files changed, 78 insertions(+), 57 deletions(-) diff --git a/tests/engines/test_multi_task_segmentor.py b/tests/engines/test_multi_task_segmentor.py index 612615228..9e3e96a88 100644 --- a/tests/engines/test_multi_task_segmentor.py +++ b/tests/engines/test_multi_task_segmentor.py @@ -338,6 +338,7 @@ def test_wsi_mtsegmentor_zarr( # Prediction masks can be tracked and saved as for layer segmentation in # HoVerNet Plus. return_predictions=(False, True), + verbose=False, ) output_tile_ = zarr.open(output_tile[wsi4_1k_1k_svs], mode="r") diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index d2be6a6de..7e334edd5 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -139,7 +139,6 @@ from tiatoolbox.tools.patchextraction import PatchExtractor from tiatoolbox.utils.misc import ( create_smart_array, - get_tqdm, get_tqdm_full, make_valid_poly, ) @@ -426,11 +425,11 @@ def infer_patches( raw_predictions["probabilities"] = [[] for _ in range(num_expected_output)] # Inference loop - tqdm_ = get_tqdm() - tqdm_loop = ( - tqdm_(dataloader, leave=False, desc="Inferring patches") - if self.verbose - else dataloader + tqdm_loop = get_tqdm_full( + dataloader, + leave=False, + desc="Inferring patches", + verbose=self.verbose, ) for batch_data in tqdm_loop: @@ -572,11 +571,11 @@ def infer_wsi( ) # Inference loop - tqdm_ = get_tqdm() - tqdm_loop = ( - tqdm_(dataloader, leave=False, desc="Inferring patches") - if self.verbose - else dataloader + tqdm_loop = get_tqdm_full( + dataloader, + leave=False, + desc="Inferring patches", + verbose=self.verbose, ) # Expected number of outputs from the model @@ -654,7 +653,6 @@ def infer_wsi( count_zarr=count_zarr, memory_threshold=memory_threshold, tqdm_loop=tqdm_loop, - tqdm_=tqdm_, save_path=save_path, num_expected_output=num_expected_output, verbose=self.verbose, @@ -1090,20 +1088,20 @@ def _process_tile_mode( tile_info_sets = self._get_tile_info(wsi_proc_shape, self.ioconfig) ioconfig = self.ioconfig.to_baseline() - tqdm_ = get_tqdm() - tile_metadata = _build_tile_tasks( tile_info_sets=tile_info_sets, + verbose=self.verbose, ) wsi_info_dict = None merge_idx = 0 # Only used for delayed processing. 
self._probabilities = probabilities # skipcq: PYL-W0201 - for i in tqdm_( + for i in get_tqdm_full( range(0, len(tile_metadata), num_workers), leave=False, desc="Post-Processing WSI to generate predictions and contours", + verbose=self.verbose, ): tile_metadata_ = tile_metadata[i : i + num_workers] @@ -1112,10 +1110,11 @@ def _process_tile_mode( self._compute_tile( _tile_meta[0], ) - for _tile_meta in tqdm_( + for _tile_meta in get_tqdm_full( tile_metadata_, leave=False, desc="Creating list of delayed tasks for writing annotations", + verbose=self.verbose, ) ] @@ -1172,8 +1171,11 @@ def _process_tile_mode( wsi_info_dict[inst_id]["info_dict"].pop(inst_uuid, None) for idx, wsi_info_dict_ in enumerate( - tqdm_( - wsi_info_dict, leave=False, desc="Converting 'info_dict' to dask arrays" + get_tqdm_full( + wsi_info_dict, + leave=False, + desc="Converting 'info_dict' to dask arrays", + verbose=self.verbose, ) ): info_df = pd.DataFrame(wsi_info_dict_["info_dict"]).transpose() @@ -1629,6 +1631,7 @@ def _save_predictions_as_annotationstore( class_dict=class_dict, scale_factor=scale_factor, num_workers=num_workers, + verbose=self.verbose, ) save_paths.append(output_path) @@ -1645,6 +1648,7 @@ def _save_predictions_as_annotationstore( class_dict=class_dict, scale_factor=scale_factor, num_workers=num_workers, + verbose=self.verbose, ) save_paths.append(output_path) @@ -1994,6 +1998,8 @@ def dict_to_store( origin: tuple[float, float] = (0, 0), scale_factor: tuple[float, float] = (1, 1), num_workers: int = multiprocessing.cpu_count(), + *, + verbose: bool = True, ) -> AnnotationStore: """Write polygonal multitask predictions into an SQLite-backed AnnotationStore. @@ -2035,6 +2041,8 @@ def dict_to_store( num_workers (int): Number of parallel worker threads to use. If set to 0 or None, defaults to the number of CPU cores. + verbose (bool): + Whether to display logs and progress bar. Returns: AnnotationStore: @@ -2055,8 +2063,6 @@ def dict_to_store( contours = processed_predictions.pop("contours") n = len(contours) - tqdm_ = get_tqdm() - # Build delayed tasks delayed_tasks = [ _build_single_annotation( @@ -2067,15 +2073,19 @@ def dict_to_store( origin, scale_factor, ) - for i in tqdm_( + for i in get_tqdm_full( range(n), leave=False, desc="Creating list of delayed tasks for writing annotations", + verbose=verbose, ) ] ann = compute_dask_delayed_with_progress( - delayed_tasks, num_workers=num_workers, desc="Saving annotations " + delayed_tasks, + num_workers=num_workers, + desc="Saving annotations ", + verbose=verbose, ) logger.info("Added %d annotations.", len(ann)) @@ -2089,6 +2099,8 @@ def compute_dask_delayed_with_progress( num_workers: int = multiprocessing.cpu_count(), desc: str = "Computing", batch_size: int | None = None, + *, + verbose: bool = True, ) -> list: """Compute a list of Dask delayed tasks in parallel while displaying a progress bar. @@ -2108,6 +2120,8 @@ def compute_dask_delayed_with_progress( batch_size (int | None): batch_size to process dask delayed. batch_size is set to num_workers if batch_size is not provided. + verbose (bool): + Whether to display logs and progress bar. Returns: A list containing the computed results from all delayed tasks, in order. 
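`compute_dask_delayed_with_progress` (its body is rewritten in the next hunk) evaluates delayed tasks in fixed-size batches so that only `batch_size` large results are materialised at once. A hypothetical usage sketch; `build_annotation` is an illustrative stand-in, not a toolbox function:

    from dask import delayed

    @delayed
    def build_annotation(i: int) -> int:
        # Placeholder for expensive per-instance work (contours, geometry, ...).
        return i * i

    tasks = [build_annotation(i) for i in range(10_000)]
    results = compute_dask_delayed_with_progress(
        tasks,
        num_workers=8,
        desc="Saving annotations",
        batch_size=64,  # caps in-flight results to bound peak memory
    )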
@@ -2117,21 +2131,22 @@ def compute_dask_delayed_with_progress( batch_size = num_workers if batch_size is None else batch_size results: list[Any] = [] - tqdm_ = get_tqdm() - - with tqdm_(total=total, desc=desc, leave=False) as pbar: - for i in range(0, total, batch_size): - batch = delayed_tasks[i : i + batch_size] + for i in get_tqdm_full( + range(0, total, batch_size), + desc=desc, + leave=False, + verbose=verbose, + ): + batch = delayed_tasks[i : i + batch_size] - # Compute this batch in parallel - batch_results = compute( - *batch, - scheduler="threads", - num_workers=num_workers, - ) + # Compute this batch in parallel + batch_results = compute( + *batch, + scheduler="threads", + num_workers=num_workers, + ) - results.extend(batch_results) - pbar.update(len(batch)) + results.extend(batch_results) return results @@ -2441,6 +2456,8 @@ def merge_multitask_vertical_chunkwise( zarr_group: zarr.Group, save_path: Path, memory_threshold: int = 80, + *, + verbose: bool = True, ) -> list[da.Array]: """Merge horizontally stitched row blocks into final WSI probability maps. @@ -2484,6 +2501,8 @@ def merge_multitask_vertical_chunkwise( memory_threshold (int): Maximum allowed RAM usage (percentage) before converting in-memory probability accumulators to Zarr-backed arrays. Default is 80. + verbose (bool): + Whether to display logs and progress bar. Returns: list[da.Array]: @@ -2517,16 +2536,17 @@ def merge_multitask_vertical_chunkwise( num_chunks = canvas_.numblocks[0] chunk_shape = tuple(chunk[0] for chunk in canvas_.chunks) - tqdm_ = get_tqdm() - tqdm_loop = tqdm_( - overlaps, leave=False, desc=f"Merging rows for probability map {idx}" - ) - curr_chunk = canvas_.blocks[0, 0].compute() curr_count = count[idx].blocks[0, 0].compute() next_chunk = canvas_.blocks[1, 0].compute() if num_chunks > 1 else None next_count = count[idx].blocks[1, 0].compute() if num_chunks > 1 else None + tqdm_loop = get_tqdm_full( + overlaps, + leave=False, + desc=f"Merging rows for probability map {idx}", + verbose=verbose, + ) for i, overlap in enumerate(tqdm_loop): if next_chunk is not None and overlap > 0: curr_chunk[-overlap:] += next_chunk[:overlap] @@ -2550,7 +2570,7 @@ def merge_multitask_vertical_chunkwise( probabilities_da=probabilities_da, probabilities=probabilities, idx=idx, - tqdm_=tqdm_, + tqdm_loop=tqdm_loop, save_path=save_path, chunk_shape=chunk_shape, memory_threshold=memory_threshold, @@ -2582,7 +2602,7 @@ def _save_multitask_vertical_to_cache( probabilities_da: list[da.Array] | list[None], probabilities: np.ndarray, idx: int, - tqdm_: type[tqdm_notebook | tqdm], + tqdm_loop: type[tqdm_notebook | tqdm], save_path: Path, chunk_shape: tuple, memory_threshold: int = 80, @@ -2595,13 +2615,13 @@ def _save_multitask_vertical_to_cache( total_bytes = sum(0 if arr is None else arr.nbytes for arr in probabilities_da) used_percent = (total_bytes / max(vm.available, 1)) * 100 if probabilities_zarr[idx] is None and used_percent > memory_threshold: - desc = tqdm_.desc + desc = tqdm_loop.desc msg = ( f"Current Memory usage: {used_percent} % " f"exceeds specified threshold: {memory_threshold}. " f"Saving intermediate results to disk." 
) - tqdm_.desc = msg + tqdm_loop.desc = msg zarr_group = zarr.open(str(save_path), mode="a") probabilities_zarr[idx] = zarr_group.create_dataset( name=f"probabilities/{idx}", @@ -2611,7 +2631,7 @@ def _save_multitask_vertical_to_cache( overwrite=True, ) probabilities_zarr[idx][:] = probabilities_da[idx].compute() - tqdm_.desc = desc + tqdm_loop.desc = desc probabilities_da[idx] = None return probabilities_zarr, probabilities_da @@ -2683,7 +2703,6 @@ def _check_and_update_for_memory_overload( count_zarr: list[zarr.Array | None], memory_threshold: int, tqdm_loop: DataLoader | tqdm, - tqdm_: type[tqdm_notebook | tqdm], save_path: Path, num_expected_output: int, *, @@ -2714,7 +2733,7 @@ def _check_and_update_for_memory_overload( f"exceeds specified threshold: {memory_threshold}. " f"Saving intermediate results to disk." ) - tqdm_.desc = msg + tqdm_loop.desc = msg # Flush data in Memory and clear dask graph canvas_zarr, count_zarr = save_multitask_to_cache( canvas, @@ -2742,6 +2761,8 @@ def _save_annotation_store( class_dict: dict, scale_factor: tuple[float, float], num_workers: int, + *, + verbose: bool = True, ) -> Path: """Helper function to save to annotation store.""" if isinstance(curr_image, Path): @@ -2769,6 +2790,7 @@ def _save_annotation_store( scale_factor=scale_factor, origin=origin, num_workers=num_workers, + verbose=verbose, ) store.commit() @@ -3120,6 +3142,8 @@ def _update_tile_based_predictions_array( def _build_tile_tasks( tile_info_sets: list, + *, + verbose: bool = True, ) -> list[ tuple, # metadata ]: @@ -3135,6 +3159,8 @@ def _build_tile_tasks( (set_bounds, set_flags) - set_bounds: list of tile bounds (x0, y0, x1, y1) - set_flags: list of per-tile flags used for instance merging + verbose (bool): + Whether to display logs and progress bar. Returns: list: @@ -3143,20 +3169,21 @@ def _build_tile_tasks( """ tile_metadata: list = [] - tqdm_ = get_tqdm() for set_idx, (set_bounds, set_flags) in enumerate( - tqdm_( + get_tqdm_full( tile_info_sets, leave=False, desc="Building delayed tile-processing tasks", + verbose=verbose, ) ): for tile_idx, tile_bounds in enumerate( - tqdm_( + get_tqdm_full( set_bounds, leave=False, desc=f"Building delayed tile-processing tasks for tile set {set_idx}", + verbose=verbose, ) ): tile_flag = set_flags[tile_idx] diff --git a/tiatoolbox/utils/misc.py b/tiatoolbox/utils/misc.py index 61b3b90e1..e2976bbfb 100644 --- a/tiatoolbox/utils/misc.py +++ b/tiatoolbox/utils/misc.py @@ -1668,13 +1668,6 @@ def write_probability_heatmap_as_ome_tiff( logger.info(msg) -def get_tqdm() -> type[tqdm_notebook | tqdm]: - """Returns appropriate tqdm tqdm object.""" - if is_notebook(): # pragma: no cover - return tqdm_notebook.tqdm - return tqdm - - def get_tqdm_full( iterable_input: Iterable, desc: str = "Processing input", From 6fb8293f0cc8dc396abf9f5cde107b241842f346 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Sat, 7 Feb 2026 23:33:48 +0000 Subject: [PATCH 114/156] :bug: Fix `mypy` errors. 
--- tests/engines/test_multi_task_segmentor.py | 18 ++++++------------ tiatoolbox/utils/misc.py | 2 +- 2 files changed, 7 insertions(+), 13 deletions(-) diff --git a/tests/engines/test_multi_task_segmentor.py b/tests/engines/test_multi_task_segmentor.py index 9e3e96a88..2470031c0 100644 --- a/tests/engines/test_multi_task_segmentor.py +++ b/tests/engines/test_multi_task_segmentor.py @@ -4,7 +4,7 @@ import shutil from pathlib import Path -from typing import TYPE_CHECKING, Any, ClassVar, Final +from typing import TYPE_CHECKING, Any, Final import dask.array as da import numpy as np @@ -25,6 +25,7 @@ ) from tiatoolbox.utils import env_detection as toolbox_env from tiatoolbox.utils import imwrite +from tiatoolbox.utils.misc import get_tqdm_full from tiatoolbox.wsicore import WSIReader if TYPE_CHECKING: @@ -594,16 +595,9 @@ class FakeVM: # --- Real numpy array for shape/dtype --- probabilities = np.zeros((1, 3)) - class DummyTqdm: - """Dummy tqdm with a write() method.""" - - messages: ClassVar[list[str]] = [] - desc: str = "Test Method" - - @classmethod - def write(cls: DummyTqdm, msg: str) -> None: - """Append a message to the messages list.""" - cls.messages.append(msg) + tqdm_loop = get_tqdm_full( + range(1), + ) # --- Call function --- new_zarr, new_da = _save_multitask_vertical_to_cache( @@ -611,7 +605,7 @@ def write(cls: DummyTqdm, msg: str) -> None: probabilities_da=probabilities_da, probabilities=probabilities, idx=idx, - tqdm_=DummyTqdm, + tqdm_loop=tqdm_loop, save_path=tmp_path / "cache.zarr", chunk_shape=(1,), memory_threshold=0, # ensure branch triggers diff --git a/tiatoolbox/utils/misc.py b/tiatoolbox/utils/misc.py index e2976bbfb..a5e94e7c3 100644 --- a/tiatoolbox/utils/misc.py +++ b/tiatoolbox/utils/misc.py @@ -1674,7 +1674,7 @@ def get_tqdm_full( *, leave: bool = False, verbose: bool = True, -) -> type[tqdm_notebook | tqdm] | Iterable: +) -> Iterable: """Helper function to get appropriate tqdm progress bar. Args: From 1fc69d7e07f447f1d625a14a029d87d9ce82ab25 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Sun, 8 Feb 2026 00:04:51 +0000 Subject: [PATCH 115/156] :zap: Try computing full array --- .../models/engine/multi_task_segmentor.py | 118 ++++++++---------- 1 file changed, 53 insertions(+), 65 deletions(-) diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index 7e334edd5..eff49391d 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -1094,81 +1094,69 @@ def _process_tile_mode( ) wsi_info_dict = None - merge_idx = 0 # Only used for delayed processing. self._probabilities = probabilities # skipcq: PYL-W0201 - for i in get_tqdm_full( - range(0, len(tile_metadata), num_workers), - leave=False, - desc="Post-Processing WSI to generate predictions and contours", - verbose=self.verbose, - ): - tile_metadata_ = tile_metadata[i : i + num_workers] - - # Build delayed tasks - delayed_tasks = [ - self._compute_tile( - _tile_meta[0], - ) - for _tile_meta in get_tqdm_full( - tile_metadata_, - leave=False, - desc="Creating list of delayed tasks for writing annotations", - verbose=self.verbose, - ) - ] - # Compute only this batch in parallel to avoid memory overload. 
- batch_outputs = compute( - *delayed_tasks, scheduler="threads", num_workers=num_workers + # Build delayed tasks + delayed_tasks = [ + self._compute_tile( + _tile_meta[0], + ) + for _tile_meta in get_tqdm_full( + tile_metadata, + leave=False, + desc="Creating list of delayed tasks for writing annotations", + verbose=self.verbose, ) + ] - # Merge each tile result immediately - for post_process_output in batch_outputs: - tile_bounds, tile_flag, tile_mode = tile_metadata[merge_idx] - merge_idx += 1 + # Compute only this batch in parallel to avoid memory overload. + batch_outputs = compute( + *delayed_tasks, scheduler="threads", num_workers=num_workers + ) - # create a list of info dict for each task - wsi_info_dict = _create_wsi_info_dict( - post_process_output=post_process_output, - wsi_info_dict=wsi_info_dict, - wsi_proc_shape=wsi_proc_shape, - save_path=save_path, - memory_threshold=memory_threshold, - return_predictions=return_predictions, - ) + # Merge each tile result immediately + for merge_idx, post_process_output in enumerate(batch_outputs): + tile_bounds, tile_flag, tile_mode = tile_metadata[merge_idx] + # create a list of info dict for each task + wsi_info_dict = _create_wsi_info_dict( + post_process_output=post_process_output, + wsi_info_dict=wsi_info_dict, + wsi_proc_shape=wsi_proc_shape, + save_path=save_path, + memory_threshold=memory_threshold, + return_predictions=return_predictions, + ) - wsi_info_dict = _update_tile_based_predictions_array( - post_process_output=post_process_output, - wsi_info_dict=wsi_info_dict, - bounds=tile_bounds, - ) + wsi_info_dict = _update_tile_based_predictions_array( + post_process_output=post_process_output, + wsi_info_dict=wsi_info_dict, + bounds=tile_bounds, + ) - inst_dicts = _get_inst_info_dicts( - post_process_output=post_process_output + inst_dicts = _get_inst_info_dicts(post_process_output=post_process_output) + tile_tl = tile_bounds[:2] + tile_br = tile_bounds[2:] + tile_shape = tile_br - tile_tl + + new_inst_dicts, remove_insts_in_origs = [], [] + for inst_id, inst_dict in enumerate(inst_dicts): + new_inst_dict, remove_insts_in_orig = _process_instance_predictions( + inst_dict, + ioconfig, + tile_shape, + tile_flag, + tile_mode, + tile_tl, + wsi_info_dict[inst_id]["info_dict"], ) - tile_tl = tile_bounds[:2] - tile_br = tile_bounds[2:] - tile_shape = tile_br - tile_tl - - new_inst_dicts, remove_insts_in_origs = [], [] - for inst_id, inst_dict in enumerate(inst_dicts): - new_inst_dict, remove_insts_in_orig = _process_instance_predictions( - inst_dict, - ioconfig, - tile_shape, - tile_flag, - tile_mode, - tile_tl, - wsi_info_dict[inst_id]["info_dict"], - ) - new_inst_dicts.append(new_inst_dict) - remove_insts_in_origs.append(remove_insts_in_orig) + new_inst_dicts.append(new_inst_dict) + remove_insts_in_origs.append(remove_insts_in_orig) - for inst_id, new_inst_dict in enumerate(new_inst_dicts): - wsi_info_dict[inst_id]["info_dict"].update(new_inst_dict) - for inst_uuid in remove_insts_in_origs[inst_id]: - wsi_info_dict[inst_id]["info_dict"].pop(inst_uuid, None) + for inst_id, new_inst_dict in enumerate(new_inst_dicts): + wsi_info_dict[inst_id]["info_dict"].update(new_inst_dict) + for inst_uuid in remove_insts_in_origs[inst_id]: + wsi_info_dict[inst_id]["info_dict"].pop(inst_uuid, None) for idx, wsi_info_dict_ in enumerate( get_tqdm_full( From 628cdd0e06e30c43b4e90c87a4d40a7eb60bdaaa Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Sun, 8 Feb 2026 01:33:23 +0000 Subject: [PATCH 116/156] 
:lipstick: Use tqdm progress bar for dask --- tiatoolbox/models/engine/engine_abc.py | 100 +++++++++++++++++- .../models/engine/multi_task_segmentor.py | 24 ++++- tiatoolbox/models/engine/nucleus_detector.py | 11 +- tiatoolbox/utils/misc.py | 8 +- 4 files changed, 130 insertions(+), 13 deletions(-) diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index 045afe225..c3eb94c95 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -49,6 +49,7 @@ from dask.diagnostics import ProgressBar from numcodecs import Pickle from torch import nn +from tqdm.auto import tqdm from typing_extensions import Unpack from tiatoolbox import DuplicateFilter, logger, rcParam @@ -833,8 +834,13 @@ def save_predictions_as_zarr( ) msg = f"Saving output to {save_path}." - logger.info(msg=msg) - with ProgressBar(): + progressbar = TqdmProgressBar( + total=len(write_tasks), + desc=msg, + leave=False, + verbose=self.verbose, + ) + with progressbar: compute(*write_tasks) zarr_group = zarr.open(save_path, mode="r+") @@ -1752,6 +1758,96 @@ def run( ) +class TqdmProgressBar(ProgressBar): + """A Dask progress bar that forwards progress updates to a ``tqdm`` bar. + + This class integrates Dask's diagnostic progress reporting with a + ``tqdm`` progress bar, providing a familiar and visually rich progress + indicator during ``dask.compute()`` calls. + + The bar updates based on Dask's internal progress fraction and + automatically closes when computation finishes. + + Args: + total (int): + Total number of tasks expected to be computed. Typically, the + number of delayed objects passed to ``dask.compute()``. + desc (str, optional): + Description text displayed to the left of the tqdm bar. + Defaults to ``"Computing"``. + leave (bool, optional): + Whether to leave the tqdm bar visible after completion. + Defaults to ``False``. + verbose (bool, optional): + If ``True``, prints Dask's diagnostic messages in addition to + updating the tqdm bar. Defaults to ``False``. + + Attributes: + pbar (tqdm): + The underlying tqdm progress bar instance. + + """ + + def __init__( + self, + total: int, + desc: str = "Computing", + *, + leave: bool = False, + verbose: bool = True, + ) -> None: + """TqdmProgressBar constructor. + + Args: + total (int): + Total number of tasks expected to be computed. Typically, the + number of delayed objects passed to ``dask.compute()``. + desc (str, optional): + Description text displayed to the left of the tqdm bar. + Defaults to ``"Computing"``. + leave (bool, optional): + Whether to leave the tqdm bar visible after completion. + Defaults to ``False``. + verbose (bool, optional): + If ``True``, prints Dask's diagnostic messages in addition to + updating the tqdm bar. Defaults to ``False``. 
+ + """ + super().__init__(dt=0.1, out=_SilentFile()) + self.verbose = verbose + self.pbar = tqdm(total=total, desc=desc, leave=leave) + + def _draw_bar(self, *args: object, **kwargs: object) -> None: + """Update the tqdm bar based on Dask's reported completion fraction.""" + fraction = getattr(self, "_last_fraction", None) + + if fraction is not None: + new_value = int(fraction * self.pbar.total) + self.pbar.n = new_value + self.pbar.refresh() + + if self.verbose: + super()._draw_bar(*args, **kwargs) + + def _finish(self, *args: object, **kwargs: object) -> None: + """Finalize and close the tqdm bar when computation completes.""" + self.pbar.n = self.pbar.total + self.pbar.close() + + if self.verbose: + super()._finish(*args, **kwargs) + + +class _SilentFile: + """A file-like object that discards all writes.""" + + def write(self, s: str) -> None: + """Dummy write function.""" + + def flush(self) -> None: + """Dummy flush function.""" + + def prepare_engines_save_dir( save_dir: str | Path | None, *, diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index eff49391d..a434dd5b8 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -144,6 +144,7 @@ ) from tiatoolbox.wsicore.wsireader import is_zarr +from .engine_abc import TqdmProgressBar from .semantic_segmentor import ( SemanticSegmentor, SemanticSegmentorRunParams, @@ -1110,13 +1111,28 @@ def _process_tile_mode( ) ] - # Compute only this batch in parallel to avoid memory overload. - batch_outputs = compute( - *delayed_tasks, scheduler="threads", num_workers=num_workers + progressbar = TqdmProgressBar( + total=len(delayed_tasks), + desc="Post processing inference output", + leave=False, + verbose=self.verbose, + ) + + with progressbar: + # Compute only this batch in parallel to avoid memory overload. 
+ batch_outputs = compute( + *delayed_tasks, scheduler="threads", num_workers=num_workers + ) + + tqdm_loop = get_tqdm_full( + batch_outputs, + leave=False, + verbose=self.verbose, + desc="Merging Output Predictions", ) # Merge each tile result immediately - for merge_idx, post_process_output in enumerate(batch_outputs): + for merge_idx, post_process_output in enumerate(tqdm_loop): tile_bounds, tile_flag, tile_mode = tile_metadata[merge_idx] # create a list of info dict for each task wsi_info_dict = _create_wsi_info_dict( diff --git a/tiatoolbox/models/engine/nucleus_detector.py b/tiatoolbox/models/engine/nucleus_detector.py index 36a625d50..fe0f66510 100644 --- a/tiatoolbox/models/engine/nucleus_detector.py +++ b/tiatoolbox/models/engine/nucleus_detector.py @@ -54,7 +54,6 @@ import numpy as np import zarr from dask import compute -from dask.diagnostics.progress import ProgressBar from shapely.geometry import Point from tiatoolbox import logger @@ -65,6 +64,8 @@ ) from tiatoolbox.utils.misc import get_tqdm_full +from .engine_abc import TqdmProgressBar + if TYPE_CHECKING: # pragma: no cover import os from typing import Unpack @@ -464,7 +465,13 @@ def post_process_wsi( task = centroid_maps.to_zarr( url=zarr_file, component="centroid_maps", compute=False, object_codec=None ) - with ProgressBar(): + progressbar = TqdmProgressBar( + total=len(task), + desc="Computing Centroids", + leave=False, + verbose=self.verbose, + ) + with progressbar: compute(task) self.drop_keys.append("centroid_maps") diff --git a/tiatoolbox/utils/misc.py b/tiatoolbox/utils/misc.py index a5e94e7c3..6f2f342ce 100644 --- a/tiatoolbox/utils/misc.py +++ b/tiatoolbox/utils/misc.py @@ -25,12 +25,11 @@ from shapely.geometry import Polygon from shapely.geometry import shape as feature2geometry from skimage import exposure -from tqdm import notebook as tqdm_notebook -from tqdm import tqdm, trange +from tqdm import trange +from tqdm.auto import tqdm from tiatoolbox import logger from tiatoolbox.annotation.storage import Annotation, AnnotationStore, SQLiteStore -from tiatoolbox.utils.env_detection import is_notebook from tiatoolbox.utils.exceptions import FileNotSupportedError if TYPE_CHECKING: # pragma: no cover @@ -1692,8 +1691,7 @@ def get_tqdm_full( Iterable of tqdm progress bar if self.verbose is True else input Iterable. 
""" - tqdm_ = tqdm_notebook.tqdm if is_notebook() else tqdm - return tqdm_(iterable_input, leave=leave, desc=desc) if verbose else iterable_input + return tqdm(iterable_input, leave=leave, desc=desc) if verbose else iterable_input def cast_to_min_dtype(array: np.ndarray | da.Array) -> np.ndarray | da.Array: From b7df2e05e2783831fced885b941b6a3579cdf530 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Sun, 8 Feb 2026 01:59:45 +0000 Subject: [PATCH 117/156] :lipstick: Use tqdm.dask progress bar for dask --- tiatoolbox/models/engine/engine_abc.py | 104 +----------------- .../models/engine/multi_task_segmentor.py | 13 +-- tiatoolbox/models/engine/nucleus_detector.py | 14 +-- 3 files changed, 11 insertions(+), 120 deletions(-) diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index c3eb94c95..d7b5df94a 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -46,10 +46,9 @@ import torch import zarr from dask import compute -from dask.diagnostics import ProgressBar from numcodecs import Pickle from torch import nn -from tqdm.auto import tqdm +from tqdm.dask import TqdmCallback from typing_extensions import Unpack from tiatoolbox import DuplicateFilter, logger, rcParam @@ -834,14 +833,9 @@ def save_predictions_as_zarr( ) msg = f"Saving output to {save_path}." - progressbar = TqdmProgressBar( - total=len(write_tasks), - desc=msg, - leave=False, - verbose=self.verbose, - ) - with progressbar: - compute(*write_tasks) + + with TqdmCallback(desc=msg, leave=False): + compute(*write_tasks, scheduler="processes", num_workers=self.num_workers) zarr_group = zarr.open(save_path, mode="r+") for key in self.drop_keys: @@ -1758,96 +1752,6 @@ def run( ) -class TqdmProgressBar(ProgressBar): - """A Dask progress bar that forwards progress updates to a ``tqdm`` bar. - - This class integrates Dask's diagnostic progress reporting with a - ``tqdm`` progress bar, providing a familiar and visually rich progress - indicator during ``dask.compute()`` calls. - - The bar updates based on Dask's internal progress fraction and - automatically closes when computation finishes. - - Args: - total (int): - Total number of tasks expected to be computed. Typically, the - number of delayed objects passed to ``dask.compute()``. - desc (str, optional): - Description text displayed to the left of the tqdm bar. - Defaults to ``"Computing"``. - leave (bool, optional): - Whether to leave the tqdm bar visible after completion. - Defaults to ``False``. - verbose (bool, optional): - If ``True``, prints Dask's diagnostic messages in addition to - updating the tqdm bar. Defaults to ``False``. - - Attributes: - pbar (tqdm): - The underlying tqdm progress bar instance. - - """ - - def __init__( - self, - total: int, - desc: str = "Computing", - *, - leave: bool = False, - verbose: bool = True, - ) -> None: - """TqdmProgressBar constructor. - - Args: - total (int): - Total number of tasks expected to be computed. Typically, the - number of delayed objects passed to ``dask.compute()``. - desc (str, optional): - Description text displayed to the left of the tqdm bar. - Defaults to ``"Computing"``. - leave (bool, optional): - Whether to leave the tqdm bar visible after completion. - Defaults to ``False``. - verbose (bool, optional): - If ``True``, prints Dask's diagnostic messages in addition to - updating the tqdm bar. Defaults to ``False``. 
- - """ - super().__init__(dt=0.1, out=_SilentFile()) - self.verbose = verbose - self.pbar = tqdm(total=total, desc=desc, leave=leave) - - def _draw_bar(self, *args: object, **kwargs: object) -> None: - """Update the tqdm bar based on Dask's reported completion fraction.""" - fraction = getattr(self, "_last_fraction", None) - - if fraction is not None: - new_value = int(fraction * self.pbar.total) - self.pbar.n = new_value - self.pbar.refresh() - - if self.verbose: - super()._draw_bar(*args, **kwargs) - - def _finish(self, *args: object, **kwargs: object) -> None: - """Finalize and close the tqdm bar when computation completes.""" - self.pbar.n = self.pbar.total - self.pbar.close() - - if self.verbose: - super()._finish(*args, **kwargs) - - -class _SilentFile: - """A file-like object that discards all writes.""" - - def write(self, s: str) -> None: - """Dummy write function.""" - - def flush(self) -> None: - """Dummy flush function.""" - - def prepare_engines_save_dir( save_dir: str | Path | None, *, diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index a434dd5b8..dcc7c040f 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -131,6 +131,7 @@ from shapely.geometry import box as shapely_box from shapely.geometry import shape as feature2geometry from shapely.strtree import STRtree +from tqdm.dask import TqdmCallback from typing_extensions import Unpack from tiatoolbox import logger @@ -144,7 +145,6 @@ ) from tiatoolbox.wsicore.wsireader import is_zarr -from .engine_abc import TqdmProgressBar from .semantic_segmentor import ( SemanticSegmentor, SemanticSegmentorRunParams, @@ -1111,17 +1111,10 @@ def _process_tile_mode( ) ] - progressbar = TqdmProgressBar( - total=len(delayed_tasks), - desc="Post processing inference output", - leave=False, - verbose=self.verbose, - ) - - with progressbar: + with TqdmCallback(desc="Post processing inference output", leave=False): # Compute only this batch in parallel to avoid memory overload. 
batch_outputs = compute( - *delayed_tasks, scheduler="threads", num_workers=num_workers + *delayed_tasks, scheduler="processes", num_workers=num_workers ) tqdm_loop = get_tqdm_full( diff --git a/tiatoolbox/models/engine/nucleus_detector.py b/tiatoolbox/models/engine/nucleus_detector.py index fe0f66510..b7a6b55de 100644 --- a/tiatoolbox/models/engine/nucleus_detector.py +++ b/tiatoolbox/models/engine/nucleus_detector.py @@ -55,6 +55,7 @@ import zarr from dask import compute from shapely.geometry import Point +from tqdm.dask import TqdmCallback from tiatoolbox import logger from tiatoolbox.annotation.storage import Annotation, SQLiteStore @@ -64,8 +65,6 @@ ) from tiatoolbox.utils.misc import get_tqdm_full -from .engine_abc import TqdmProgressBar - if TYPE_CHECKING: # pragma: no cover import os from typing import Unpack @@ -465,14 +464,9 @@ def post_process_wsi( task = centroid_maps.to_zarr( url=zarr_file, component="centroid_maps", compute=False, object_codec=None ) - progressbar = TqdmProgressBar( - total=len(task), - desc="Computing Centroids", - leave=False, - verbose=self.verbose, - ) - with progressbar: - compute(task) + + with TqdmCallback(desc="Computing Centroids", leave=False): + compute(task, scheduler="processes", num_workers=self.num_workers) self.drop_keys.append("centroid_maps") zarr_group = zarr.open(zarr_file, mode="r+") From 69fa82de9b17621585a1c7f6bac327d945efb8b9 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Sun, 8 Feb 2026 02:03:47 +0000 Subject: [PATCH 118/156] :bug: Fix delayed_tasks compute --- tiatoolbox/models/engine/multi_task_segmentor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index dcc7c040f..66a210910 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -1114,7 +1114,7 @@ def _process_tile_mode( with TqdmCallback(desc="Post processing inference output", leave=False): # Compute only this batch in parallel to avoid memory overload. batch_outputs = compute( - *delayed_tasks, scheduler="processes", num_workers=num_workers + *delayed_tasks, scheduler="threads", num_workers=num_workers ) tqdm_loop = get_tqdm_full( From 0b91f28eb6a6b06400fcd1516226e56893cf41f2 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Sun, 8 Feb 2026 02:38:54 +0000 Subject: [PATCH 119/156] :recycle: Add `tqdm_dask_progress_bar` --- tiatoolbox/models/engine/engine_abc.py | 48 ++++++++++- .../models/engine/multi_task_segmentor.py | 86 ++++--------------- tiatoolbox/models/engine/nucleus_detector.py | 15 ++-- 3 files changed, 71 insertions(+), 78 deletions(-) diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index d7b5df94a..c58223f29 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -833,9 +833,14 @@ def save_predictions_as_zarr( ) msg = f"Saving output to {save_path}." 
- - with TqdmCallback(desc=msg, leave=False): - compute(*write_tasks, scheduler="processes", num_workers=self.num_workers) + _ = tqdm_dask_progress_bar( + msg=msg, + write_tasks=write_tasks, + num_workers=self.num_workers, + scheduler="threads", # tasks are I/O-bound and shared memory use threads + leave=False, + verbose=self.verbose, + ) zarr_group = zarr.open(save_path, mode="r+") for key in self.drop_keys: @@ -1806,3 +1811,40 @@ def prepare_engines_save_dir( save_dir.mkdir(parents=True) return save_dir + + +def tqdm_dask_progress_bar( + msg: str, + write_tasks: list, + num_workers: int, + scheduler: str = "threads", + *, + leave: bool = False, + verbose: bool = True, +) -> list: + """Helper function for tqdm_dask_progress_bar. + + Args: + msg (str): + Message to display for the progress bar. + write_tasks (list): + List of dask tasks to compute. + num_workers (int): + Number of workers to use. + scheduler (str): + dask compute scheduler to use e.g., "threads" or "processes". + leave (bool): + Whether to leave progress bar after completion. + verbose (bool): + Whether to display progress bar. + + Returns: + List: + list of outputs from dask compute. + + """ + if verbose: + with TqdmCallback(desc=msg, leave=leave): + return compute(*write_tasks, scheduler=scheduler, num_workers=num_workers) + + return compute(*write_tasks, scheduler=scheduler, num_workers=num_workers) diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index 66a210910..3aab59dc4 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -127,11 +127,10 @@ import psutil import torch import zarr -from dask import compute, delayed +from dask import delayed from shapely.geometry import box as shapely_box from shapely.geometry import shape as feature2geometry from shapely.strtree import STRtree -from tqdm.dask import TqdmCallback from typing_extensions import Unpack from tiatoolbox import logger @@ -145,6 +144,7 @@ ) from tiatoolbox.wsicore.wsireader import is_zarr +from .engine_abc import tqdm_dask_progress_bar from .semantic_segmentor import ( SemanticSegmentor, SemanticSegmentorRunParams, @@ -905,7 +905,7 @@ def post_process_wsi( # skipcq: PYL-R0201 return_predictions=return_predictions, ) else: - num_workers = ( + self.num_workers = ( kwargs.get("num_workers", multiprocessing.cpu_count()) if self.num_workers == 0 else self.num_workers @@ -914,7 +914,6 @@ def post_process_wsi( # skipcq: PYL-R0201 probabilities=probabilities, save_path=save_path.with_suffix(".zarr"), memory_threshold=kwargs.get("memory_threshold", 80), - num_workers=num_workers, return_predictions=kwargs.get("return_predictions"), ) @@ -1005,7 +1004,6 @@ def _process_tile_mode( probabilities: list[da.Array | np.ndarray], save_path: Path, memory_threshold: float = 80, - num_workers: int = multiprocessing.cpu_count(), *, return_predictions: tuple[bool, ...] | None = None, ) -> list[dict] | None: @@ -1111,11 +1109,14 @@ def _process_tile_mode( ) ] - with TqdmCallback(desc="Post processing inference output", leave=False): - # Compute only this batch in parallel to avoid memory overload. 
- batch_outputs = compute( - *delayed_tasks, scheduler="threads", num_workers=num_workers - ) + batch_outputs = tqdm_dask_progress_bar( + msg="Post processing inference output", + write_tasks=delayed_tasks, + num_workers=self.num_workers, + scheduler="threads", + leave=False, + verbose=self.verbose, + ) tqdm_loop = get_tqdm_full( batch_outputs, @@ -2078,10 +2079,12 @@ def dict_to_store( ) ] - ann = compute_dask_delayed_with_progress( - delayed_tasks, + ann = tqdm_dask_progress_bar( + msg="Saving annotations", + write_tasks=delayed_tasks, num_workers=num_workers, - desc="Saving annotations ", + scheduler="threads", + leave=False, verbose=verbose, ) @@ -2091,63 +2094,6 @@ def dict_to_store( return store -def compute_dask_delayed_with_progress( - delayed_tasks: list, - num_workers: int = multiprocessing.cpu_count(), - desc: str = "Computing", - batch_size: int | None = None, - *, - verbose: bool = True, -) -> list: - """Compute a list of Dask delayed tasks in parallel while displaying a progress bar. - - This function batches tasks according to `num_workers`, ensuring that only - `num_workers` tasks are computed concurrently. This avoids excessive memory - usage when each delayed task returns a large object (e.g., NumPy arrays, - geometries, or annotations). A tqdm progress bar is updated after each batch. - - Args: - delayed_tasks (list): - A list of Dask delayed objects to compute. - num_workers (int): - Number of parallel worker threads to use. If set to 0 or None, - defaults to the number of CPU cores. - desc (str): - Description string shown in the tqdm progress bar. - batch_size (int | None): - batch_size to process dask delayed. - batch_size is set to num_workers if batch_size is not provided. - verbose (bool): - Whether to display logs and progress bar. - - Returns: - A list containing the computed results from all delayed tasks, in order. 
- - """ - total = len(delayed_tasks) - batch_size = num_workers if batch_size is None else batch_size - results: list[Any] = [] - - for i in get_tqdm_full( - range(0, total, batch_size), - desc=desc, - leave=False, - verbose=verbose, - ): - batch = delayed_tasks[i : i + batch_size] - - # Compute this batch in parallel - batch_results = compute( - *batch, - scheduler="threads", - num_workers=num_workers, - ) - - results.extend(batch_results) - - return results - - def prepare_multitask_full_batch( batch_output: tuple[np.ndarray], batch_locs: np.ndarray, diff --git a/tiatoolbox/models/engine/nucleus_detector.py b/tiatoolbox/models/engine/nucleus_detector.py index b7a6b55de..7785253c4 100644 --- a/tiatoolbox/models/engine/nucleus_detector.py +++ b/tiatoolbox/models/engine/nucleus_detector.py @@ -53,9 +53,7 @@ import dask.array as da import numpy as np import zarr -from dask import compute from shapely.geometry import Point -from tqdm.dask import TqdmCallback from tiatoolbox import logger from tiatoolbox.annotation.storage import Annotation, SQLiteStore @@ -65,6 +63,8 @@ ) from tiatoolbox.utils.misc import get_tqdm_full +from .engine_abc import tqdm_dask_progress_bar + if TYPE_CHECKING: # pragma: no cover import os from typing import Unpack @@ -464,9 +464,14 @@ def post_process_wsi( task = centroid_maps.to_zarr( url=zarr_file, component="centroid_maps", compute=False, object_codec=None ) - - with TqdmCallback(desc="Computing Centroids", leave=False): - compute(task, scheduler="processes", num_workers=self.num_workers) + _ = tqdm_dask_progress_bar( + msg="Computing Centroids", + write_tasks=[task], + num_workers=self.num_workers, + scheduler="processes", + leave=False, + verbose=self.verbose, + ) self.drop_keys.append("centroid_maps") zarr_group = zarr.open(zarr_file, mode="r+") From dba148f14d4d69d07b9c14c6bc0548f40e9d0cb3 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Mon, 9 Feb 2026 14:16:25 +0000 Subject: [PATCH 120/156] :zap: Optimise performance and move tqdm progress bar to misc --- tiatoolbox/models/architecture/hovernet.py | 108 ++++++---- tiatoolbox/models/engine/engine_abc.py | 41 +--- .../models/engine/multi_task_segmentor.py | 202 +++++++++++------- tiatoolbox/models/engine/nucleus_detector.py | 8 +- tiatoolbox/utils/misc.py | 39 ++++ 5 files changed, 227 insertions(+), 171 deletions(-) diff --git a/tiatoolbox/models/architecture/hovernet.py b/tiatoolbox/models/architecture/hovernet.py index 7ff8ffa96..56ef8daa7 100644 --- a/tiatoolbox/models/architecture/hovernet.py +++ b/tiatoolbox/models/architecture/hovernet.py @@ -3,6 +3,7 @@ from __future__ import annotations import math +import multiprocessing import warnings from collections import OrderedDict @@ -12,6 +13,7 @@ import pandas as pd import torch import torch.nn.functional as F # noqa: N812 +from dask import delayed from scipy import ndimage from skimage.morphology import remove_small_objects from skimage.segmentation import watershed @@ -23,7 +25,10 @@ centre_crop_to_shape, ) from tiatoolbox.models.models_abc import ModelABC -from tiatoolbox.utils.misc import get_bounding_box, get_tqdm_full +from tiatoolbox.utils.misc import ( + get_bounding_box, + tqdm_dask_progress_bar, +) class TFSamepaddingLayer(nn.Module): @@ -662,51 +667,22 @@ def get_instance_info( """ inst_id_list = np.unique(pred_inst)[1:] # exclude background - inst_info_dict = {} - tqdm_loop = get_tqdm_full( - inst_id_list, - leave=False, - desc="Generating 'info_dict' for instances", - verbose=verbose, - ) - 
for inst_id in tqdm_loop: - inst_map = pred_inst == inst_id - inst_box = get_bounding_box(inst_map) - inst_box_tl = inst_box[:2] - inst_map = inst_map[inst_box[1] : inst_box[3], inst_box[0] : inst_box[2]] - inst_map = inst_map.astype(np.uint8) - inst_moment = cv2.moments(inst_map) - inst_contour = cv2.findContours( - inst_map, - cv2.RETR_TREE, - cv2.CHAIN_APPROX_SIMPLE, + + tasks = [compute_inst_info(inst_id, pred_inst) for inst_id in inst_id_list] + + inst_info_dict = dict( + result + for result in tqdm_dask_progress_bar( + desc="Generating 'info_dict' for instances", + write_tasks=tasks, + num_workers=multiprocessing.cpu_count(), + scheduler="threads", + leave=False, + verbose=verbose, ) - # * opencv protocol format may break - inst_contour = inst_contour[0][0].astype(np.int32) - inst_contour = np.squeeze(inst_contour) - - # < 3 points does not make a contour, so skip, likely artifact too - # as the contours obtained via approximation => too small - if inst_contour.shape[0] < 3: # pragma: no cover # noqa: PLR2004 - continue - # ! check for trickery shape - if len(inst_contour.shape) != 2: # pragma: no cover # noqa: PLR2004 - continue - - inst_centroid = [ - (inst_moment["m10"] / inst_moment["m00"]), - (inst_moment["m01"] / inst_moment["m00"]), - ] - inst_centroid = np.array(inst_centroid) - inst_contour += inst_box_tl[None] - inst_centroid += inst_box_tl # X - inst_info_dict[inst_id] = { # inst_id should start at 1 - "box": inst_box, - "centroid": inst_centroid, - "contours": inst_contour, - "prob": None, - "type": None, - } + ) + + inst_info_dict = {k: v for k, v in inst_info_dict.items() if v is not None} if pred_type is not None: # * Get class of each instance id, stored at index id-1 @@ -910,3 +886,45 @@ def _inst_dict_for_dask_processing( else col_np ) return inst_info_dict_ + + +@delayed +def compute_inst_info(inst_id: int, pred_inst: np.ndarray) -> tuple[int, dict]: + """Helper function to compute instance info with dask delayed.""" + inst_map = pred_inst == inst_id + inst_box = get_bounding_box(inst_map) + inst_box_tl = inst_box[:2] + inst_map = inst_map[inst_box[1] : inst_box[3], inst_box[0] : inst_box[2]] + inst_map = inst_map.astype(np.uint8) + inst_moment = cv2.moments(inst_map) + inst_contour = cv2.findContours( + inst_map, + cv2.RETR_TREE, + cv2.CHAIN_APPROX_SIMPLE, + ) + # * opencv protocol format may break + inst_contour = inst_contour[0][0].astype(np.int32) + inst_contour = np.squeeze(inst_contour) + + # < 3 points does not make a contour, so skip, likely artifact too + # as the contours obtained via approximation => too small + # ! 
check for trickery shape + if ( + inst_contour.shape[0] < 3 or inst_contour.ndim != 2 # noqa: PLR2004 + ): # pragma: no cover + return inst_id, None + + inst_centroid = [ + (inst_moment["m10"] / inst_moment["m00"]), + (inst_moment["m01"] / inst_moment["m00"]), + ] + inst_centroid = np.array(inst_centroid) + inst_contour += inst_box_tl[None] + inst_centroid += inst_box_tl # X + return inst_id, { + "box": inst_box, + "centroid": inst_centroid, + "contours": inst_contour, + "prob": None, + "type": None, + } diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index c58223f29..8377d7afd 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -48,7 +48,6 @@ from dask import compute from numcodecs import Pickle from torch import nn -from tqdm.dask import TqdmCallback from typing_extensions import Unpack from tiatoolbox import DuplicateFilter, logger, rcParam @@ -59,6 +58,7 @@ from tiatoolbox.utils.misc import ( dict_to_store_patch_predictions, get_tqdm_full, + tqdm_dask_progress_bar, ) from tiatoolbox.wsicore.wsireader import WSIReader, is_zarr @@ -834,7 +834,7 @@ def save_predictions_as_zarr( msg = f"Saving output to {save_path}." _ = tqdm_dask_progress_bar( - msg=msg, + desc=msg, write_tasks=write_tasks, num_workers=self.num_workers, scheduler="threads", # tasks are I/O-bound and shared memory use threads @@ -1811,40 +1811,3 @@ def prepare_engines_save_dir( save_dir.mkdir(parents=True) return save_dir - - -def tqdm_dask_progress_bar( - msg: str, - write_tasks: list, - num_workers: int, - scheduler: str = "threads", - *, - leave: bool = False, - verbose: bool = True, -) -> list: - """Helper function for tqdm_dask_progress_bar. - - Args: - msg (str): - Message to display for the progress bar. - write_tasks (list): - List of dask tasks to compute. - num_workers (int): - Number of workers to use. - scheduler (str): - dask compute scheduler to use e.g., "threads" or "processes". - leave (bool): - Whether to leave progress bar after completion. - verbose (bool): - Whether to display progress bar. - - Returns: - List: - list of outputs from dask compute. - - """ - if verbose: - with TqdmCallback(desc=msg, leave=leave): - return compute(*write_tasks, scheduler=scheduler, num_workers=num_workers) - - return compute(*write_tasks, scheduler=scheduler, num_workers=num_workers) diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index 3aab59dc4..1e9533e13 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -114,6 +114,7 @@ from __future__ import annotations import gc +import math import multiprocessing import shutil import uuid @@ -141,10 +142,10 @@ create_smart_array, get_tqdm_full, make_valid_poly, + tqdm_dask_progress_bar, ) from tiatoolbox.wsicore.wsireader import is_zarr -from .engine_abc import tqdm_dask_progress_bar from .semantic_segmentor import ( SemanticSegmentor, SemanticSegmentorRunParams, @@ -1093,80 +1094,121 @@ def _process_tile_mode( ) wsi_info_dict = None + merge_idx = 0 + # Only used for delayed processing. 
- self._probabilities = probabilities # skipcq: PYL-W0201 + self._probabilities = probabilities # skipcq: PYL-W0201 # skipcq: PYL-W0201 - # Build delayed tasks - delayed_tasks = [ - self._compute_tile( - _tile_meta[0], - ) - for _tile_meta in get_tqdm_full( - tile_metadata, - leave=False, - desc="Creating list of delayed tasks for writing annotations", - verbose=self.verbose, - ) - ] + # Calculate batch size for dask compute + vm = psutil.virtual_memory() + bytes_per_element = np.dtype(probabilities[0].dtype).itemsize + tile_elements = np.prod(self.ioconfig.tile_shape) + prod_dim2 = math.prod(p.shape[2] for p in probabilities if len(p.shape) > 2) # noqa: PLR2004 + tile_memory = len(probabilities) * tile_elements * prod_dim2 * bytes_per_element + # available memory + available_memory = vm.available * (memory_threshold / 100) + # batch size for dask compute should be greater than 0 + batch_size = max(int(available_memory // tile_memory), 1) - batch_outputs = tqdm_dask_progress_bar( - msg="Post processing inference output", - write_tasks=delayed_tasks, - num_workers=self.num_workers, - scheduler="threads", + for i in get_tqdm_full( + range(0, len(tile_metadata), batch_size), leave=False, + desc="Post-Processing WSI to generate predictions and contours", verbose=self.verbose, - ) + ): + tile_metadata_ = tile_metadata[i : i + batch_size] - tqdm_loop = get_tqdm_full( - batch_outputs, - leave=False, - verbose=self.verbose, - desc="Merging Output Predictions", - ) + # Build delayed tasks + delayed_tasks = [ + self._compute_tile( + _tile_meta[0], + ) + for _tile_meta in get_tqdm_full( + tile_metadata_, + leave=False, + desc="Creating list of delayed tasks for post-processing", + verbose=self.verbose, + ) + ] - # Merge each tile result immediately - for merge_idx, post_process_output in enumerate(tqdm_loop): - tile_bounds, tile_flag, tile_mode = tile_metadata[merge_idx] - # create a list of info dict for each task - wsi_info_dict = _create_wsi_info_dict( - post_process_output=post_process_output, - wsi_info_dict=wsi_info_dict, - wsi_proc_shape=wsi_proc_shape, - save_path=save_path, - memory_threshold=memory_threshold, - return_predictions=return_predictions, + # Compute only this batch in parallel to avoid memory overload. 
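# Aside: a worked instance of the batch-size heuristic above, under assumed
# numbers (three float32 maps, a (2048, 2048) tile, channel-dim product 4,
# 8 GiB available, memory_threshold=80; all values illustrative):
tile_memory = 3 * (2048 * 2048) * 4 * 4  # = 201_326_592 bytes (~192 MiB/tile)
available_memory = 8 * 1024**3 * 0.8  # = 6_871_947_673.6 bytes usable
batch_size = max(int(available_memory // tile_memory), 1)  # = 34 tiles/batch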
+ batch_outputs = tqdm_dask_progress_bar( + desc="Running tile-based post-processing", + write_tasks=delayed_tasks, + num_workers=self.num_workers, + scheduler="threads", + leave=False, + verbose=self.verbose, ) - wsi_info_dict = _update_tile_based_predictions_array( - post_process_output=post_process_output, - wsi_info_dict=wsi_info_dict, - bounds=tile_bounds, + tqdm_loop = get_tqdm_full( + batch_outputs, + leave=False, + verbose=self.verbose, + desc="Merging Output Predictions", ) - inst_dicts = _get_inst_info_dicts(post_process_output=post_process_output) - tile_tl = tile_bounds[:2] - tile_br = tile_bounds[2:] - tile_shape = tile_br - tile_tl - - new_inst_dicts, remove_insts_in_origs = [], [] - for inst_id, inst_dict in enumerate(inst_dicts): - new_inst_dict, remove_insts_in_orig = _process_instance_predictions( - inst_dict, - ioconfig, - tile_shape, - tile_flag, - tile_mode, - tile_tl, - wsi_info_dict[inst_id]["info_dict"], + # Merge each tile result + for post_process_output in tqdm_loop: + tile_bounds, tile_flag, tile_mode = tile_metadata[merge_idx] + merge_idx += 1 + + # create a list of info dict for each task + wsi_info_dict = _create_wsi_info_dict( + post_process_output=post_process_output, + wsi_info_dict=wsi_info_dict, + wsi_proc_shape=wsi_proc_shape, + save_path=save_path, + memory_threshold=memory_threshold, + return_predictions=return_predictions, ) - new_inst_dicts.append(new_inst_dict) - remove_insts_in_origs.append(remove_insts_in_orig) - for inst_id, new_inst_dict in enumerate(new_inst_dicts): - wsi_info_dict[inst_id]["info_dict"].update(new_inst_dict) - for inst_uuid in remove_insts_in_origs[inst_id]: - wsi_info_dict[inst_id]["info_dict"].pop(inst_uuid, None) + wsi_info_dict = _update_tile_based_predictions_array( + post_process_output=post_process_output, + wsi_info_dict=wsi_info_dict, + bounds=tile_bounds, + ) + + inst_dicts = _get_inst_info_dicts( + post_process_output=post_process_output + ) + tile_tl = tile_bounds[:2] + tile_br = tile_bounds[2:] + tile_shape = tile_br - tile_tl + + ref_inst_rtree = STRtree([]) + processed_inst_predicts = [] + for inst_id, inst_dict in enumerate(inst_dicts): + if tile_mode == 3: # noqa: PLR2004 + inst_boxes = [ + v["box"] + for v in wsi_info_dict[inst_id]["info_dict"].values() + ] + inst_boxes = np.array(inst_boxes) + + geometries = [shapely_box(*bounds) for bounds in inst_boxes] + ref_inst_rtree = STRtree(geometries) + + processed_inst_predicts.append( + _process_instance_predictions( + inst_dict=inst_dict, + ioconfig=ioconfig, + tile_shape=tile_shape, + tile_flag=tile_flag, + tile_mode=tile_mode, + tile_tl=tile_tl, + ref_inst_dict=wsi_info_dict[inst_id]["info_dict"], + ref_inst_rtree=ref_inst_rtree, + ) + ) + + for inst_id, processed_inst_predict in enumerate( + processed_inst_predicts + ): + new_inst_dict, remove_insts_in_origs = processed_inst_predict + wsi_info_dict[inst_id]["info_dict"].update(new_inst_dict) + for inst_uuid in remove_insts_in_origs: + wsi_info_dict[inst_id]["info_dict"].pop(inst_uuid, None) for idx, wsi_info_dict_ in enumerate( get_tqdm_full( @@ -2080,11 +2122,9 @@ def dict_to_store( ] ann = tqdm_dask_progress_bar( - msg="Saving annotations", write_tasks=delayed_tasks, + desc="Saving annotations", num_workers=num_workers, - scheduler="threads", - leave=False, verbose=verbose, ) @@ -2750,6 +2790,7 @@ def _process_instance_predictions( tile_mode: int, tile_tl: tuple[int, int], ref_inst_dict: dict, + ref_inst_rtree: STRtree, ) -> list | tuple: """Function to merge new tile prediction with existing prediction. 
@@ -2789,6 +2830,8 @@ def _process_instance_predictions( Dictionary contains accumulated output. The expected format is {instance_id: {type: int, contour: List[List[int]], centroid:List[float], box:List[int]}. + ref_inst_rtree (STRtree): + A query-only R-tree spatial index for a list of geometries. Returns: new_inst_dict (dict): @@ -2814,29 +2857,16 @@ def _process_instance_predictions( tile_flag=tile_flag, ) - def retrieve_sel_uids(sel_indices: list, inst_dict: dict) -> list: - """Helper to retrieved selected instance uids.""" - if len(sel_indices) > 0: - # not sure how costly this is in large dict - inst_uids = list(inst_dict.keys()) - return [inst_uids[idx] for idx in sel_indices] - remove_insts_in_tile = retrieve_sel_uids(sel_indices, inst_dict) # external removal only for tile at cross-sections # this one should contain UUID with the reference database remove_insts_in_orig = [] if tile_mode == 3: # noqa: PLR2004 - inst_boxes = [v["box"] for v in ref_inst_dict.values()] - inst_boxes = np.array(inst_boxes) - - geometries = [shapely_box(*bounds) for bounds in inst_boxes] - ref_inst_rtree = STRtree(geometries) - sel_indices = [ + sel_indices_remove = [ geo for bounds in margin_lines for geo in ref_inst_rtree.query(bounds) ] - - remove_insts_in_orig = retrieve_sel_uids(sel_indices, ref_inst_dict) + remove_insts_in_orig = retrieve_sel_uids(sel_indices_remove, ref_inst_dict) new_inst_dict = _move_tile_space_to_wsi_space( inst_dict=inst_dict, @@ -2847,6 +2877,14 @@ def retrieve_sel_uids(sel_indices: list, inst_dict: dict) -> list: return new_inst_dict, remove_insts_in_orig +def retrieve_sel_uids(sel_indices_: list, inst_dict_: dict) -> list: + """Helper to retrieved selected instance uids.""" + if len(sel_indices_) > 0: + # not sure how costly this is in large dict + inst_uids = list(inst_dict_.keys()) + return [inst_uids[idx] for idx in sel_indices_] + + def _get_sel_indices_margin_lines( ioconfig: IOSegmentorConfig, tile_shape: tuple[int, int], @@ -2991,7 +3029,7 @@ def _get_inst_info_dicts(post_process_output: tuple[dict]) -> list: def _create_wsi_info_dict( post_process_output: tuple[dict], - wsi_info_dict: tuple[dict] | None, + wsi_info_dict: tuple[dict, ...] | None, wsi_proc_shape: tuple[int, ...], save_path: Path, return_predictions: tuple[bool, ...] 
| None, @@ -3063,7 +3101,7 @@ def _create_wsi_info_dict( def _update_tile_based_predictions_array( post_process_output: tuple[dict], - wsi_info_dict: tuple[dict], + wsi_info_dict: tuple[dict, ...], bounds: tuple[int, int, int, int], ) -> tuple[dict, ...]: """Helper function to update tile based predictions array.""" diff --git a/tiatoolbox/models/engine/nucleus_detector.py b/tiatoolbox/models/engine/nucleus_detector.py index 7785253c4..ae95ce25b 100644 --- a/tiatoolbox/models/engine/nucleus_detector.py +++ b/tiatoolbox/models/engine/nucleus_detector.py @@ -61,9 +61,7 @@ SemanticSegmentor, SemanticSegmentorRunParams, ) -from tiatoolbox.utils.misc import get_tqdm_full - -from .engine_abc import tqdm_dask_progress_bar +from tiatoolbox.utils.misc import get_tqdm_full, tqdm_dask_progress_bar if TYPE_CHECKING: # pragma: no cover import os @@ -465,10 +463,10 @@ def post_process_wsi( url=zarr_file, component="centroid_maps", compute=False, object_codec=None ) _ = tqdm_dask_progress_bar( - msg="Computing Centroids", + desc="Computing Centroids", write_tasks=[task], num_workers=self.num_workers, - scheduler="processes", + scheduler="threads", leave=False, verbose=self.verbose, ) diff --git a/tiatoolbox/utils/misc.py b/tiatoolbox/utils/misc.py index 6f2f342ce..06f0c9702 100644 --- a/tiatoolbox/utils/misc.py +++ b/tiatoolbox/utils/misc.py @@ -20,6 +20,7 @@ import tifffile import yaml import zarr +from dask import compute from filelock import FileLock from shapely.affinity import translate from shapely.geometry import Polygon @@ -27,6 +28,7 @@ from skimage import exposure from tqdm import trange from tqdm.auto import tqdm +from tqdm.dask import TqdmCallback from tiatoolbox import logger from tiatoolbox.annotation.storage import Annotation, AnnotationStore, SQLiteStore @@ -1791,3 +1793,40 @@ def create_smart_array( chunks=chunks, dtype=dtype, ) + + +def tqdm_dask_progress_bar( + desc: str, + write_tasks: list, + num_workers: int, + scheduler: str = "threads", + *, + leave: bool = False, + verbose: bool = True, +) -> list: + """Helper function for tqdm_dask_progress_bar. + + Args: + desc (str): + Message to display for the progress bar. + write_tasks (list): + List of dask tasks to compute. + num_workers (int): + Number of workers to use. + scheduler (str): + dask compute scheduler to use e.g., "threads" or "processes". + leave (bool): + Whether to leave progress bar after completion. + verbose (bool): + Whether to display progress bar. + + Returns: + List: + list of outputs from dask compute. 
+ + """ + if verbose: + with TqdmCallback(desc=desc, leave=leave): + return compute(*write_tasks, scheduler=scheduler, num_workers=num_workers) + + return compute(*write_tasks, scheduler=scheduler, num_workers=num_workers) From af5d598be097832dfa822137c186aabce9ba2493 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Mon, 9 Feb 2026 15:23:55 +0000 Subject: [PATCH 121/156] :zap: Tile based processing is faster --- tiatoolbox/models/architecture/hovernet.py | 11 +++++++- .../models/engine/multi_task_segmentor.py | 25 ++++++++++--------- tiatoolbox/utils/misc.py | 9 ++++--- 3 files changed, 28 insertions(+), 17 deletions(-) diff --git a/tiatoolbox/models/architecture/hovernet.py b/tiatoolbox/models/architecture/hovernet.py index 56ef8daa7..a11c89add 100644 --- a/tiatoolbox/models/architecture/hovernet.py +++ b/tiatoolbox/models/architecture/hovernet.py @@ -27,6 +27,7 @@ from tiatoolbox.models.models_abc import ModelABC from tiatoolbox.utils.misc import ( get_bounding_box, + get_tqdm_full, tqdm_dask_progress_bar, ) @@ -668,7 +669,15 @@ def get_instance_info( """ inst_id_list = np.unique(pred_inst)[1:] # exclude background - tasks = [compute_inst_info(inst_id, pred_inst) for inst_id in inst_id_list] + tasks = [ + compute_inst_info(inst_id, pred_inst) + for inst_id in get_tqdm_full( + inst_id_list, + desc="Creating list of tasks for computing instance info", + leave=False, + verbose=verbose, + ) + ] inst_info_dict = dict( result diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index 1e9533e13..91ead8075 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -891,21 +891,17 @@ def post_process_wsi( # skipcq: PYL-R0201 """ probabilities = raw_predictions["probabilities"] - probabilities_is_zarr = False - for probabilities_ in probabilities: - if any("from-zarr" in str(key) for key in probabilities_.dask.layers): - probabilities_is_zarr = True - break + tile_h, tile_w = self.ioconfig.tile_shape + + trigger_tile_proc = any( + p.shape[0] > tile_h or p.shape[1] > tile_w for p in probabilities + ) return_predictions = kwargs.get("return_predictions") # If dask array can fit in memory process without tiling. # This ignores post-processing tile size even if it is smaller. 
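For reference, the `tqdm_dask_progress_bar` helper relocated to `tiatoolbox.utils.misc` above simply wraps `dask.compute` in tqdm's Dask callback. A minimal usage sketch with toy delayed tasks (the `work` function is illustrative):

    from dask import compute, delayed
    from tqdm.dask import TqdmCallback

    @delayed
    def work(x: int) -> int:
        return x * x

    tasks = [work(i) for i in range(100)]

    # Equivalent to tqdm_dask_progress_bar(write_tasks=tasks, desc="Squaring", ...)
    # with verbose=True: the callback renders one bar over the whole graph.
    with TqdmCallback(desc="Squaring", leave=False):
        results = compute(*tasks, scheduler="threads", num_workers=4)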
- if not probabilities_is_zarr: - post_process_predictions = self._process_full_wsi( - probabilities=probabilities, - return_predictions=return_predictions, - ) - else: + if trigger_tile_proc: + logger.info("Processing tiles") self.num_workers = ( kwargs.get("num_workers", multiprocessing.cpu_count()) if self.num_workers == 0 @@ -917,6 +913,11 @@ def post_process_wsi( # skipcq: PYL-R0201 memory_threshold=kwargs.get("memory_threshold", 80), return_predictions=kwargs.get("return_predictions"), ) + else: + post_process_predictions = self._process_full_wsi( + probabilities=probabilities, + return_predictions=return_predictions, + ) tasks = set() for idx, seg in enumerate(post_process_predictions): @@ -1106,7 +1107,7 @@ def _process_tile_mode( prod_dim2 = math.prod(p.shape[2] for p in probabilities if len(p.shape) > 2) # noqa: PLR2004 tile_memory = len(probabilities) * tile_elements * prod_dim2 * bytes_per_element # available memory - available_memory = vm.available * (memory_threshold / 100) + available_memory = vm.available * (80 / 100) # batch size for dask compute should be greater than 0 batch_size = max(int(available_memory // tile_memory), 1) diff --git a/tiatoolbox/utils/misc.py b/tiatoolbox/utils/misc.py index 06f0c9702..f86934d42 100644 --- a/tiatoolbox/utils/misc.py +++ b/tiatoolbox/utils/misc.py @@ -4,6 +4,7 @@ import copy import json +import multiprocessing import shutil import tempfile import zipfile @@ -1796,10 +1797,10 @@ def create_smart_array( def tqdm_dask_progress_bar( - desc: str, write_tasks: list, - num_workers: int, + num_workers: int = multiprocessing.cpu_count(), scheduler: str = "threads", + desc: str = "Processing data", *, leave: bool = False, verbose: bool = True, @@ -1807,14 +1808,14 @@ def tqdm_dask_progress_bar( """Helper function for tqdm_dask_progress_bar. Args: - desc (str): - Message to display for the progress bar. write_tasks (list): List of dask tasks to compute. num_workers (int): Number of workers to use. scheduler (str): dask compute scheduler to use e.g., "threads" or "processes". + desc (str): + Message to display for the progress bar. leave (bool): Whether to leave progress bar after completion. verbose (bool): From 8741c937b1223d79dfd1552a38e6c65d405fd0c6 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Mon, 9 Feb 2026 16:44:28 +0000 Subject: [PATCH 122/156] :bug: Fix memory threshold variable. 
--- .../models/engine/multi_task_segmentor.py | 66 ++++++++++--------- 1 file changed, 34 insertions(+), 32 deletions(-) diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index 91ead8075..f1a115afd 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -1107,7 +1107,7 @@ def _process_tile_mode( prod_dim2 = math.prod(p.shape[2] for p in probabilities if len(p.shape) > 2) # noqa: PLR2004 tile_memory = len(probabilities) * tile_elements * prod_dim2 * bytes_per_element # available memory - available_memory = vm.available * (80 / 100) + available_memory = vm.available * (memory_threshold / 100) # batch size for dask compute should be greater than 0 batch_size = max(int(available_memory // tile_memory), 1) @@ -1177,31 +1177,18 @@ def _process_tile_mode( tile_br = tile_bounds[2:] tile_shape = tile_br - tile_tl - ref_inst_rtree = STRtree([]) - processed_inst_predicts = [] - for inst_id, inst_dict in enumerate(inst_dicts): - if tile_mode == 3: # noqa: PLR2004 - inst_boxes = [ - v["box"] - for v in wsi_info_dict[inst_id]["info_dict"].values() - ] - inst_boxes = np.array(inst_boxes) - - geometries = [shapely_box(*bounds) for bounds in inst_boxes] - ref_inst_rtree = STRtree(geometries) - - processed_inst_predicts.append( - _process_instance_predictions( - inst_dict=inst_dict, - ioconfig=ioconfig, - tile_shape=tile_shape, - tile_flag=tile_flag, - tile_mode=tile_mode, - tile_tl=tile_tl, - ref_inst_dict=wsi_info_dict[inst_id]["info_dict"], - ref_inst_rtree=ref_inst_rtree, - ) + processed_inst_predicts = [ + _process_instance_predictions( + inst_dict=inst_dict, + ioconfig=ioconfig, + tile_shape=tile_shape, + tile_flag=tile_flag, + tile_mode=tile_mode, + tile_tl=tile_tl, + ref_inst_dict=wsi_info_dict[inst_id]["info_dict"], ) + for inst_id, inst_dict in enumerate(inst_dicts) + ] for inst_id, processed_inst_predict in enumerate( processed_inst_predicts @@ -2791,7 +2778,6 @@ def _process_instance_predictions( tile_mode: int, tile_tl: tuple[int, int], ref_inst_dict: dict, - ref_inst_rtree: STRtree, ) -> list | tuple: """Function to merge new tile prediction with existing prediction. 
@@ -2860,23 +2846,39 @@ def _process_instance_predictions( remove_insts_in_tile = retrieve_sel_uids(sel_indices, inst_dict) + if tile_mode != 3: # noqa: PLR2004 + return ( + _move_tile_space_to_wsi_space( + inst_dict=inst_dict, + tile_tl=tile_tl, + remove_insts_in_tile=remove_insts_in_tile, + ), + [], + ) # external removal only for tile at cross-sections # this one should contain UUID with the reference database remove_insts_in_orig = [] if tile_mode == 3: # noqa: PLR2004 + inst_boxes = [v["box"] for v in ref_inst_dict.values()] + inst_boxes = np.array(inst_boxes) + + geometries = [shapely_box(*bounds) for bounds in inst_boxes] + ref_inst_rtree = STRtree(geometries) + sel_indices_remove = [ geo for bounds in margin_lines for geo in ref_inst_rtree.query(bounds) ] remove_insts_in_orig = retrieve_sel_uids(sel_indices_remove, ref_inst_dict) - new_inst_dict = _move_tile_space_to_wsi_space( - inst_dict=inst_dict, - tile_tl=tile_tl, - remove_insts_in_tile=remove_insts_in_tile, + return ( + _move_tile_space_to_wsi_space( + inst_dict=inst_dict, + tile_tl=tile_tl, + remove_insts_in_tile=remove_insts_in_tile, + ), + remove_insts_in_orig, ) - return new_inst_dict, remove_insts_in_orig - def retrieve_sel_uids(sel_indices_: list, inst_dict_: dict) -> list: """Helper to retrieved selected instance uids.""" From 3fd396bfb604092bde8ddbea740e43093ac3cd39 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Mon, 9 Feb 2026 19:41:40 +0000 Subject: [PATCH 123/156] :bug: Update appropriate tile_shape --- tests/engines/test_multi_task_segmentor.py | 3 +- .../models/engine/multi_task_segmentor.py | 42 +++++++++++-------- .../models/engine/semantic_segmentor.py | 1 - 3 files changed, 27 insertions(+), 19 deletions(-) diff --git a/tests/engines/test_multi_task_segmentor.py b/tests/engines/test_multi_task_segmentor.py index 2470031c0..7e7e07af1 100644 --- a/tests/engines/test_multi_task_segmentor.py +++ b/tests/engines/test_multi_task_segmentor.py @@ -318,7 +318,8 @@ def test_wsi_mtsegmentor_zarr( assert "count" not in output_full_["layer_segmentation"] # Redefine tile size to force tile-based processing. 
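The query path above reduces to: index the reference instance boxes in an STRtree, intersect each margin geometry, and translate the positional hits back to instance uids, as `retrieve_sel_uids` does. A toy end-to-end sketch (uids and boxes are illustrative):

    import shapely
    from shapely.strtree import STRtree

    # Illustrative reference instances: uid -> bounding box (x0, y0, x1, y1).
    ref_inst_dict = {
        "a": {"box": (0, 0, 10, 10)},
        "b": {"box": (95, 0, 105, 10)},
        "c": {"box": (40, 40, 60, 60)},
    }

    geometries = [shapely.box(*v["box"]) for v in ref_inst_dict.values()]
    tree = STRtree(geometries)

    # A vertical margin strip near x=100; query returns positional indices
    # of tree geometries whose extents intersect it.
    margin = shapely.box(99, 0, 101, 200)
    sel_indices = tree.query(margin)

    # Map positional indices back to instance uids.
    uids = list(ref_inst_dict)
    print([uids[i] for i in sel_indices])  # -> ['b']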
- ioconfig.tile_shape = (512, 512) + # 350 x 350 forces tile mode 3 (overlap) + ioconfig.tile_shape = (350, 350) mtsegmentor.drop_keys = [] # Return Probabilities is False diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index f1a115afd..0b51ee30c 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -1177,18 +1177,31 @@ def _process_tile_mode( tile_br = tile_bounds[2:] tile_shape = tile_br - tile_tl - processed_inst_predicts = [ - _process_instance_predictions( - inst_dict=inst_dict, - ioconfig=ioconfig, - tile_shape=tile_shape, - tile_flag=tile_flag, - tile_mode=tile_mode, - tile_tl=tile_tl, - ref_inst_dict=wsi_info_dict[inst_id]["info_dict"], + processed_inst_predicts = [] + for inst_id, inst_dict in enumerate(inst_dicts): + ref_inst_rtree = STRtree([]) + if tile_mode == 3: # noqa: PLR2004 + inst_boxes = [ + v["box"] + for v in wsi_info_dict[inst_id]["info_dict"].values() + ] + inst_boxes = np.array(inst_boxes) + + geometries = [shapely_box(*bounds) for bounds in inst_boxes] + ref_inst_rtree = STRtree(geometries) + + processed_inst_predicts.append( + _process_instance_predictions( + inst_dict=inst_dict, + ioconfig=ioconfig, + tile_shape=tile_shape, + tile_flag=tile_flag, + tile_mode=tile_mode, + tile_tl=tile_tl, + ref_inst_dict=wsi_info_dict[inst_id]["info_dict"], + ref_inst_rtree=ref_inst_rtree, + ) ) - for inst_id, inst_dict in enumerate(inst_dicts) - ] for inst_id, processed_inst_predict in enumerate( processed_inst_predicts @@ -2778,6 +2791,7 @@ def _process_instance_predictions( tile_mode: int, tile_tl: tuple[int, int], ref_inst_dict: dict, + ref_inst_rtree: STRtree, ) -> list | tuple: """Function to merge new tile prediction with existing prediction. @@ -2859,12 +2873,6 @@ def _process_instance_predictions( # this one should contain UUID with the reference database remove_insts_in_orig = [] if tile_mode == 3: # noqa: PLR2004 - inst_boxes = [v["box"] for v in ref_inst_dict.values()] - inst_boxes = np.array(inst_boxes) - - geometries = [shapely_box(*bounds) for bounds in inst_boxes] - ref_inst_rtree = STRtree(geometries) - sel_indices_remove = [ geo for bounds in margin_lines for geo in ref_inst_rtree.query(bounds) ] diff --git a/tiatoolbox/models/engine/semantic_segmentor.py b/tiatoolbox/models/engine/semantic_segmentor.py index 2dd0999b5..e87285c89 100644 --- a/tiatoolbox/models/engine/semantic_segmentor.py +++ b/tiatoolbox/models/engine/semantic_segmentor.py @@ -1452,7 +1452,6 @@ def prepare_full_batch( """ # Map batch locations back to indices in the full output grid. # Use a dict to avoid allocating a huge dense array when locations are sparse. 
- # Use np.intersect1d once numpy version is upgraded to 2.0 full_output_dict = {tuple(row): i for i, row in enumerate(full_output_locs)} matches = np.array([full_output_dict[tuple(row)] for row in batch_locs]) From 239b434bd14a5fef83aac5e1446962f42ddb26f0 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Mon, 9 Feb 2026 21:10:13 +0000 Subject: [PATCH 124/156] :zap: Optimise by removing unnecessary dask compute --- tiatoolbox/models/architecture/hovernet.py | 117 +++++++----------- .../models/engine/multi_task_segmentor.py | 75 ++++++----- 2 files changed, 90 insertions(+), 102 deletions(-) diff --git a/tiatoolbox/models/architecture/hovernet.py b/tiatoolbox/models/architecture/hovernet.py index a11c89add..7ff8ffa96 100644 --- a/tiatoolbox/models/architecture/hovernet.py +++ b/tiatoolbox/models/architecture/hovernet.py @@ -3,7 +3,6 @@ from __future__ import annotations import math -import multiprocessing import warnings from collections import OrderedDict @@ -13,7 +12,6 @@ import pandas as pd import torch import torch.nn.functional as F # noqa: N812 -from dask import delayed from scipy import ndimage from skimage.morphology import remove_small_objects from skimage.segmentation import watershed @@ -25,11 +23,7 @@ centre_crop_to_shape, ) from tiatoolbox.models.models_abc import ModelABC -from tiatoolbox.utils.misc import ( - get_bounding_box, - get_tqdm_full, - tqdm_dask_progress_bar, -) +from tiatoolbox.utils.misc import get_bounding_box, get_tqdm_full class TFSamepaddingLayer(nn.Module): @@ -668,30 +662,51 @@ def get_instance_info( """ inst_id_list = np.unique(pred_inst)[1:] # exclude background - - tasks = [ - compute_inst_info(inst_id, pred_inst) - for inst_id in get_tqdm_full( - inst_id_list, - desc="Creating list of tasks for computing instance info", - leave=False, - verbose=verbose, - ) - ] - - inst_info_dict = dict( - result - for result in tqdm_dask_progress_bar( - desc="Generating 'info_dict' for instances", - write_tasks=tasks, - num_workers=multiprocessing.cpu_count(), - scheduler="threads", - leave=False, - verbose=verbose, - ) + inst_info_dict = {} + tqdm_loop = get_tqdm_full( + inst_id_list, + leave=False, + desc="Generating 'info_dict' for instances", + verbose=verbose, ) - - inst_info_dict = {k: v for k, v in inst_info_dict.items() if v is not None} + for inst_id in tqdm_loop: + inst_map = pred_inst == inst_id + inst_box = get_bounding_box(inst_map) + inst_box_tl = inst_box[:2] + inst_map = inst_map[inst_box[1] : inst_box[3], inst_box[0] : inst_box[2]] + inst_map = inst_map.astype(np.uint8) + inst_moment = cv2.moments(inst_map) + inst_contour = cv2.findContours( + inst_map, + cv2.RETR_TREE, + cv2.CHAIN_APPROX_SIMPLE, + ) + # * opencv protocol format may break + inst_contour = inst_contour[0][0].astype(np.int32) + inst_contour = np.squeeze(inst_contour) + + # < 3 points does not make a contour, so skip, likely artifact too + # as the contours obtained via approximation => too small + if inst_contour.shape[0] < 3: # pragma: no cover # noqa: PLR2004 + continue + # ! 
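The replacement mapping hashes each coordinate row once, so memory scales with the number of locations rather than with the coordinate extent. A toy sketch of the same lookup:

    import numpy as np

    full_output_locs = np.array([[0, 0], [0, 512], [512, 0], [512, 512]])
    batch_locs = np.array([[512, 0], [0, 512]])

    # O(n + m) dict lookup instead of allocating a dense index array
    # spanning the full coordinate space.
    full_output_dict = {tuple(row): i for i, row in enumerate(full_output_locs)}
    matches = np.array([full_output_dict[tuple(row)] for row in batch_locs])
    print(matches)  # -> [2 1]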
check for trickery shape + if len(inst_contour.shape) != 2: # pragma: no cover # noqa: PLR2004 + continue + + inst_centroid = [ + (inst_moment["m10"] / inst_moment["m00"]), + (inst_moment["m01"] / inst_moment["m00"]), + ] + inst_centroid = np.array(inst_centroid) + inst_contour += inst_box_tl[None] + inst_centroid += inst_box_tl # X + inst_info_dict[inst_id] = { # inst_id should start at 1 + "box": inst_box, + "centroid": inst_centroid, + "contours": inst_contour, + "prob": None, + "type": None, + } if pred_type is not None: # * Get class of each instance id, stored at index id-1 @@ -895,45 +910,3 @@ def _inst_dict_for_dask_processing( else col_np ) return inst_info_dict_ - - -@delayed -def compute_inst_info(inst_id: int, pred_inst: np.ndarray) -> tuple[int, dict]: - """Helper function to compute instance info with dask delayed.""" - inst_map = pred_inst == inst_id - inst_box = get_bounding_box(inst_map) - inst_box_tl = inst_box[:2] - inst_map = inst_map[inst_box[1] : inst_box[3], inst_box[0] : inst_box[2]] - inst_map = inst_map.astype(np.uint8) - inst_moment = cv2.moments(inst_map) - inst_contour = cv2.findContours( - inst_map, - cv2.RETR_TREE, - cv2.CHAIN_APPROX_SIMPLE, - ) - # * opencv protocol format may break - inst_contour = inst_contour[0][0].astype(np.int32) - inst_contour = np.squeeze(inst_contour) - - # < 3 points does not make a contour, so skip, likely artifact too - # as the contours obtained via approximation => too small - # ! check for trickery shape - if ( - inst_contour.shape[0] < 3 or inst_contour.ndim != 2 # noqa: PLR2004 - ): # pragma: no cover - return inst_id, None - - inst_centroid = [ - (inst_moment["m10"] / inst_moment["m00"]), - (inst_moment["m01"] / inst_moment["m00"]), - ] - inst_centroid = np.array(inst_centroid) - inst_contour += inst_box_tl[None] - inst_centroid += inst_box_tl # X - return inst_id, { - "box": inst_box, - "centroid": inst_centroid, - "contours": inst_contour, - "prob": None, - "type": None, - } diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index 0b51ee30c..786aa79b4 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -1177,31 +1177,18 @@ def _process_tile_mode( tile_br = tile_bounds[2:] tile_shape = tile_br - tile_tl - processed_inst_predicts = [] - for inst_id, inst_dict in enumerate(inst_dicts): - ref_inst_rtree = STRtree([]) - if tile_mode == 3: # noqa: PLR2004 - inst_boxes = [ - v["box"] - for v in wsi_info_dict[inst_id]["info_dict"].values() - ] - inst_boxes = np.array(inst_boxes) - - geometries = [shapely_box(*bounds) for bounds in inst_boxes] - ref_inst_rtree = STRtree(geometries) - - processed_inst_predicts.append( - _process_instance_predictions( - inst_dict=inst_dict, - ioconfig=ioconfig, - tile_shape=tile_shape, - tile_flag=tile_flag, - tile_mode=tile_mode, - tile_tl=tile_tl, - ref_inst_dict=wsi_info_dict[inst_id]["info_dict"], - ref_inst_rtree=ref_inst_rtree, - ) + processed_inst_predicts = [ + _compute_info_dict_for_merge( + inst_dict=inst_dict, + tile_mode=tile_mode, + ref_inst_info_dict=wsi_info_dict[inst_id]["info_dict"], + ioconfig=ioconfig, + tile_shape=tile_shape, + tile_tl=tile_tl, + tile_flag=tile_flag, ) + for inst_id, inst_dict in enumerate(inst_dicts) + ] for inst_id, processed_inst_predict in enumerate( processed_inst_predicts @@ -2871,12 +2858,10 @@ def _process_instance_predictions( ) # external removal only for tile at cross-sections # this one should contain UUID with the reference 
database - remove_insts_in_orig = [] - if tile_mode == 3: # noqa: PLR2004 - sel_indices_remove = [ - geo for bounds in margin_lines for geo in ref_inst_rtree.query(bounds) - ] - remove_insts_in_orig = retrieve_sel_uids(sel_indices_remove, ref_inst_dict) + sel_indices_remove = [ + geo for bounds in margin_lines for geo in ref_inst_rtree.query(bounds) + ] + remove_insts_in_orig = retrieve_sel_uids(sel_indices_remove, ref_inst_dict) return ( _move_tile_space_to_wsi_space( @@ -3229,3 +3214,33 @@ def _build_single_annotation( } return Annotation(geom, properties) + + +def _compute_info_dict_for_merge( + inst_dict: dict, + tile_mode: int, + ref_inst_info_dict: dict, + ioconfig: IOSegmentorConfig, + tile_shape: tuple[int, int], + tile_tl: tuple[int, int], + tile_flag: tuple[int, int, int, int], +) -> list | tuple: + """Helper function to compute info dict with remove inst ids.""" + ref_inst_rtree = STRtree([]) + if tile_mode == 3: # noqa: PLR2004 + inst_boxes = [v["box"] for v in ref_inst_info_dict.values()] + inst_boxes = np.array(inst_boxes) + + geometries = [shapely_box(*bounds) for bounds in inst_boxes] + ref_inst_rtree = STRtree(geometries) + + return _process_instance_predictions( + inst_dict=inst_dict, + ioconfig=ioconfig, + tile_shape=tile_shape, + tile_flag=tile_flag, + tile_mode=tile_mode, + tile_tl=tile_tl, + ref_inst_dict=ref_inst_info_dict, + ref_inst_rtree=ref_inst_rtree, + ) From dcd73843e283e6b2bbf1f4c514a89291ea4815d2 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Mon, 9 Feb 2026 21:47:41 +0000 Subject: [PATCH 125/156] :bug: Fix outputs for annotationstore with multiple inputs. --- tiatoolbox/models/engine/multi_task_segmentor.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index 786aa79b4..f3773af98 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -1664,21 +1664,21 @@ def _save_predictions_as_annotationstore( save_paths.append(output_path) else: - for idx, curr_image in enumerate(self.images): - values = [processed_predictions[key] for key in keys_to_compute] - output_path = _save_annotation_store( - curr_image=curr_image, + values = [processed_predictions[key] for key in keys_to_compute] + save_paths = [ + _save_annotation_store( + curr_image=save_path, keys_to_compute=keys_to_compute, values=values, task_name=task_name, - idx=idx, + idx=0, save_path=save_path, class_dict=class_dict, scale_factor=scale_factor, num_workers=num_workers, verbose=self.verbose, ) - save_paths.append(output_path) + ] for key in keys_to_compute: del processed_predictions[key] From bc86d74aa9075d212689d81d93e973dd4ea30af1 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Mon, 9 Feb 2026 22:39:12 +0000 Subject: [PATCH 126/156] :fire: Remove unnecessary variables --- tiatoolbox/models/engine/multi_task_segmentor.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index f3773af98..893ff8d4f 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -1094,9 +1094,6 @@ def _process_tile_mode( verbose=self.verbose, ) - wsi_info_dict = None - merge_idx = 0 - # Only used for delayed processing. 
self._probabilities = probabilities # skipcq: PYL-W0201 # skipcq: PYL-W0201 @@ -1111,6 +1108,7 @@ def _process_tile_mode( # batch size for dask compute should be greater than 0 batch_size = max(int(available_memory // tile_memory), 1) + wsi_info_dict = None for i in get_tqdm_full( range(0, len(tile_metadata), batch_size), leave=False, @@ -1150,9 +1148,8 @@ def _process_tile_mode( ) # Merge each tile result - for post_process_output in tqdm_loop: - tile_bounds, tile_flag, tile_mode = tile_metadata[merge_idx] - merge_idx += 1 + for _tile_id, post_process_output in enumerate(tqdm_loop): + tile_bounds, tile_flag, tile_mode = tile_metadata_[_tile_id] # create a list of info dict for each task wsi_info_dict = _create_wsi_info_dict( From 5d26867fdcffb8dfeda7b5a83006d456836d02f4 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Wed, 11 Feb 2026 14:34:44 +0000 Subject: [PATCH 127/156] :bug: Fix writing annotations for large images --- tiatoolbox/models/engine/multi_task_segmentor.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index 893ff8d4f..afee54fc3 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -2744,12 +2744,8 @@ def _save_annotation_store( predictions_ = dict(zip(keys_to_compute, values, strict=False)) output_path = save_path.parent / store_file_name # Patch mode indexes the "coordinates" while calculating "values" variable. - origin = ( - predictions_.pop("coordinates")[0][:2] - if len(predictions_["coordinates"].shape) > 1 - else predictions_.pop("coordinates")[:2] - ) - origin = tuple(max(0.0, float(x)) for x in origin) + origin = (0.0, 0.0) + _ = predictions_.pop("coordinates") store = SQLiteStore() store = dict_to_store( store=store, From 90e3fd5eeefe7a89b52cdce98cb687f05709b830 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Wed, 11 Feb 2026 15:20:17 +0000 Subject: [PATCH 128/156] :zap: Improve annotationstore writing --- .../models/engine/multi_task_segmentor.py | 404 +++++++++++------- 1 file changed, 255 insertions(+), 149 deletions(-) diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index afee54fc3..d1c3dbca1 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -120,7 +120,7 @@ import uuid from collections import deque from pathlib import Path -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING import dask.array as da import numpy as np @@ -2016,109 +2016,6 @@ def run( ) -def dict_to_store( - store: SQLiteStore, - processed_predictions: dict, - class_dict: dict | None = None, - origin: tuple[float, float] = (0, 0), - scale_factor: tuple[float, float] = (1, 1), - num_workers: int = multiprocessing.cpu_count(), - *, - verbose: bool = True, -) -> AnnotationStore: - """Write polygonal multitask predictions into an SQLite-backed AnnotationStore. - - Converts a task dictionary (with per-object fields) into `Annotation` records, - applying coordinate scaling and translation to move predictions into the slide's - baseline coordinate space. Each geometry is created from the per-object - `"contours"` entry, validated, and shifted by `origin`. 
All remaining keys in - `processed_predictions` are attached as annotation properties; the `"type"` key - can be mapped via `class_dict`. - - Expected `processed_predictions` structure: - - "contours": list-like of polygon coordinates per object, where each item - is shaped like `[[x0, y0], [x1, y1], ..., [xN, yN]]`. These are interpreted - according to `"geom_type"` (default `"Polygon"`). - - Optional "geom_type": str (e.g., "Polygon", "MultiPolygon"). - Defaults to "Polygon". - - Additional per-object fields (e.g., "type", "probability", scores, attributes) - with list-like values aligned to `contours` length. - - Args: - store (SQLiteStore): - Target annotation store that will receive the converted annotations. - processed_predictions (dict): - Dictionary containing per-object fields. Must include `"contours"`; - may include `"geom_type"` and any number of additional fields to be - written as properties. - class_dict (dict | None): - Optional mapping for the `"type"` field. When provided and when - `"type"` is present in `processed_predictions`, each `"type"` value is - replaced by `class_dict[type_id]` in the saved annotation properties. - origin (tuple[float, float]): - `(x0, y0)` offset to add to the final geometry coordinates (in pixels) - after scaling. Typically corresponds to the tile/patch origin in WSI - space. - scale_factor (tuple[float, float]): - `(sx, sy)` factors applied to coordinates before translation, used to - convert from model space to baseline slide resolution (e.g., - `model_mpp / slide_mpp`). - num_workers (int): - Number of parallel worker threads to use. If set to 0 or None, - defaults to the number of CPU cores. - verbose (bool): - Whether to display logs and progress bar. - - Returns: - AnnotationStore: - The input `store` after appending all converted annotations. - - Notes: - - Geometries are constructed from `processed_predictions["contours"]` using - `geom_type` (default `"Polygon"`), scaled by `scale_factor`, and translated - by `origin`. Invalid geometries are auto-corrected using `make_valid_poly`. - - Per-object properties are created by taking the i-th element from each - remaining key in `processed_predictions`. Scalars are coerced to arrays - first, then converted with `.tolist()` to ensure JSON-serializable values. - - If `class_dict` is provided and a `"type"` key exists, `"type"` values are - mapped prior to saving. - - All annotations are appended in a single batch via `store.append_many(...)`. 
- - """ - contours = processed_predictions.pop("contours") - n = len(contours) - - # Build delayed tasks - delayed_tasks = [ - _build_single_annotation( - i, - contours[i], - processed_predictions, - class_dict, - origin, - scale_factor, - ) - for i in get_tqdm_full( - range(n), - leave=False, - desc="Creating list of delayed tasks for writing annotations", - verbose=verbose, - ) - ] - - ann = tqdm_dask_progress_bar( - write_tasks=delayed_tasks, - desc="Saving annotations", - num_workers=num_workers, - verbose=verbose, - ) - - logger.info("Added %d annotations.", len(ann)) - store.append_many(ann) - - return store - - def prepare_multitask_full_batch( batch_output: tuple[np.ndarray], batch_locs: np.ndarray, @@ -3164,51 +3061,6 @@ def _build_tile_tasks( return tile_metadata -@delayed -def _build_single_annotation( - i: int, - contour: np.ndarray, - processed_predictions: dict[str, Any], - class_dict: dict[int, str] | None, - origin: tuple[float, float], - scale_factor: tuple[float, float], -) -> Annotation: - """Creates a delayed annotation to run with dask. - - Build a single Annotation object for index `i`. - - This function performs: - - geometry creation - - coordinate scaling + translation - - per-object property extraction - - class_dict mapping (if provided) - - Returns: - A single Annotation instance. - - """ - geom = make_valid_poly( - feature2geometry( - { - "type": processed_predictions.get("geom_type", "Polygon"), - "coordinates": scale_factor * np.array([contour]), - } - ), - tuple(origin), - ) - - properties = { - prop: ( - class_dict[processed_predictions[prop][i]] - if prop == "type" and class_dict is not None - else np.array(processed_predictions[prop][i]).tolist() - ) - for prop in processed_predictions - } - - return Annotation(geom, properties) - - def _compute_info_dict_for_merge( inst_dict: dict, tile_mode: int, @@ -3237,3 +3089,257 @@ def _compute_info_dict_for_merge( ref_inst_dict=ref_inst_info_dict, ref_inst_rtree=ref_inst_rtree, ) + + +def dict_to_store( + store: SQLiteStore, + processed_predictions: dict, + class_dict: dict | None = None, + origin: tuple[float, float] = (0, 0), + scale_factor: tuple[float, float] = (1, 1), + num_workers: int = multiprocessing.cpu_count(), + *, + verbose: bool = True, +) -> AnnotationStore: + """Write polygonal multitask predictions into an SQLite-backed AnnotationStore. + + Converts a task dictionary (with per-object fields) into `Annotation` records, + applying coordinate scaling and translation to move predictions into the slide's + baseline coordinate space. Each geometry is created from the per-object + `"contours"` entry, validated, and shifted by `origin`. All remaining keys in + `processed_predictions` are attached as annotation properties; the `"type"` key + can be mapped via `class_dict`. + + Expected `processed_predictions` structure: + - "contours": list-like of polygon coordinates per object, where each item + is shaped like `[[x0, y0], [x1, y1], ..., [xN, yN]]`. These are interpreted + according to `"geom_type"` (default `"Polygon"`). + - Optional "geom_type": str (e.g., "Polygon", "MultiPolygon"). + Defaults to "Polygon". + - Additional per-object fields (e.g., "type", "probability", scores, attributes) + with list-like values aligned to `contours` length. + + Args: + store (SQLiteStore): + Target annotation store that will receive the converted annotations. + processed_predictions (dict): + Dictionary containing per-object fields. 
Must include `"contours"`; + may include `"geom_type"` and any number of additional fields to be + written as properties. + class_dict (dict | None): + Optional mapping for the `"type"` field. When provided and when + `"type"` is present in `processed_predictions`, each `"type"` value is + replaced by `class_dict[type_id]` in the saved annotation properties. + origin (tuple[float, float]): + `(x0, y0)` offset to add to the final geometry coordinates (in pixels) + after scaling. Typically corresponds to the tile/patch origin in WSI + space. + scale_factor (tuple[float, float]): + `(sx, sy)` factors applied to coordinates before translation, used to + convert from model space to baseline slide resolution (e.g., + `model_mpp / slide_mpp`). + num_workers (int): + Number of parallel worker threads to use. If set to 0 or None, + defaults to the number of CPU cores. + verbose (bool): + Whether to display logs and progress bar. + + Returns: + AnnotationStore: + The input `store` after appending all converted annotations. + + Notes: + - Geometries are constructed from `processed_predictions["contours"]` using + `geom_type` (default `"Polygon"`), scaled by `scale_factor`, and translated + by `origin`. Invalid geometries are auto-corrected using `make_valid_poly`. + - Per-object properties are created by taking the i-th element from each + remaining key in `processed_predictions`. Scalars are coerced to arrays + first, then converted with `.tolist()` to ensure JSON-serializable values. + - If `class_dict` is provided and a `"type"` key exists, `"type"` values are + mapped prior to saving. + - All annotations are appended in a single batch via `store.append_many(...)`. + + """ + contours = processed_predictions.pop("contours") + delayed_tasks = DaskDelayedAnnotationStore( + contours=contours, + processed_predictions=processed_predictions, + ) + + return delayed_tasks.compute_annotations( + store=store, + class_dict=class_dict, + origin=origin, + scale_factor=scale_factor, + batch_size=100, + num_workers=num_workers, + verbose=verbose, + ) + + +class DaskDelayedAnnotationStore: + """Compute and write TIAToolbox annotations using batched Dask Delayed tasks. + + This class parallelizes annotation construction using Dask Delayed while + avoiding serialization overhead by storing contours and prediction arrays + as instance attributes. Annotations are computed in batches and written + directly to a TIAToolbox `SQLiteStore` via `append_many()`. + + """ + + def __init__( + self: DaskDelayedAnnotationStore, + contours: np.ndarray, + processed_predictions: dict, + ) -> DaskDelayedAnnotationStore: + """Initialize :class:`DaskDelayedAnnotationStore`. + + Args: + contours (np.ndarray): + A sequence of polygon contours. Each element is an array-like + of shape ``(N_i, 2)`` representing the coordinates of a single + object contour. + + processed_predictions (dict): + A dictionary of per-object prediction fields. Each key maps to + an array-like of length ``len(contours)``. Example keys include + ``"type"``, ``"prob"``, ``"centroid"``, etc. May also contain + a global field ``"geom_type"``. + + """ + self._contours = contours + self._processed_predictions = processed_predictions + + def _build_single_annotation( + self: DaskDelayedAnnotationStore, + i: int, + class_dict: dict[int, str] | None, + origin: tuple[float, float], + scale_factor: tuple[float, float], + ) -> Annotation: + """Build a single annotation for index ``i``. 
+ + This method performs: + - geometry creation + - coordinate scaling and translation + - per-object property extraction + - optional class label mapping + + Args: + i (int): + Index of the object to convert into an annotation. + + class_dict (dict[int, str] | None): + Optional mapping from integer class IDs to string labels. + If ``None``, raw integer class IDs are used. + + origin (tuple[float, float]): + Translation offset ``(x, y)`` applied after scaling. + + scale_factor (tuple[float, float]): + Scaling factors ``(sx, sy)`` applied to contour coordinates. + + Returns: + Annotation: + A fully constructed TIAToolbox `Annotation` instance. + + """ + geom = make_valid_poly( + feature2geometry( + { + "type": self._processed_predictions.get("geom_type", "Polygon"), + "coordinates": scale_factor * np.array([self._contours[i]]), + } + ), + tuple(origin), + ) + + properties = { + prop: ( + class_dict[self._processed_predictions[prop][i]] + if prop == "type" and class_dict is not None + else np.array(self._processed_predictions[prop][i]).tolist() + ) + for prop in self._processed_predictions + } + + return Annotation(geom, properties) + + def compute_annotations( + self: DaskDelayedAnnotationStore, + store: SQLiteStore, + class_dict: dict[int, str] | None, + origin: tuple[float, float] = (0, 0), + scale_factor: tuple[float, float] = (1, 1), + batch_size: int = 100, + num_workers: int = 0, + *, + verbose: bool = True, + ) -> SQLiteStore: + """Compute annotations in batches and write them to a SQLiteStore. + + This method creates Dask Delayed tasks in batches to reduce scheduler + overhead. Each batch is computed and written immediately using + ``store.append_many()``. + + Args: + store (SQLiteStore): + A TIAToolbox SQLiteStore instance used to write annotations. + + class_dict (dict[int, str] | None): + Optional mapping from integer class IDs to string labels. + + origin (tuple[float, float], optional): + Translation offset ``(x, y)`` applied after scaling. + Defaults to ``(0, 0)``. + + scale_factor (tuple[float, float], optional): + Scaling factors ``(sx, sy)`` applied to contour coordinates. + Defaults to ``(1, 1)``. + + batch_size (int, optional): + Number of annotations to compute per batch. Larger batches + reduce Dask scheduler overhead. Defaults to ``100``. + + num_workers (int, optional): + Number of Dask workers to use. ``0`` means auto-detect. + Passed through to the progress bar helper. Defaults to ``0``. + + verbose (bool, optional): + Whether to display progress bars. Defaults to ``True``. + + Returns: + SQLiteStore: + The same store instance, after all annotations have been written. 
+ + """ + num_contours = len(self._contours) + for batch_id in get_tqdm_full( + range(0, num_contours, batch_size), + leave=False, + desc="Calculating annotations in batches.", + ): + delayed_tasks = [ + delayed(self._build_single_annotation)( + i, + class_dict, + origin, + scale_factor, + ) + for i in get_tqdm_full( + range(batch_id, min(batch_id + batch_size, num_contours)), + leave=False, + desc="Creating list of delayed tasks for writing annotations", + verbose=True, + ) + ] + + store.append_many( + tqdm_dask_progress_bar( + write_tasks=delayed_tasks, + desc="Saving annotations", + verbose=verbose, + num_workers=num_workers, + ) + ) + return store From 787c489a9098772943eee97247d015e1192f378a Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Wed, 11 Feb 2026 15:34:48 +0000 Subject: [PATCH 129/156] :bug: Fix output annotation type --- tiatoolbox/models/engine/multi_task_segmentor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index d1c3dbca1..69865d5bd 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -3191,7 +3191,7 @@ def __init__( self: DaskDelayedAnnotationStore, contours: np.ndarray, processed_predictions: dict, - ) -> DaskDelayedAnnotationStore: + ) -> None: """Initialize :class:`DaskDelayedAnnotationStore`. Args: From bd289918455230a6a394447d273bd40b70eefee0 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Wed, 11 Feb 2026 16:31:44 +0000 Subject: [PATCH 130/156] :zap: Use vectorized shapely.box instead of shapely.geometry.box --- .../models/engine/multi_task_segmentor.py | 34 +++++++++---------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index 69865d5bd..aed2e5800 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -126,10 +126,10 @@ import numpy as np import pandas as pd import psutil +import shapely import torch import zarr from dask import delayed -from shapely.geometry import box as shapely_box from shapely.geometry import shape as feature2geometry from shapely.strtree import STRtree from typing_extensions import Unpack @@ -1314,12 +1314,12 @@ def _get_tile_info( def unset_removal_flag(boxes: tuple, removal_flag: np.ndarray) -> np.ndarray: """Unset removal flags for tiles intersecting image boundaries.""" sel_boxes = [ - shapely_box(0, 0, w, 0), # top edge - shapely_box(0, h, w, h), # bottom edge - shapely_box(0, 0, 0, h), # left - shapely_box(w, 0, w, h), # right + shapely.box(0, 0, w, 0), # top edge + shapely.box(0, h, w, h), # bottom edge + shapely.box(0, 0, 0, h), # left + shapely.box(w, 0, w, h), # right ] - geometries = [shapely_box(*bounds) for bounds in boxes] + geometries = [shapely.box(*bounds) for bounds in boxes] spatial_indexer = STRtree(geometries) for idx, sel_box in enumerate(sel_boxes): @@ -2789,23 +2789,23 @@ def _get_sel_indices_margin_lines( inst_boxes = [v["box"] for v in inst_dict.values()] inst_boxes = np.array(inst_boxes) - geometries = [shapely_box(*bounds) for bounds in inst_boxes] + geometries = [shapely.box(*bounds) for bounds in inst_boxes] tile_rtree = STRtree(geometries) # ! 
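The batching in `compute_annotations` trades one huge task graph for many small ones, flushing each batch through `append_many` before building the next, which bounds both scheduler overhead and peak memory. The same pattern in isolation (generic `build`/`sink` stand-ins, not the store API):

    from dask import compute, delayed

    def write_in_batches(items, build, sink, batch_size=100):
        """Build objects in parallel batches and flush each batch to a sink."""
        for start in range(0, len(items), batch_size):
            tasks = [delayed(build)(item) for item in items[start : start + batch_size]]
            sink(compute(*tasks, scheduler="threads"))

    # Illustrative use: squares collected into a plain list.
    out: list[int] = []
    write_in_batches(list(range(250)), build=lambda x: x * x, sink=out.extend)
    print(len(out))  # -> 250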
# create margin bounding box, ordering should match with # created tile info flag (top, bottom, left, right) boundary_lines = [ - shapely_box(0, 0, width, 1), # top egde - shapely_box(0, height - 1, width, height), # bottom edge - shapely_box(0, 0, 1, height), # left - shapely_box(width - 1, 0, width, height), # right + shapely.box(0, 0, width, 1), # top egde + shapely.box(0, height - 1, width, height), # bottom edge + shapely.box(0, 0, 1, height), # left + shapely.box(width - 1, 0, width, height), # right ] margin_boxes = [ - shapely_box(0, 0, width, margin), # top egde - shapely_box(0, height - margin, width, height), # bottom edge - shapely_box(0, 0, margin, height), # left - shapely_box(width - margin, 0, width, height), # right + shapely.box(0, 0, width, margin), # top egde + shapely.box(0, height - margin, width, height), # bottom edge + shapely.box(0, 0, margin, height), # left + shapely.box(width - margin, 0, width, height), # right ] margin_lines = _get_margin_lines( margin=margin, @@ -2865,7 +2865,7 @@ def _get_margin_lines( [[width - margin, margin], [width - margin, height - margin]], # right ] margin_lines = np.array(margin_lines) + tile_tl[None, None] - return [shapely_box(*v.flatten().tolist()) for v in margin_lines] + return [shapely.box(*v.flatten().tolist()) for v in margin_lines] def _move_tile_space_to_wsi_space( @@ -3076,7 +3076,7 @@ def _compute_info_dict_for_merge( inst_boxes = [v["box"] for v in ref_inst_info_dict.values()] inst_boxes = np.array(inst_boxes) - geometries = [shapely_box(*bounds) for bounds in inst_boxes] + geometries = [shapely.box(*bounds) for bounds in inst_boxes] ref_inst_rtree = STRtree(geometries) return _process_instance_predictions( From 476a365d24546bc719cb53f860ff93ccd1f96904 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Thu, 12 Feb 2026 10:15:35 +0000 Subject: [PATCH 131/156] :bulb: Address review comments --- tests/engines/test_multi_task_segmentor.py | 4 +- tiatoolbox/cli/__init__.py | 11 ++-- tiatoolbox/models/architecture/hovernet.py | 7 ++- tiatoolbox/models/architecture/micronet.py | 6 +- .../models/engine/deep_feature_extractor.py | 12 ++-- tiatoolbox/models/engine/engine_abc.py | 10 +-- .../models/engine/multi_task_segmentor.py | 63 ++++++++++--------- tiatoolbox/models/engine/nucleus_detector.py | 11 ++-- .../engine/nucleus_instance_segmentor.py | 2 +- .../models/engine/semantic_segmentor.py | 25 ++++---- tiatoolbox/utils/misc.py | 40 +++++------- 11 files changed, 94 insertions(+), 97 deletions(-) diff --git a/tests/engines/test_multi_task_segmentor.py b/tests/engines/test_multi_task_segmentor.py index 7e7e07af1..fda263104 100644 --- a/tests/engines/test_multi_task_segmentor.py +++ b/tests/engines/test_multi_task_segmentor.py @@ -13,6 +13,7 @@ import torch import zarr from click.testing import CliRunner +from tqdm.auto import tqdm from tiatoolbox import cli from tiatoolbox.annotation import SQLiteStore @@ -25,7 +26,6 @@ ) from tiatoolbox.utils import env_detection as toolbox_env from tiatoolbox.utils import imwrite -from tiatoolbox.utils.misc import get_tqdm_full from tiatoolbox.wsicore import WSIReader if TYPE_CHECKING: @@ -596,7 +596,7 @@ class FakeVM: # --- Real numpy array for shape/dtype --- probabilities = np.zeros((1, 3)) - tqdm_loop = get_tqdm_full( + tqdm_loop = tqdm( range(1), ) diff --git a/tiatoolbox/cli/__init__.py b/tiatoolbox/cli/__init__.py index 838b14dfa..f1b20655f 100644 --- a/tiatoolbox/cli/__init__.py +++ b/tiatoolbox/cli/__init__.py @@ -41,21 +41,20 @@ 
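Note that the top-level `shapely.box` is a vectorized ufunc in Shapely 2.x, so besides being a drop-in for `shapely.geometry.box` it can also consume whole coordinate columns at once. The hunks above keep per-element comprehensions, but the array form looks like this (illustrative boxes):

    import numpy as np
    import shapely

    inst_boxes = np.array(
        [[0, 0, 10, 10], [5, 5, 20, 20], [30, 0, 40, 10]], dtype=float
    )

    # Passing the four coordinate columns builds all polygons in one
    # C-level call, returning an object array of geometries.
    geoms = shapely.box(
        inst_boxes[:, 0], inst_boxes[:, 1], inst_boxes[:, 2], inst_boxes[:, 3]
    )
    print(geoms.shape, geoms[0].bounds)  # -> (3,) (0.0, 0.0, 10.0, 10.0)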
def main() -> int: return 0 +main.add_command(deep_feature_extractor) +main.add_command(multitask_segmentor) +main.add_command(nucleus_detector) main.add_command(nucleus_instance_segment) main.add_command(patch_predictor) main.add_command(read_bounds) main.add_command(save_tiles) main.add_command(semantic_segmentor) -main.add_command(multitask_segmentor) -main.add_command(nucleus_detector) -main.add_command(nucleus_instance_segment) -main.add_command(deep_feature_extractor) +main.add_command(show_wsi) main.add_command(slide_info) main.add_command(slide_thumbnail) -main.add_command(tissue_mask) main.add_command(stain_norm) +main.add_command(tissue_mask) main.add_command(visualize) -main.add_command(show_wsi) if __name__ == "__main__": diff --git a/tiatoolbox/models/architecture/hovernet.py b/tiatoolbox/models/architecture/hovernet.py index 7ff8ffa96..eb31c8d6a 100644 --- a/tiatoolbox/models/architecture/hovernet.py +++ b/tiatoolbox/models/architecture/hovernet.py @@ -16,6 +16,7 @@ from skimage.morphology import remove_small_objects from skimage.segmentation import watershed from torch import nn +from tqdm.auto import tqdm from tiatoolbox.models.architecture.utils import ( UpSample2x, @@ -23,7 +24,7 @@ centre_crop_to_shape, ) from tiatoolbox.models.models_abc import ModelABC -from tiatoolbox.utils.misc import get_bounding_box, get_tqdm_full +from tiatoolbox.utils.misc import get_bounding_box class TFSamepaddingLayer(nn.Module): @@ -663,11 +664,11 @@ def get_instance_info( """ inst_id_list = np.unique(pred_inst)[1:] # exclude background inst_info_dict = {} - tqdm_loop = get_tqdm_full( + tqdm_loop = tqdm( inst_id_list, leave=False, desc="Generating 'info_dict' for instances", - verbose=verbose, + disable=not verbose, ) for inst_id in tqdm_loop: inst_map = pred_inst == inst_id diff --git a/tiatoolbox/models/architecture/micronet.py b/tiatoolbox/models/architecture/micronet.py index 01caced79..f5595d4cb 100644 --- a/tiatoolbox/models/architecture/micronet.py +++ b/tiatoolbox/models/architecture/micronet.py @@ -17,13 +17,13 @@ from skimage import morphology from torch import nn from torch.nn import functional +from tqdm.auto import tqdm from tiatoolbox.models.architecture.hovernet import ( HoVerNet, _inst_dict_for_dask_processing, ) from tiatoolbox.models.models_abc import ModelABC -from tiatoolbox.utils.misc import get_tqdm_full def group1_forward_branch( @@ -601,11 +601,11 @@ def postproc( pred_inst = ndimage.label(pred_bin)[0] pred_inst = morphology.remove_small_objects(pred_inst, min_size=50) canvas = np.zeros(pred_inst.shape[:2], dtype=np.int32) - tqdm_loop = get_tqdm_full( + tqdm_loop = tqdm( range(1, np.max(pred_inst) + 1), leave=False, desc="Performing morphological operations to improve segmentation quality.", - verbose=verbose, + disable=not verbose, ) for inst_id in tqdm_loop: # Get coordinates of this instance diff --git a/tiatoolbox/models/engine/deep_feature_extractor.py b/tiatoolbox/models/engine/deep_feature_extractor.py index e596e706e..66b15f5cf 100644 --- a/tiatoolbox/models/engine/deep_feature_extractor.py +++ b/tiatoolbox/models/engine/deep_feature_extractor.py @@ -46,9 +46,10 @@ import psutil import zarr from dask import compute +from tqdm.auto import tqdm from typing_extensions import Unpack -from tiatoolbox.utils.misc import get_tqdm_full +from tiatoolbox.utils.misc import update_tqdm_desc from .patch_predictor import PatchPredictor, PredictorRunParams @@ -292,11 +293,11 @@ def infer_wsi( ) # Inference loop - tqdm_loop = get_tqdm_full( + tqdm_loop = tqdm( dataloader, 
leave=False, desc="Inferring Patches", - verbose=self.verbose, + disable=not self.verbose, ) probabilities_zarr, coordinates_zarr = None, None @@ -335,8 +336,7 @@ def infer_wsi( f"exceeds specified threshold: {memory_threshold}. " f"Saving intermediate results to disk." ) - - tqdm_loop.desc = msg + update_tqdm_desc(tqdm_loop=tqdm_loop, desc=msg) # Flush data in Memory and clear dask graph probabilities_zarr, coordinates_zarr = save_to_cache( probabilities, @@ -349,7 +349,7 @@ def infer_wsi( probabilities, coordinates = [], [] probabilities_used_percent = 0 gc.collect() - tqdm_loop.desc = "Inferring patches" + update_tqdm_desc(tqdm_loop=tqdm_loop, desc="Inferring patches") if probabilities_zarr is not None: probabilities_zarr, coordinates_zarr = save_to_cache( diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index 8377d7afd..e107ac9c7 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -48,6 +48,7 @@ from dask import compute from numcodecs import Pickle from torch import nn +from tqdm.auto import tqdm from typing_extensions import Unpack from tiatoolbox import DuplicateFilter, logger, rcParam @@ -57,7 +58,6 @@ from tiatoolbox.models.models_abc import load_torch_model from tiatoolbox.utils.misc import ( dict_to_store_patch_predictions, - get_tqdm_full, tqdm_dask_progress_bar, ) from tiatoolbox.wsicore.wsireader import WSIReader, is_zarr @@ -532,11 +532,11 @@ def infer_patches( raw_predictions = {key: [] for key in keys} # Inference loop - tqdm_loop = get_tqdm_full( + tqdm_loop = tqdm( dataloader, leave=False, desc="Inferring patches", - verbose=self.verbose, + disable=not self.verbose, ) infer_batch = self._get_model_attr("infer_batch") @@ -1566,11 +1566,11 @@ def get_path(image: Path | WSIReader) -> Path: for image in self.images } - tqdm_loop = get_tqdm_full( + tqdm_loop = tqdm( self.images, leave=False, desc="Processing WSIs", - verbose=self.verbose, + disable=not self.verbose, ) for image_num, image in enumerate(tqdm_loop): diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index aed2e5800..0d1569986 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -132,6 +132,7 @@ from dask import delayed from shapely.geometry import shape as feature2geometry from shapely.strtree import STRtree +from tqdm.auto import tqdm from typing_extensions import Unpack from tiatoolbox import logger @@ -140,9 +141,9 @@ from tiatoolbox.tools.patchextraction import PatchExtractor from tiatoolbox.utils.misc import ( create_smart_array, - get_tqdm_full, make_valid_poly, tqdm_dask_progress_bar, + update_tqdm_desc, ) from tiatoolbox.wsicore.wsireader import is_zarr @@ -160,7 +161,6 @@ import os from torch.utils.data import DataLoader - from tqdm import tqdm, tqdm_notebook from tiatoolbox.annotation import AnnotationStore from tiatoolbox.models.models_abc import ModelABC @@ -427,11 +427,11 @@ def infer_patches( raw_predictions["probabilities"] = [[] for _ in range(num_expected_output)] # Inference loop - tqdm_loop = get_tqdm_full( + tqdm_loop = tqdm( dataloader, leave=False, desc="Inferring patches", - verbose=self.verbose, + disable=not self.verbose, ) for batch_data in tqdm_loop: @@ -573,11 +573,11 @@ def infer_wsi( ) # Inference loop - tqdm_loop = get_tqdm_full( + tqdm_loop = tqdm( dataloader, leave=False, desc="Inferring patches", - verbose=self.verbose, + disable=not self.verbose, ) # Expected 
number of outputs from the model @@ -1109,11 +1109,11 @@ def _process_tile_mode( batch_size = max(int(available_memory // tile_memory), 1) wsi_info_dict = None - for i in get_tqdm_full( + for i in tqdm( range(0, len(tile_metadata), batch_size), leave=False, desc="Post-Processing WSI to generate predictions and contours", - verbose=self.verbose, + disable=not self.verbose, ): tile_metadata_ = tile_metadata[i : i + batch_size] @@ -1122,11 +1122,11 @@ def _process_tile_mode( self._compute_tile( _tile_meta[0], ) - for _tile_meta in get_tqdm_full( + for _tile_meta in tqdm( tile_metadata_, leave=False, desc="Creating list of delayed tasks for post-processing", - verbose=self.verbose, + disable=not self.verbose, ) ] @@ -1140,10 +1140,10 @@ def _process_tile_mode( verbose=self.verbose, ) - tqdm_loop = get_tqdm_full( + tqdm_loop = tqdm( batch_outputs, leave=False, - verbose=self.verbose, + disable=not self.verbose, desc="Merging Output Predictions", ) @@ -1196,11 +1196,11 @@ def _process_tile_mode( wsi_info_dict[inst_id]["info_dict"].pop(inst_uuid, None) for idx, wsi_info_dict_ in enumerate( - get_tqdm_full( + tqdm( wsi_info_dict, leave=False, desc="Converting 'info_dict' to dask arrays", - verbose=self.verbose, + disable=not self.verbose, ) ): info_df = pd.DataFrame(wsi_info_dict_["info_dict"]).transpose() @@ -2295,10 +2295,10 @@ def save_multitask_to_cache( and ``count`` to free RAM and continue populating new entries. """ - tqdm_loop = get_tqdm_full( + tqdm_loop = tqdm( canvas, desc="Memory Overload, Spilling to disk", - verbose=verbose, + disable=not verbose, ) for idx, canvas_ in enumerate(tqdm_loop): canvas_zarr[idx], count_zarr[idx] = save_to_cache( @@ -2406,11 +2406,11 @@ def merge_multitask_vertical_chunkwise( next_chunk = canvas_.blocks[1, 0].compute() if num_chunks > 1 else None next_count = count[idx].blocks[1, 0].compute() if num_chunks > 1 else None - tqdm_loop = get_tqdm_full( + tqdm_loop = tqdm( overlaps, leave=False, desc=f"Merging rows for probability map {idx}", - verbose=verbose, + disable=not verbose, ) for i, overlap in enumerate(tqdm_loop): if next_chunk is not None and overlap > 0: @@ -2467,7 +2467,7 @@ def _save_multitask_vertical_to_cache( probabilities_da: list[da.Array] | list[None], probabilities: np.ndarray, idx: int, - tqdm_loop: type[tqdm_notebook | tqdm], + tqdm_loop: type[tqdm], save_path: Path, chunk_shape: tuple, memory_threshold: int = 80, @@ -2480,13 +2480,13 @@ def _save_multitask_vertical_to_cache( total_bytes = sum(0 if arr is None else arr.nbytes for arr in probabilities_da) used_percent = (total_bytes / max(vm.available, 1)) * 100 if probabilities_zarr[idx] is None and used_percent > memory_threshold: - desc = tqdm_loop.desc + desc = tqdm_loop.desc if hasattr(tqdm_loop, "desc") else "" msg = ( f"Current Memory usage: {used_percent} % " f"exceeds specified threshold: {memory_threshold}. " f"Saving intermediate results to disk." ) - tqdm_loop.desc = msg + update_tqdm_desc(tqdm_loop=tqdm_loop, desc=msg) zarr_group = zarr.open(str(save_path), mode="a") probabilities_zarr[idx] = zarr_group.create_dataset( name=f"probabilities/{idx}", @@ -2496,7 +2496,7 @@ def _save_multitask_vertical_to_cache( overwrite=True, ) probabilities_zarr[idx][:] = probabilities_da[idx].compute() - tqdm_loop.desc = desc + update_tqdm_desc(tqdm_loop=tqdm_loop, desc=desc) probabilities_da[idx] = None return probabilities_zarr, probabilities_da @@ -2598,7 +2598,7 @@ def _check_and_update_for_memory_overload( f"exceeds specified threshold: {memory_threshold}. 
" f"Saving intermediate results to disk." ) - tqdm_loop.desc = msg + update_tqdm_desc(tqdm_loop=tqdm_loop, desc=msg) # Flush data in Memory and clear dask graph canvas_zarr, count_zarr = save_multitask_to_cache( canvas, @@ -2611,7 +2611,7 @@ def _check_and_update_for_memory_overload( canvas = [None for _ in range(num_expected_output)] count = [None for _ in range(num_expected_output)] gc.collect() - tqdm_loop.desc = "Inferring patches" + update_tqdm_desc(tqdm_loop=tqdm_loop, desc="Inferring patches") return canvas, count, canvas_zarr, count_zarr, tqdm_loop @@ -3038,19 +3038,19 @@ def _build_tile_tasks( tile_metadata: list = [] for set_idx, (set_bounds, set_flags) in enumerate( - get_tqdm_full( + tqdm( tile_info_sets, leave=False, desc="Building delayed tile-processing tasks", - verbose=verbose, + disable=not verbose, ) ): for tile_idx, tile_bounds in enumerate( - get_tqdm_full( + tqdm( set_bounds, leave=False, desc=f"Building delayed tile-processing tasks for tile set {set_idx}", - verbose=verbose, + disable=not verbose, ) ): tile_flag = set_flags[tile_idx] @@ -3314,10 +3314,11 @@ def compute_annotations( """ num_contours = len(self._contours) - for batch_id in get_tqdm_full( + for batch_id in tqdm( range(0, num_contours, batch_size), leave=False, desc="Calculating annotations in batches.", + disable=not verbose, ): delayed_tasks = [ delayed(self._build_single_annotation)( @@ -3326,11 +3327,11 @@ def compute_annotations( origin, scale_factor, ) - for i in get_tqdm_full( + for i in tqdm( range(batch_id, min(batch_id + batch_size, num_contours)), leave=False, desc="Creating list of delayed tasks for writing annotations", - verbose=True, + disable=not verbose, ) ] diff --git a/tiatoolbox/models/engine/nucleus_detector.py b/tiatoolbox/models/engine/nucleus_detector.py index ae95ce25b..601cadbac 100644 --- a/tiatoolbox/models/engine/nucleus_detector.py +++ b/tiatoolbox/models/engine/nucleus_detector.py @@ -54,6 +54,7 @@ import numpy as np import zarr from shapely.geometry import Point +from tqdm.auto import tqdm from tiatoolbox import logger from tiatoolbox.annotation.storage import Annotation, SQLiteStore @@ -61,7 +62,7 @@ SemanticSegmentor, SemanticSegmentorRunParams, ) -from tiatoolbox.utils.misc import get_tqdm_full, tqdm_dask_progress_bar +from tiatoolbox.utils.misc import tqdm_dask_progress_bar if TYPE_CHECKING: # pragma: no cover import os @@ -773,10 +774,10 @@ class IDs for each detection (``np.uint32``). classes_list = [] probs_list = [] - tqdm_loop = get_tqdm_full( + tqdm_loop = tqdm( range(num_blocks_h), desc="Processing detection blocks", - verbose=verbose, + disable=not verbose, ) for i in tqdm_loop: for j in range(num_blocks_w): @@ -896,10 +897,10 @@ def make_points(xs_batch: np.ndarray, ys_batch: np.ndarray) -> list[Point]: for xx, yy in zip(xs_batch, ys_batch, strict=True) ] - tqdm_loop = get_tqdm_full( + tqdm_loop = tqdm( range(0, n, batch_size), desc="Writing detections to store", - verbose=verbose, + disable=not verbose, ) written = 0 for i in tqdm_loop: diff --git a/tiatoolbox/models/engine/nucleus_instance_segmentor.py b/tiatoolbox/models/engine/nucleus_instance_segmentor.py index 5729547d1..353e472e6 100644 --- a/tiatoolbox/models/engine/nucleus_instance_segmentor.py +++ b/tiatoolbox/models/engine/nucleus_instance_segmentor.py @@ -133,7 +133,7 @@ def __init__( device: str = "cpu", verbose: bool = True, ) -> None: - """Initialize :class:`MultiTaskSegmentor`. + """Initialize :class:`NucleusInstanceSegmentor`. 
Args: model (str | ModelABC): diff --git a/tiatoolbox/models/engine/semantic_segmentor.py b/tiatoolbox/models/engine/semantic_segmentor.py index e87285c89..6fd982216 100644 --- a/tiatoolbox/models/engine/semantic_segmentor.py +++ b/tiatoolbox/models/engine/semantic_segmentor.py @@ -61,13 +61,14 @@ import psutil import torch import zarr +from tqdm.auto import tqdm from typing_extensions import Unpack from tiatoolbox import logger from tiatoolbox.models.dataset.dataset_abc import WSIPatchDataset from tiatoolbox.utils.misc import ( dict_to_store_semantic_segmentor, - get_tqdm_full, + update_tqdm_desc, ) from tiatoolbox.wsicore.wsireader import is_zarr @@ -460,11 +461,11 @@ def infer_wsi( ) # Inference loop - tqdm_loop = get_tqdm_full( + tqdm_loop = tqdm( dataloader, leave=False, desc="Inferring patches", - verbose=self.verbose, + disable=not self.verbose, ) canvas_np, output_locs_y_ = None, None @@ -536,7 +537,7 @@ def infer_wsi( f"exceeds specified threshold: {memory_threshold}. " f"Saving intermediate results to disk." ) - tqdm_loop.desc = msg + update_tqdm_desc(tqdm_loop=tqdm_loop, desc=msg) # Flush data in Memory and clear dask graph canvas_zarr, count_zarr = save_to_cache( canvas, @@ -548,7 +549,7 @@ def infer_wsi( ) canvas, count = None, None gc.collect() - tqdm_loop.desc = "Inferring patches" + update_tqdm_desc(tqdm_loop=tqdm_loop, desc="Inferring patches") coordinates.append( da.from_array( @@ -1198,11 +1199,11 @@ def save_to_cache( # Append remaining blocks one-at-a-time to limit peak memory. num_blocks = canvas.numblocks[0] - tqdm_loop = get_tqdm_full( + tqdm_loop = tqdm( range(start_idx, num_blocks), leave=False, desc="Memory Overload, Spilling to disk", - verbose=verbose, + disable=not verbose, ) for block_idx in tqdm_loop: canvas_block = canvas.blocks[block_idx, 0, 0].compute() @@ -1269,11 +1270,11 @@ def merge_vertical_chunkwise( probabilities_zarr, probabilities_da = None, None chunk_shape = tuple(chunk[0] for chunk in canvas.chunks) - tqdm_loop = get_tqdm_full( + tqdm_loop = tqdm( overlaps, leave=False, desc="Merging rows", - verbose=verbose, + disable=not verbose, ) used_percent = 0 @@ -1304,13 +1305,13 @@ def merge_vertical_chunkwise( vm = psutil.virtual_memory() used_percent = (probabilities_da.nbytes / vm.free) * 100 if probabilities_zarr is None and used_percent > memory_threshold: - desc = tqdm_loop.desc + desc = tqdm_loop.desc if hasattr(tqdm_loop, "desc") else "" msg = ( f"Current Memory usage: {used_percent} % " f"exceeds specified threshold: {memory_threshold}. " f"Saving intermediate results to disk." 
) - tqdm_loop.desc = msg + update_tqdm_desc(tqdm_loop=tqdm_loop, desc=msg) zarr_group = zarr.open(str(save_path), mode="a") probabilities_zarr = zarr_group.create_dataset( name="probabilities", @@ -1322,7 +1323,7 @@ def merge_vertical_chunkwise( probabilities_zarr[:] = probabilities_da.compute() probabilities_da = None - tqdm_loop.desc = desc + update_tqdm_desc(tqdm_loop=tqdm_loop, desc=desc) if next_chunk is not None: curr_chunk, curr_count = next_chunk[overlap:], next_count[overlap:] diff --git a/tiatoolbox/utils/misc.py b/tiatoolbox/utils/misc.py index f86934d42..71d0cf3af 100644 --- a/tiatoolbox/utils/misc.py +++ b/tiatoolbox/utils/misc.py @@ -1224,11 +1224,11 @@ def patch_predictions_as_annotations( ) -> list: """Helper function to generate annotation per patch predictions.""" annotations = [] - tqdm_loop = get_tqdm_full( + tqdm_loop = tqdm( patch_coords, leave=False, desc="Converting outputs to AnnotationStore.", - verbose=verbose, + disable=not verbose, ) for i, _ in enumerate(tqdm_loop): @@ -1405,11 +1405,11 @@ def dict_to_store_semantic_segmentor( annotations_list: list[Annotation] = [] - tqdm_loop = get_tqdm_full( + tqdm_loop = tqdm( layer_list, leave=False, desc="Converting outputs to AnnotationStore.", - verbose=verbose, + disable=not verbose, ) for type_class in tqdm_loop: @@ -1670,31 +1670,24 @@ def write_probability_heatmap_as_ome_tiff( logger.info(msg) -def get_tqdm_full( - iterable_input: Iterable, - desc: str = "Processing input", - *, - leave: bool = False, - verbose: bool = True, -) -> Iterable: - """Helper function to get appropriate tqdm progress bar. +def update_tqdm_desc( + tqdm_loop: tqdm | Iterable, + desc: str, +) -> None: + """Helper function to update tqdm progress bar description. Args: - iterable_input (Iterable): - Any iterable input. + tqdm_loop (tqdm): + tqdm progress bar. desc (str): tqdm progress bar description. - leave (bool): - Whether to leave progress bar after completion. - verbose (bool): - Whether to return progress bar or the input iterator. Returns: - Iterable: - Iterable of tqdm progress bar if self.verbose is True else input Iterable. + None """ - return tqdm(iterable_input, leave=leave, desc=desc) if verbose else iterable_input + if hasattr(tqdm_loop, "desc"): + tqdm_loop.desc = desc def cast_to_min_dtype(array: np.ndarray | da.Array) -> np.ndarray | da.Array: @@ -1735,9 +1728,9 @@ def create_smart_array( shape: tuple[int, ...], dtype: np.dtype | str, memory_threshold: float, - name: str | None, + name: str, zarr_path: str | Path, - chunks: tuple[int, ...] | None = None, + chunks: tuple[int, ...] | str = "auto", ) -> np.ndarray | zarr.Array: """Allocate a NumPy or Zarr array depending on available memory and a threshold. @@ -1826,6 +1819,7 @@ def tqdm_dask_progress_bar( list of outputs from dask compute. 
""" + num_workers = max(num_workers, 1) if verbose: with TqdmCallback(desc=desc, leave=leave): return compute(*write_tasks, scheduler=scheduler, num_workers=num_workers) From de2c89996d179b20d232b924df096fe039525a30 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Thu, 12 Feb 2026 10:29:47 +0000 Subject: [PATCH 132/156] :bug: Fix deepsource error --- .../models/engine/semantic_segmentor.py | 29 +++++++++++++++---- 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/tiatoolbox/models/engine/semantic_segmentor.py b/tiatoolbox/models/engine/semantic_segmentor.py index 6fd982216..5a08c5a76 100644 --- a/tiatoolbox/models/engine/semantic_segmentor.py +++ b/tiatoolbox/models/engine/semantic_segmentor.py @@ -1284,6 +1284,8 @@ def merge_vertical_chunkwise( next_chunk = canvas.blocks[1, 0].compute() if num_chunks > 1 else None next_count = count.blocks[1, 0].compute() if num_chunks > 1 else None + probabilities = np.empty(0) + for i, overlap in enumerate(tqdm_loop): if next_chunk is not None and overlap > 0: curr_chunk[-overlap:] += next_chunk[:overlap] @@ -1335,17 +1337,32 @@ def merge_vertical_chunkwise( next_chunk, next_count = None, None if probabilities_zarr: - if "canvas" in zarr_group: - del zarr_group["canvas"] - if "count" in zarr_group: - del zarr_group["count"] - return da.from_zarr( - probabilities_zarr, chunks=(chunk_shape[0], *probabilities.shape[1:]) + return _get_probabilities_da_from_zarr( + zarr_group=zarr_group, + probabilities_zarr=probabilities_zarr, + chunk_shape=chunk_shape, + probabilities=probabilities, ) return probabilities_da +def _get_probabilities_da_from_zarr( + zarr_group: zarr.Group, + probabilities_zarr: zarr.Array, + chunk_shape: tuple, + probabilities: zarr.Array | np.ndarray, +) -> da.Array: + """Helper function to return dask array after probabilities have been merged.""" + if "canvas" in zarr_group: + del zarr_group["canvas"] + if "count" in zarr_group: + del zarr_group["count"] + return da.from_zarr( + probabilities_zarr, chunks=(chunk_shape[0], *probabilities.shape[1:]) + ) + + def store_probabilities( probabilities: np.ndarray, chunk_shape: tuple[int, ...], From 5137f6f3ec699a04516147654d46f1c6052c7652 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Thu, 12 Feb 2026 10:33:25 +0000 Subject: [PATCH 133/156] :lipstick: Set `leave=False` for tqdm loops --- tiatoolbox/models/engine/multi_task_segmentor.py | 1 + tiatoolbox/models/engine/nucleus_detector.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index 0d1569986..0f5ea7808 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -2297,6 +2297,7 @@ def save_multitask_to_cache( """ tqdm_loop = tqdm( canvas, + leave=False, desc="Memory Overload, Spilling to disk", disable=not verbose, ) diff --git a/tiatoolbox/models/engine/nucleus_detector.py b/tiatoolbox/models/engine/nucleus_detector.py index 601cadbac..f235ffc7c 100644 --- a/tiatoolbox/models/engine/nucleus_detector.py +++ b/tiatoolbox/models/engine/nucleus_detector.py @@ -776,6 +776,7 @@ class IDs for each detection (``np.uint32``). 
tqdm_loop = tqdm( range(num_blocks_h), + leave=False, desc="Processing detection blocks", disable=not verbose, ) @@ -899,6 +900,7 @@ def make_points(xs_batch: np.ndarray, ys_batch: np.ndarray) -> list[Point]: tqdm_loop = tqdm( range(0, n, batch_size), + leave=False, desc="Writing detections to store", disable=not verbose, ) From 6e7597ee4a0b6fc941114e6ddf1e6c569fede4d9 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Thu, 12 Feb 2026 14:35:23 +0000 Subject: [PATCH 134/156] :bug: Fix multi-gpu run --- .../models/engine/multi_task_segmentor.py | 17 +++++++++++------ tiatoolbox/models/engine/nucleus_detector.py | 6 ++++-- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index 0f5ea7808..de3942e0f 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -409,7 +409,8 @@ def infer_patches( coordinates = [] # Expected number of outputs from the model - batch_output = self.model.infer_batch( + infer_batch = self._get_model_attr("infer_batch") + batch_output = infer_batch( self.model, torch.Tensor(dataloader.dataset[0]["image"][np.newaxis, ...]), device=self.device, @@ -435,7 +436,7 @@ def infer_patches( ) for batch_data in tqdm_loop: - batch_output = self.model.infer_batch( + batch_output = infer_batch( self.model, batch_data["image"], device=self.device, @@ -581,7 +582,8 @@ def infer_wsi( ) # Expected number of outputs from the model - batch_output = self.model.infer_batch( + infer_batch = self._get_model_attr("infer_batch") + batch_output = infer_batch( self.model, torch.Tensor(dataloader.dataset[0]["image"][np.newaxis, ...]), device=self.device, @@ -783,8 +785,9 @@ def post_process_patches( # skipcq: PYL-R0201 """ probabilities = raw_predictions["probabilities"] + postproc_func = self._get_model_attr("postproc_func") post_process_predictions = [ - self.model.postproc_func(list(probs_for_idx)) + postproc_func(list(probs_for_idx)) for probs_for_idx in zip(*probabilities, strict=False) ] @@ -992,7 +995,8 @@ def _process_full_wsi( removed from the output. 
""" - post_process_predictions = self.model.postproc_func(probabilities) + postproc_func = self._get_model_attr("postproc_func") + post_process_predictions = postproc_func(probabilities) if return_predictions is None: return_predictions = [False for _ in post_process_predictions] for idx, return_predictions_ in enumerate(return_predictions): @@ -1244,7 +1248,8 @@ def _compute_tile( ].compute() for p in self._probabilities ] - return self.model.postproc_func(head_raws) + postproc_func = self._get_model_attr("postproc_func") + return postproc_func(head_raws) @staticmethod def _get_tile_info( diff --git a/tiatoolbox/models/engine/nucleus_detector.py b/tiatoolbox/models/engine/nucleus_detector.py index f235ffc7c..2652722d4 100644 --- a/tiatoolbox/models/engine/nucleus_detector.py +++ b/tiatoolbox/models/engine/nucleus_detector.py @@ -343,7 +343,8 @@ def post_process_patches( # Process each patch's predictions for i in range(raw_predictions["probabilities"].shape[0]): probs_prediction_patch = raw_predictions["probabilities"][i].compute() - centroids_map_patch = self.model.postproc( + postproc_func = self._get_model_attr("postproc_func") + centroids_map_patch = postproc_func( probs_prediction_patch, min_distance=min_distance, threshold_abs=threshold_abs, @@ -439,8 +440,9 @@ def post_process_wsi( (postproc_tile_shape[0], postproc_tile_shape[1], -1) ) + postproc_func = self._get_model_attr("postproc_func") centroid_maps = da.map_overlap( - self.model.postproc, + postproc_func, rechunked_probability_map, min_distance=min_distance, threshold_abs=threshold_abs, From ba7f5d2257e783e8015361d6a3082f8cd0f56962 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Thu, 12 Feb 2026 18:28:45 +0000 Subject: [PATCH 135/156] :zap: Convert to annotationstore from memory --- tiatoolbox/models/engine/multi_task_segmentor.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index de3942e0f..74701f7ab 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -3166,6 +3166,11 @@ def dict_to_store( - All annotations are appended in a single batch via `store.append_many(...)`. """ + # Assumes annotationstore is computed for properties which can fit in memory. 
+        processed_predictions = {
+            key: np.asarray(arr) if isinstance(arr, zarr.Array) and len(arr) > 0 else arr
+            for key, arr in processed_predictions.items()
+        }
         contours = processed_predictions.pop("contours")
         delayed_tasks = DaskDelayedAnnotationStore(
             contours=contours,

From 0328ba95fd257a8418813ebb4558e970f550361a Mon Sep 17 00:00:00 2001
From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com>
Date: Fri, 13 Feb 2026 11:24:53 +0000
Subject: [PATCH 136/156] :construction: Add qupath json

---
 tests/engines/test_patch_predictor.py  | 48 ++++++++++++++++
 tiatoolbox/models/engine/engine_abc.py | 33 +++++++++--
 tiatoolbox/utils/misc.py               | 80 ++++++++++++++++++++++++++
 3 files changed, 156 insertions(+), 5 deletions(-)

diff --git a/tests/engines/test_patch_predictor.py b/tests/engines/test_patch_predictor.py
index c4b0317bb..77eeed7b0 100644
--- a/tests/engines/test_patch_predictor.py
+++ b/tests/engines/test_patch_predictor.py
@@ -544,6 +544,54 @@ def test_engine_run_wsi_annotation_store(
     shutil.rmtree(save_dir)


+def test_engine_run_wsi_qupath(
+    sample_wsi_dict: dict,
+    track_tmp_path: Path,
+    caplog: pytest.LogCaptureFixture,
+) -> None:
+    """Test the engine run for whole slide images with QuPath output."""
+    # convert to pathlib Path to prevent wsireader complaint
+    mini_wsi_svs = Path(sample_wsi_dict["wsi2_4k_4k_svs"])
+    mini_wsi_msk = Path(sample_wsi_dict["wsi2_4k_4k_msk"])
+
+    eng = PatchPredictor(model="alexnet-kather100k")
+
+    patch_size = np.array([224, 224])
+    save_dir = f"{track_tmp_path}/model_wsi_output"
+
+    kwargs = {
+        "patch_input_shape": patch_size,
+        "stride_shape": patch_size,
+        "resolution": 0.5,
+        "save_dir": save_dir,
+        "units": "mpp",
+        "scale_factor": (2.0, 2.0),
+    }
+
+    output = eng.run(
+        images=[mini_wsi_svs],
+        masks=[mini_wsi_msk],
+        patch_mode=False,
+        output_type="QuPath",
+        batch_size=4,
+        **kwargs,
+    )
+
+    output_ = output[mini_wsi_svs]
+
+    assert output_.exists()
+    assert output_.suffix == ".db"
+    output_ = _extract_probabilities_from_annotation_store(output_)
+
+    # prediction for each patch
+    assert np.array(output_["predictions"]).shape == (69,)
+    assert _validate_probabilities(output_)
+
+    assert "Output file saved at " in caplog.text
+
+    shutil.rmtree(save_dir)
+
+
 # --------------------------------------------------------------------------------------
 # torch.compile
 # --------------------------------------------------------------------------------------
diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py
index e107ac9c7..2114ce801 100644
--- a/tiatoolbox/models/engine/engine_abc.py
+++ b/tiatoolbox/models/engine/engine_abc.py
@@ -740,6 +740,24 @@ def save_predictions(
             verbose=self.verbose,
         )

+        if output_type.lower() == "qupath":
+            save_path = Path(
+                kwargs.get("output_file", save_path.parent / "output.json")
+            )
+
+            # scale_factor set from kwargs
+            scale_factor = kwargs.get("scale_factor", (1.0, 1.0))
+            # class_dict set from kwargs
+            class_dict = kwargs.get("class_dict", self.model.class_dict)
+
+            return dict_to_store_patch_predictions(
+                processed_predictions,
+                scale_factor,
+                class_dict,
+                save_path,
+                verbose=self.verbose,
+            )
+
         msg = f"Unsupported output type: {output_type}"
         raise TypeError(msg)

@@ -1292,8 +1310,8 @@ def _update_run_params(
         self.patch_mode = patch_mode
         self._validate_input_numbers(images=images, masks=masks, labels=self.labels)

-        if output_type.lower() not in ["dict", "zarr", "annotationstore"]:
-            msg = "output_type must be 'dict' or 'zarr' or 'annotationstore'."
+ if output_type.lower() not in ["dict", "zarr", "qupath", "annotationstore"]: + msg = "output_type must be 'dict' or 'zarr', 'qupath' or 'annotationstore'." raise TypeError(msg) self.output_type = output_type @@ -1301,6 +1319,7 @@ def _update_run_params( if save_dir is not None and output_type.lower() not in [ "zarr", "annotationstore", + "qupath", ]: self.output_type = "zarr" msg = ( @@ -1310,7 +1329,11 @@ def _update_run_params( ) logger.info(msg) - if save_dir is None and output_type.lower() in ["zarr", "annotationstore"]: + if save_dir is None and output_type.lower() in [ + "zarr", + "qupath", + "annotationstore", + ]: msg = f"Please provide save_dir for output_type={output_type}" raise ValueError(msg) @@ -1414,7 +1437,7 @@ def _run_patch_mode( ) raw_predictions = self.infer_patches( dataloader=self.dataloader, - return_coordinates=output_type == "annotationstore", + return_coordinates=output_type in ["annotationstore", "qupath"], ) processed_predictions = self.post_process_patches( @@ -1660,7 +1683,7 @@ def run( overwrite (bool): Whether to overwrite existing output files. Default is False. output_type (str): - Desired output format: "dict", "zarr", or "annotationstore". + Desired output format: "dict", "zarr", "QuPath", or "annotationstore". **kwargs (EngineABCRunParams): Additional runtime parameters to update engine attributes. diff --git a/tiatoolbox/utils/misc.py b/tiatoolbox/utils/misc.py index 71d0cf3af..3f6fa2772 100644 --- a/tiatoolbox/utils/misc.py +++ b/tiatoolbox/utils/misc.py @@ -1538,6 +1538,86 @@ def dict_to_store_patch_predictions( return store +def patch_predictions_to_qupath_json( + patch_output: dict, + scale_factor: tuple[float, float], + class_dict: dict | None = None, + save_path: Path | None = None, +) -> dict | Path: + """Convert TIAToolbox PatchPredictor output to QuPath-compatible GeoJSON. + + Args: + patch_output (dict): + Must contain "coordinates", "predictions", and optionally "probabilities", "labels". + scale_factor (tuple): + Scale factor to convert coordinates to baseline resolution. + class_dict (dict): + Optional mapping from class index → class name. + save_path (Path): + Optional path to save the resulting JSON. + + Returns: + dict or Path: + A QuPath FeatureCollection JSON structure. 
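+
+    Example:
+        >>> # Hypothetical inputs; coordinates follow the (x, y, w, h)
+        >>> # convention unpacked in the loop below.
+        >>> qp = patch_predictions_to_qupath_json(
+        ...     {"coordinates": [[0, 0, 224, 224]], "predictions": [1]},
+        ...     scale_factor=(1.0, 1.0),
+        ... )
+        >>> qp["type"]
+        'FeatureCollection'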
+ + """ + if "coordinates" not in patch_output: + raise ValueError("Patch output must contain coordinates.") + + coords = np.array(patch_output["coordinates"], dtype=float) + preds = np.array(patch_output["predictions"], dtype=int) + + # Apply scale factor + if not np.all(np.array(scale_factor) == 1): + coords = coords * np.tile(scale_factor, 2) + + # Determine class dictionary + if class_dict is None: + unique_classes = np.unique(preds).tolist() + class_dict = {i: f"class_{i}" for i in unique_classes} + + # --- Build QuPath FeatureCollection --- + features = [] + + for i, (x, y, w, h) in enumerate(coords): + class_idx = int(preds[i]) + class_name = class_dict[class_idx] + + # Rectangle polygon for QuPath + polygon = [ + [x, y], + [x + w, y], + [x + w, y + h], + [x, y + h], + [x, y], # close polygon + ] + + feature = { + "type": "Feature", + "id": f"patch_{i}", + "geometry": {"type": "Polygon", "coordinates": [polygon]}, + "properties": { + "classification": { + "name": class_name, + "color": None, # QuPath will auto-assign if None + } + }, + } + + features.append(feature) + + qupath_json = {"type": "FeatureCollection", "features": features} + + # Save if requested + if save_path: + save_path.parent.mkdir(parents=True, exist_ok=True) + with open(save_path, "w") as f: + json.dump(qupath_json, f, indent=2) + return save_path + + return qupath_json + + def _tiles( in_img: np.ndarray | zarr.core.Array, tile_size: tuple[int, int], From 892a08f812e5893da06e8fe56c21a2fa6391873f Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Fri, 13 Feb 2026 14:29:52 +0000 Subject: [PATCH 137/156] :bulb: Address review comments --- tiatoolbox/cli/multitask_segmentor.py | 4 ++-- tiatoolbox/cli/nucleus_instance_segment.py | 4 ++-- tiatoolbox/models/engine/multi_task_segmentor.py | 4 ++-- tiatoolbox/models/engine/nucleus_instance_segmentor.py | 2 +- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/tiatoolbox/cli/multitask_segmentor.py b/tiatoolbox/cli/multitask_segmentor.py index 4b747ccb3..387d1b2a6 100644 --- a/tiatoolbox/cli/multitask_segmentor.py +++ b/tiatoolbox/cli/multitask_segmentor.py @@ -42,8 +42,8 @@ @tiatoolbox_cli.command() @cli_img_input() @cli_output_path( - usage_help="Output directory where model segmentation will be saved.", - default="semantic_segmentation", + usage_help="Output directory where model output will be saved.", + default="multitask_segmentor", ) @cli_output_file(default=None) @cli_file_type( diff --git a/tiatoolbox/cli/nucleus_instance_segment.py b/tiatoolbox/cli/nucleus_instance_segment.py index ea33055ee..56e62d493 100644 --- a/tiatoolbox/cli/nucleus_instance_segment.py +++ b/tiatoolbox/cli/nucleus_instance_segment.py @@ -43,8 +43,8 @@ @tiatoolbox_cli.command() @cli_img_input() @cli_output_path( - usage_help="Output directory where model segmentation will be saved.", - default="semantic_segmentation", + usage_help="Output directory where model output will be saved.", + default="nucleus_instance_segment", ) @cli_output_file(default=None) @cli_file_type( diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index 74701f7ab..1073302c6 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -29,8 +29,8 @@ MultiTaskSegmentorRunParams TypedDict of runtime parameters used across the engine. 
Extends - :class:`SemanticSegmentorRunParams` with additional multitask options: - `return_predictions`, `return_probabilities`, `memory_threshold`, etc. + :class:`SemanticSegmentorRunParams` with additional multitask option: + `return_predictions`. Important Functions infer_patches(dataloader, *, return_coordinates=False) -> dict diff --git a/tiatoolbox/models/engine/nucleus_instance_segmentor.py b/tiatoolbox/models/engine/nucleus_instance_segmentor.py index 353e472e6..4b2d7bb63 100644 --- a/tiatoolbox/models/engine/nucleus_instance_segmentor.py +++ b/tiatoolbox/models/engine/nucleus_instance_segmentor.py @@ -41,7 +41,7 @@ class NucleusInstanceSegmentor(MultiTaskSegmentor): weights (str | Path | None): Path to model weights. If None, default weights are used. - >>> engine = SemanticSegmentor( + >>> engine = NucleusInstanceSegmentor( ... model="pretrained-model", ... weights="/path/to/pretrained-local-weights.pth" ... ) From 34c63ef0371bb0fae92f072d7740807ec082486a Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Fri, 13 Feb 2026 14:56:10 +0000 Subject: [PATCH 138/156] :sparkles: Add support for `qupath` output --- tests/engines/test_engine_abc.py | 2 +- tests/engines/test_patch_predictor.py | 26 ++++- tiatoolbox/models/engine/engine_abc.py | 21 +--- tiatoolbox/utils/misc.py | 131 ++++++++----------------- 4 files changed, 70 insertions(+), 110 deletions(-) diff --git a/tests/engines/test_engine_abc.py b/tests/engines/test_engine_abc.py index 50da6309d..1fdbcf1c0 100644 --- a/tests/engines/test_engine_abc.py +++ b/tests/engines/test_engine_abc.py @@ -125,7 +125,7 @@ def test_incorrect_output_type() -> NoReturn: with pytest.raises( TypeError, - match=r".*output_type must be 'dict' or 'zarr' or 'annotationstore*", + match=r".*output_type must be 'dict' or 'zarr', 'qupath' or 'annotationstore*", ): _ = eng.run( images=np.zeros((10, 224, 224, 3), dtype=np.uint8), diff --git a/tests/engines/test_patch_predictor.py b/tests/engines/test_patch_predictor.py index 77eeed7b0..946d78679 100644 --- a/tests/engines/test_patch_predictor.py +++ b/tests/engines/test_patch_predictor.py @@ -341,6 +341,28 @@ def _extract_probabilities_from_annotation_store(dbfile: str) -> dict: return output +def _extract_from_qupath_json(json_file: str) -> dict: + """Extract predictions (and optionally coordinates) from QuPath GeoJSON.""" + with Path.open(json_file, "r") as f: + data = json.load(f) + + output = {"predictions": [], "coordinates": []} + + for feature in data.get("features", []): + props = feature.get("properties", {}) + cls = props.get("classification", {}) + + # prediction - class name + output["predictions"].append(cls.get("name")) + + # geometry - polygon + geom = feature.get("geometry", {}) + coords = geom.get("coordinates", [[]])[0] # first ring of polygon + output["coordinates"].append(coords) + + return output + + def _validate_probabilities(output: list | dict | zarr.group) -> bool: """Helper function to test if the probabilities value are valid.""" probabilities = np.array([0.5]) @@ -580,8 +602,8 @@ def test_engine_run_wsi_qupath( output_ = output[mini_wsi_svs] assert output_.exists() - assert output_.suffix == ".db" - output_ = _extract_probabilities_from_annotation_store(output_) + assert output_.suffix == ".json" + output_ = _extract_from_qupath_json(output_) # prediction for each patch assert np.array(output_["predictions"]).shape == (69,) diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index 
2114ce801..5851f98c9 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -724,7 +724,7 @@ def save_predictions( if output_type.lower() == "dict": return processed_predictions - if output_type.lower() == "annotationstore": + if output_type.lower() in ["qupath", "annotationstore"]: save_path = Path(kwargs.get("output_file", save_path.parent / "output.db")) # scale_factor set from kwargs @@ -737,24 +737,7 @@ def save_predictions( scale_factor, class_dict, save_path, - verbose=self.verbose, - ) - - if output_type.lower() == "qupath": - save_path = Path( - kwargs.get("output_file", save_path.parent / "output.json") - ) - - # scale_factor set from kwargs - scale_factor = kwargs.get("scale_factor", (1.0, 1.0)) - # class_dict set from kwargs - class_dict = kwargs.get("class_dict", self.model.class_dict) - - return dict_to_store_patch_predictions( - processed_predictions, - scale_factor, - class_dict, - save_path, + output_type=output_type, verbose=self.verbose, ) diff --git a/tiatoolbox/utils/misc.py b/tiatoolbox/utils/misc.py index 3f6fa2772..19cc95e3a 100644 --- a/tiatoolbox/utils/misc.py +++ b/tiatoolbox/utils/misc.py @@ -1216,7 +1216,7 @@ def patch_predictions_as_annotations( keys: list, class_dict: dict, class_probs: list | np.ndarray, - patch_coords: list, + patch_coords: list | np.ndarray, classes_predicted: list, labels: list, *, @@ -1451,10 +1451,11 @@ def dict_to_store_patch_predictions( scale_factor: tuple[float, float], class_dict: dict | None = None, save_path: Path | None = None, + output_type: str = "AnnotationStore", *, verbose: bool = True, -) -> AnnotationStore | Path: - """Converts output of TIAToolbox PatchPredictor engine to AnnotationStore. +) -> AnnotationStore | dict | Path: + """Converts output of the PatchPredictor engine to AnnotationStore or QuPath json. Args: patch_output (dict | zarr.Group): @@ -1470,6 +1471,9 @@ def dict_to_store_patch_predictions( save_path (str or Path): Optional Output directory to save the Annotation Store results. + output_type (str): + "annotationstore" → return AnnotationStore + "qupath" → return QuPath JSON dict verbose (bool): Whether to display logs and progress bar. 
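+    Example:
+        >>> # Hypothetical minimal input; with save_path=None the "qupath"
+        >>> # branch returns the GeoJSON dict directly.
+        >>> qp = dict_to_store_patch_predictions(
+        ...     {"coordinates": [[0, 0, 2, 2]], "predictions": [0]},
+        ...     scale_factor=(1.0, 1.0),
+        ...     output_type="qupath",
+        ... )
+        >>> qp["type"]
+        'FeatureCollection'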
@@ -1488,13 +1492,15 @@ def dict_to_store_patch_predictions( # get relevant keys class_probs = get_zarr_array(patch_output.get("probabilities", [])) preds = get_zarr_array(patch_output.get("predictions", [])) - patch_coords = np.array(patch_output.get("coordinates", [])) + + # Scale coordinates if not np.all(np.array(scale_factor) == 1): patch_coords = patch_coords * (np.tile(scale_factor, 2)) # to baseline mpp labels = patch_output.get("labels", []) - # get classes to consider + + # Determine classes if len(class_probs) == 0: classes_predicted = np.unique(preds).tolist() else: @@ -1507,12 +1513,41 @@ def dict_to_store_patch_predictions( else: class_dict = {i: i for i in range(len(class_probs[0]))} - # find what keys we need to save + # Keys to save keys = ["predictions"] keys = keys + [key for key in ["probabilities", "labels"] if key in patch_output] + if output_type.lower() == "qupath": + features = [] + + for i, (x, y, w, h) in enumerate(patch_coords): + class_idx = int(preds[i]) + class_name = class_dict[class_idx] + + polygon = [[x, y], [x + w, y], [x + w, y + h], [x, y + h], [x, y]] + + feature = { + "type": "Feature", + "id": f"patch_{i}", + "geometry": {"type": "Polygon", "coordinates": [polygon]}, + "properties": {"classification": {"name": class_name, "color": None}}, + } + + features.append(feature) + + qupath_json = {"type": "FeatureCollection", "features": features} + + if save_path: + save_path.parent.mkdir(parents=True, exist_ok=True) + save_path = save_path.with_suffix(".json") + with Path.open(save_path, "w") as f: + json.dump(qupath_json, f, indent=2) + return save_path + + return qupath_json + # put patch predictions into a store - annotations = patch_predictions_as_annotations( + annotations_ = patch_predictions_as_annotations( preds.astype(float), keys, class_dict, @@ -1524,7 +1559,7 @@ def dict_to_store_patch_predictions( ) store = SQLiteStore() - _ = store.append_many(annotations, [str(i) for i in range(len(annotations))]) + _ = store.append_many(annotations_, [str(i) for i in range(len(annotations_))]) # if a save director is provided, then dump store into a file if save_path: @@ -1538,86 +1573,6 @@ def dict_to_store_patch_predictions( return store -def patch_predictions_to_qupath_json( - patch_output: dict, - scale_factor: tuple[float, float], - class_dict: dict | None = None, - save_path: Path | None = None, -) -> dict | Path: - """Convert TIAToolbox PatchPredictor output to QuPath-compatible GeoJSON. - - Args: - patch_output (dict): - Must contain "coordinates", "predictions", and optionally "probabilities", "labels". - scale_factor (tuple): - Scale factor to convert coordinates to baseline resolution. - class_dict (dict): - Optional mapping from class index → class name. - save_path (Path): - Optional path to save the resulting JSON. - - Returns: - dict or Path: - A QuPath FeatureCollection JSON structure. 
- - """ - if "coordinates" not in patch_output: - raise ValueError("Patch output must contain coordinates.") - - coords = np.array(patch_output["coordinates"], dtype=float) - preds = np.array(patch_output["predictions"], dtype=int) - - # Apply scale factor - if not np.all(np.array(scale_factor) == 1): - coords = coords * np.tile(scale_factor, 2) - - # Determine class dictionary - if class_dict is None: - unique_classes = np.unique(preds).tolist() - class_dict = {i: f"class_{i}" for i in unique_classes} - - # --- Build QuPath FeatureCollection --- - features = [] - - for i, (x, y, w, h) in enumerate(coords): - class_idx = int(preds[i]) - class_name = class_dict[class_idx] - - # Rectangle polygon for QuPath - polygon = [ - [x, y], - [x + w, y], - [x + w, y + h], - [x, y + h], - [x, y], # close polygon - ] - - feature = { - "type": "Feature", - "id": f"patch_{i}", - "geometry": {"type": "Polygon", "coordinates": [polygon]}, - "properties": { - "classification": { - "name": class_name, - "color": None, # QuPath will auto-assign if None - } - }, - } - - features.append(feature) - - qupath_json = {"type": "FeatureCollection", "features": features} - - # Save if requested - if save_path: - save_path.parent.mkdir(parents=True, exist_ok=True) - with open(save_path, "w") as f: - json.dump(qupath_json, f, indent=2) - return save_path - - return qupath_json - - def _tiles( in_img: np.ndarray | zarr.core.Array, tile_size: tuple[int, int], From 4ee26d912d9ef4a8f6d14e745eeff1de4a857fcb Mon Sep 17 00:00:00 2001 From: Jiaqi Lv Date: Fri, 13 Feb 2026 15:58:35 +0000 Subject: [PATCH 139/156] update tiatoolbox/utils/misc.py --- tiatoolbox/utils/misc.py | 32 ++++++++++++++++++++++++++------ 1 file changed, 26 insertions(+), 6 deletions(-) diff --git a/tiatoolbox/utils/misc.py b/tiatoolbox/utils/misc.py index 19cc95e3a..cfbafe8f8 100644 --- a/tiatoolbox/utils/misc.py +++ b/tiatoolbox/utils/misc.py @@ -14,6 +14,7 @@ import cv2 import dask.array as da import joblib +import matplotlib.pyplot as plt import numpy as np import pandas as pd import psutil @@ -24,7 +25,7 @@ from dask import compute from filelock import FileLock from shapely.affinity import translate -from shapely.geometry import Polygon +from shapely.geometry import Polygon, mapping from shapely.geometry import shape as feature2geometry from skimage import exposure from tqdm import trange @@ -1519,18 +1520,37 @@ def dict_to_store_patch_predictions( if output_type.lower() == "qupath": features = [] + # pick a color for each class based on the class index, using a colormap + num_classes = len(class_dict) + cmap = plt.cm.get_cmap("tab20", num_classes) + class_colours = { + class_idx: [ + int(cmap(class_idx)[0] * 255), + int(cmap(class_idx)[1] * 255), + int(cmap(class_idx)[2] * 255), + ] + for class_idx in class_dict + } - for i, (x, y, w, h) in enumerate(patch_coords): + for i in range(patch_coords.shape[0]): class_idx = int(preds[i]) class_name = class_dict[class_idx] - - polygon = [[x, y], [x + w, y], [x + w, y + h], [x, y + h], [x, y]] + polygon_geo = Polygon.from_bounds(*patch_coords[i]) + polygon_feat = mapping(polygon_geo) feature = { "type": "Feature", "id": f"patch_{i}", - "geometry": {"type": "Polygon", "coordinates": [polygon]}, - "properties": {"classification": {"name": class_name, "color": None}}, + "geometry": polygon_feat, + "properties": { + "classification": { + "name": class_name, + "color": class_colours[class_idx], + } + }, + "objectType": "annotation", + "name": class_name, + "class_value": class_idx, } features.append(feature) From 
1cdcd76ea4233138cc83fb202873de3522f68537 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Fri, 13 Feb 2026 16:27:51 +0000 Subject: [PATCH 140/156] :bug: Fix deep-source errors --- tiatoolbox/utils/misc.py | 100 ++++++++++++++++++++++++--------------- 1 file changed, 61 insertions(+), 39 deletions(-) diff --git a/tiatoolbox/utils/misc.py b/tiatoolbox/utils/misc.py index cfbafe8f8..85d6728fd 100644 --- a/tiatoolbox/utils/misc.py +++ b/tiatoolbox/utils/misc.py @@ -28,8 +28,7 @@ from shapely.geometry import Polygon, mapping from shapely.geometry import shape as feature2geometry from skimage import exposure -from tqdm import trange -from tqdm.auto import tqdm +from tqdm.auto import tqdm, trange from tqdm.dask import TqdmCallback from tiatoolbox import logger @@ -1248,6 +1247,60 @@ def patch_predictions_as_annotations( return annotations +def patch_predictions_as_qupath_json( + preds: list | np.ndarray, + class_dict: dict, + patch_coords: list | np.ndarray, + *, + verbose: bool = True, +) -> dict: + """Helper function to generate QuPath JSON per patch predictions.""" + features = [] + # pick a color for each class based on the class index, using a colormap + num_classes = len(class_dict) + cmap = plt.cm.get_cmap("tab20", num_classes) + class_colours = { + class_idx: [ + int(cmap(class_idx)[0] * 255), + int(cmap(class_idx)[1] * 255), + int(cmap(class_idx)[2] * 255), + ] + for class_idx in class_dict + } + + tqdm_loop = tqdm( + range(patch_coords.shape[0]), + leave=False, + desc="Converting outputs to QuPath JSON.", + disable=not verbose, + ) + + for i in tqdm_loop: + class_idx = int(preds[i]) + class_name = class_dict[class_idx] + polygon_geo = Polygon.from_bounds(*patch_coords[i]) + polygon_feat = mapping(polygon_geo) + + feature = { + "type": "Feature", + "id": f"patch_{i}", + "geometry": polygon_feat, + "properties": { + "classification": { + "name": class_name, + "color": class_colours[class_idx], + } + }, + "objectType": "annotation", + "name": class_name, + "class_value": class_idx, + } + + features.append(feature) + + return {"type": "FeatureCollection", "features": features} + + def get_zarr_array(zarr_array: zarr.core.Array | np.ndarray | list) -> np.ndarray: """Converts a zarr array into a numpy array.""" if isinstance(zarr_array, zarr.core.Array): @@ -1519,43 +1572,12 @@ def dict_to_store_patch_predictions( keys = keys + [key for key in ["probabilities", "labels"] if key in patch_output] if output_type.lower() == "qupath": - features = [] - # pick a color for each class based on the class index, using a colormap - num_classes = len(class_dict) - cmap = plt.cm.get_cmap("tab20", num_classes) - class_colours = { - class_idx: [ - int(cmap(class_idx)[0] * 255), - int(cmap(class_idx)[1] * 255), - int(cmap(class_idx)[2] * 255), - ] - for class_idx in class_dict - } - - for i in range(patch_coords.shape[0]): - class_idx = int(preds[i]) - class_name = class_dict[class_idx] - polygon_geo = Polygon.from_bounds(*patch_coords[i]) - polygon_feat = mapping(polygon_geo) - - feature = { - "type": "Feature", - "id": f"patch_{i}", - "geometry": polygon_feat, - "properties": { - "classification": { - "name": class_name, - "color": class_colours[class_idx], - } - }, - "objectType": "annotation", - "name": class_name, - "class_value": class_idx, - } - - features.append(feature) - - qupath_json = {"type": "FeatureCollection", "features": features} + qupath_json = patch_predictions_as_qupath_json( + preds=preds, + class_dict=class_dict, + 
patch_coords=patch_coords,
+            verbose=verbose,
+        )

         if save_path:
             save_path.parent.mkdir(parents=True, exist_ok=True)

From 21cdd9a61fbde5e033bb6f049a909afa83313890 Mon Sep 17 00:00:00 2001
From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com>
Date: Fri, 13 Feb 2026 16:55:07 +0000
Subject: [PATCH 141/156] :bug: Fix `mypy` errors

---
 tiatoolbox/utils/misc.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tiatoolbox/utils/misc.py b/tiatoolbox/utils/misc.py
index 85d6728fd..4e605ea70 100644
--- a/tiatoolbox/utils/misc.py
+++ b/tiatoolbox/utils/misc.py
@@ -1269,7 +1269,7 @@ def patch_predictions_as_qupath_json(
     }

     tqdm_loop = tqdm(
-        range(patch_coords.shape[0]),
+        range(np.asarray(patch_coords).shape[0]),
         leave=False,
         desc="Converting outputs to QuPath JSON.",
         disable=not verbose,

From e98d1a15cce25ead3bc6b27f80d0788c1297 Mon Sep 17 00:00:00 2001
From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com>
Date: Sat, 14 Feb 2026 14:25:17 +0000
Subject: [PATCH 142/156] :bug: Allow "qupath" output_type.

---
 tiatoolbox/models/engine/engine_abc.py      | 27 ++++++++++++-------
 tiatoolbox/models/engine/patch_predictor.py |  9 ++++---
 2 files changed, 23 insertions(+), 13 deletions(-)

diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py
index 5851f98c9..f4720a587 100644
--- a/tiatoolbox/models/engine/engine_abc.py
+++ b/tiatoolbox/models/engine/engine_abc.py
@@ -234,7 +234,7 @@ class EngineABC(ABC):  # noqa: B024
         drop_keys (list):
             Keys to exclude from model output.
         output_type (Any):
-            Format of output ("dict", "zarr", "AnnotationStore").
+            Format of output ("dict", "zarr", "qupath", "AnnotationStore").
         verbose (bool):
             Whether to enable verbose logging.

@@ -649,7 +649,7 @@ def save_predictions(
             Dictionary containing processed model predictions.
         output_type (str):
             Desired output format.
-            Supported values are "dict", "zarr", and "annotationstore".
+            Supported values are "dict", "zarr", "qupath" and "annotationstore".
         save_path (Path | None):
             Path to save the output file. Required for "zarr" and
             "annotationstore" formats.
@@ -694,6 +694,8 @@ def save_predictions(
             dict | AnnotationStore | Path:
                 - If output_type is "dict": returns predictions as a dictionary.
                 - If output_type is "zarr": returns path to saved zarr file.
+                - If output_type is "qupath": returns a QuPath JSON
+                  or path to .json file.
                 - If output_type is "annotationstore": returns an AnnotationStore
                   or path to .db file.

@@ -725,7 +727,10 @@ def save_predictions(
             return processed_predictions

         if output_type.lower() in ["qupath", "annotationstore"]:
-            save_path = Path(kwargs.get("output_file", save_path.parent / "output.db"))
+            suffix = ".json" if output_type.lower() == "qupath" else ".db"
+            save_path = Path(
+                kwargs.get("output_file", save_path.parent / ("output" + suffix))
+            )
@@ -1217,7 +1222,7 @@ def _update_run_params(
         ioconfig (ModelIOConfigABC | None):
             IO configuration for patch extraction and resolution settings.
         output_type (str):
-            Desired output format: "dict", "zarr", or "annotationstore".
+            Desired output format: "dict", "zarr", "qupath" or "annotationstore".
         overwrite (bool):
             Whether to overwrite existing output files. Default is False.
         patch_mode (bool):
@@ -1269,7 +1274,7 @@ def _update_run_params(
         ValueError:
             If required configuration or input parameters are missing.
ValueError: - If save_dir is not provided and output_type is "zarr" + If save_dir is not provided and output_type is "zarr", "qupath" or "annotationstore". """ @@ -1355,7 +1360,7 @@ def _run_patch_mode( Args: output_type (str): Desired output format. Supported values are "dict", "zarr", - and "annotationstore". + "qupath" and "annotationstore". save_dir (Path): Directory to save the output files. **kwargs (EngineABCRunParams): @@ -1399,6 +1404,8 @@ def _run_patch_mode( dict | AnnotationStore | Path: - If output_type is "dict": returns predictions as a dictionary. - If output_type is "zarr": returns path to saved zarr file. + - If output_type is "qupath": returns a QuPath JSON + or path to .json file. - If output_type is "annotationstore": returns an AnnotationStore or path to .db file. @@ -1508,7 +1515,7 @@ def _run_wsi_mode( Args: output_type (str): Desired output format. Supported values are "dict", "zarr", - and "annotationstore". + "qupath" and "annotationstore". save_dir (Path): Directory to save the output files. **kwargs (EngineABCRunParams): @@ -1555,8 +1562,10 @@ def _run_wsi_mode( """ suffix = ".zarr" - if output_type == "AnnotationStore": + if output_type.lower() == "annotationstore": suffix = ".db" + if output_type.lower() == "qupath": + suffix = ".json" def get_path(image: Path | WSIReader) -> Path: """Return path to output file.""" @@ -1666,7 +1675,7 @@ def run( overwrite (bool): Whether to overwrite existing output files. Default is False. output_type (str): - Desired output format: "dict", "zarr", "QuPath", or "annotationstore". + Desired output format: "dict", "zarr", "qupath", or "annotationstore". **kwargs (EngineABCRunParams): Additional runtime parameters to update engine attributes. diff --git a/tiatoolbox/models/engine/patch_predictor.py b/tiatoolbox/models/engine/patch_predictor.py index c2e71325d..a64a07945 100644 --- a/tiatoolbox/models/engine/patch_predictor.py +++ b/tiatoolbox/models/engine/patch_predictor.py @@ -241,7 +241,7 @@ class PatchPredictor(EngineABC): drop_keys (list): Keys to exclude from model output. output_type (str): - Format of output ("dict", "zarr", "annotationstore"). + Format of output ("dict", "zarr", "qupath", "annotationstore"). Example: >>> # list of 2 image patches as input @@ -479,7 +479,8 @@ def _update_run_params( ioconfig (IOPatchPredictorConfig | None): IO configuration for patch extraction and resolution. output_type (str): - Desired output format: "dict", "zarr", or "annotationstore". + Desired output format: "dict", "zarr", "qupath" + or "annotationstore". overwrite (bool): Whether to overwrite existing output files. Default is False. patch_mode (bool): @@ -589,8 +590,8 @@ def run( overwrite (bool): Whether to overwrite existing output files. Default is False. output_type (str): - Desired output format: "dict", "zarr", or "annotationstore". - Default value is "zarr". + Desired output format: "dict", "zarr", "qupath" + or "annotationstore". Default value is "zarr". **kwargs (PredictorRunParams): Additional runtime parameters to configure prediction. 
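
A minimal usage sketch of the new "qupath" output type, distilled from
test_engine_run_wsi_qupath above; the model name comes from the tests in this
series, while the input image and save_dir paths below are placeholders:

    from pathlib import Path

    from tiatoolbox.models.engine.patch_predictor import PatchPredictor

    predictor = PatchPredictor(model="alexnet-kather100k")
    # output_type is matched case-insensitively; save_dir is required for any
    # on-disk output type ("zarr", "qupath", "annotationstore").
    output = predictor.run(
        images=[Path("sample.svs")],
        patch_mode=False,
        output_type="qupath",
        save_dir=Path("qupath_out"),
    )
    # In WSI mode, run() returns a mapping from each input image to the path
    # of its QuPath GeoJSON file (written with a ".json" suffix).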
From 8e2c620526489021e9c619477141720c7a6e7c61 Mon Sep 17 00:00:00 2001
From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com>
Date: Sat, 14 Feb 2026 16:33:50 +0000
Subject: [PATCH 143/156] :sparkles: Add support for QuPath output in
 semantic_segmentor.py

---
 tests/engines/test_patch_predictor.py    |  17 ++-
 tests/engines/test_semantic_segmentor.py |  88 ++++++++++++-
 .../models/engine/semantic_segmentor.py  |  29 +++--
 tiatoolbox/utils/misc.py                 | 122 +++++++++++++++++-
 4 files changed, 240 insertions(+), 16 deletions(-)

diff --git a/tests/engines/test_patch_predictor.py b/tests/engines/test_patch_predictor.py
index 946d78679..98f297555 100644
--- a/tests/engines/test_patch_predictor.py
+++ b/tests/engines/test_patch_predictor.py
@@ -501,7 +501,6 @@ def test_patch_predictor_patch_mode_no_probabilities(

     assert "probabilities" not in output

-    # don't run test on GPU
     output = predictor.run(
         images=inputs,
         return_probabilities=False,
@@ -517,6 +516,22 @@ def test_patch_predictor_patch_mode_no_probabilities(
     assert np.all(output["predictions"] == [6, 3])
     assert output["probabilities"] == []

+    # QuPath Output
+    output = predictor.run(
+        images=inputs,
+        return_probabilities=False,
+        return_labels=False,
+        device=device,
+        patch_mode=True,
+        save_dir=track_tmp_path / "patch_out_check",
+        output_type="qupath",
+        overwrite=True,
+    )
+
+    assert output.exists()
+    output = _extract_from_qupath_json(output)
+    assert np.all(output["predictions"] == [6, 3])
+

 def test_engine_run_wsi_annotation_store(
     sample_wsi_dict: dict,
diff --git a/tests/engines/test_semantic_segmentor.py b/tests/engines/test_semantic_segmentor.py
index a2fc939ac..4d97030c8 100644
--- a/tests/engines/test_semantic_segmentor.py
+++ b/tests/engines/test_semantic_segmentor.py
@@ -154,6 +154,39 @@ def _test_store_output_patch(output: Path) -> None:
     assert annotations_properties is not None


+def _test_qupath_output_patch(output: Path) -> None:
+    """Helper function to test QuPath JSON output for a patch."""
+    with Path.open(output) as f:
+        data = json.load(f)
+
+    assert "features" in data
+    features = data["features"]
+    assert len(features) > 0
+
+    geometry_types = []
+    class_values = set()
+
+    for feat in features:
+        # geometry type
+        geom = feat.get("geometry", {})
+        geometry_types.append(geom.get("type"))
+
+        # class index, stored by the writer in the "class_value" field
+        class_val = feat.get("class_value")
+        if class_val is not None:
+            class_values.add(class_val)
+
+    # Check geometry type
+    assert "Polygon" in geometry_types
+
+    # When class_dict is None, types are assigned as 0, 1, ...
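+    # (The fcn-tissue_mask model exercised by the tests below segments
+    # background (0) versus tissue (1), so both values should be present.)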
+    assert 0 in class_values
+    assert 1 in class_values
+
+    # Basic sanity check
+    assert features is not None
+
+
 def test_semantic_segmentor_tiles(track_tmp_path: Path) -> None:
     """Tests SemanticSegmentor on image tiles with no mpp metadata."""
     segmentor = SemanticSegmentor(
@@ -198,7 +231,7 @@ def test_save_annotation_store(remote_sample: Callable, track_tmp_path: Path) ->

     # Test str input
     sample_image = remote_sample("thumbnail-1k-1k")
-    inputs = [str(sample_image)]
+    inputs = [Path(sample_image)]

     output = segmentor.run(
         images=inputs,
@@ -216,6 +249,33 @@ def test_save_annotation_store(remote_sample: Callable, track_tmp_path: Path) ->
     _test_store_output_patch(output[0])


+def test_save_qupath_json(remote_sample: Callable, track_tmp_path: Path) -> None:
+    """Test for saving output as QuPath JSON."""
+    segmentor = SemanticSegmentor(
+        model="fcn-tissue_mask", batch_size=32, verbose=False, device=device
+    )
+
+    # Test Path input
+    sample_image = remote_sample("thumbnail-1k-1k")
+
+    inputs = [Path(sample_image)]
+
+    output = segmentor.run(
+        images=inputs,
+        return_probabilities=False,
+        return_labels=False,
+        device=device,
+        patch_mode=True,
+        save_dir=track_tmp_path / "output1",
+        output_type="qupath",
+        verbose=True,
+    )
+
+    assert output[0] == track_tmp_path / "output1" / (sample_image.stem + ".json")
+    assert len(output) == 1
+    _test_qupath_output_patch(output[0])
+
+
 def test_save_annotation_store_nparray(
     remote_sample: Callable, track_tmp_path: Path, caplog: pytest.LogCaptureFixture
 ) -> None:
@@ -557,7 +617,31 @@ def test_wsi_segmentor_annotationstore(
     zarr_group = zarr.open(output[sample_svs].with_suffix(".zarr"), mode="r")
     assert "probabilities" in zarr_group

-    assert "Probability maps cannot be saved as AnnotationStore." in caplog.text
+    assert "Probability maps cannot be saved as AnnotationStore" in caplog.text
+
+
+def test_wsi_segmentor_qupath_json(sample_svs: Path, track_tmp_path: Path) -> None:
+    """Test SemanticSegmentor for WSIs with QuPath JSON output."""
+    segmentor = SemanticSegmentor(
+        model="fcn-tissue_mask",
+        batch_size=32,
+        verbose=False,
+    )
+    # Return Probabilities is False
+    output = segmentor.run(
+        images=[sample_svs],
+        return_probabilities=False,
+        return_labels=False,
+        device=device,
+        patch_mode=False,
+        save_dir=track_tmp_path / "wsi_out_check",
+        verbose=False,
+        output_type="QuPath",
+    )
+
+    assert output[sample_svs] == track_tmp_path / "wsi_out_check" / (
+        sample_svs.stem + ".json"
+    )


 def test_prepare_full_batch_low_memory(track_tmp_path: Path) -> None:
diff --git a/tiatoolbox/models/engine/semantic_segmentor.py b/tiatoolbox/models/engine/semantic_segmentor.py
index 5a08c5a76..2cb74f424 100644
--- a/tiatoolbox/models/engine/semantic_segmentor.py
+++ b/tiatoolbox/models/engine/semantic_segmentor.py
@@ -233,7 +233,7 @@ class SemanticSegmentor(PatchPredictor):
         drop_keys (list):
             Keys to exclude from model output.
         output_type (str):
-            Format of output ("dict", "zarr", "annotationstore").
+            Format of output ("dict", "zarr", "qupath", "annotationstore").
         output_locations (list | None):
             Coordinates of output patches used during WSI processing.

@@ -618,9 +618,10 @@ def save_predictions(
         processed_predictions (dict):
             Dictionary containing processed model predictions.
         output_type (str):
-            Desired output format: "dict", "zarr", or "annotationstore".
+            Desired output format: "dict", "zarr", "qupath" or "annotationstore".
         save_path (Path | None):
-            Path to save the output file. Required for "zarr" and "annotationstore".
+            Path to save the output file.
Required for "zarr", "qupath" + and "annotationstore". **kwargs (SemanticSegmentorRunParams): Additional runtime parameters to configure segmentation. @@ -664,12 +665,14 @@ def save_predictions( dict | AnnotationStore | Path | list[Path]: - If output_type is "dict": returns predictions as a dictionary. - If output_type is "zarr": returns path to saved Zarr file. + - If output_type is "qupath": returns QuPath JSON + or path or list of paths to .json file. - If output_type is "annotationstore": returns AnnotationStore or path or list of paths to .db file. """ # Conversion to annotationstore uses a different function for SemanticSegmentor - if output_type.lower() != "annotationstore": + if output_type.lower() not in ["qupath", "annotationstore"]: return super().save_predictions( processed_predictions, output_type, save_path=save_path, **kwargs ) @@ -700,16 +703,18 @@ def save_predictions( save_paths = [] logger.info("Saving predictions as AnnotationStore.") + suffix = ".json" if output_type.lower() == "qupath" else ".db" if self.patch_mode: for i, predictions in enumerate(processed_predictions["predictions"]): if isinstance(self.images[i], Path): - output_path = save_path.parent / (self.images[i].stem + ".db") + output_path = save_path.parent / (self.images[i].stem + suffix) else: - output_path = save_path.parent / (str(i) + ".db") + output_path = save_path.parent / (str(i) + suffix) out_file = dict_to_store_semantic_segmentor( patch_output={"predictions": predictions}, scale_factor=scale_factor, + output_type=output_type, class_dict=class_dict, save_path=output_path, verbose=self.verbose, @@ -720,15 +725,16 @@ def save_predictions( out_file = dict_to_store_semantic_segmentor( patch_output=processed_predictions, scale_factor=scale_factor, + output_type=output_type, class_dict=class_dict, - save_path=save_path.with_suffix(".db"), + save_path=save_path.with_suffix(suffix), verbose=self.verbose, ) save_paths = out_file if return_probabilities: msg = ( - f"Probability maps cannot be saved as AnnotationStore. " + f"Probability maps cannot be saved as AnnotationStore or JSON. " f"To visualise heatmaps in TIAToolbox Visualization tool," f"convert heatmaps in {save_path} to ome.tiff using" f"tiatoolbox.utils.misc.write_probability_heatmap_as_ome_tiff." @@ -777,7 +783,8 @@ def _update_run_params( ioconfig (ModelIOConfigABC | None): IO configuration for patch extraction and resolution. output_type (str): - Desired output format: "dict", "zarr", or "annotationstore". + Desired output format: "dict", "zarr", "qupath", + or "annotationstore". overwrite (bool): Whether to overwrite existing output files. Default is False. patch_mode (bool): @@ -893,8 +900,8 @@ def run( overwrite (bool): Whether to overwrite existing output files. Default is False. output_type (str): - Desired output format: "dict", "zarr", or "annotationstore". Default - is "dict". + Desired output format: "dict", "zarr", "qupath", + or "annotationstore". Default is "dict". **kwargs (SemanticSegmentorRunParams): Additional runtime parameters to configure segmentation. diff --git a/tiatoolbox/utils/misc.py b/tiatoolbox/utils/misc.py index 4e605ea70..42d0d7f2e 100644 --- a/tiatoolbox/utils/misc.py +++ b/tiatoolbox/utils/misc.py @@ -1417,6 +1417,7 @@ def process_contours( def dict_to_store_semantic_segmentor( patch_output: dict | zarr.Group, scale_factor: tuple[float, float], + output_type: str, class_dict: dict | None = None, save_path: Path | None = None, *, @@ -1432,6 +1433,9 @@ def dict_to_store_semantic_segmentor( annotations. 
All coordinates will be multiplied by this factor to allow conversion of annotations saved at non-baseline resolution to baseline. Should be model_mpp/slide_mpp. + output_type (str): + "annotationstore" → return AnnotationStore + "qupath" → return QuPath JSON dict class_dict (dict): Optional dictionary mapping class indices to class names. save_path (str or Path): @@ -1452,11 +1456,125 @@ def dict_to_store_semantic_segmentor( # Get the number of unique predictions layer_list = da.unique(preds).compute() - store = SQLiteStore() - if class_dict is None: class_dict = {int(i): int(i) for i in layer_list.tolist()} + if output_type.lower() == "qupath": + return _semantic_segmentations_as_qupath_json( + layer_list=layer_list, + preds=preds, + scale_factor=scale_factor, + class_dict=class_dict, + save_path=save_path, + verbose=verbose, + ) + + return _semantic_segmentations_as_annotations( + layer_list=layer_list, + preds=preds, + scale_factor=scale_factor, + class_dict=class_dict, + save_path=save_path, + verbose=verbose, + ) + + +def _semantic_segmentations_as_qupath_json( + layer_list: list, + preds: da.Array, + scale_factor: tuple[float, float], + class_dict: dict | None = None, + save_path: Path | None = None, + *, + verbose: bool = True, +) -> dict | Path: + """Helper function to save semantic segmentation as QuPath json.""" + features = [] + + # color map for classes + num_classes = len(class_dict) + cmap = plt.cm.get_cmap("tab20", num_classes) + class_colours = { + class_idx: [ + int(cmap(class_idx)[0] * 255), + int(cmap(class_idx)[1] * 255), + int(cmap(class_idx)[2] * 255), + ] + for class_idx in class_dict + } + + tqdm_loop = tqdm( + layer_list, + leave=False, + desc="Converting outputs to QuPath JSON.", + disable=not verbose, + ) + + for type_class in tqdm_loop: + class_id = int(type_class) + class_label = class_dict[class_id] + + # binary mask for this class + layer = da.where(preds == type_class, 1, 0).astype("uint8").compute() + + contours, _ = cv2.findContours(layer, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE) + + contours = cast("list[np.ndarray]", contours) + + # Convert contours to polygons + for cnt in contours: + if cnt.shape[0] < 3: # noqa: PLR2004 + continue + + # scale coordinates + cnt_scaled = cnt.squeeze(1).astype(float) + cnt_scaled[:, 0] *= scale_factor[0] + cnt_scaled[:, 1] *= scale_factor[1] + + poly = Polygon(cnt_scaled) + poly_geo = mapping(poly) + + feature = { + "type": "Feature", + "geometry": poly_geo, + "id": f"class_{class_id}_{len(features)}", + "properties": { + "classification": { + "name": class_label, + "color": class_colours[class_id], + } + }, + "objectType": "annotation", + "name": class_label, + "class_value": class_id, + } + + features.append(feature) + + qupath_json = {"type": "FeatureCollection", "features": features} + + # if a save director is provided, then dump json into a file + if save_path: + save_path.parent.mkdir(parents=True, exist_ok=True) + save_path = save_path.with_suffix(".json") + with Path.open(save_path, "w") as f: + json.dump(qupath_json, f, indent=2) + return save_path + + return qupath_json + + +def _semantic_segmentations_as_annotations( + layer_list: list, + preds: da.Array, + scale_factor: tuple[float, float], + class_dict: dict | None = None, + save_path: Path | None = None, + *, + verbose: bool = True, +) -> AnnotationStore | Path: + """Helper function to save semantic segmentation as annotations.""" + store = SQLiteStore() annotations_list: list[Annotation] = [] tqdm_loop = tqdm( From 3e3207ba0db4dfcf33801f897841789ea13a4f5c 
Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Sat, 14 Feb 2026 16:54:16 +0000 Subject: [PATCH 144/156] :bug: Fix tests and `mypy` errors --- tests/engines/test_semantic_segmentor.py | 4 ++-- tests/test_utils.py | 8 ++++++++ tiatoolbox/utils/misc.py | 8 ++++---- 3 files changed, 14 insertions(+), 6 deletions(-) diff --git a/tests/engines/test_semantic_segmentor.py b/tests/engines/test_semantic_segmentor.py index 4d97030c8..b5a6f220c 100644 --- a/tests/engines/test_semantic_segmentor.py +++ b/tests/engines/test_semantic_segmentor.py @@ -307,7 +307,7 @@ def test_save_annotation_store_nparray( zarr_group = zarr.open(str(track_tmp_path / "output1" / "output.zarr"), mode="r") assert "probabilities" in zarr_group - assert "Probability maps cannot be saved as AnnotationStore." in caplog.text + assert "Probability maps cannot be saved as AnnotationStore" in caplog.text _test_store_output_patch(output[0]) _test_store_output_patch(output[1]) @@ -617,7 +617,7 @@ def test_wsi_segmentor_annotationstore( zarr_group = zarr.open(output[sample_svs].with_suffix(".zarr"), mode="r") assert "probabilities" in zarr_group - assert "Probability maps cannot be saved as AnnotationStore" in caplog.text + assert "Probability maps cannot be saved as AnnotationStore or JSON." in caplog.text def test_wsi_segmentor_qupath_json(sample_svs: Path, track_tmp_path: Path) -> None: diff --git a/tests/test_utils.py b/tests/test_utils.py index 04e91405c..5d524dfb6 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1886,6 +1886,7 @@ def test_dict_to_store_semantic_segment() -> None: scale_factor=(1.0, 1.0), class_dict=None, save_path=None, + output_type="annotationstore", ) assert len(store_) == 1 for annotation in store_.values(): @@ -1899,6 +1900,7 @@ def test_dict_to_store_semantic_segment() -> None: scale_factor=(1.0, 1.0), class_dict=None, save_path=None, + output_type="annotationstore", ) assert len(store_) == 2 @@ -1918,6 +1920,7 @@ def test_dict_to_store_semantic_segment() -> None: scale_factor=(1.0, 1.0), class_dict=None, save_path=None, + output_type="annotationstore", ) assert len(store_) == 3 @@ -1938,6 +1941,7 @@ def test_dict_to_store_semantic_segment() -> None: scale_factor=(1.0, 1.0), class_dict=None, save_path=None, + output_type="annotationstore", ) assert len(store_) == 4 annotations_ = store_.values() @@ -1974,6 +1978,7 @@ def test_dict_to_store_semantic_segment_holes(track_tmp_path: Path) -> None: scale_factor=(1.0, 1.0), class_dict={0: "background", 1: "object"}, save_path=save_dir_path, + output_type="annotationstore", ) assert save_dir_path.exists() @@ -1983,6 +1988,7 @@ def test_dict_to_store_semantic_segment_holes(track_tmp_path: Path) -> None: scale_factor=(1.0, 1.0), class_dict={0: "background", 1: "object"}, save_path=None, + output_type="annotationstore", ) # outer contour and inner contour/hole are now within the same geometry @@ -2033,6 +2039,7 @@ def test_dict_to_store_semantic_segment_multiple_holes() -> None: scale_factor=(1.0, 1.0), class_dict={0: "background", 1: "object"}, save_path=None, + output_type="annotationstore", ) # outer contour and inner contour/hole are now within the same geometry @@ -2081,6 +2088,7 @@ def test_dict_to_store_semantic_segment_no_holes() -> None: scale_factor=(1.0, 1.0), class_dict={0: "background", 1: "object"}, save_path=None, + output_type="annotationstore", ) # outer contour and inner contour/hole are now within the same geometry diff --git a/tiatoolbox/utils/misc.py b/tiatoolbox/utils/misc.py 
index 42d0d7f2e..529354153 100644 --- a/tiatoolbox/utils/misc.py +++ b/tiatoolbox/utils/misc.py @@ -1483,13 +1483,13 @@ def _semantic_segmentations_as_qupath_json( layer_list: list, preds: da.Array, scale_factor: tuple[float, float], - class_dict: dict | None = None, + class_dict: dict, save_path: Path | None = None, *, verbose: bool = True, ) -> dict | Path: """Helper function to save semantic segmentation as QuPath json.""" - features = [] + features: list = [] # color map for classes num_classes = len(class_dict) @@ -1527,7 +1527,7 @@ def _semantic_segmentations_as_qupath_json( continue # scale coordinates - cnt_scaled = cnt.squeeze(1).astype(float) + cnt_scaled: np.ndarray = cnt.squeeze(1).astype(float) cnt_scaled[:, 0] *= scale_factor[0] cnt_scaled[:, 1] *= scale_factor[1] @@ -1568,7 +1568,7 @@ def _semantic_segmentations_as_annotations( layer_list: list, preds: da.Array, scale_factor: tuple[float, float], - class_dict: dict | None = None, + class_dict: dict, save_path: Path | None = None, *, verbose: bool = True, From c82273fc8a0db676f3cf8e9d869091d05d4624eb Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Sat, 14 Feb 2026 16:59:32 +0000 Subject: [PATCH 145/156] :bug: Fix `mypy` errors --- tiatoolbox/utils/misc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tiatoolbox/utils/misc.py b/tiatoolbox/utils/misc.py index 529354153..38d931075 100644 --- a/tiatoolbox/utils/misc.py +++ b/tiatoolbox/utils/misc.py @@ -1422,7 +1422,7 @@ def dict_to_store_semantic_segmentor( save_path: Path | None = None, *, verbose: bool = True, -) -> AnnotationStore | Path: +) -> AnnotationStore | dict | Path: """Converts output of TIAToolbox SemanticSegmentor engine to AnnotationStore. Args: From 607be4ae46a2dc75d3c1280a98b083ae46975428 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Sun, 15 Feb 2026 10:26:53 +0000 Subject: [PATCH 146/156] :recycle: Restructure code to avoid duplicates. 
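The refactor below extracts the duplicated file-saving logic into two shared
helpers, `save_annotations` and `save_qupath_json`, each of which normalises
the file suffix before writing (".db" for annotation stores, ".json" for
QuPath). A minimal usage sketch follows (illustrative only: `out_dir` is a
hypothetical path, and note that the helpers assume the parent directory
already exists, which is why the tests in this patch gain explicit `mkdir`
calls):

    from pathlib import Path

    from tiatoolbox.annotation.storage import SQLiteStore
    from tiatoolbox.utils.misc import save_annotations, save_qupath_json

    out_dir = Path("results")  # hypothetical output directory
    out_dir.mkdir(parents=True, exist_ok=True)  # the helpers do not create it

    # AnnotationStore output: the suffix is rewritten to ".db".
    db_path = save_annotations(save_path=out_dir / "sample", store=SQLiteStore())

    # QuPath output: the suffix is rewritten to ".json".
    json_path = save_qupath_json(
        save_path=out_dir / "sample",
        qupath_json={"type": "FeatureCollection", "features": []},
    )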
--- tests/test_utils.py | 2 ++ tiatoolbox/utils/misc.py | 53 +++++++++++++++++++++++----------------- 2 files changed, 32 insertions(+), 23 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index 5d524dfb6..e814f910a 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1780,6 +1780,7 @@ def test_patch_pred_store_persist(track_tmp_path: pytest.TempPathFactory) -> Non "labels": [1, 0, 1], } save_path = track_tmp_path / "patch_output" / "output.db" + save_path.parent.mkdir() store_path = misc.dict_to_store_patch_predictions( patch_output, (1.0, 1.0), save_path=save_path @@ -1816,6 +1817,7 @@ def test_patch_pred_store_persist_ext(track_tmp_path: pytest.TempPathFactory) -> # sends the path of a jpeg source image, expects .db file in the same directory save_path = track_tmp_path / "patch_output" / "output.jpeg" + save_path.parent.mkdir() store_path = misc.dict_to_store_patch_predictions( patch_output, (1.0, 1.0), save_path=save_path diff --git a/tiatoolbox/utils/misc.py b/tiatoolbox/utils/misc.py index 38d931075..f414c6afd 100644 --- a/tiatoolbox/utils/misc.py +++ b/tiatoolbox/utils/misc.py @@ -1555,11 +1555,7 @@ def _semantic_segmentations_as_qupath_json( # if a save director is provided, then dump json into a file if save_path: - save_path.parent.mkdir(parents=True, exist_ok=True) - save_path = save_path.with_suffix(".json") - with Path.open(save_path, "w") as f: - json.dump(qupath_json, f, indent=2) - return save_path + return save_qupath_json(save_path=save_path, qupath_json=qupath_json) return qupath_json @@ -1607,17 +1603,34 @@ def _semantic_segmentations_as_annotations( # # if a save director is provided, then dump store into a file if save_path: - # ensure parent directory exists - save_path.parent.absolute().mkdir(parents=True, exist_ok=True) - # ensure proper db extension - save_path = save_path.parent.absolute() / (save_path.stem + ".db") - store.commit() - store.dump(save_path) - return save_path + return save_annotations( + save_path=save_path, + store=store, + ) return store +def save_annotations( + save_path: Path, + store: AnnotationStore, +) -> Path: + """Saves Annotation Store to disk.""" + # ensure proper db extension + save_path = save_path.parent.absolute() / (save_path.stem + ".db") + store.commit() + store.dump(save_path) + return save_path + + +def save_qupath_json(save_path: Path, qupath_json: dict) -> Path: + """Saves QuPath JSON to disk.""" + save_path = save_path.with_suffix(".json") + with Path.open(save_path, "w") as f: + json.dump(qupath_json, f, indent=2) + return save_path + + def dict_to_store_patch_predictions( patch_output: dict | zarr.group, scale_factor: tuple[float, float], @@ -1698,11 +1711,7 @@ def dict_to_store_patch_predictions( ) if save_path: - save_path.parent.mkdir(parents=True, exist_ok=True) - save_path = save_path.with_suffix(".json") - with Path.open(save_path, "w") as f: - json.dump(qupath_json, f, indent=2) - return save_path + return save_qupath_json(save_path=save_path, qupath_json=qupath_json) return qupath_json @@ -1723,12 +1732,10 @@ def dict_to_store_patch_predictions( # if a save director is provided, then dump store into a file if save_path: - # ensure parent directory exists - save_path.parent.absolute().mkdir(parents=True, exist_ok=True) - # ensure proper db extension - save_path = save_path.parent.absolute() / (save_path.stem + ".db") - store.dump(save_path) - return save_path + return save_annotations( + save_path=save_path, + store=store, + ) return store From 18b8fcd6c1b550ee4ead581a256a337f9abe8c77 
Mon Sep 17 00:00:00 2001
From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com>
Date: Sun, 15 Feb 2026 11:59:14 +0000
Subject: [PATCH 147/156] :sparkles: Add `QuPath` support for nucleus detection.

---
 .../engines/test_nucleus_detection_engine.py | 82 ++-
 tiatoolbox/models/engine/nucleus_detector.py | 527 +++++++++++------
 2 files changed, 410 insertions(+), 199 deletions(-)

diff --git a/tests/engines/test_nucleus_detection_engine.py b/tests/engines/test_nucleus_detection_engine.py
index c91993b6a..15b413c18 100644
--- a/tests/engines/test_nucleus_detection_engine.py
+++ b/tests/engines/test_nucleus_detection_engine.py
@@ -1,5 +1,6 @@
 """Tests for NucleusDetector."""
 
+import json
 import shutil
 from collections.abc import Callable
 from pathlib import Path
@@ -14,6 +15,9 @@
 from tiatoolbox.annotation.storage import SQLiteStore
 from tiatoolbox.models.engine.nucleus_detector import (
     NucleusDetector,
+    _write_detection_arrays_to_store,
+    save_detection_arrays_to_qupath_json,
+    save_detection_arrays_to_store,
 )
 from tiatoolbox.utils import env_detection as toolbox_env
 from tiatoolbox.utils.misc import imwrite
@@ -57,7 +61,7 @@ def test_write_detection_arrays_to_store() -> None:
         "probabilities": np.array([1.0, 0.5], dtype=np.float32),
     }
 
-    store = NucleusDetector.save_detection_arrays_to_store(detection_arrays)
+    store = save_detection_arrays_to_store(detection_arrays)
     assert len(store.values()) == 2
 
     detection_arrays = {
@@ -70,7 +74,32 @@ def test_write_detection_arrays_to_store() -> None:
         ValueError,
         match=r"Detection record lengths are misaligned.",
    ):
-        _ = NucleusDetector.save_detection_arrays_to_store(detection_arrays)
+        _ = save_detection_arrays_to_store(detection_arrays)
+
+
+def test_write_detection_arrays_to_qupath() -> None:
+    """Test writing detection arrays to QuPath JSON."""
+    detection_arrays = {
+        "x": np.array([1, 3], dtype=np.uint32),
+        "y": np.array([1, 2], dtype=np.uint32),
+        "classes": np.array([0, 1], dtype=np.uint32),
+        "probabilities": np.array([1.0, 0.5], dtype=np.float32),
+    }
+
+    json_ = save_detection_arrays_to_qupath_json(detection_arrays)
+    assert len(json_["features"]) == 2
+
+    detection_arrays = {
+        "x": np.array([1], dtype=np.uint32),
+        "y": np.array([1, 2], dtype=np.uint32),
+        "classes": np.array([0], dtype=np.uint32),
+        "probabilities": np.array([1.0, 0.5], dtype=np.float32),
+    }
+    with pytest.raises(
+        ValueError,
+        match=r"Detection record lengths are misaligned.",
+    ):
+        _ = save_detection_arrays_to_qupath_json(detection_arrays)
 
 
 def test_write_detection_records_to_store_no_class_dict() -> None:
     detection_records = (np.array([1]), np.array([2]), np.array([0]), np.array([1.0]))
 
     dummy_store = SQLiteStore()
-    total = NucleusDetector._write_detection_arrays_to_store(
+    total = _write_detection_arrays_to_store(
         detection_records, store=dummy_store, scale_factor=(1.0, 1.0), class_dict=None
     )
     assert len(dummy_store.values()) == 1
 
     dummy_store.close()
 
 
-def test_nucleus_detector_patch_annotation_store_output(
+def test_nuc_detector_patch_qupath_json_annotation_store(
     remote_sample: Callable, track_tmp_path: Path
 ) -> None:
     """Test for nucleus detection engine in patch mode."""
 
     assert len(store_2.values()) == 0
     store_2.close()
 
+    _ = nucleus_detector.run(
+        patch_mode=True,
+        device=device,
+        output_type="qupath",
+        
memory_threshold=50, + images=[image_dir / "patch_0.png", image_dir / "patch_1.png"], + save_dir=save_dir, + overwrite=True, + ) + + with Path.open(save_dir / "patch_0.json", "r") as f: + data_1 = json.load(f) + features_1 = data_1.get("features", []) + assert len(features_1) == 1 + + with Path.open(save_dir / "patch_1.json", "r") as f: + data_2 = json.load(f) + features_2 = data_2.get("features", []) + + assert len(features_2) == 0 + _rm_dir(save_dir) @@ -271,6 +321,30 @@ def test_nucleus_detector_wsi(remote_sample: Callable, track_tmp_path: Path) -> assert annotation.properties["type"] == "test_nucleus" store.close() + # QuPath + nucleus_detector.drop_keys = [] + _ = nucleus_detector.run( + patch_mode=False, + device=device, + output_type="qupath", + memory_threshold=50, + images=[mini_wsi_svs], + save_dir=save_dir, + overwrite=True, + batch_size=8, + class_dict={0: "test_nucleus"}, + min_distance=5, + postproc_tile_shape=(2048, 2048), + ) + + with Path.open(save_dir / "wsi4_512_512.json", "r") as f: + qupath_json = json.load(f) + features: list[dict] = qupath_json.get("features", []) + assert 245 <= len(features) <= 255 + first = features[0] + # Classification name + assert first["properties"]["classification"]["name"] == "test_nucleus" + # Check cached centroid maps are removed temp_zarr_files = save_dir / "wsi4_512_512.zarr" assert not temp_zarr_files.exists() diff --git a/tiatoolbox/models/engine/nucleus_detector.py b/tiatoolbox/models/engine/nucleus_detector.py index 2652722d4..aa16a3070 100644 --- a/tiatoolbox/models/engine/nucleus_detector.py +++ b/tiatoolbox/models/engine/nucleus_detector.py @@ -53,6 +53,7 @@ import dask.array as da import numpy as np import zarr +from matplotlib import pyplot as plt from shapely.geometry import Point from tqdm.auto import tqdm @@ -62,7 +63,11 @@ SemanticSegmentor, SemanticSegmentorRunParams, ) -from tiatoolbox.utils.misc import tqdm_dask_progress_bar +from tiatoolbox.utils.misc import ( + save_annotations, + save_qupath_json, + tqdm_dask_progress_bar, +) if TYPE_CHECKING: # pragma: no cover import os @@ -212,7 +217,8 @@ class NucleusDetector(SemanticSegmentor): drop_keys (list): Keys to exclude from model output when saving results. output_type (str): - Output format (``"dict"``, ``"zarr"``, or ``"annotationstore"``). + Output format (``"dict"``, ``"zarr"``, ``"qupath"``, + or ``"annotationstore"``). Examples: >>> from tiatoolbox.models.engine.nucleus_detector import NucleusDetector @@ -520,7 +526,8 @@ class IDs. - ``"probabilities"`` (da.Array): detection probabilities. output_type (str): - Desired output format: ``"dict"``, ``"zarr"``, or ``"annotationstore"``. + Desired output format: ``"dict"``, ``"zarr"``, ``"qupath"`` + or ``"annotationstore"``. save_path (Path | None): Path at which to save the output file(s). Required for file outputs (e.g., Zarr or SQLite .db). If ``None`` and ``output_type="dict"``, @@ -576,6 +583,10 @@ class names. returns a Python dictionary of predictions. - If ``output_type="zarr"``: returns the path to the saved ``.zarr`` group. + - If ``output_type="qupath"``: + returns QuPath JSON or the path(s) to saved + ``.json`` file(s). In patch mode, a list of per-image paths + may be returned. - If ``output_type="annotationstore"``: returns an AnnotationStore handle or the path(s) to saved ``.db`` file(s). In patch mode, a list of per-image paths @@ -587,7 +598,7 @@ class names. TIAToolbox engines. 
""" - if output_type.lower() != "annotationstore": + if output_type.lower() not in ["qupath", "annotationstore"]: out = super().save_predictions( processed_predictions, output_type, @@ -602,11 +613,12 @@ class names. if class_dict is None: class_dict = self.model.output_class_dict - out = self._save_predictions_annotation_store( + out = self._save_predictions_qupath_json_annotations_db( processed_predictions, save_path=save_path, scale_factor=scale_factor, class_dict=class_dict, + output_type=output_type, ) # Remove cached centroid maps if wsi mode @@ -619,12 +631,13 @@ class names. return out - def _save_predictions_annotation_store( + def _save_predictions_qupath_json_annotations_db( self: NucleusDetector, processed_predictions: dict, save_path: Path | None = None, scale_factor: tuple[float, float] = (1.0, 1.0), class_dict: dict | None = None, + output_type: str = "annotationstore", ) -> AnnotationStore | Path | list[Path]: """Save nucleus detections to an AnnotationStore (.db). @@ -664,6 +677,8 @@ def _save_predictions_annotation_store( Scaling factors applied to x and y coordinates prior to writing. Typically corresponds to ``model_mpp / slide_mpp``. Defaults to ``(1.0, 1.0)``. + output_type (str): + Desired output format: ``"qupath"`` or ``"annotationstore"``. class_dict (dict or None): Optional mapping from original class IDs to class names or remapped IDs. If ``None``, an identity mapping based on present classes is used. @@ -686,11 +701,12 @@ def _save_predictions_annotation_store( save_paths = [] num_patches = len(processed_predictions["x"]) + suffix = ".json" if output_type == "qupath" else ".db" for i in range(num_patches): if isinstance(self.images[i], Path): - output_path = save_path.parent / (self.images[i].stem + ".db") + output_path = save_path.parent / (self.images[i].stem + suffix) else: - output_path = save_path.parent / (str(i) + ".db") + output_path = save_path.parent / (str(i) + suffix) detection_arrays = { "x": processed_predictions["x"][i], @@ -699,17 +715,34 @@ def _save_predictions_annotation_store( "probabilities": processed_predictions["probabilities"][i], } - out_file = self.save_detection_arrays_to_store( - detection_arrays=detection_arrays, - scale_factor=scale_factor, - class_dict=class_dict, - save_path=output_path, + out_file = ( + save_detection_arrays_to_qupath_json( + detection_arrays=detection_arrays, + scale_factor=scale_factor, + class_dict=class_dict, + save_path=output_path, + ) + if output_type == "qupath" + else save_detection_arrays_to_store( + detection_arrays=detection_arrays, + scale_factor=scale_factor, + class_dict=class_dict, + save_path=output_path, + ) ) save_paths.append(out_file) return save_paths - return self.save_detection_arrays_to_store( + if output_type == "qupath": + return save_detection_arrays_to_qupath_json( + detection_arrays=processed_predictions, + scale_factor=scale_factor, + save_path=save_path, + class_dict=class_dict, + ) + + return save_detection_arrays_to_store( detection_arrays=processed_predictions, scale_factor=scale_factor, save_path=save_path, @@ -825,189 +858,9 @@ class IDs for each detection (``np.uint32``). 
"probabilities": da.from_array(probs, chunks="auto"), } - @staticmethod - def _write_detection_arrays_to_store( - detection_arrays: tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], - store: SQLiteStore, - scale_factor: tuple[float, float], - class_dict: dict[int, str | int] | None, - batch_size: int = 5000, - *, - verbose: bool = True, - ) -> int: - """Write detection arrays to an AnnotationStore in batches. - - Converts coordinate, class, and probability arrays into `Annotation` - objects and appends them to an SQLite-backed store in configurable - batch sizes. Coordinates are scaled to baseline slide resolution using - the provided `scale_factor`, and optional class-ID remapping is applied - via `class_dict`. - - Args: - detection_arrays (tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]): - Tuple of arrays in the order: - `(x_coords, y_coords, class_ids, probabilities)`. - Each element must be a 1-D NumPy array of equal length. - store (SQLiteStore): - Target `AnnotationStore` instance to receive the detections. - scale_factor (tuple[float, float]): - Factors applied to `(x, y)` coordinates prior to writing, - typically `(model_mpp / slide_mpp)`. The scaled coordinates are - rounded to `np.uint32`. - class_dict (dict[int, str | int] | None): - Optional mapping from original class IDs to names or remapped IDs. - If `None`, an identity mapping is used for the set of present classes. - batch_size (int): - Number of records to write per batch. Default is `5000`. - verbose (bool): - Whether to display logs and progress bar. - - Returns: - int: - Total number of detection records written to the store. - - Notes: - - Coordinates are scaled and rounded to integers to ensure consistent - geometry creation for `Annotation` points. - - Class mapping is applied per-record; unmapped IDs fall back to their - original values. - - Writing in batches reduces memory pressure and improves throughput - on large number of detections. - - """ - xs, ys, classes, probs = detection_arrays - n = len(xs) - if n == 0: - return 0 # nothing to write - - # scale coordinates - xs = np.rint(xs * scale_factor[0]).astype(np.uint32, copy=False) - ys = np.rint(ys * scale_factor[1]).astype(np.uint32, copy=False) - - # class mapping - if class_dict is None: - # identity over actually-present types - uniq = np.unique(classes) - class_dict = {int(k): int(k) for k in uniq} - labels = np.array( - [class_dict.get(int(k), int(k)) for k in classes], dtype=object - ) - - def make_points(xs_batch: np.ndarray, ys_batch: np.ndarray) -> list[Point]: - """Create Shapely Point geometries from coordinate arrays in batches.""" - return [ - Point(int(xx), int(yy)) - for xx, yy in zip(xs_batch, ys_batch, strict=True) - ] - - tqdm_loop = tqdm( - range(0, n, batch_size), - leave=False, - desc="Writing detections to store", - disable=not verbose, - ) - written = 0 - for i in tqdm_loop: - j = min(i + batch_size, n) - pts = make_points(xs[i:j], ys[i:j]) - - anns = [ - Annotation( - geometry=pt, properties={"type": lbl, "probability": float(pp)} - ) - for pt, lbl, pp in zip(pts, labels[i:j], probs[i:j], strict=True) - ] - store.append_many(anns) - written += j - i - return written - - @staticmethod - def save_detection_arrays_to_store( - detection_arrays: dict[str, da.Array], - scale_factor: tuple[float, float] = (1.0, 1.0), - class_dict: dict | None = None, - save_path: Path | None = None, - batch_size: int = 5000, - ) -> Path | SQLiteStore: - """Write nucleus detection arrays to an SQLite-backed AnnotationStore. 
- - Converts the detection arrays into NumPy form, applies coordinate scaling - and optional class-ID remapping, and writes the results into an in-memory - SQLiteStore. If `save_path` is provided, the store is committed and saved - to disk as a `.db` file. This method provides a unified interface for - converting Dask-based detection outputs into persistent annotation storage. - - Args: - detection_arrays (dict[str, da.Array]): - A dictionary containing the detection fields: - - ``"x"``: dask array of x coordinates (``np.uint32``). - - ``"y"``: dask array of y coordinates (``np.uint32``). - - ``"classes"``: dask array of class IDs (``np.uint32``). - - ``"probabilities"``: dask array of detection scores (``np.float32``). - scale_factor (tuple[float, float], optional): - Multiplicative factors applied to the x and y coordinates before - saving. The scaled coordinates are rounded to integer pixel - locations. Defaults to ``(1.0, 1.0)``. - class_dict (dict or None): - Optional mapping of class IDs to class names or remapped IDs. - If ``None``, an identity mapping is used based on the detected - class IDs. - save_path (Path or None): - Destination path for saving the `.db` file. If ``None``, the - resulting SQLiteStore is returned in memory. If provided, the - parent directory is created if needed, and the final store is - written as ``save_path.with_suffix(".db")``. - batch_size (int): - Number of detection records to write per batch. Defaults to ``5000``. - - Returns: - Path or SQLiteStore: - - If `save_path` is provided: the path to the saved `.db` file. - - If `save_path` is ``None``: an in-memory `SQLiteStore` containing - all detections. - - Notes: - - The heavy lifting is delegated to - :meth:`NucleusDetector._write_detection_arrays_to_store`, - which performs coordinate scaling, class mapping, and batch writing. - - """ - xs = detection_arrays["x"] - ys = detection_arrays["y"] - classes = detection_arrays["classes"] - probs = detection_arrays["probabilities"] - - xs = np.atleast_1d(np.asarray(xs)) - ys = np.atleast_1d(np.asarray(ys)) - classes = np.atleast_1d(np.asarray(classes)) - probs = np.atleast_1d(np.asarray(probs)) - - if not len(xs) == len(ys) == len(classes) == len(probs): - msg = "Detection record lengths are misaligned." - raise ValueError(msg) - - store = SQLiteStore() - total_written = NucleusDetector._write_detection_arrays_to_store( - (xs, ys, classes, probs), - store, - scale_factor, - class_dict, - batch_size, - ) - logger.info("Total detections written to store: %s", total_written) - - if save_path: - save_path.parent.absolute().mkdir(parents=True, exist_ok=True) - save_path = save_path.parent.absolute() / (save_path.stem + ".db") - store.commit() - store.dump(save_path) - return save_path - - return store - def run( self: NucleusDetector, - images: list[os.PathLike | Path | WSIReader] | np.ndarray, + images: list[os.PathLike | Path | WSIReader | np.ndarray] | np.ndarray, *, masks: list[os.PathLike | Path] | np.ndarray | None = None, input_resolutions: list[dict[Units, Resolution]] | None = None, @@ -1140,3 +993,287 @@ class names. output_type=output_type, **kwargs, ) + + +def save_detection_arrays_to_qupath_json( + detection_arrays: dict[str, da.Array], + scale_factor: tuple[float, float] = (1.0, 1.0), + class_dict: dict | None = None, + save_path: Path | None = None, +) -> dict | Path: + """Write nucleus detection arrays to QuPath JSON. 
+ + Produces a FeatureCollection where each detection is represented as a + Point geometry with classification metadata and probability score. + + Args: + detection_arrays (dict[str, da.Array]): + A dictionary containing the detection fields: + - ``"x"``: dask array of x coordinates (``np.uint32``). + - ``"y"``: dask array of y coordinates (``np.uint32``). + - ``"classes"``: dask array of class IDs (``np.uint32``). + - ``"probabilities"``: dask array of detection scores (``np.float32``). + scale_factor (tuple[float, float], optional): + Multiplicative factors applied to the x and y coordinates before + saving. The scaled coordinates are rounded to integer pixel + locations. Defaults to ``(1.0, 1.0)``. + class_dict (dict or None): + Optional mapping of class IDs to class names or remapped IDs. + If ``None``, an identity mapping is used based on the detected + class IDs. + save_path (Path or None): + Destination path for saving the `.db` file. If ``None``, the + resulting SQLiteStore is returned in memory. If provided, the + parent directory is created if needed, and the final store is + written as ``save_path.with_suffix(".db")``. + batch_size (int): + Number of detection records to write per batch. Defaults to ``5000``. + + Returns: + Path or QuPath: + - If `save_path` is provided: the path to the saved `.json` file. + - If `save_path` is ``None``: an in-memory `JSON` containing + all detections. + + """ + xs = np.atleast_1d(np.asarray(detection_arrays["x"])) + ys = np.atleast_1d(np.asarray(detection_arrays["y"])) + classes = np.atleast_1d(np.asarray(detection_arrays["classes"])) + probs = np.atleast_1d(np.asarray(detection_arrays["probabilities"])) + + if not len(xs) == len(ys) == len(classes) == len(probs): + msg = "Detection record lengths are misaligned." + raise ValueError(msg) + + # Determine class dictionary + unique_classes = np.unique(classes).tolist() + if class_dict is None: + class_dict = {int(i): int(i) for i in unique_classes} + + # Color map for classes + num_classes = len(class_dict) + cmap = plt.cm.get_cmap("tab20", num_classes) + class_colours = { + class_idx: [ + int(cmap(class_idx)[0] * 255), + int(cmap(class_idx)[1] * 255), + int(cmap(class_idx)[2] * 255), + ] + for class_idx in class_dict + } + + features: list[dict] = [] + + for i in range(len(xs)): + # Scale coordinates + x = float(xs[i]) * scale_factor[0] + y = float(ys[i]) * scale_factor[1] + + class_id = int(classes[i]) + class_label = class_dict[class_id] + prob = float(probs[i]) + + # QuPath point geometry + point_geo = { + "type": "Point", + "coordinates": [x, y], + } + + feature = { + "type": "Feature", + "id": f"detection_{i}", + "geometry": point_geo, + "properties": { + "classification": { + "name": class_label, + "color": class_colours[class_id], + }, + "probability": prob, + }, + "objectType": "detection", + "name": class_label, + "class_value": class_id, + } + + features.append(feature) + + qupath_json = {"type": "FeatureCollection", "features": features} + + if save_path: + return save_qupath_json(save_path=save_path, qupath_json=qupath_json) + + return qupath_json + + +def save_detection_arrays_to_store( + detection_arrays: dict[str, da.Array], + scale_factor: tuple[float, float] = (1.0, 1.0), + class_dict: dict | None = None, + save_path: Path | None = None, + batch_size: int = 5000, +) -> Path | SQLiteStore: + """Write nucleus detection arrays to an SQLite-backed AnnotationStore. 
+ + Converts the detection arrays into NumPy form, applies coordinate scaling + and optional class-ID remapping, and writes the results into an in-memory + SQLiteStore. If `save_path` is provided, the store is committed and saved + to disk as a `.db` file. This method provides a unified interface for + converting Dask-based detection outputs into persistent annotation storage. + + Args: + detection_arrays (dict[str, da.Array]): + A dictionary containing the detection fields: + - ``"x"``: dask array of x coordinates (``np.uint32``). + - ``"y"``: dask array of y coordinates (``np.uint32``). + - ``"classes"``: dask array of class IDs (``np.uint32``). + - ``"probabilities"``: dask array of detection scores (``np.float32``). + scale_factor (tuple[float, float], optional): + Multiplicative factors applied to the x and y coordinates before + saving. The scaled coordinates are rounded to integer pixel + locations. Defaults to ``(1.0, 1.0)``. + class_dict (dict or None): + Optional mapping of class IDs to class names or remapped IDs. + If ``None``, an identity mapping is used based on the detected + class IDs. + save_path (Path or None): + Destination path for saving the `.db` file. If ``None``, the + resulting SQLiteStore is returned in memory. If provided, the + parent directory is created if needed, and the final store is + written as ``save_path.with_suffix(".db")``. + batch_size (int): + Number of detection records to write per batch. Defaults to ``5000``. + + Returns: + Path or SQLiteStore: + - If `save_path` is provided: the path to the saved `.db` file. + - If `save_path` is ``None``: an in-memory `SQLiteStore` containing + all detections. + + Notes: + - The heavy lifting is delegated to + :meth:`_write_detection_arrays_to_store`, + which performs coordinate scaling, class mapping, and batch writing. + + """ + xs = detection_arrays["x"] + ys = detection_arrays["y"] + classes = detection_arrays["classes"] + probs = detection_arrays["probabilities"] + + xs = np.atleast_1d(np.asarray(xs)) + ys = np.atleast_1d(np.asarray(ys)) + classes = np.atleast_1d(np.asarray(classes)) + probs = np.atleast_1d(np.asarray(probs)) + + if not len(xs) == len(ys) == len(classes) == len(probs): + msg = "Detection record lengths are misaligned." + raise ValueError(msg) + + store = SQLiteStore() + total_written = _write_detection_arrays_to_store( + detection_arrays=(xs, ys, classes, probs), + store=store, + scale_factor=scale_factor, + class_dict=class_dict, + batch_size=batch_size, + ) + logger.info("Total detections written to store: %s", total_written) + + if save_path: + return save_annotations( + save_path=save_path, + store=store, + ) + + return store + + +def _write_detection_arrays_to_store( + detection_arrays: tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], + store: SQLiteStore, + scale_factor: tuple[float, float], + class_dict: dict[int, str | int] | None, + batch_size: int = 5000, + *, + verbose: bool = True, +) -> int: + """Write detection arrays to an AnnotationStore in batches. + + Converts coordinate, class, and probability arrays into `Annotation` + objects and appends them to an SQLite-backed store in configurable + batch sizes. Coordinates are scaled to baseline slide resolution using + the provided `scale_factor`, and optional class-ID remapping is applied + via `class_dict`. + + Args: + detection_arrays (tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]): + Tuple of arrays in the order: + `(x_coords, y_coords, class_ids, probabilities)`. 
+ Each element must be a 1-D NumPy array of equal length. + store (SQLiteStore): + Target `AnnotationStore` instance to receive the detections. + scale_factor (tuple[float, float]): + Factors applied to `(x, y)` coordinates prior to writing, + typically `(model_mpp / slide_mpp)`. The scaled coordinates are + rounded to `np.uint32`. + class_dict (dict[int, str | int] | None): + Optional mapping from original class IDs to names or remapped IDs. + If `None`, an identity mapping is used for the set of present classes. + batch_size (int): + Number of records to write per batch. Default is `5000`. + verbose (bool): + Whether to display logs and progress bar. + + Returns: + int: + Total number of detection records written to the store. + + Notes: + - Coordinates are scaled and rounded to integers to ensure consistent + geometry creation for `Annotation` points. + - Class mapping is applied per-record; unmapped IDs fall back to their + original values. + - Writing in batches reduces memory pressure and improves throughput + on large number of detections. + + """ + xs, ys, classes, probs = detection_arrays + n = len(xs) + if n == 0: + return 0 # nothing to write + + # scale coordinates + xs = np.rint(xs * scale_factor[0]).astype(np.uint32, copy=False) + ys = np.rint(ys * scale_factor[1]).astype(np.uint32, copy=False) + + # class mapping + if class_dict is None: + # identity over actually-present types + uniq = np.unique(classes) + class_dict = {int(k): int(k) for k in uniq} + labels = np.array([class_dict.get(int(k), int(k)) for k in classes], dtype=object) + + def make_points(xs_batch: np.ndarray, ys_batch: np.ndarray) -> list[Point]: + """Create Shapely Point geometries from coordinate arrays in batches.""" + return [ + Point(int(xx), int(yy)) for xx, yy in zip(xs_batch, ys_batch, strict=True) + ] + + tqdm_loop = tqdm( + range(0, n, batch_size), + leave=False, + desc="Writing detections to store", + disable=not verbose, + ) + written = 0 + for i in tqdm_loop: + j = min(i + batch_size, n) + pts = make_points(xs[i:j], ys[i:j]) + + anns = [ + Annotation(geometry=pt, properties={"type": lbl, "probability": float(pp)}) + for pt, lbl, pp in zip(pts, labels[i:j], probs[i:j], strict=True) + ] + store.append_many(anns) + written += j - i + return written From 7441f9932add69fef4181eb122f7b7cdd1c08db0 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Sun, 15 Feb 2026 12:07:29 +0000 Subject: [PATCH 148/156] :bug: Fix deepsource error --- tiatoolbox/models/engine/nucleus_detector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tiatoolbox/models/engine/nucleus_detector.py b/tiatoolbox/models/engine/nucleus_detector.py index aa16a3070..f0124765d 100644 --- a/tiatoolbox/models/engine/nucleus_detector.py +++ b/tiatoolbox/models/engine/nucleus_detector.py @@ -1064,7 +1064,7 @@ class IDs. 
features: list[dict] = [] - for i in range(len(xs)): + for i, _ in enumerate(xs): # Scale coordinates x = float(xs[i]) * scale_factor[0] y = float(ys[i]) * scale_factor[1] From f14154967a7334ca7231bcf966befa21a4d0ea65 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Mon, 16 Feb 2026 13:52:33 +0000 Subject: [PATCH 149/156] :twisted_rightwards_arrows: Merge branch 'dev-define-engines-abc' into `dev-define-multitask-segmentor` # Conflicts: # tests/engines/test_engine_abc.py --- tests/engines/test_engine_abc.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/engines/test_engine_abc.py b/tests/engines/test_engine_abc.py index f24e0e85a..3e9f01cc7 100644 --- a/tests/engines/test_engine_abc.py +++ b/tests/engines/test_engine_abc.py @@ -90,7 +90,7 @@ def test_incorrect_output_type_save_dir() -> None: ): _ = eng.run( images=np.zeros((10, 224, 224, 3), dtype=np.uint8), - on_gpu=False, + device=device, patch_mode=True, ioconfig=None, output_type="zarr", @@ -102,7 +102,7 @@ def test_incorrect_output_type_save_dir() -> None: ): _ = eng.run( images=np.zeros((10, 224, 224, 3), dtype=np.uint8), - on_gpu=False, + device=device, patch_mode=True, ioconfig=None, output_type="annotationstore", From 459c9ed700417d0e25bf75f2a2137efc12435986 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Mon, 16 Feb 2026 17:05:10 +0000 Subject: [PATCH 150/156] :white_check_mark: Improve tests --- tests/engines/test_semantic_segmentor.py | 31 ++------------ tiatoolbox/models/engine/nucleus_detector.py | 43 ++++++++++---------- 2 files changed, 26 insertions(+), 48 deletions(-) diff --git a/tests/engines/test_semantic_segmentor.py b/tests/engines/test_semantic_segmentor.py index 8a8b0908c..23409cc53 100644 --- a/tests/engines/test_semantic_segmentor.py +++ b/tests/engines/test_semantic_segmentor.py @@ -514,7 +514,7 @@ def test_wsi_segmentor_zarr( assert 0.48 < np.mean(output_["probabilities"][:]) < 0.52 -def test_wsi_segmentor_annotationstore( +def test_wsi_segmentor_annotationstore_qupath( remote_sample: Callable, track_tmp_path: Path, caplog: pytest.CaptureFixture ) -> None: """Test SemanticSegmentor for WSIs with AnnotationStore output.""" @@ -547,6 +547,7 @@ def test_wsi_segmentor_annotationstore( verbose=False, ) # Return Probabilities is True + # Check QuPath output output = segmentor.run( images=[wsi4_512_512_svs], return_probabilities=True, @@ -555,11 +556,11 @@ def test_wsi_segmentor_annotationstore( patch_mode=False, save_dir=track_tmp_path / "wsi_prob_out_check", verbose=True, - output_type="annotationstore", + output_type="QuPath", ) assert output[wsi4_512_512_svs] == track_tmp_path / "wsi_prob_out_check" / ( - wsi4_512_512_svs.stem + ".db" + wsi4_512_512_svs.stem + ".json" ) assert output[wsi4_512_512_svs].with_suffix(".zarr").exists() @@ -568,30 +569,6 @@ def test_wsi_segmentor_annotationstore( assert "Probability maps cannot be saved as AnnotationStore or JSON." 
in caplog.text -def test_wsi_segmentor_qupath_json(sample_svs: Path, track_tmp_path: Path) -> None: - """Test SemanticSegmentor for WSIs with QuPath JSON output.""" - segmentor = SemanticSegmentor( - model="fcn-tissue_mask", - batch_size=32, - verbose=False, - ) - # Return Probabilities is False - output = segmentor.run( - images=[sample_svs], - return_probabilities=False, - return_labels=False, - device=device, - patch_mode=False, - save_dir=track_tmp_path / "wsi_out_check", - verbose=False, - output_type="QuPath", - ) - - assert output[sample_svs] == track_tmp_path / "wsi_out_check" / ( - sample_svs.stem + ".json" - ) - - def test_prepare_full_batch_low_memory(track_tmp_path: Path) -> None: """Test prepare_full_batch with low memory condition (disk-based zarr).""" # Create mock data diff --git a/tiatoolbox/models/engine/nucleus_detector.py b/tiatoolbox/models/engine/nucleus_detector.py index f0124765d..3d7ba0504 100644 --- a/tiatoolbox/models/engine/nucleus_detector.py +++ b/tiatoolbox/models/engine/nucleus_detector.py @@ -1036,14 +1036,9 @@ class IDs. all detections. """ - xs = np.atleast_1d(np.asarray(detection_arrays["x"])) - ys = np.atleast_1d(np.asarray(detection_arrays["y"])) - classes = np.atleast_1d(np.asarray(detection_arrays["classes"])) - probs = np.atleast_1d(np.asarray(detection_arrays["probabilities"])) - - if not len(xs) == len(ys) == len(classes) == len(probs): - msg = "Detection record lengths are misaligned." - raise ValueError(msg) + xs, ys, classes, probs = _validate_detections_for_saving_to_json( + detection_arrays=detection_arrays, + ) # Determine class dictionary unique_classes = np.unique(classes).tolist() @@ -1155,19 +1150,9 @@ class IDs. which performs coordinate scaling, class mapping, and batch writing. """ - xs = detection_arrays["x"] - ys = detection_arrays["y"] - classes = detection_arrays["classes"] - probs = detection_arrays["probabilities"] - - xs = np.atleast_1d(np.asarray(xs)) - ys = np.atleast_1d(np.asarray(ys)) - classes = np.atleast_1d(np.asarray(classes)) - probs = np.atleast_1d(np.asarray(probs)) - - if not len(xs) == len(ys) == len(classes) == len(probs): - msg = "Detection record lengths are misaligned." - raise ValueError(msg) + xs, ys, classes, probs = _validate_detections_for_saving_to_json( + detection_arrays=detection_arrays, + ) store = SQLiteStore() total_written = _write_detection_arrays_to_store( @@ -1188,6 +1173,22 @@ class IDs. return store +def _validate_detections_for_saving_to_json( + detection_arrays: dict[str, da.Array], +) -> tuple: + """Validates x, y, classes and probs for writing to QuPath or AnnotationStore.""" + xs = np.atleast_1d(np.asarray(detection_arrays["x"])) + ys = np.atleast_1d(np.asarray(detection_arrays["y"])) + classes = np.atleast_1d(np.asarray(detection_arrays["classes"])) + probs = np.atleast_1d(np.asarray(detection_arrays["probabilities"])) + + if not len(xs) == len(ys) == len(classes) == len(probs): + msg = "Detection record lengths are misaligned." 
+ raise ValueError(msg) + + return xs, ys, classes, probs + + def _write_detection_arrays_to_store( detection_arrays: tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], store: SQLiteStore, From ecc84e198ab7569f118652307c49a3e11bdeea54 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Mon, 16 Feb 2026 18:39:24 +0000 Subject: [PATCH 151/156] :sparkles: Add `qupath` output to multitask segmentor --- tests/engines/test_multi_task_segmentor.py | 39 ++++ .../models/engine/multi_task_segmentor.py | 209 +++++++++++++++--- 2 files changed, 219 insertions(+), 29 deletions(-) diff --git a/tests/engines/test_multi_task_segmentor.py b/tests/engines/test_multi_task_segmentor.py index fda263104..e97792556 100644 --- a/tests/engines/test_multi_task_segmentor.py +++ b/tests/engines/test_multi_task_segmentor.py @@ -454,6 +454,45 @@ def test_wsi_segmentor_annotationstore( weights_path.unlink() +def test_wsi_segmentor_qupath(remote_sample: Callable, track_tmp_path: Path) -> None: + """Test MultiTaskSegmentor for WSIs with AnnotationStore output.""" + wsi4_512_512_svs = remote_sample("wsi4_512_512_svs") + # testing different configuration for hovernet. + # kumar only has two probability maps + model_name = "hovernet_fast-pannuke" + mtsegmentor = MultiTaskSegmentor( + model=model_name, + batch_size=32, + verbose=False, + ) + + class_dict = mtsegmentor.model.class_dict + + # Return Probabilities is False + output = mtsegmentor.run( + images=[wsi4_512_512_svs], + return_probabilities=False, + device=device, + patch_mode=False, + save_dir=track_tmp_path / "wsi_out_check", + verbose=True, + output_type="qupath", + class_dict=class_dict, + memory_threshold=0, + ) + + for output_ in output[wsi4_512_512_svs]: + assert output_.suffix != ".zarr" + + json_file_name = f"{wsi4_512_512_svs.stem}.json" + json_file_name = track_tmp_path / "wsi_out_check" / json_file_name + assert json_file_name.exists() + assert json_file_name == output[wsi4_512_512_svs][0] + + weights_path = Path(fetch_pretrained_weights(model_name=model_name)) + weights_path.unlink() + + def test_wsi_segmentor_annotationstore_probabilities( remote_sample: Callable, track_tmp_path: Path, caplog: pytest.CaptureFixture ) -> None: diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index 1073302c6..fda7a84a1 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -130,6 +130,9 @@ import torch import zarr from dask import delayed +from matplotlib import pyplot as plt +from shapely import Polygon +from shapely.geometry import mapping from shapely.geometry import shape as feature2geometry from shapely.strtree import STRtree from tqdm.auto import tqdm @@ -142,6 +145,7 @@ from tiatoolbox.utils.misc import ( create_smart_array, make_valid_poly, + save_qupath_json, tqdm_dask_progress_bar, update_tqdm_desc, ) @@ -1609,11 +1613,12 @@ def _save_predictions_as_dict_zarr( ) return save_path - def _save_predictions_as_annotationstore( + def _save_predictions_as_json_store( self: MultiTaskSegmentor, processed_predictions: dict, task_name: str | None = None, save_path: Path | None = None, + output_type: str = "annotationstore", **kwargs: Unpack[MultiTaskSegmentorRunParams], ) -> dict | AnnotationStore | Path | list[Path]: """Helper function to save predictions as annotationstore.""" @@ -1651,13 +1656,14 @@ def _save_predictions_as_annotationstore( if self.patch_mode: for idx, curr_image in 
enumerate(self.images): values = [processed_predictions[key][idx] for key in keys_to_compute] - output_path = _save_annotation_store( + predictions = dict(zip(keys_to_compute, values, strict=False)) + output_path = _save_annotation_json_store( curr_image=curr_image, - keys_to_compute=keys_to_compute, - values=values, + predictions=predictions, task_name=task_name, idx=idx, save_path=save_path, + output_type=output_type, class_dict=class_dict, scale_factor=scale_factor, num_workers=num_workers, @@ -1667,14 +1673,15 @@ def _save_predictions_as_annotationstore( else: values = [processed_predictions[key] for key in keys_to_compute] + predictions = dict(zip(keys_to_compute, values, strict=False)) save_paths = [ - _save_annotation_store( + _save_annotation_json_store( curr_image=save_path, - keys_to_compute=keys_to_compute, - values=values, + predictions=predictions, task_name=task_name, idx=0, save_path=save_path, + output_type=output_type, class_dict=class_dict, scale_factor=scale_factor, num_workers=num_workers, @@ -1863,10 +1870,11 @@ def save_predictions( **processed_predictions[task_name], "coordinates": processed_predictions["coordinates"], } - out_path = self._save_predictions_as_annotationstore( + out_path = self._save_predictions_as_json_store( processed_predictions=dict_for_store, task_name=task_name, save_path=save_path, + output_type=output_type, **kwargs, ) save_paths += out_path @@ -1875,10 +1883,11 @@ def save_predictions( return save_paths - return self._save_predictions_as_annotationstore( + return self._save_predictions_as_json_store( processed_predictions=processed_predictions, task_name=None, save_path=save_path, + output_type=output_type, **kwargs, ) @@ -2622,13 +2631,13 @@ def _check_and_update_for_memory_overload( return canvas, count, canvas_zarr, count_zarr, tqdm_loop -def _save_annotation_store( +def _save_annotation_json_store( curr_image: Path | None, - keys_to_compute: list[str], - values: list[da.Array | list[da.Array]], + predictions: dict[str, da.Array | list[da.Array]], task_name: str, idx: int, save_path: Path, + output_type: str, class_dict: dict, scale_factor: tuple[float, float], num_workers: int, @@ -2644,27 +2653,21 @@ def _save_annotation_store( ) else: store_file_name = f"{idx}.db" if task_name is None else f"{idx}_{task_name}.db" - predictions_ = dict(zip(keys_to_compute, values, strict=False)) output_path = save_path.parent / store_file_name # Patch mode indexes the "coordinates" while calculating "values" variable. origin = (0.0, 0.0) - _ = predictions_.pop("coordinates") - store = SQLiteStore() - store = dict_to_store( - store=store, - processed_predictions=predictions_, + _ = predictions.pop("coordinates") + return dict_to_json_store( + processed_predictions=predictions, class_dict=class_dict, scale_factor=scale_factor, origin=origin, num_workers=num_workers, verbose=verbose, + output_path=output_path, + output_type=output_type, ) - store.commit() - store.dump(output_path) - - return output_path - def _process_instance_predictions( inst_dict: dict, @@ -3097,16 +3100,17 @@ def _compute_info_dict_for_merge( ) -def dict_to_store( - store: SQLiteStore, +def dict_to_json_store( processed_predictions: dict, + output_path: str | Path, + output_type: str, class_dict: dict | None = None, origin: tuple[float, float] = (0, 0), scale_factor: tuple[float, float] = (1, 1), num_workers: int = multiprocessing.cpu_count(), *, verbose: bool = True, -) -> AnnotationStore: +) -> Path: """Write polygonal multitask predictions into an SQLite-backed AnnotationStore. 
Converts a task dictionary (with per-object fields) into `Annotation` records, @@ -3126,12 +3130,14 @@ def dict_to_store( with list-like values aligned to `contours` length. Args: - store (SQLiteStore): - Target annotation store that will receive the converted annotations. processed_predictions (dict): Dictionary containing per-object fields. Must include `"contours"`; may include `"geom_type"` and any number of additional fields to be written as properties. + output_path (str | Path): + Path to save the output. + output_type (str): + Desired output format: "qupath" or "annotationstore". class_dict (dict | None): Optional mapping for the `"type"` field. When provided and when `"type"` is present in `processed_predictions`, each `"type"` value is @@ -3177,7 +3183,16 @@ def dict_to_store( processed_predictions=processed_predictions, ) - return delayed_tasks.compute_annotations( + if output_type == "qupath": + return delayed_tasks.compute_qupath_json( + class_dict={0: "Tumor", 1: "Stroma"}, + origin=(0, 0), + scale_factor=scale_factor, + save_path=output_path.with_suffix(".json"), + ) + + store = SQLiteStore() + store = delayed_tasks.compute_annotations( store=store, class_dict=class_dict, origin=origin, @@ -3187,6 +3202,11 @@ def dict_to_store( verbose=verbose, ) + store.commit() + store.dump(output_path) + + return output_path + class DaskDelayedAnnotationStore: """Compute and write TIAToolbox annotations using batched Dask Delayed tasks. @@ -3276,6 +3296,64 @@ def _build_single_annotation( return Annotation(geom, properties) + def _build_single_qupath_feature( + self: DaskDelayedAnnotationStore, + i: int, + class_dict: dict | None, + origin: tuple[float, float], + scale_factor: tuple[float, float], + class_colours: dict, + ) -> dict: + contour = np.array(self._contours[i], dtype=float) + contour[:, 0] = contour[:, 0] * scale_factor[0] + origin[0] + contour[:, 1] = contour[:, 1] * scale_factor[1] + origin[1] + + poly = Polygon(contour) + poly_geo = mapping(poly) + + props = {} + class_value = None + class_name = None + + for key, arr in self._processed_predictions.items(): + value = arr[i] + if key == "type": + # Convert numpy/zarr scalar to Python + value = value.tolist() if hasattr(value, "tolist") else value + + # Safe class lookup + if class_dict is not None and value in class_dict: + class_name = class_dict[value] + else: + class_name = value # keep raw value + + if class_name is not None: + props["type"] = class_name + class_value = value + + else: + if value is None: + continue + props[key] = np.array(value).tolist() + + # Classification block + if class_name is not None and class_value in class_colours: + color = class_colours[class_value] + props["classification"] = { + "name": class_name, + "color": color, + } + props["class_value"] = class_value + + return { + "type": "Feature", + "id": f"object_{i}", + "geometry": poly_geo, + "properties": props, + "objectType": "annotation", + "name": class_name if class_name is not None else "object", + } + def compute_annotations( self: DaskDelayedAnnotationStore, store: SQLiteStore, @@ -3355,3 +3433,76 @@ def compute_annotations( ) ) return store + + def compute_qupath_json( + self: DaskDelayedAnnotationStore, + class_dict: dict[int, str] | None, + origin: tuple[float, float] = (0, 0), + scale_factor: tuple[float, float] = (1, 1), + save_path: Path | None = None, + batch_size: int = 100, + num_workers: int = 0, + *, + verbose: bool = True, + ) -> Path: + """Compute annotations in batches and return/save QuPath JSON.""" + num_contours = 
len(self._contours) + features: list[dict] = [] + + if class_dict is None: + type_arr = self._processed_predictions.get("type") + if type_arr is not None: + max_class = int(type_arr.max()) + class_dict = {i: i for i in range(max_class + 1)} + else: + class_dict = {0: 0} + + # Enumerate class_dict keys to assign stable integer color indices + class_keys = list(class_dict.keys()) + num_classes = len(class_keys) + cmap = plt.cm.get_cmap("tab20", num_classes) + + class_colours = { + key: [ + int(cmap(i)[0] * 255), + int(cmap(i)[1] * 255), + int(cmap(i)[2] * 255), + ] + for i, key in enumerate(class_keys) + } + + # Batch processing (mirrors compute_annotations) + for batch_id in tqdm( + range(0, num_contours, batch_size), + leave=False, + desc="Calculating QuPath features in batches.", + disable=not verbose, + ): + delayed_tasks = [ + delayed(self._build_single_qupath_feature)( + i, + class_dict, + origin, + scale_factor, + class_colours, + ) + for i in tqdm( + range(batch_id, min(batch_id + batch_size, num_contours)), + leave=False, + desc="Creating delayed tasks for QuPath JSON", + disable=not verbose, + ) + ] + + # Compute batch immediately + batch_features = tqdm_dask_progress_bar( + write_tasks=delayed_tasks, + desc="Computing QuPath features", + verbose=verbose, + num_workers=num_workers, + ) + features.extend(batch_features) + + qupath_json = {"type": "FeatureCollection", "features": features} + + return save_qupath_json(save_path=save_path, qupath_json=qupath_json) From b4991a25ac996403b83bbb987a417e1edb2e68de Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Tue, 17 Feb 2026 09:30:18 +0000 Subject: [PATCH 152/156] :bulb: Address Co-Pilot comments. --- tiatoolbox/models/engine/engine_abc.py | 2 +- tiatoolbox/models/engine/nucleus_detector.py | 4 ++-- tiatoolbox/models/engine/nucleus_instance_segmentor.py | 2 +- tiatoolbox/models/engine/semantic_segmentor.py | 3 ++- tiatoolbox/utils/misc.py | 2 +- 5 files changed, 7 insertions(+), 6 deletions(-) diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index f4720a587..b84d92169 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -727,7 +727,7 @@ def save_predictions( return processed_predictions if output_type.lower() in ["qupath", "annotationstore"]: - suffix = "output.json" if output_type.lower() == "qupath" else ".db" + suffix = ".json" if output_type.lower() == "qupath" else ".db" save_path = Path( kwargs.get("output_file", save_path.parent / ("output" + suffix)) ) diff --git a/tiatoolbox/models/engine/nucleus_detector.py b/tiatoolbox/models/engine/nucleus_detector.py index 3d7ba0504..adc81c85a 100644 --- a/tiatoolbox/models/engine/nucleus_detector.py +++ b/tiatoolbox/models/engine/nucleus_detector.py @@ -1065,7 +1065,7 @@ class IDs. y = float(ys[i]) * scale_factor[1] class_id = int(classes[i]) - class_label = class_dict[class_id] + class_label = class_dict.get(class_id, class_id) prob = float(probs[i]) # QuPath point geometry @@ -1081,7 +1081,7 @@ class IDs. 
"properties": { "classification": { "name": class_label, - "color": class_colours[class_id], + "color": class_colours.get(class_id, class_id), }, "probability": prob, }, diff --git a/tiatoolbox/models/engine/nucleus_instance_segmentor.py b/tiatoolbox/models/engine/nucleus_instance_segmentor.py index 4b2d7bb63..9369707c5 100644 --- a/tiatoolbox/models/engine/nucleus_instance_segmentor.py +++ b/tiatoolbox/models/engine/nucleus_instance_segmentor.py @@ -124,7 +124,7 @@ class NucleusInstanceSegmentor(MultiTaskSegmentor): """ def __init__( - self: MultiTaskSegmentor, + self: NucleusInstanceSegmentor, model: str | ModelABC, batch_size: int = 8, num_workers: int = 0, diff --git a/tiatoolbox/models/engine/semantic_segmentor.py b/tiatoolbox/models/engine/semantic_segmentor.py index 2cb74f424..a6aa1526f 100644 --- a/tiatoolbox/models/engine/semantic_segmentor.py +++ b/tiatoolbox/models/engine/semantic_segmentor.py @@ -702,8 +702,9 @@ def save_predictions( # Need to add support for zarr conversion. save_paths = [] - logger.info("Saving predictions as AnnotationStore.") suffix = ".json" if output_type.lower() == "qupath" else ".db" + msg = f"Saving predictions as f{output_type} in {suffix} format." + logger.info(msg) if self.patch_mode: for i, predictions in enumerate(processed_predictions["predictions"]): if isinstance(self.images[i], Path): diff --git a/tiatoolbox/utils/misc.py b/tiatoolbox/utils/misc.py index f414c6afd..c6b084f97 100644 --- a/tiatoolbox/utils/misc.py +++ b/tiatoolbox/utils/misc.py @@ -1707,7 +1707,7 @@ def dict_to_store_patch_predictions( preds=preds, class_dict=class_dict, patch_coords=patch_coords, - verbose=True, + verbose=verbose, ) if save_path: From c3a6c65a9232f7d4a0cc636fec106a2b15936360 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Tue, 17 Feb 2026 09:46:43 +0000 Subject: [PATCH 153/156] :white_check_mark: Add tests for coverage. 
--- tests/test_utils.py | 115 +++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 114 insertions(+), 1 deletion(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index e814f910a..8fbc44f5e 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -36,7 +36,13 @@ ) from tiatoolbox.utils import misc from tiatoolbox.utils.exceptions import FileNotSupportedError -from tiatoolbox.utils.misc import cast_to_min_dtype, create_smart_array +from tiatoolbox.utils.misc import ( + _semantic_segmentations_as_qupath_json, + _tiles, + cast_to_min_dtype, + create_smart_array, + dict_to_store_patch_predictions, +) from tiatoolbox.utils.transforms import locsize2bounds if TYPE_CHECKING: @@ -2343,3 +2349,110 @@ class FakeVM: assert isinstance(arr, np.ndarray) assert arr.shape == shape assert arr.dtype == dtype + + +def test_tiles_zero_iterations() -> None: + """Test helper function with zero iterations.""" + in_img = np.zeros((0, 0), dtype=np.uint8) + + tile_size = (32, 32) # larger than the image + colormap = 2 # arbitrary valid OpenCV colormap + + tile_iter = _tiles(in_img, tile_size, colormap=colormap, level=0) + + tiles = list(tile_iter) + + assert tiles == [] # no tiles generated + + +def test_semantic_segmentation_returns_json_dict() -> None: + """Test for semantic_segmentation QuPath JSON dict.""" + # Fake 4 x 4 prediction map with two classes: 0 and 1 + preds_np = np.array( + [ + [0, 0, 1, 1], + [0, 0, 1, 1], + [0, 0, 1, 1], + [0, 0, 1, 1], + ], + dtype=np.uint8, + ) + + preds = da.from_array(preds_np, chunks=(4, 4)) + + layer_list = [0, 1] # two classes + scale_factor = (1.0, 1.0) + class_dict = {0: "Background", 1: "Tumor"} + + qupath_json = _semantic_segmentations_as_qupath_json( + layer_list=layer_list, + preds=preds, + scale_factor=scale_factor, + class_dict=class_dict, + save_path=None, + verbose=False, + ) + + # --- Assert --- + assert isinstance(qupath_json, dict) + assert "type" in qupath_json + assert qupath_json["type"] == "FeatureCollection" + assert "features" in qupath_json + assert isinstance(qupath_json["features"], list) + assert len(qupath_json["features"]) > 0 + + for feature in qupath_json["features"]: + assert feature["properties"]["classification"]["name"] in class_dict.values() + assert feature["properties"]["classification"]["color"] is not None + assert feature["name"] in class_dict.values() + assert feature["class_value"] in class_dict + + +def test_dict_to_store_patch_predictions_returns_qupath_json() -> None: + """Test for dict_to_store_patch_predictions QuPath JSON dict.""" + # Fake patch output + patch_output = { + "predictions": np.array([0, 1, 0, 1], dtype=np.uint8), + "coordinates": np.array( + [ + [0, 0, 10, 10], + [10, 0, 20, 10], + [0, 10, 10, 20], + [10, 10, 20, 20], + ] + ), + "labels": np.array([0, 1, 0, 1]), + } + + scale_factor = (1.0, 1.0) + class_dict = {0: "Background", 1: "Tumor"} + + result = dict_to_store_patch_predictions( + patch_output=patch_output, + scale_factor=scale_factor, + class_dict=class_dict, + save_path=None, + output_type="qupath", + verbose=False, + ) + + assert isinstance(result, dict) + + assert "type" in result + assert result["type"] == "FeatureCollection" + assert "features" in result + assert isinstance(result["features"], list) + + assert len(result["features"]) > 0 + + for feature in result["features"]: + assert feature["type"] == "Feature" + assert "geometry" in feature + assert "properties" in feature + assert "classification" in feature["properties"] + assert "name" in feature + assert "class_value" in 
feature + + assert feature["class_value"] in class_dict + assert feature["properties"]["classification"]["name"] in class_dict.values() + assert feature["properties"]["classification"]["color"] is not None From acafdd94650ae380783bb44acf69fe2311df121c Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Tue, 17 Feb 2026 12:56:31 +0000 Subject: [PATCH 154/156] :white_check_mark: Add tests to validate QuPath output. --- tests/engines/test_multi_task_segmentor.py | 192 +++++++++++++++++- tiatoolbox/models/engine/engine_abc.py | 2 +- .../models/engine/multi_task_segmentor.py | 96 ++++++--- 3 files changed, 258 insertions(+), 32 deletions(-) diff --git a/tests/engines/test_multi_task_segmentor.py b/tests/engines/test_multi_task_segmentor.py index e97792556..fb80b5309 100644 --- a/tests/engines/test_multi_task_segmentor.py +++ b/tests/engines/test_multi_task_segmentor.py @@ -2,6 +2,7 @@ from __future__ import annotations +import json import shutil from pathlib import Path from typing import TYPE_CHECKING, Any, Final @@ -140,9 +141,10 @@ def test_mtsegmentor_patches(remote_sample: Callable, track_tmp_path: Path) -> N assert len(output_ann) == 6 + fields_nuclei = ["box", "centroid", "contours", "prob", "type"] + fields_layer = ["contours", "type"] + for task_name in mtsegmentor.tasks: - fields_nuclei = ["box", "centroid", "contours", "prob", "type"] - fields_layer = ["contours", "type"] fields = fields_nuclei if task_name == "nuclei_segmentation" else fields_layer output_ann_ = [p for p in output_ann if p.name.endswith(f"{task_name}.db")] expected_counts = ( @@ -161,6 +163,37 @@ def test_mtsegmentor_patches(remote_sample: Callable, track_tmp_path: Path) -> N class_dict=mtsegmentor.model.class_dict, ) + # QuPath JSON does not have fields + fields_nuclei = ["contours", "prob", "type"] + # QuPath output comparison + output_json = mtsegmentor.run( + images=patches, + patch_mode=True, + device=device, + output_type="QuPath", + save_dir=track_tmp_path / "patch_output_qupath", + ) + + assert len(output_json) == 6 + + for task_name in mtsegmentor.tasks: + fields = fields_nuclei if task_name == "nuclei_segmentation" else fields_layer + output_json_ = [p for p in output_json if p.name.endswith(f"{task_name}.json")] + expected_counts = ( + expected_counts_nuclei + if task_name == "nuclei_segmentation" + else expected_counts_layer + ) + assert_qupath_json_patch_output( + inputs=patches, + output_json=output_json_, + output_dict=output_dict[task_name], + track_tmp_path=track_tmp_path, + fields=fields, + expected_counts=expected_counts, + task_name=task_name, + ) + def test_single_task_mtsegmentor( remote_sample: Callable, @@ -278,7 +311,49 @@ def test_single_task_mtsegmentor( for field in fields: assert field not in zarr_group - assert "Probability maps cannot be saved as AnnotationStore" in caplog.text + assert "Probability maps cannot be saved as AnnotationStore or JSON" in caplog.text + + # QuPath output comparison + + # Reinitialize to check for probabilities in output. 
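+    # Emptying drop_keys below keeps "probabilities" in the raw output;
+    # probability heatmaps cannot be serialised to QuPath JSON, so they are
+    # expected to land in output.zarr with a warning (both asserted below).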
+    mtsegmentor.drop_keys = []
+    output_json = mtsegmentor.run(
+        images=inputs,
+        patch_mode=True,
+        device=device,
+        output_type="QuPath",
+        save_dir=track_tmp_path / "patch_output_qupath",
+        return_probabilities=True,
+    )
+
+    assert len(output_json) == 3
+
+    assert_qupath_json_patch_output(
+        inputs=inputs,
+        output_json=output_json,
+        output_dict=output_dict,
+        track_tmp_path=track_tmp_path,
+        fields=["box", "centroid", "contours", "prob", "type"],
+        expected_counts=expected_counts_nuclei,
+        task_name=None,
+    )
+
+    zarr_file = track_tmp_path / "patch_output_qupath" / "output.zarr"
+
+    assert zarr_file.exists()
+
+    zarr_group = zarr.open(
+        str(zarr_file),
+        mode="r",
+    )
+
+    assert "probabilities" in zarr_group
+
+    fields = ["box", "centroid", "contours", "prob", "type", "predictions"]
+    for field in fields:
+        assert field not in zarr_group
+
+    assert "Probability maps cannot be saved as AnnotationStore or JSON" in caplog.text


 def test_wsi_mtsegmentor_zarr(
@@ -450,16 +525,14 @@ def test_wsi_segmentor_annotationstore(
     assert store_file_path.exists()
     assert store_file_path == output[wsi4_512_512_svs][0]

-    weights_path = Path(fetch_pretrained_weights(model_name=model_name))
-    weights_path.unlink()
-

 def test_wsi_segmentor_qupath(remote_sample: Callable, track_tmp_path: Path) -> None:
     """Test MultiTaskSegmentor for WSIs with QuPath JSON output."""
     wsi4_512_512_svs = remote_sample("wsi4_512_512_svs")
     # testing different configuration for hovernet.
     # kumar only has two probability maps
-    model_name = "hovernet_fast-pannuke"
+    # Need to test Null values in JSON output.
+    model_name = "hovernet_original-kumar"
     mtsegmentor = MultiTaskSegmentor(
         model=model_name,
         batch_size=32,
@@ -489,6 +562,7 @@ def test_wsi_segmentor_qupath(remote_sample: Callable, track_tmp_path: Path) ->
     assert json_file_name.exists()
     assert json_file_name == output[wsi4_512_512_svs][0]

+    # Weights not used after this test
     weights_path = Path(fetch_pretrained_weights(model_name=model_name))
     weights_path.unlink()

@@ -515,7 +589,7 @@ def test_wsi_segmentor_annotationstore_probabilities(
         output_type="annotationstore",
     )

-    assert "Probability maps cannot be saved as AnnotationStore." in caplog.text
+    assert "Probability maps cannot be saved as AnnotationStore or JSON." in caplog.text

     zarr_group = zarr.open(output[wsi4_512_512_svs][0], mode="r")
     assert "probabilities" in zarr_group
@@ -787,6 +861,108 @@ def assert_annotation_store_patch_output(
     assert annotations_list == []


+def assert_qupath_json_patch_output(
+    inputs: list | np.ndarray,
+    output_json: list[Path],
+    task_name: str | None,
+    track_tmp_path: Path,
+    expected_counts: Sequence[int],
+    output_dict: dict,
+    fields: list[str],
+) -> None:
+    """Helper function to test QuPath JSON output."""
+    for patch_idx, json_path in enumerate(output_json):
+        # --- 1. Verify filename matches expected pattern ---
+        if isinstance(inputs[patch_idx], Path):
+            file_name = (
+                f"{inputs[patch_idx].stem}.json"
+                if task_name is None
+                else f"{inputs[patch_idx].stem}_{task_name}.json"
+            )
+        else:
+            file_name = (
+                f"{patch_idx}.json"
+                if task_name is None
+                else f"{patch_idx}_{task_name}.json"
+            )
+
+        assert json_path == track_tmp_path / "patch_output_qupath" / file_name
+
+        # --- 2. Load JSON ---
+        with Path.open(json_path, "r") as f:
+            qupath_json = json.load(f)
+
+        features = qupath_json.get("features", [])
+        assert isinstance(features, list)
+
+        # --- 3. Zero-object case ---
+        if expected_counts[patch_idx] == 0:
+            assert len(features) == 0
+            continue
+
+        # --- 4. Non-zero case ---
+        assert len(features) == expected_counts[patch_idx]
+
+        # Extract results from JSON
+        result = {field: [] for field in fields}
+
+        for feat in features:
+            props = feat.get("properties", {})
+
+            # non-geometric fields (box, centroid, prob, type, etc.)
+            for field in fields:
+                if field == "contours":
+                    continue
+                if field in props:
+                    result[field].append(props[field])
+
+            # contours from geometry
+            if "contours" in fields:
+                geom = feat["geometry"]
+                coords = geom["coordinates"][0]  # exterior ring
+                coords = [(int(x), int(y)) for x, y in coords]
+                result["contours"].append(coords)
+
+        # Wrap for compatibility with assert_output_lengths
+        result_wrapped = {field: [result[field]] for field in fields}
+
+        # --- 5. Length check ---
+        assert_output_lengths(
+            result_wrapped,
+            expected_counts=[expected_counts[patch_idx]],
+            fields=fields,
+        )
+
+        # --- 6. Equality check for non-contour fields ---
+        fields_no_contours = fields.copy()
+        if "contours" in fields_no_contours:
+            fields_no_contours.remove("contours")
+
+        assert_output_equal(
+            result_wrapped,
+            output_dict,
+            fields=fields_no_contours,
+            indices_a=[0],
+            indices_b=[patch_idx],
+        )
+
+        # --- 7. Contour comparison ---
+        if "contours" in fields:
+            matches = []
+            for a, b in zip(
+                result["contours"],
+                output_dict["contours"][patch_idx],
+                strict=False,
+            ):
+                # Discard last point (closed polygon)
+                a_arr = np.array(a[:-1], dtype=int)
+                b_arr = np.array(b, dtype=int)
+                matches.append(np.array_equal(a_arr, b_arr))
+
+            # Allow small geometric differences
+            assert sum(matches) / len(matches) >= 0.95
+
+
 # -------------------------------------------------------------------------------------
 # Command Line Interface
 # -------------------------------------------------------------------------------------
diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py
index b84d92169..47de20caf 100644
--- a/tiatoolbox/models/engine/engine_abc.py
+++ b/tiatoolbox/models/engine/engine_abc.py
@@ -1427,7 +1427,7 @@ def _run_patch_mode(
         )
         raw_predictions = self.infer_patches(
             dataloader=self.dataloader,
-            return_coordinates=output_type in ["annotationstore", "qupath"],
+            return_coordinates=output_type.lower() in ["annotationstore", "qupath"],
         )

         processed_predictions = self.post_process_patches(
diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py
index fda7a84a1..83d396a28 100644
--- a/tiatoolbox/models/engine/multi_task_segmentor.py
+++ b/tiatoolbox/models/engine/multi_task_segmentor.py
@@ -1695,7 +1695,7 @@ def _save_predictions_as_json_store(
     return_probabilities = kwargs.get("return_probabilities", False)
     if return_probabilities:
         msg = (
-            f"Probability maps cannot be saved as AnnotationStore. "
+            f"Probability maps cannot be saved as AnnotationStore or JSON. "
             f"To visualise heatmaps in TIAToolbox Visualization tool, "
             f"convert heatmaps in {save_path} to ome.tiff using "
             f"tiatoolbox.utils.misc.write_probability_heatmap_as_ome_tiff."
@@ -2653,7 +2653,8 @@ def _save_annotation_json_store(
         )
     else:
         store_file_name = f"{idx}.db" if task_name is None else f"{idx}_{task_name}.db"
-    output_path = save_path.parent / store_file_name
+    suffix = ".json" if output_type.lower() == "qupath" else ".db"
+    output_path = (save_path.parent / store_file_name).with_suffix(suffix)

     # Patch mode indexes the "coordinates" while calculating "values" variable.
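     # The origin can therefore stay at (0.0, 0.0); the spent "coordinates"
     # entry is popped below so that only per-object fields reach
     # dict_to_json_store.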
origin = (0.0, 0.0) _ = predictions.pop("coordinates") @@ -3178,16 +3179,19 @@ def dict_to_json_store( for key, arr in processed_predictions.items() } contours = processed_predictions.pop("contours") - delayed_tasks = DaskDelayedAnnotationStore( + delayed_tasks = DaskDelayedJSONStore( contours=contours, processed_predictions=processed_predictions, ) - if output_type == "qupath": + if output_type.lower() == "qupath": return delayed_tasks.compute_qupath_json( - class_dict={0: "Tumor", 1: "Stroma"}, - origin=(0, 0), + class_dict=class_dict, + origin=origin, scale_factor=scale_factor, + batch_size=100, + num_workers=num_workers, + verbose=verbose, save_path=output_path.with_suffix(".json"), ) @@ -3208,7 +3212,7 @@ def dict_to_json_store( return output_path -class DaskDelayedAnnotationStore: +class DaskDelayedJSONStore: """Compute and write TIAToolbox annotations using batched Dask Delayed tasks. This class parallelizes annotation construction using Dask Delayed while @@ -3219,7 +3223,7 @@ class DaskDelayedAnnotationStore: """ def __init__( - self: DaskDelayedAnnotationStore, + self: DaskDelayedJSONStore, contours: np.ndarray, processed_predictions: dict, ) -> None: @@ -3242,7 +3246,7 @@ def __init__( self._processed_predictions = processed_predictions def _build_single_annotation( - self: DaskDelayedAnnotationStore, + self: DaskDelayedJSONStore, i: int, class_dict: dict[int, str] | None, origin: tuple[float, float], @@ -3297,13 +3301,44 @@ def _build_single_annotation( return Annotation(geom, properties) def _build_single_qupath_feature( - self: DaskDelayedAnnotationStore, + self: DaskDelayedJSONStore, i: int, class_dict: dict | None, origin: tuple[float, float], scale_factor: tuple[float, float], class_colours: dict, ) -> dict: + """Build a single feature for index ``i``. + + This method performs: + - geometry creation + - coordinate scaling and translation + - per-object property extraction + - optional class label mapping + + Args: + i (int): + Index of the object to convert into an annotation. + + class_dict (dict[int, str] | None): + Optional mapping from integer class IDs to string labels. + If ``None``, raw integer class IDs are used. + + origin (tuple[float, float]): + Translation offset ``(x, y)`` applied after scaling. + + scale_factor (tuple[float, float]): + Scaling factors ``(sx, sy)`` applied to contour coordinates. + + class_colours (dict): + Maps classes to specific colors. + + Returns: + dict: + A fully constructed Feature dictionary instance for writing + to QuPath JSON. 
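+
+        Example:
+            A minimal sketch, assuming ``delayed_store`` is an already
+            constructed :class:`DaskDelayedJSONStore` holding at least one
+            contour; the class mapping and colour are illustrative:
+
+            >>> feature = delayed_store._build_single_qupath_feature(
+            ...     0,
+            ...     class_dict={0: "Tumor"},
+            ...     origin=(0.0, 0.0),
+            ...     scale_factor=(1.0, 1.0),
+            ...     class_colours={0: [255, 0, 0]},
+            ... )
+            >>> feature["type"]
+            'Feature'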
+ + """ contour = np.array(self._contours[i], dtype=float) contour[:, 0] = contour[:, 0] * scale_factor[0] + origin[0] contour[:, 1] = contour[:, 1] * scale_factor[1] + origin[1] @@ -3316,21 +3351,26 @@ def _build_single_qupath_feature( class_name = None for key, arr in self._processed_predictions.items(): - value = arr[i] + value = arr[i].tolist() if hasattr(arr[i], "tolist") else arr[i] + if key == "type": - # Convert numpy/zarr scalar to Python - value = value.tolist() if hasattr(value, "tolist") else value + # Handle None class name + if value is None: + # Assign default class 0 + class_value = 0 + class_name = class_dict.get(0, 0) + props["type"] = class_name + continue # Safe class lookup if class_dict is not None and value in class_dict: class_name = class_dict[value] else: - class_name = value # keep raw value - - if class_name is not None: - props["type"] = class_name - class_value = value + # Already a name or no mapping available + class_name = value + props["type"] = class_name + class_value = value else: if value is None: continue @@ -3355,7 +3395,7 @@ def _build_single_qupath_feature( } def compute_annotations( - self: DaskDelayedAnnotationStore, + self: DaskDelayedJSONStore, store: SQLiteStore, class_dict: dict[int, str] | None, origin: tuple[float, float] = (0, 0), @@ -3435,7 +3475,7 @@ def compute_annotations( return store def compute_qupath_json( - self: DaskDelayedAnnotationStore, + self: DaskDelayedJSONStore, class_dict: dict[int, str] | None, origin: tuple[float, float] = (0, 0), scale_factor: tuple[float, float] = (1, 1), @@ -3451,11 +3491,21 @@ def compute_qupath_json( if class_dict is None: type_arr = self._processed_predictions.get("type") - if type_arr is not None: - max_class = int(type_arr.max()) + + # Extract only valid class IDs/names + valid_ids = [v for v in type_arr if v is not None] + + if len(valid_ids) == 0: + # No class info at all → fallback + class_dict = {0: 0} + # Numeric class IDs + elif all(isinstance(v, (int, np.integer)) for v in valid_ids): + max_class = int(max(valid_ids)) class_dict = {i: i for i in range(max_class + 1)} else: - class_dict = {0: 0} + # Already class names + unique_names = sorted(set(valid_ids)) + class_dict = {name: name for name in unique_names} # Enumerate class_dict keys to assign stable integer color indices class_keys = list(class_dict.keys()) From be10ef47bbbb056cfe4d3668969e57dfe38390e3 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Tue, 17 Feb 2026 13:24:47 +0000 Subject: [PATCH 155/156] :bug: Fix tests --- tests/engines/test_engine_abc.py | 42 -------------------------------- 1 file changed, 42 deletions(-) diff --git a/tests/engines/test_engine_abc.py b/tests/engines/test_engine_abc.py index 2e7ed9395..939b2de97 100644 --- a/tests/engines/test_engine_abc.py +++ b/tests/engines/test_engine_abc.py @@ -401,48 +401,6 @@ def test_patch_pred_zarr_store(track_tmp_path: pytest.TempPathFactory) -> None: ) assert Path.exists(out), "Zarr output file does not exist" - eng = TestEngineABC(model="alexnet-kather100k") - with pytest.raises( - ValueError, - match=r".*Patch output must contain coordinates.", - ): - _ = eng.run( - images=np.zeros((10, 224, 224, 3), dtype=np.uint8), - labels=list(range(10)), - device=device, - save_dir=save_dir, - overwrite=True, - output_type="AnnotationStore", - ) - - with pytest.raises( - ValueError, - match=r".*Patch output must contain coordinates.", - ): - _ = eng.run( - images=np.zeros((10, 224, 224, 3), dtype=np.uint8), - 
labels=list(range(10)), - device=device, - save_dir=save_dir, - overwrite=True, - output_type="AnnotationStore", - class_dict={0: "class0", 1: "class1"}, - ) - - with pytest.raises( - ValueError, - match=r".*Patch output must contain coordinates.", - ): - _ = eng.run( - images=np.zeros((10, 224, 224, 3), dtype=np.uint8), - labels=list(range(10)), - device=device, - save_dir=save_dir, - overwrite=True, - output_type="AnnotationStore", - scale_factor=(2.0, 2.0), - ) - def test_get_dataloader(sample_svs: Path) -> None: """Test the get_dataloader function.""" From 82d84548819f03463a78b015f6047f0ca890bed8 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Tue, 17 Feb 2026 13:31:54 +0000 Subject: [PATCH 156/156] :bug: Fix deepsource error. --- tests/engines/test_multi_task_segmentor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/engines/test_multi_task_segmentor.py b/tests/engines/test_multi_task_segmentor.py index fb80b5309..2721a70df 100644 --- a/tests/engines/test_multi_task_segmentor.py +++ b/tests/engines/test_multi_task_segmentor.py @@ -861,7 +861,7 @@ def assert_annotation_store_patch_output( assert annotations_list == [] -def assert_qupath_json_patch_output( +def assert_qupath_json_patch_output( # skipcq: PY-R1000 inputs: list | np.ndarray, output_json: list[Path], task_name: str | None,