diff --git a/monailabel/datastore/cvat.py b/monailabel/datastore/cvat.py index 461883ea0..0fb293bc5 100644 --- a/monailabel/datastore/cvat.py +++ b/monailabel/datastore/cvat.py @@ -16,6 +16,7 @@ import tempfile import time import urllib.parse +from typing import Any, Dict import numpy as np import requests @@ -318,6 +319,22 @@ def download_from_cvat(self, max_retry_count=5, retry_wait_time=10): retry_count += 1 return None + def add_directory(self, directory_id: str, filename: str, info: Dict[str, Any]) -> str: + """Not Implemented""" + raise NotImplementedError("This datastore does not support adding directories") + + def get_is_multichannel(self) -> bool: + """ + Returns whether the application's studies is directed at multichannel (4D) data + """ + raise NotImplementedError("This datastore does not support multichannel imaging") + + def get_is_multi_file(self) -> bool: + """ + Returns whether the application's studies is directed at directories containing multiple images per sample + """ + raise NotImplementedError("This datastore does not support support multi-volume imaging") + """ def main(): diff --git a/monailabel/datastore/dicom.py b/monailabel/datastore/dicom.py index ed4733ca6..0ebe7de2f 100644 --- a/monailabel/datastore/dicom.py +++ b/monailabel/datastore/dicom.py @@ -264,3 +264,19 @@ def _download_labeled_data(self): def datalist(self, full_path=True) -> List[Dict[str, Any]]: self._download_labeled_data() return super().datalist(full_path) + + def add_directory(self, directory_id: str, filename: str, info: Dict[str, Any]) -> str: + """Not Implemented""" + raise NotImplementedError("This datastore does not support adding directories") + + def get_is_multichannel(self) -> bool: + """ + Returns whether the application's studies is directed at multichannel (4D) data + """ + raise NotImplementedError("This datastore does not support multichannel imaging") + + def get_is_multi_file(self) -> bool: + """ + Returns whether the application's studies is directed 
at directories containing multiple images per sample + """ + raise NotImplementedError("This datastore does not support support multi-volume imaging") diff --git a/monailabel/datastore/dsa.py b/monailabel/datastore/dsa.py index 365cef24e..0184dc4cb 100644 --- a/monailabel/datastore/dsa.py +++ b/monailabel/datastore/dsa.py @@ -270,6 +270,22 @@ def status(self) -> Dict[str, Any]: def json(self): return self.datalist() + def add_directory(self, directory_id: str, filename: str, info: Dict[str, Any]) -> str: + """Not Implemented""" + raise NotImplementedError("This datastore does not support adding directories") + + def get_is_multichannel(self) -> bool: + """ + Returns whether the application's studies is directed at multichannel (4D) data + """ + raise NotImplementedError("This datastore does not support multichannel imaging") + + def get_is_multi_file(self) -> bool: + """ + Returns whether the application's studies is directed at directories containing multiple images per sample + """ + raise NotImplementedError("This datastore does not support support multi-volume imaging") + """ def main(): diff --git a/monailabel/datastore/local.py b/monailabel/datastore/local.py index d8b0538aa..cbe0a4134 100644 --- a/monailabel/datastore/local.py +++ b/monailabel/datastore/local.py @@ -102,9 +102,11 @@ def __init__( images_dir: str = ".", labels_dir: str = "labels", datastore_config: str = "datastore_v2.json", - extensions=("*.nii.gz", "*.nii"), + extensions=("*.nii.gz", "*.nii", "*.nrrd"), auto_reload=False, read_only=False, + multichannel: bool = False, + multi_file: bool = False, ): """ Creates a `LocalDataset` object @@ -124,6 +126,8 @@ def __init__( self._ignore_event_config = False self._config_ts = 0 self._auto_reload = auto_reload + self._multichannel: bool = multichannel + self._multi_file: bool = multi_file logging.getLogger("filelock").setLevel(logging.ERROR) @@ -256,6 +260,12 @@ def datalist(self, full_path=True) -> List[Dict[str, Any]]: ds = 
json.loads(json.dumps(ds).replace(f"{self._datastore_path.rstrip(os.pathsep)}{os.pathsep}", "")) return ds + def get_is_multichannel(self) -> bool: + return self._multichannel + + def get_is_multi_file(self) -> bool: + return self._multi_file + def get_image(self, image_id: str, params=None) -> Any: """ Retrieve image object based on image id @@ -431,6 +441,29 @@ def refresh(self): """ self._reconcile_datastore() + def add_directory(self, directory_id: str, filename: str, info: Dict[str, Any]) -> str: + id = os.path.basename(filename) + if not directory_id: + directory_id = id + + logger.info(f"Adding Image: {directory_id} => {filename}") + name = directory_id + dest = os.path.realpath(os.path.join(self._datastore.image_path(), name)) + + with FileLock(self._lock_file): + logger.debug("Acquired the lock!") + shutil.copy(filename, dest) + + info = info if info else {} + info["ts"] = int(time.time()) + info["name"] = name + + # images = get_directory_contents(filename) + self._datastore.objects[directory_id] = ImageLabelModel(image=DataModel(info=info, ext="")) + self._update_datastore_file(lock=False) + logger.debug("Released the lock!") + return directory_id + def add_image(self, image_id: str, image_filename: str, image_info: Dict[str, Any]) -> str: id, image_ext = self._to_id(os.path.basename(image_filename)) if not image_id: @@ -552,10 +585,15 @@ def _list_files(self, path, patterns): files = os.listdir(path) filtered = dict() - for pattern in patterns: - matching = fnmatch.filter(files, pattern) - for file in matching: - filtered[os.path.basename(file)] = file + if not self._multi_file: + for pattern in patterns: + matching = fnmatch.filter(files, pattern) + for file in matching: + filtered[os.path.basename(file)] = file + else: + for file in files: + if file.lower() not in ["labels", ".lock", "datastore_v2.json"]: + filtered[os.path.basename(file)] = file return filtered def _reconcile_datastore(self): @@ -585,23 +623,26 @@ def _add_non_existing_images(self) 
-> int: invalidate = 0 self._init_from_datastore_file() - local_images = self._list_files(self._datastore.image_path(), self._extensions) + local_files = self._list_files(self._datastore.image_path(), self._extensions) - image_ids = list(self._datastore.objects.keys()) - for image_file in local_images: - image_id, image_ext = self._to_id(image_file) - if image_id not in image_ids: - logger.info(f"Adding New Image: {image_id} => {image_file}") + ids = list(self._datastore.objects.keys()) + for file in local_files: + if self._multi_file: + # Directories have no extension — use the name as-is + file_id = file + file_ext_str = "" + else: + file_id, file_ext_str = self._to_id(file) - name = self._filename(image_id, image_ext) - image_info = { + if file_id not in ids: + logger.info(f"Adding New Image: {file_id} => {file}") + name = self._filename(file_id, file_ext_str) + file_info = { "ts": int(time.time()), - # "checksum": file_checksum(os.path.join(self._datastore.image_path(), name)), "name": name, } - invalidate += 1 - self._datastore.objects[image_id] = ImageLabelModel(image=DataModel(info=image_info, ext=image_ext)) + self._datastore.objects[file_id] = ImageLabelModel(name=DataModel(info=file_info, ext=file_ext_str)) return invalidate diff --git a/monailabel/datastore/xnat.py b/monailabel/datastore/xnat.py index c0904fd8d..895aaeb15 100644 --- a/monailabel/datastore/xnat.py +++ b/monailabel/datastore/xnat.py @@ -386,6 +386,22 @@ def __upload_assessment(self, aiaa_model_name, image_id, file_path, type): self._request_put(url, data, type=type) + def add_directory(self, directory_id: str, filename: str, info: Dict[str, Any]) -> str: + """Not Implemented""" + raise NotImplementedError("This datastore does not support adding directories") + + def get_is_multichannel(self) -> bool: + """ + Returns whether the application's studies is directed at multichannel (4D) data + """ + raise NotImplementedError("This datastore does not support multichannel imaging") + + def 
get_is_multi_file(self) -> bool: + """ + Returns whether the application's studies is directed at directories containing multiple images per sample + """ + raise NotImplementedError("This datastore does not support support multi-volume imaging") + """ def main(): diff --git a/monailabel/endpoints/datastore.py b/monailabel/endpoints/datastore.py index fdd63bb6e..2cf26bb6f 100644 --- a/monailabel/endpoints/datastore.py +++ b/monailabel/endpoints/datastore.py @@ -68,7 +68,7 @@ def add_image( logger.info(f"Image: {image}; File: {file}; params: {params}") file_ext = "".join(pathlib.Path(file.filename).suffixes) if file.filename else ".nii.gz" - image_id = image if image else os.path.basename(file.filename).replace(file_ext, "") + id = image if image else os.path.basename(file.filename).replace(file_ext, "") image_file = tempfile.NamedTemporaryFile(suffix=file_ext).name with open(image_file, "wb") as buffer: @@ -79,8 +79,12 @@ def add_image( save_params: Dict[str, Any] = json.loads(params) if params else {} if user: save_params["user"] = user - image_id = instance.datastore().add_image(image_id, image_file, save_params) - return {"image": image_id} + if not instance.datastore().get_is_multi_file(): + image_id = instance.datastore().add_image(id, image_file, save_params) + return {"image": image_id} + else: + directory_id = instance.datastore().add_directory(id, image_file, save_params) + return {"image": directory_id} def remove_image(id: str, user: Optional[str] = None): diff --git a/monailabel/interfaces/app.py b/monailabel/interfaces/app.py index f9b405bde..a8c92aa23 100644 --- a/monailabel/interfaces/app.py +++ b/monailabel/interfaces/app.py @@ -90,7 +90,9 @@ def __init__( self.app_dir = app_dir self.studies = studies self.conf = conf if conf else {} - + self.multichannel: bool = strtobool(conf.get("multichannel", False)) + self.multi_file: bool = strtobool(conf.get("multi_file", False)) + self.input_channels = conf.get("input_channels", False) self.name = name 
self.description = description self.version = version @@ -146,6 +148,8 @@ def init_datastore(self) -> Datastore: extensions=settings.MONAI_LABEL_DATASTORE_FILE_EXT, auto_reload=settings.MONAI_LABEL_DATASTORE_AUTO_RELOAD, read_only=settings.MONAI_LABEL_DATASTORE_READ_ONLY, + multichannel=self.multichannel, + multi_file=self.multi_file, ) def init_remote_datastore(self) -> Datastore: @@ -281,6 +285,10 @@ def infer(self, request, datastore=None): f"Inference Task is not Initialized. There is no model '{model}' available", ) + request["multi_file"] = self.multi_file + request["multichannel"] = self.multichannel + request["input_channels"] = self.input_channels + request = copy.deepcopy(request) request["description"] = task.description @@ -292,7 +300,7 @@ def infer(self, request, datastore=None): else: request["image"] = datastore.get_image_uri(request["image"]) - if os.path.isdir(request["image"]): + if os.path.isdir(request["image"]) and not self.multi_file: logger.info("Input is a Directory; Consider it as DICOM") logger.debug(f"Image => {request['image']}") @@ -430,6 +438,11 @@ def train(self, request): f"Train Task is not Initialized. 
There is no model '{model}' available; {request}", ) + # 4D image support, send train task information regarding data + request["multi_file"] = self.multi_file + request["multichannel"] = self.multichannel + request["input_channels"] = self.input_channels + request = copy.deepcopy(request) result = task(request, self.datastore()) diff --git a/monailabel/interfaces/datastore.py b/monailabel/interfaces/datastore.py index 78fa0aecc..dea84c592 100644 --- a/monailabel/interfaces/datastore.py +++ b/monailabel/interfaces/datastore.py @@ -201,6 +201,18 @@ def refresh(self) -> None: """ pass + @abstractmethod + def add_directory(self, id: str, filename: str, info: Dict[str, Any]) -> str: + """ + Save a directory for the given directory id and return the newly saved directory's id + + :param id: the directory id for the image; If None then base filename will be used + :param filename: the path to the directory + :param info: additional info for the directory + :return: the directory id for the saved image filename + """ + pass + @abstractmethod def add_image(self, image_id: str, image_filename: str, image_info: Dict[str, Any]) -> str: """ @@ -279,3 +291,17 @@ def json(self): Return json representation of datastore """ pass + + @abstractmethod + def get_is_multichannel(self) -> bool: + """ + Returns whether the application's studies is directed at multichannel (4D) data + """ + pass + + @abstractmethod + def get_is_multi_file(self) -> bool: + """ + Returns whether the application's studies is directed at directories containing multiple images per sample + """ + pass diff --git a/monailabel/tasks/activelearning/first.py b/monailabel/tasks/activelearning/first.py index 2a0ffa675..e80ec7704 100644 --- a/monailabel/tasks/activelearning/first.py +++ b/monailabel/tasks/activelearning/first.py @@ -35,5 +35,13 @@ def __call__(self, request, datastore: Datastore): images.sort() image = images[0] + # If the datastore contains 4d images send the multichannel flag to ensure images are 
loaded as sequences + if datastore.get_is_multichannel(): + return {"id": image, "multichannel": True} + + # If the datastore is multi_file, each sample has a directory with multiple images + if datastore.get_is_multi_file(): + return {"id": image, "multi_file": True} + logger.info(f"First: Selected Image: {image}") return {"id": image} diff --git a/monailabel/tasks/activelearning/random.py b/monailabel/tasks/activelearning/random.py index b196f7a6b..22b58bc98 100644 --- a/monailabel/tasks/activelearning/random.py +++ b/monailabel/tasks/activelearning/random.py @@ -45,4 +45,17 @@ def __call__(self, request, datastore: Datastore): image = random.choices(images, weights=weights)[0] logger.debug(f"Random: Images: {images}; Weight: {weights}") logger.info(f"Random: Selected Image: {image}; Weight: {weights[0]}") + + # If the datastore contains 4d images send the multichannel flag to ensure images are loaded as sequences + if datastore.get_is_multichannel(): + return {"id": image, "weight": weights[0], "multichannel": True} + + # If the datastore is multi_file, each sample has a directory with multiple images + if datastore.get_is_multi_file(): + return { + "id": image, + "weight": weights[0], + "multi_file": True, + } # this will send the directory and we will walk it later on + return {"id": image, "weight": weights[0]} diff --git a/monailabel/tasks/train/basic_train.py b/monailabel/tasks/train/basic_train.py index 9e5d0b1a9..65211f47a 100644 --- a/monailabel/tasks/train/basic_train.py +++ b/monailabel/tasks/train/basic_train.py @@ -83,6 +83,8 @@ def __init__(self): self.multi_gpu = False # multi gpu enabled self.local_rank = 0 # local rank in case of multi gpu self.world_size = 0 # world size in case of multi gpu + self.input_channels = 1 + self.multi_file = False self.request = None self.trainer = None @@ -490,6 +492,9 @@ def train(self, rank, world_size, request, datalist): context.run_id = request["run_id"] context.multi_gpu = request["multi_gpu"] + 
context.multi_file = request.get("multi_file", False) + context.input_channels = request.get("input_channels", 1) + if context.multi_gpu: os.environ["LOCAL_RANK"] = str(context.local_rank) diff --git a/plugins/slicer/MONAILabel/MONAILabel.py b/plugins/slicer/MONAILabel/MONAILabel.py index 5051dacf3..208e797e5 100644 --- a/plugins/slicer/MONAILabel/MONAILabel.py +++ b/plugins/slicer/MONAILabel/MONAILabel.py @@ -1303,19 +1303,57 @@ def onNextSampleButton(self): return logging.info(sample) - image_id = sample["id"] + id = sample["id"] image_file = sample.get("path") - image_name = sample.get("name", image_id) - node_name = sample.get("PatientID", sample.get("name", image_id)) + image_name = sample.get("name", id) + node_name = sample.get("PatientID", sample.get("name", id)) checksum = sample.get("checksum") local_exists = image_file and os.path.exists(image_file) + multichannel: bool = bool(sample.get("multichannel", False)) + multi_file: bool = bool(sample.get("multi_file", False)) logging.info(f"Check if file exists/shared locally: {image_file} => {local_exists}") if local_exists: - self._volumeNode = slicer.util.loadVolume(image_file) - self._volumeNode.SetName(node_name) + if multichannel: + # For 4D multichannel images, NOTE: slicer does not like 4D nifti images + # from https://github.com/Project-MONAI/MONAILabel/issues/241#issuecomment-1497788857 + volumeSequenceNode = slicer.util.loadSequence(image_file) + volumeSequenceNode.SetName(node_name) + # Get a volume node + browserNode = slicer.modules.sequences.logic().GetFirstBrowserNodeForSequenceNode( + volumeSequenceNode + ) + browserNode.SetOverwriteProxyName( + None, True + ) # set the proxy node name based on the sequence node name + self._volumeNode = browserNode.GetProxyNode(volumeSequenceNode) + else: + if not multi_file: + self._volumeNode = slicer.util.loadVolume(image_file) + self._volumeNode.SetName(node_name) + else: # in the case the underlying dataset is multi_file, we load all the images in the 
directory + dir_path = image_file + if not os.path.isdir(dir_path): + raise ValueError(f"multi_file=True but path is not a directory: {dir_path}") + + # get valid image paths + entries = sorted(os.listdir(dir_path)) + image_paths = [] + for name in entries: + full_path = os.path.join(dir_path, name) + if os.path.isfile(full_path): + image_paths.append(full_path) + + nodes = [] + for idx, image in enumerate(image_paths): + image_base_name = os.path.basename(image) + node = slicer.util.loadVolume(image) + node.SetName(image_base_name) + nodes.append(node) + + self._volumeNode = nodes[0] else: - download_uri = f"{self.serverUrl()}/datastore/image?image={quote_plus(image_id)}" + download_uri = f"{self.serverUrl()}/datastore/image?image={quote_plus(id)}" logging.info(download_uri) sampleDataLogic = SampleData.SampleDataLogic() @@ -1326,7 +1364,7 @@ def onNextSampleButton(self): if slicer.util.settingsValue("MONAILabel/originalLabel", True, converter=slicer.util.toBool): try: datastore = self.logic.datastore() - label_info = datastore["objects"][image_id]["labels"]["original"]["info"] + label_info = datastore["objects"][id]["labels"]["original"]["info"] labels = label_info.get("params", {}).get("label_names", {}) if labels: @@ -1338,7 +1376,7 @@ def onNextSampleButton(self): labels = self.logic.info().get("labels") # ext = datastore['objects'][image_id]['labels']['original']['ext'] - maskFile = self.logic.download_label(image_id, "original") + maskFile = self.logic.download_label(id, "original") self.updateSegmentationMask(maskFile, list(labels)) print("Original label uploaded! ") diff --git a/sample-apps/radiology/README.md b/sample-apps/radiology/README.md index 5c7356940..74ec5638d 100644 --- a/sample-apps/radiology/README.md +++ b/sample-apps/radiology/README.md @@ -215,6 +215,34 @@ the model to learn on new organ. - Output: N channels representing the segmented organs/tumors/tissues +
+ + Segmentation BraTS is a model based on SegResNet for automated multi-label brain tumor segmentation. This model is designed for multi-label segmentation tasks using pre-aligned, multi-modal MRI volumes. + + +> monailabel start_server --app workspace/radiology --studies workspace/images --conf models segmentation_brats --conf input_channels 4 --conf multi_file true + +- Additional Configs *(pass them as **--conf name value** while starting MONAILabel Server)* + +| Name | Values | Description | +|----------------------|------------------|--------------------------------------------------------------------------| +| use_pretrained_model | true, **false** | Set to `true` to download and load the pretrained weights | +| preload | true, **false** | Preload model into GPU at startup | +| scribbles | **true**, false | Set to `false` to disable scribble-based interactive segmentation models | + +- Network: This model uses the [SegResNet](https://docs.monai.io/en/latest/networks.html#segresnet) as the default network. Researchers can define their own network or use one of those listed [here](https://docs.monai.io/en/latest/networks.html) +- Labels + ```json + { + "tumor core": 1, + "whole tumor": 2, + "enhancing tumor": 3 + } + ``` +- Dataset: The model is trained on the BraTS 2020 dataset: https://www.med.upenn.edu/cbica/brats2020/ +- Inputs: 4 channels for the 4 BraTS image modalities +- Output: N channels representing the segmented tumors/tissues +
diff --git a/sample-apps/radiology/lib/configs/segmentation_brats.py b/sample-apps/radiology/lib/configs/segmentation_brats.py new file mode 100644 index 000000000..cd2f628a8 --- /dev/null +++ b/sample-apps/radiology/lib/configs/segmentation_brats.py @@ -0,0 +1,107 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +from typing import Any, Dict, Optional, Union + +import lib.infers +import lib.trainers +from monai.networks.nets import SegResNet +from monai.utils import optional_import + +from monailabel.interfaces.config import TaskConfig +from monailabel.interfaces.tasks.infer_v2 import InferTask +from monailabel.interfaces.tasks.train import TrainTask +from monailabel.utils.others.generic import download_file, strtobool + +_, has_cp = optional_import("cupy") +_, has_cucim = optional_import("cucim") + +logger = logging.getLogger(__name__) + + +class Segmentation(TaskConfig): + def init(self, name: str, model_dir: str, conf: Dict[str, str], planner: Any, **kwargs): + super().init(name, model_dir, conf, planner, **kwargs) + + # BraTS labels: 3 multi-label channels produced by ConvertToMultiChannelBasedOnBratsClassesd + # Channel 0: TC - Tumor Core (label 2 OR label 3) + # Channel 1: WT - Whole Tumor (label 1 OR label 2 OR label 3) + # Channel 2: ET - Enhancing Tumor (label 2) + self.labels = { + "tumor core": 1, # Tumor Core + "whole tumor": 2, # Whole Tumor + "enhancing tumor": 3, # Enhancing Tumor + } + + # Model 
Files + self.path = [ + os.path.join(self.model_dir, f"pretrained_{name}.pt"), # pretrained + os.path.join(self.model_dir, f"{name}.pt"), # published + ] + + # Download PreTrained Model (optional) + if strtobool(self.conf.get("use_pretrained_model", "false")): + url = f"{self.conf.get('pretrained_path', self.PRE_TRAINED_PATH)}" + url = f"{url}/radiology_segmentation_segresnet_brats.pt" + download_file(url, self.path[0]) + + # Spacing and ROI for BraTS (isotropic 1mm, large crop matching tutorial) + self.target_spacing = (1.0, 1.0, 1.0) + self.roi_size = (224, 224, 144) + + # Number of input channels: 4 MRI modalities (FLAIR, T1, T1Gd, T2) + # when multi_file=True the LoadDirectoryImagesd loader stacks them; + # when multi_file=False the image file must already be a 4-channel volume. + try: + input_channels = int(self.conf.get("input_channels", 4)) + except (ValueError, TypeError): + logger.warning("Could not parse input_channels, defaulting to 4") + input_channels = 4 + + # Network + self.network = SegResNet( + blocks_down=(1, 2, 2, 4), + blocks_up=(1, 1, 1), + init_filters=16, + in_channels=input_channels, + out_channels=len(self.labels), # TC, WT, ET — sigmoid multilabel, no background channel + dropout_prob=0.2, + ) + + def infer(self) -> Union[InferTask, Dict[str, InferTask]]: + task: InferTask = lib.infers.Segmentation( + path=self.path, + network=self.network, + roi_size=self.roi_size, + target_spacing=self.target_spacing, + labels=self.labels, + preload=strtobool(self.conf.get("preload", "false")), + config={"largest_cc": True if has_cp and has_cucim else False}, + ) + return task + + def trainer(self) -> Optional[TrainTask]: + output_dir = os.path.join(self.model_dir, self.name) + load_path = self.path[0] if os.path.exists(self.path[0]) else self.path[1] + + task: TrainTask = lib.trainers.Segmentation( + model_dir=output_dir, + network=self.network, + roi_size=self.roi_size, + target_spacing=self.target_spacing, + load_path=load_path, + 
publish_path=self.path[1], + description="Train BraTS Segmentation Model (TC/WT/ET multilabel)", + labels=self.labels, + ) + return task diff --git a/sample-apps/radiology/lib/infers/segmentation_brats.py b/sample-apps/radiology/lib/infers/segmentation_brats.py new file mode 100644 index 000000000..6a832946d --- /dev/null +++ b/sample-apps/radiology/lib/infers/segmentation_brats.py @@ -0,0 +1,159 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Sequence + +from lib.transforms.transforms import ConvertFromMultiChannelBasedOnBratsClassesd, GetCentroidsd, LoadDirectoryImagesd +from monai.inferers import Inferer, SlidingWindowInferer +from monai.transforms import ( + Activationsd, + AsDiscreted, + EnsureChannelFirstd, + EnsureTyped, + KeepLargestConnectedComponentd, + LoadImaged, + NormalizeIntensityd, + Orientationd, + Spacingd, +) + +from monailabel.interfaces.tasks.infer_v2 import InferType +from monailabel.tasks.infer.basic_infer import BasicInferTask +from monailabel.transform.post import Restored + + +class Segmentation(BasicInferTask): + """ + Inference Engine for BraTS brain tumour segmentation using a SegResNet. + + The model outputs 3 channels (TC, WT, ET) with sigmoid activations — it is + a multilabel task, NOT a softmax classification. Each channel is thresholded + independently at 0.5 to produce binary maps. 
+ + Two image loading modes are supported (set via ``data["multi_file"]``): + - False (default): the input image is a single 4-channel NIfTI volume. + - True: ``data["image"]`` is a directory containing 4 single- + modality NIfTI files; LoadDirectoryImagesd stacks them. + """ + + def __init__( + self, + path, + network=None, + target_spacing=(1.0, 1.0, 1.0), + type=InferType.SEGMENTATION, + labels=None, + dimension=3, + description="Pre-trained BraTS SegResNet — TC/WT/ET multilabel segmentation", + **kwargs, + ): + super().__init__( + path=path, + network=network, + type=type, + labels=labels, + dimension=dimension, + description=description, + load_strict=False, + **kwargs, + ) + self.target_spacing = target_spacing + + def pre_transforms(self, data=None) -> Sequence[Callable]: + """ + Pre-processing pipeline matching the official MONAI BraTS tutorial. + + NOTE: ScaleIntensityRangePercentilesd and CenterSpatialCropd from the + original file have been removed — they are not part of the BraTS pipeline + and would distort MRI intensity normalisation. NormalizeIntensityd with + nonzero=True, channel_wise=True is the correct approach for multi-modal MRI. + """ + channels = data.get("input_channels", 4) + t = [ + ( + LoadImaged(keys="image", reader="ITKReader", ensure_channel_first=True) + if data.get("multi_file", False) is False + else LoadDirectoryImagesd( + keys="image", + target_spacing=self.target_spacing, + channels=channels, + ) + ), + EnsureTyped(keys="image", device=data.get("device") if data else None), + # EnsureChannelFirstd is safe to keep as a guard; if the channel dim is + # already present (ITKReader + ensure_channel_first) it is a no-op. + EnsureChannelFirstd(keys="image", channel_dim=0), + Orientationd(keys="image", axcodes="RAS"), + Spacingd( + keys="image", + pixdim=self.target_spacing, + allow_missing_keys=True, + ), + # Channel-wise intensity normalisation on non-zero voxels only. 
+ # This matches both the tutorial and the training pipeline exactly. + NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True), + ] + return t + + def inferer(self, data=None) -> Inferer: + return SlidingWindowInferer( + roi_size=self.roi_size, + sw_batch_size=2, + overlap=0.4, + padding_mode="replicate", + mode="gaussian", + ) + + def inverse_transforms(self, data=None): + return [] + + def post_transforms(self, data=None) -> Sequence[Callable]: + """ + Post-processing for multilabel sigmoid output. + + IMPORTANT differences from a softmax segmentation: + - Activationsd uses sigmoid=True (not softmax=True). + - AsDiscreted thresholds each channel at 0.5 independently + (not argmax, because channels are not mutually exclusive). + - KeepLargestConnectedComponentd is applied per-channel if available. + """ + t = [ + EnsureTyped(keys="pred", device=data.get("device") if data else None), + # Sigmoid: each of the 3 channels (TC, WT, ET) is activated independently. + Activationsd(keys="pred", sigmoid=True), + # Threshold each channel at 0.5 to produce binary masks. + AsDiscreted(keys="pred", threshold=0.5), + ] + + if data and data.get("largest_cc", False): + # Apply per-channel so TC, WT and ET are each cleaned independently. + t.append( + KeepLargestConnectedComponentd( + keys="pred", + independent=True, # treat each channel separately + ) + ) + + t.extend( + [ + # Merge 3 binary channels → single-channel integer label map + # Must happen before Restored so spatial metadata is applied + # to the final (1, H, W, D) output, not the intermediate (3, H, W, D). 
+ ConvertFromMultiChannelBasedOnBratsClassesd(keys="pred"), + Restored( + keys="pred", + ref_image="image", + config_labels=self.labels if data.get("restore_label_idx", False) else None, + ), + GetCentroidsd(keys="pred", centroids_key="centroids"), + ] + ) + return t diff --git a/sample-apps/radiology/lib/trainers/segmentation_brats.py b/sample-apps/radiology/lib/trainers/segmentation_brats.py new file mode 100644 index 000000000..34fcbd8be --- /dev/null +++ b/sample-apps/radiology/lib/trainers/segmentation_brats.py @@ -0,0 +1,199 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class Segmentation(BasicTrainTask):
    """
    Training task for a BraTS-style sigmoid multilabel segmentation model.

    Labels are converted to three binary channels (TC, WT, ET) by
    ``ConvertToMultiChannelBasedOnBratsClassesd``; the image for each sample is
    loaded from a directory of single-modality files by ``LoadDirectoryImagesd``.

    Args:
        model_dir: forwarded to ``BasicTrainTask``.
        network: network instance returned as-is by :meth:`network`.
        roi_size: spatial size of the random training crop and of the
            sliding-window patch used for validation.
        target_spacing: voxel spacing used to resample image and label.
        num_samples: stored on the instance; not read by any method of this class.
        description: forwarded to ``BasicTrainTask``.
        **kwargs: forwarded to ``BasicTrainTask``.
    """

    def __init__(
        self,
        model_dir,
        network,
        roi_size=(224, 224, 144),
        target_spacing=(1.0, 1.0, 1.0),
        num_samples=4,
        description="Train BraTS Segmentation model (TC/WT/ET multilabel)",
        **kwargs,
    ):
        self._network = network
        self.roi_size = roi_size
        self.target_spacing = target_spacing
        self.num_samples = num_samples
        super().__init__(model_dir, description, **kwargs)

    def network(self, context: Context):
        """Return the network instance supplied at construction time."""
        return self._network

    def optimizer(self, context: Context):
        """Return an Adam optimizer (lr=1e-4, weight_decay=1e-5) over the context network's parameters."""
        return torch.optim.Adam(context.network.parameters(), lr=1e-4, weight_decay=1e-5)

    def loss_function(self, context: Context):
        """Return the Dice loss configured for independent sigmoid channels."""
        # BraTS is a sigmoid multilabel task (TC, WT, ET).
        # to_onehot_y=False because the label is already 3-channel after
        # ConvertToMultiChannelBasedOnBratsClassesd.
        # sigmoid=True because each channel is independent (not mutually exclusive).
        return DiceLoss(
            smooth_nr=0,
            smooth_dr=1e-5,
            squared_pred=True,
            to_onehot_y=False,
            sigmoid=True,
        )

    def lr_scheduler_handler(self, context: Context):
        """Return ``None``: this task does not supply a learning-rate scheduler handler."""
        return None

    def train_data_loader(self, context, num_workers=0, shuffle=False):
        """Delegate to the base loader, always passing ``shuffle=True`` regardless of the argument."""
        return super().train_data_loader(context, num_workers, True)

    def _shared_pre_transforms(self, context: Context):
        """Return the loading / label-conversion / resampling steps common to training and validation."""
        return [
            # The image entry is a directory; its files are stacked into one multi-channel image.
            LoadDirectoryImagesd(
                keys="image",
                target_spacing=self.target_spacing,
                channels=context.input_channels,
            ),
            LoadImaged(keys="label", reader="ITKReader", ensure_channel_first=True),
            # Converts the integer label map to a 3-channel binary tensor: [TC, WT, ET].
            ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
            EnsureTyped(keys=["image", "label"]),
            Orientationd(keys=["image", "label"], axcodes="RAS"),
            # Honour the configured spacing instead of a hard-coded (1.0, 1.0, 1.0);
            # the default target_spacing keeps the previous behaviour.
            Spacingd(
                keys=["image", "label"],
                pixdim=self.target_spacing,
                mode=("bilinear", "nearest"),
            ),
        ]

    def train_pre_transforms(self, context: Context):
        """
        Return the training pre-transforms.

        The shared loading/resampling steps are followed by a fixed-size random
        crop, random flips along each spatial axis, channel-wise intensity
        normalisation and random intensity scale/shift.
        """
        flip_axes = (0, 1, 2)
        return [
            *self._shared_pre_transforms(context),
            RandSpatialCropd(
                keys=["image", "label"],
                roi_size=self.roi_size,
                random_size=False,
            ),
            *(RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=axis) for axis in flip_axes),
            # Channel-wise zero-mean / unit-std normalisation on non-zero voxels only
            NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
            RandScaleIntensityd(keys="image", factors=0.1, prob=1.0),
            RandShiftIntensityd(keys="image", offsets=0.1, prob=1.0),
        ]

    def train_post_transforms(self, context: Context):
        """
        Return the post-transforms applied to predictions for training metrics.

        Because this is a sigmoid multilabel task, a per-channel sigmoid is
        applied and the result is thresholded at 0.5. The label is already a
        binary 3-channel tensor, so it is left untouched.
        """
        return [
            EnsureTyped(keys="pred", device=context.device),
            Activationsd(keys="pred", sigmoid=True),
            AsDiscreted(keys="pred", threshold=0.5),
        ]

    def val_pre_transforms(self, context: Context):
        """
        Return the validation pre-transforms.

        Same loading/resampling steps as training, followed only by intensity
        normalisation; no crop is applied because validation uses the
        sliding-window inferer from :meth:`val_inferer`.
        """
        return [
            *self._shared_pre_transforms(context),
            NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        ]

    def val_inferer(self, context: Context):
        """Return a Gaussian-weighted sliding-window inferer using ``roi_size`` patches."""
        return SlidingWindowInferer(
            roi_size=self.roi_size,
            sw_batch_size=2,
            overlap=0.4,
            padding_mode="replicate",
            mode="gaussian",
        )

    def norm_labels(self):
        """
        Map each label name except ``"background"`` to its position in ``self._labels``.

        Positions are counted over all keys (background included), starting at 0.
        """
        return {name: idx for idx, name in enumerate(self._labels.keys()) if name != "background"}

    def train_key_metric(self, context: Context):
        """Return the region-wise metrics for training, keyed as ``train_mean_dice``."""
        return region_wise_metrics(self.norm_labels(), "train_mean_dice", "train")

    def val_key_metric(self, context: Context):
        """Return the region-wise metrics for validation, keyed as ``val_mean_dice``."""
        return region_wise_metrics(self.norm_labels(), "val_mean_dice", "val")

    def train_handlers(self, context: Context):
        """Return the base training handlers, plus a TensorBoard image handler on local rank 0."""
        handlers = super().train_handlers(context)
        if context.local_rank != 0:
            return handlers

        handlers.append(
            TensorBoardImageHandler(
                log_dir=context.events_dir,
                batch_transform=from_engine(["image", "label"]),
                output_transform=from_engine(["pred"]),
                interval=20,
                epoch_level=True,
            )
        )
        return handlers
# Adapted from https://github.com/Project-MONAI/MONAILabel/issues/241#issuecomment-1497561538
class LoadDirectoryImagesd(MapTransform):
    """
    Load every supported image file in a directory and stack them along the channel axis.

    - Each key must hold a directory path.
    - Files are picked up by suffix (``.nii``, ``.nii.gz``, ``.nrrd``, case-insensitive)
      and loaded in sorted path order, so channel order follows file names.
    - The first file defines the reference spatial shape for that sample; any
      later file with a different spatial shape is resized to it.
    - The stacked image is stored in ``d[key]`` and merged metadata in
      ``d[f"{key}_meta_dict"]``.

    Args:
        keys: keys of the items to be transformed.
        target_spacing: stored on the instance and used to build ``self.spacer``;
            the spacer is not applied by ``__call__``.
        allow_missing_keys: don't raise an error if a key is missing.
        channels: expected number of channels; stored only, ``None`` is accepted.

    Raises:
        ValueError: if the value for a key is not an existing directory.
        FileNotFoundError: if the directory contains no supported image file.
    """

    # Suffixes selected from the directory listing (compared against the lower-cased name).
    SUPPORTED_SUFFIXES = (".nii", ".nii.gz", ".nrrd")

    # Linear-family interpolation mode per number of spatial dimensions:
    # "bilinear" only accepts 2D spatial data, volumes need "trilinear".
    _RESIZE_MODES = {1: "linear", 2: "bilinear", 3: "trilinear"}

    def __init__(self, keys: KeysCollection, target_spacing=None, allow_missing_keys: bool = False, channels: int = 2):
        super().__init__(keys, allow_missing_keys)
        self.target_spacing = target_spacing
        self.loader = LoadImage(reader="ITKReader", image_only=False)
        self.ensure_channel_first = EnsureChannelFirst()
        self.spacer = Spacing(pixdim=self.target_spacing, mode="bilinear") if target_spacing else None
        # Tolerate an unset channel count instead of failing on int(None).
        self.channels = int(channels) if channels is not None else None

    def _resize_to(self, img, spatial_shape):
        """Resize ``img`` (channel-first) to ``spatial_shape`` with a mode valid for its dimensionality."""
        mode = self._RESIZE_MODES.get(len(spatial_shape), "bilinear")
        return Resize(spatial_size=spatial_shape, mode=mode)(img)

    def __call__(self, data: Dict):
        d = dict(data)

        for key in self.key_iterator(d):
            dir_path = d[key]
            if not os.path.isdir(dir_path):
                raise ValueError(f"Expected a directory path for key '{key}', got: {dir_path}")

            image_files = sorted(
                os.path.join(dir_path, name)
                for name in os.listdir(dir_path)
                if name.lower().endswith(self.SUPPORTED_SUFFIXES)
            )
            if not image_files:
                raise FileNotFoundError(
                    f"No supported images (.nii, .nii.gz, .nrrd) found in directory {dir_path}"
                )

            logger.info("Loading %d images from %s", len(image_files), dir_path)

            # Channels are collected in a scratch dict so the caller's dict never
            # gains (or loses) temporary "<key>_chN" entries.
            channel_images = {}
            meta_dicts = []
            # The reference shape is per sample: keeping it on the instance would
            # force every later sample to the first sample's size.
            ref_shape = None

            for idx, img_path in enumerate(image_files, start=1):
                img, meta = self.loader(img_path)
                img = self.ensure_channel_first(img)

                spatial_shape = tuple(img.shape[1:])
                if ref_shape is None:
                    ref_shape = spatial_shape
                elif spatial_shape != ref_shape:
                    img = self._resize_to(img, ref_shape)

                ch_key = f"{key}_ch{idx}"
                channel_images[ch_key] = img
                meta_dicts.append(meta)

                logger.debug("Loaded %s: %s", ch_key, img.shape)

            # MONAI-native concatenation along the channel axis
            concat = ConcatItemsd(keys=list(channel_images), name=key, dim=0)
            d[key] = concat(channel_images)[key]

            # Merged metadata starts from the first file's metadata
            merged_meta = copy.deepcopy(meta_dicts[0])
            merged_meta["filename_or_obj"] = image_files
            merged_meta["num_channels"] = len(channel_images)
            merged_meta["original_channel_dim"] = 0

            d[f"{key}_meta_dict"] = merged_meta

            logger.info("Concatenated %d images → %s", len(channel_images), d[key].shape)

        return d