diff --git a/monailabel/datastore/cvat.py b/monailabel/datastore/cvat.py
index 461883ea0..0fb293bc5 100644
--- a/monailabel/datastore/cvat.py
+++ b/monailabel/datastore/cvat.py
@@ -16,6 +16,7 @@
import tempfile
import time
import urllib.parse
+from typing import Any, Dict
import numpy as np
import requests
@@ -318,6 +319,22 @@ def download_from_cvat(self, max_retry_count=5, retry_wait_time=10):
retry_count += 1
return None
+ def add_directory(self, directory_id: str, filename: str, info: Dict[str, Any]) -> str:
+ """Not Implemented"""
+ raise NotImplementedError("This datastore does not support adding directories")
+
+ def get_is_multichannel(self) -> bool:
+ """
+ Returns whether the application's studies are directed at multichannel (4D) data
+ """
+ raise NotImplementedError("This datastore does not support multichannel imaging")
+
+ def get_is_multi_file(self) -> bool:
+ """
+ Returns whether the application's studies are directed at directories containing multiple images per sample
+ """
+ raise NotImplementedError("This datastore does not support multi-volume imaging")
+
"""
def main():
diff --git a/monailabel/datastore/dicom.py b/monailabel/datastore/dicom.py
index ed4733ca6..0ebe7de2f 100644
--- a/monailabel/datastore/dicom.py
+++ b/monailabel/datastore/dicom.py
@@ -264,3 +264,19 @@ def _download_labeled_data(self):
def datalist(self, full_path=True) -> List[Dict[str, Any]]:
self._download_labeled_data()
return super().datalist(full_path)
+
+ def add_directory(self, directory_id: str, filename: str, info: Dict[str, Any]) -> str:
+ """Not Implemented"""
+ raise NotImplementedError("This datastore does not support adding directories")
+
+ def get_is_multichannel(self) -> bool:
+ """
+ Returns whether the application's studies are directed at multichannel (4D) data
+ """
+ raise NotImplementedError("This datastore does not support multichannel imaging")
+
+ def get_is_multi_file(self) -> bool:
+ """
+ Returns whether the application's studies are directed at directories containing multiple images per sample
+ """
+ raise NotImplementedError("This datastore does not support multi-volume imaging")
diff --git a/monailabel/datastore/dsa.py b/monailabel/datastore/dsa.py
index 365cef24e..0184dc4cb 100644
--- a/monailabel/datastore/dsa.py
+++ b/monailabel/datastore/dsa.py
@@ -270,6 +270,22 @@ def status(self) -> Dict[str, Any]:
def json(self):
return self.datalist()
+ def add_directory(self, directory_id: str, filename: str, info: Dict[str, Any]) -> str:
+ """Not Implemented"""
+ raise NotImplementedError("This datastore does not support adding directories")
+
+ def get_is_multichannel(self) -> bool:
+ """
+ Returns whether the application's studies are directed at multichannel (4D) data
+ """
+ raise NotImplementedError("This datastore does not support multichannel imaging")
+
+ def get_is_multi_file(self) -> bool:
+ """
+ Returns whether the application's studies are directed at directories containing multiple images per sample
+ """
+ raise NotImplementedError("This datastore does not support multi-volume imaging")
+
"""
def main():
diff --git a/monailabel/datastore/local.py b/monailabel/datastore/local.py
index d8b0538aa..cbe0a4134 100644
--- a/monailabel/datastore/local.py
+++ b/monailabel/datastore/local.py
@@ -102,9 +102,11 @@ def __init__(
images_dir: str = ".",
labels_dir: str = "labels",
datastore_config: str = "datastore_v2.json",
- extensions=("*.nii.gz", "*.nii"),
+ extensions=("*.nii.gz", "*.nii", "*.nrrd"),
auto_reload=False,
read_only=False,
+ multichannel: bool = False,
+ multi_file: bool = False,
):
"""
Creates a `LocalDataset` object
@@ -124,6 +126,8 @@ def __init__(
self._ignore_event_config = False
self._config_ts = 0
self._auto_reload = auto_reload
+ self._multichannel: bool = multichannel
+ self._multi_file: bool = multi_file
logging.getLogger("filelock").setLevel(logging.ERROR)
@@ -256,6 +260,12 @@ def datalist(self, full_path=True) -> List[Dict[str, Any]]:
ds = json.loads(json.dumps(ds).replace(f"{self._datastore_path.rstrip(os.pathsep)}{os.pathsep}", ""))
return ds
+ def get_is_multichannel(self) -> bool:
+ return self._multichannel
+
+ def get_is_multi_file(self) -> bool:
+ return self._multi_file
+
def get_image(self, image_id: str, params=None) -> Any:
"""
Retrieve image object based on image id
@@ -431,6 +441,29 @@ def refresh(self):
"""
self._reconcile_datastore()
+ def add_directory(self, directory_id: str, filename: str, info: Dict[str, Any]) -> str:
+ id = os.path.basename(filename)
+ if not directory_id:
+ directory_id = id
+
+ logger.info(f"Adding Image: {directory_id} => {filename}")
+ name = directory_id
+ dest = os.path.realpath(os.path.join(self._datastore.image_path(), name))
+
+ with FileLock(self._lock_file):
+ logger.debug("Acquired the lock!")
+ shutil.copy(filename, dest)
+
+ info = info if info else {}
+ info["ts"] = int(time.time())
+ info["name"] = name
+
+ # images = get_directory_contents(filename)
+ self._datastore.objects[directory_id] = ImageLabelModel(image=DataModel(info=info, ext=""))
+ self._update_datastore_file(lock=False)
+ logger.debug("Released the lock!")
+ return directory_id
+
def add_image(self, image_id: str, image_filename: str, image_info: Dict[str, Any]) -> str:
id, image_ext = self._to_id(os.path.basename(image_filename))
if not image_id:
@@ -552,10 +585,15 @@ def _list_files(self, path, patterns):
files = os.listdir(path)
filtered = dict()
- for pattern in patterns:
- matching = fnmatch.filter(files, pattern)
- for file in matching:
- filtered[os.path.basename(file)] = file
+ if not self._multi_file:
+ for pattern in patterns:
+ matching = fnmatch.filter(files, pattern)
+ for file in matching:
+ filtered[os.path.basename(file)] = file
+ else:
+ for file in files:
+ if file.lower() not in ["labels", ".lock", "datastore_v2.json"]:
+ filtered[os.path.basename(file)] = file
return filtered
def _reconcile_datastore(self):
@@ -585,23 +623,26 @@ def _add_non_existing_images(self) -> int:
invalidate = 0
self._init_from_datastore_file()
- local_images = self._list_files(self._datastore.image_path(), self._extensions)
+ local_files = self._list_files(self._datastore.image_path(), self._extensions)
- image_ids = list(self._datastore.objects.keys())
- for image_file in local_images:
- image_id, image_ext = self._to_id(image_file)
- if image_id not in image_ids:
- logger.info(f"Adding New Image: {image_id} => {image_file}")
+ ids = list(self._datastore.objects.keys())
+ for file in local_files:
+ if self._multi_file:
+ # Directories have no extension — use the name as-is
+ file_id = file
+ file_ext_str = ""
+ else:
+ file_id, file_ext_str = self._to_id(file)
- name = self._filename(image_id, image_ext)
- image_info = {
+ if file_id not in ids:
+ logger.info(f"Adding New Image: {file_id} => {file}")
+ name = self._filename(file_id, file_ext_str)
+ file_info = {
"ts": int(time.time()),
- # "checksum": file_checksum(os.path.join(self._datastore.image_path(), name)),
"name": name,
}
-
invalidate += 1
- self._datastore.objects[image_id] = ImageLabelModel(image=DataModel(info=image_info, ext=image_ext))
+ self._datastore.objects[file_id] = ImageLabelModel(image=DataModel(info=file_info, ext=file_ext_str))
return invalidate
diff --git a/monailabel/datastore/xnat.py b/monailabel/datastore/xnat.py
index c0904fd8d..895aaeb15 100644
--- a/monailabel/datastore/xnat.py
+++ b/monailabel/datastore/xnat.py
@@ -386,6 +386,22 @@ def __upload_assessment(self, aiaa_model_name, image_id, file_path, type):
self._request_put(url, data, type=type)
+ def add_directory(self, directory_id: str, filename: str, info: Dict[str, Any]) -> str:
+ """Not Implemented"""
+ raise NotImplementedError("This datastore does not support adding directories")
+
+ def get_is_multichannel(self) -> bool:
+ """
+ Returns whether the application's studies are directed at multichannel (4D) data
+ """
+ raise NotImplementedError("This datastore does not support multichannel imaging")
+
+ def get_is_multi_file(self) -> bool:
+ """
+ Returns whether the application's studies are directed at directories containing multiple images per sample
+ """
+ raise NotImplementedError("This datastore does not support multi-volume imaging")
+
"""
def main():
diff --git a/monailabel/endpoints/datastore.py b/monailabel/endpoints/datastore.py
index fdd63bb6e..2cf26bb6f 100644
--- a/monailabel/endpoints/datastore.py
+++ b/monailabel/endpoints/datastore.py
@@ -68,7 +68,7 @@ def add_image(
logger.info(f"Image: {image}; File: {file}; params: {params}")
file_ext = "".join(pathlib.Path(file.filename).suffixes) if file.filename else ".nii.gz"
- image_id = image if image else os.path.basename(file.filename).replace(file_ext, "")
+ id = image if image else os.path.basename(file.filename).replace(file_ext, "")
image_file = tempfile.NamedTemporaryFile(suffix=file_ext).name
with open(image_file, "wb") as buffer:
@@ -79,8 +79,12 @@ def add_image(
save_params: Dict[str, Any] = json.loads(params) if params else {}
if user:
save_params["user"] = user
- image_id = instance.datastore().add_image(image_id, image_file, save_params)
- return {"image": image_id}
+ if not instance.datastore().get_is_multi_file():
+ image_id = instance.datastore().add_image(id, image_file, save_params)
+ return {"image": image_id}
+ else:
+ directory_id = instance.datastore().add_directory(id, image_file, save_params)
+ return {"image": directory_id}
def remove_image(id: str, user: Optional[str] = None):
diff --git a/monailabel/interfaces/app.py b/monailabel/interfaces/app.py
index f9b405bde..a8c92aa23 100644
--- a/monailabel/interfaces/app.py
+++ b/monailabel/interfaces/app.py
@@ -90,7 +90,9 @@ def __init__(
self.app_dir = app_dir
self.studies = studies
self.conf = conf if conf else {}
-
+ self.multichannel: bool = strtobool(self.conf.get("multichannel", False))
+ self.multi_file: bool = strtobool(self.conf.get("multi_file", False))
+ self.input_channels = self.conf.get("input_channels", False)
self.name = name
self.description = description
self.version = version
@@ -146,6 +148,8 @@ def init_datastore(self) -> Datastore:
extensions=settings.MONAI_LABEL_DATASTORE_FILE_EXT,
auto_reload=settings.MONAI_LABEL_DATASTORE_AUTO_RELOAD,
read_only=settings.MONAI_LABEL_DATASTORE_READ_ONLY,
+ multichannel=self.multichannel,
+ multi_file=self.multi_file,
)
def init_remote_datastore(self) -> Datastore:
@@ -281,6 +285,10 @@ def infer(self, request, datastore=None):
f"Inference Task is not Initialized. There is no model '{model}' available",
)
+ request["multi_file"] = self.multi_file
+ request["multichannel"] = self.multichannel
+ request["input_channels"] = self.input_channels
+
request = copy.deepcopy(request)
request["description"] = task.description
@@ -292,7 +300,7 @@ def infer(self, request, datastore=None):
else:
request["image"] = datastore.get_image_uri(request["image"])
- if os.path.isdir(request["image"]):
+ if os.path.isdir(request["image"]) and not self.multi_file:
logger.info("Input is a Directory; Consider it as DICOM")
logger.debug(f"Image => {request['image']}")
@@ -430,6 +438,11 @@ def train(self, request):
f"Train Task is not Initialized. There is no model '{model}' available; {request}",
)
+ # 4D image support, send train task information regarding data
+ request["multi_file"] = self.multi_file
+ request["multichannel"] = self.multichannel
+ request["input_channels"] = self.input_channels
+
request = copy.deepcopy(request)
result = task(request, self.datastore())
diff --git a/monailabel/interfaces/datastore.py b/monailabel/interfaces/datastore.py
index 78fa0aecc..dea84c592 100644
--- a/monailabel/interfaces/datastore.py
+++ b/monailabel/interfaces/datastore.py
@@ -201,6 +201,18 @@ def refresh(self) -> None:
"""
pass
+ @abstractmethod
+ def add_directory(self, id: str, filename: str, info: Dict[str, Any]) -> str:
+ """
+ Save a directory for the given directory id and return the newly saved directory's id
+
+ :param id: the directory id for the image; If None then base filename will be used
+ :param filename: the path to the directory
+ :param info: additional info for the directory
+ :return: the directory id for the saved image filename
+ """
+ pass
+
@abstractmethod
def add_image(self, image_id: str, image_filename: str, image_info: Dict[str, Any]) -> str:
"""
@@ -279,3 +291,17 @@ def json(self):
Return json representation of datastore
"""
pass
+
+ @abstractmethod
+ def get_is_multichannel(self) -> bool:
+ """
+ Returns whether the application's studies are directed at multichannel (4D) data
+ """
+ pass
+
+ @abstractmethod
+ def get_is_multi_file(self) -> bool:
+ """
+ Returns whether the application's studies are directed at directories containing multiple images per sample
+ """
+ pass
diff --git a/monailabel/tasks/activelearning/first.py b/monailabel/tasks/activelearning/first.py
index 2a0ffa675..e80ec7704 100644
--- a/monailabel/tasks/activelearning/first.py
+++ b/monailabel/tasks/activelearning/first.py
@@ -35,5 +35,13 @@ def __call__(self, request, datastore: Datastore):
images.sort()
image = images[0]
+ # If the datastore contains 4d images send the multichannel flag to ensure images are loaded as sequences
+ if datastore.get_is_multichannel():
+ return {"id": image, "multichannel": True}
+
+ # If the datastore is multi_file, each sample has a directory with multiple images
+ if datastore.get_is_multi_file():
+ return {"id": image, "multi_file": True}
+
logger.info(f"First: Selected Image: {image}")
return {"id": image}
diff --git a/monailabel/tasks/activelearning/random.py b/monailabel/tasks/activelearning/random.py
index b196f7a6b..22b58bc98 100644
--- a/monailabel/tasks/activelearning/random.py
+++ b/monailabel/tasks/activelearning/random.py
@@ -45,4 +45,17 @@ def __call__(self, request, datastore: Datastore):
image = random.choices(images, weights=weights)[0]
logger.debug(f"Random: Images: {images}; Weight: {weights}")
logger.info(f"Random: Selected Image: {image}; Weight: {weights[0]}")
+
+ # If the datastore contains 4d images send the multichannel flag to ensure images are loaded as sequences
+ if datastore.get_is_multichannel():
+ return {"id": image, "weight": weights[0], "multichannel": True}
+
+ # If the datastore is multi_file, each sample has a directory with multiple images
+ if datastore.get_is_multi_file():
+ return {
+ "id": image,
+ "weight": weights[0],
+ "multi_file": True,
+ } # this will send the directory and we will walk it later on
+
return {"id": image, "weight": weights[0]}
diff --git a/monailabel/tasks/train/basic_train.py b/monailabel/tasks/train/basic_train.py
index 9e5d0b1a9..65211f47a 100644
--- a/monailabel/tasks/train/basic_train.py
+++ b/monailabel/tasks/train/basic_train.py
@@ -83,6 +83,8 @@ def __init__(self):
self.multi_gpu = False # multi gpu enabled
self.local_rank = 0 # local rank in case of multi gpu
self.world_size = 0 # world size in case of multi gpu
+ self.input_channels = 1
+ self.multi_file = False
self.request = None
self.trainer = None
@@ -490,6 +492,9 @@ def train(self, rank, world_size, request, datalist):
context.run_id = request["run_id"]
context.multi_gpu = request["multi_gpu"]
+ context.multi_file = request.get("multi_file", False)
+ context.input_channels = request.get("input_channels", 1)
+
if context.multi_gpu:
os.environ["LOCAL_RANK"] = str(context.local_rank)
diff --git a/plugins/slicer/MONAILabel/MONAILabel.py b/plugins/slicer/MONAILabel/MONAILabel.py
index 5051dacf3..208e797e5 100644
--- a/plugins/slicer/MONAILabel/MONAILabel.py
+++ b/plugins/slicer/MONAILabel/MONAILabel.py
@@ -1303,19 +1303,57 @@ def onNextSampleButton(self):
return
logging.info(sample)
- image_id = sample["id"]
+ id = sample["id"]
image_file = sample.get("path")
- image_name = sample.get("name", image_id)
- node_name = sample.get("PatientID", sample.get("name", image_id))
+ image_name = sample.get("name", id)
+ node_name = sample.get("PatientID", sample.get("name", id))
checksum = sample.get("checksum")
local_exists = image_file and os.path.exists(image_file)
+ multichannel: bool = bool(sample.get("multichannel", False))
+ multi_file: bool = bool(sample.get("multi_file", False))
logging.info(f"Check if file exists/shared locally: {image_file} => {local_exists}")
if local_exists:
- self._volumeNode = slicer.util.loadVolume(image_file)
- self._volumeNode.SetName(node_name)
+ if multichannel:
+ # For 4D multichannel images, NOTE: slicer does not like 4D nifti images
+ # from https://github.com/Project-MONAI/MONAILabel/issues/241#issuecomment-1497788857
+ volumeSequenceNode = slicer.util.loadSequence(image_file)
+ volumeSequenceNode.SetName(node_name)
+ # Get a volume node
+ browserNode = slicer.modules.sequences.logic().GetFirstBrowserNodeForSequenceNode(
+ volumeSequenceNode
+ )
+ browserNode.SetOverwriteProxyName(
+ None, True
+ ) # set the proxy node name based on the sequence node name
+ self._volumeNode = browserNode.GetProxyNode(volumeSequenceNode)
+ else:
+ if not multi_file:
+ self._volumeNode = slicer.util.loadVolume(image_file)
+ self._volumeNode.SetName(node_name)
+ else: # in the case the underlying dataset is multi_file, we load all the images in the directory
+ dir_path = image_file
+ if not os.path.isdir(dir_path):
+ raise ValueError(f"multi_file=True but path is not a directory: {dir_path}")
+
+ # get valid image paths
+ entries = sorted(os.listdir(dir_path))
+ image_paths = []
+ for name in entries:
+ full_path = os.path.join(dir_path, name)
+ if os.path.isfile(full_path):
+ image_paths.append(full_path)
+
+ nodes = []
+ for idx, image in enumerate(image_paths):
+ image_base_name = os.path.basename(image)
+ node = slicer.util.loadVolume(image)
+ node.SetName(image_base_name)
+ nodes.append(node)
+
+ self._volumeNode = nodes[0]
else:
- download_uri = f"{self.serverUrl()}/datastore/image?image={quote_plus(image_id)}"
+ download_uri = f"{self.serverUrl()}/datastore/image?image={quote_plus(id)}"
logging.info(download_uri)
sampleDataLogic = SampleData.SampleDataLogic()
@@ -1326,7 +1364,7 @@ def onNextSampleButton(self):
if slicer.util.settingsValue("MONAILabel/originalLabel", True, converter=slicer.util.toBool):
try:
datastore = self.logic.datastore()
- label_info = datastore["objects"][image_id]["labels"]["original"]["info"]
+ label_info = datastore["objects"][id]["labels"]["original"]["info"]
labels = label_info.get("params", {}).get("label_names", {})
if labels:
@@ -1338,7 +1376,7 @@ def onNextSampleButton(self):
labels = self.logic.info().get("labels")
# ext = datastore['objects'][image_id]['labels']['original']['ext']
- maskFile = self.logic.download_label(image_id, "original")
+ maskFile = self.logic.download_label(id, "original")
self.updateSegmentationMask(maskFile, list(labels))
print("Original label uploaded! ")
diff --git a/sample-apps/radiology/README.md b/sample-apps/radiology/README.md
index 5c7356940..74ec5638d 100644
--- a/sample-apps/radiology/README.md
+++ b/sample-apps/radiology/README.md
@@ -215,6 +215,34 @@ the model to learn on new organ.
- Output: N channels representing the segmented organs/tumors/tissues
+
+
+ Segmentation BraTS is a model based on SegResNet for automated multilabel brain tumor segmentation. This model is designed for multi-label segmentation tasks using pre-aligned, multi-modal MRI volumes.
+
+
+> monailabel start_server --app workspace/radiology --studies workspace/images --conf models segmentation_brats --conf input_channels 4 --conf multi_file true
+
+- Additional Configs *(pass them as **--conf name value** while starting MONAILabel Server)*
+
+| Name | Values | Description |
+|----------------------|------------------|--------------------------------------------------------------------------|
+| use_pretrained_model | **true**, false | Set to `false` to skip loading pretrained weights |
+| preload | true, **false** | Preload model into GPU at startup |
+| scribbles | **true**, false | Set to `false` to disable scribble-based interactive segmentation models |
+
+- Network: This model uses the [SegResNet](https://docs.monai.io/en/latest/networks.html#segresnet) as the default network. Researchers can define their own network or use one of the listed [here](https://docs.monai.io/en/latest/networks.html)
+- Labels
+ ```json
+ {
+ "tumor core": 1,
+ "whole tumor": 2,
+ "enhancing tumor": 3
+ }
+ ```
+- Dataset: The model is trained on the dataset: https://www.med.upenn.edu/cbica/brats2020/
+- Inputs: 4 channels for the 4 BRATS image modalities
+- Output: N channels representing the segmented tumors/tissues
+
diff --git a/sample-apps/radiology/lib/configs/segmentation_brats.py b/sample-apps/radiology/lib/configs/segmentation_brats.py
new file mode 100644
index 000000000..cd2f628a8
--- /dev/null
+++ b/sample-apps/radiology/lib/configs/segmentation_brats.py
@@ -0,0 +1,107 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+from typing import Any, Dict, Optional, Union
+
+import lib.infers
+import lib.trainers
+from monai.networks.nets import SegResNet
+from monai.utils import optional_import
+
+from monailabel.interfaces.config import TaskConfig
+from monailabel.interfaces.tasks.infer_v2 import InferTask
+from monailabel.interfaces.tasks.train import TrainTask
+from monailabel.utils.others.generic import download_file, strtobool
+
+_, has_cp = optional_import("cupy")
+_, has_cucim = optional_import("cucim")
+
+logger = logging.getLogger(__name__)
+
+
+class Segmentation(TaskConfig):
+ def init(self, name: str, model_dir: str, conf: Dict[str, str], planner: Any, **kwargs):
+ super().init(name, model_dir, conf, planner, **kwargs)
+
+ # BraTS labels: 3 multi-label channels produced by ConvertToMultiChannelBasedOnBratsClassesd
+ # Channel 0: TC - Tumor Core (label 2 OR label 3)
+ # Channel 1: WT - Whole Tumor (label 1 OR label 2 OR label 3)
+ # Channel 2: ET - Enhancing Tumor (label 2)
+ self.labels = {
+ "tumor core": 1, # Tumor Core
+ "whole tumor": 2, # Whole Tumor
+ "enhancing tumor": 3, # Enhancing Tumor
+ }
+
+ # Model Files
+ self.path = [
+ os.path.join(self.model_dir, f"pretrained_{name}.pt"), # pretrained
+ os.path.join(self.model_dir, f"{name}.pt"), # published
+ ]
+
+ # Download PreTrained Model (optional)
+ if strtobool(self.conf.get("use_pretrained_model", "false")):
+ url = f"{self.conf.get('pretrained_path', self.PRE_TRAINED_PATH)}"
+ url = f"{url}/radiology_segmentation_segresnet_brats.pt"
+ download_file(url, self.path[0])
+
+ # Spacing and ROI for BraTS (isotropic 1mm, large crop matching tutorial)
+ self.target_spacing = (1.0, 1.0, 1.0)
+ self.roi_size = (224, 224, 144)
+
+ # Number of input channels: 4 MRI modalities (FLAIR, T1, T1Gd, T2)
+ # when multi_file=True the LoadDirectoryImagesd loader stacks them;
+ # when multi_file=False the image file must already be a 4-channel volume.
+ try:
+ input_channels = int(self.conf.get("input_channels", 4))
+ except (ValueError, TypeError):
+ logger.warning("Could not parse input_channels, defaulting to 4")
+ input_channels = 4
+
+ # Network
+ self.network = SegResNet(
+ blocks_down=(1, 2, 2, 4),
+ blocks_up=(1, 1, 1),
+ init_filters=16,
+ in_channels=input_channels,
+ out_channels=len(self.labels), # TC, WT, ET — sigmoid multilabel, no background channel
+ dropout_prob=0.2,
+ )
+
+ def infer(self) -> Union[InferTask, Dict[str, InferTask]]:
+ task: InferTask = lib.infers.Segmentation(
+ path=self.path,
+ network=self.network,
+ roi_size=self.roi_size,
+ target_spacing=self.target_spacing,
+ labels=self.labels,
+ preload=strtobool(self.conf.get("preload", "false")),
+ config={"largest_cc": True if has_cp and has_cucim else False},
+ )
+ return task
+
+ def trainer(self) -> Optional[TrainTask]:
+ output_dir = os.path.join(self.model_dir, self.name)
+ load_path = self.path[0] if os.path.exists(self.path[0]) else self.path[1]
+
+ task: TrainTask = lib.trainers.Segmentation(
+ model_dir=output_dir,
+ network=self.network,
+ roi_size=self.roi_size,
+ target_spacing=self.target_spacing,
+ load_path=load_path,
+ publish_path=self.path[1],
+ description="Train BraTS Segmentation Model (TC/WT/ET multilabel)",
+ labels=self.labels,
+ )
+ return task
diff --git a/sample-apps/radiology/lib/infers/segmentation_brats.py b/sample-apps/radiology/lib/infers/segmentation_brats.py
new file mode 100644
index 000000000..6a832946d
--- /dev/null
+++ b/sample-apps/radiology/lib/infers/segmentation_brats.py
@@ -0,0 +1,159 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Callable, Sequence
+
+from lib.transforms.transforms import ConvertFromMultiChannelBasedOnBratsClassesd, GetCentroidsd, LoadDirectoryImagesd
+from monai.inferers import Inferer, SlidingWindowInferer
+from monai.transforms import (
+ Activationsd,
+ AsDiscreted,
+ EnsureChannelFirstd,
+ EnsureTyped,
+ KeepLargestConnectedComponentd,
+ LoadImaged,
+ NormalizeIntensityd,
+ Orientationd,
+ Spacingd,
+)
+
+from monailabel.interfaces.tasks.infer_v2 import InferType
+from monailabel.tasks.infer.basic_infer import BasicInferTask
+from monailabel.transform.post import Restored
+
+
+class Segmentation(BasicInferTask):
+ """
+ Inference Engine for BraTS brain tumour segmentation using a SegResNet.
+
+ The model outputs 3 channels (TC, WT, ET) with sigmoid activations — it is
+ a multilabel task, NOT a softmax classification. Each channel is thresholded
+ independently at 0.5 to produce binary maps.
+
+ Two image loading modes are supported (set via ``data["multi_file"]``):
+ - False (default): the input image is a single 4-channel NIfTI volume.
+ - True: ``data["image"]`` is a directory containing 4 single-
+ modality NIfTI files; LoadDirectoryImagesd stacks them.
+ """
+
+ def __init__(
+ self,
+ path,
+ network=None,
+ target_spacing=(1.0, 1.0, 1.0),
+ type=InferType.SEGMENTATION,
+ labels=None,
+ dimension=3,
+ description="Pre-trained BraTS SegResNet — TC/WT/ET multilabel segmentation",
+ **kwargs,
+ ):
+ super().__init__(
+ path=path,
+ network=network,
+ type=type,
+ labels=labels,
+ dimension=dimension,
+ description=description,
+ load_strict=False,
+ **kwargs,
+ )
+ self.target_spacing = target_spacing
+
+ def pre_transforms(self, data=None) -> Sequence[Callable]:
+ """
+ Pre-processing pipeline matching the official MONAI BraTS tutorial.
+
+ NOTE: ScaleIntensityRangePercentilesd and CenterSpatialCropd from the
+ original file have been removed — they are not part of the BraTS pipeline
+ and would distort MRI intensity normalisation. NormalizeIntensityd with
+ nonzero=True, channel_wise=True is the correct approach for multi-modal MRI.
+ """
+ channels = data.get("input_channels", 4)
+ t = [
+ (
+ LoadImaged(keys="image", reader="ITKReader", ensure_channel_first=True)
+ if data.get("multi_file", False) is False
+ else LoadDirectoryImagesd(
+ keys="image",
+ target_spacing=self.target_spacing,
+ channels=channels,
+ )
+ ),
+ EnsureTyped(keys="image", device=data.get("device") if data else None),
+ # EnsureChannelFirstd is safe to keep as a guard; if the channel dim is
+ # already present (ITKReader + ensure_channel_first) it is a no-op.
+ EnsureChannelFirstd(keys="image", channel_dim=0),
+ Orientationd(keys="image", axcodes="RAS"),
+ Spacingd(
+ keys="image",
+ pixdim=self.target_spacing,
+ allow_missing_keys=True,
+ ),
+ # Channel-wise intensity normalisation on non-zero voxels only.
+ # This matches both the tutorial and the training pipeline exactly.
+ NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
+ ]
+ return t
+
+ def inferer(self, data=None) -> Inferer:
+ return SlidingWindowInferer(
+ roi_size=self.roi_size,
+ sw_batch_size=2,
+ overlap=0.4,
+ padding_mode="replicate",
+ mode="gaussian",
+ )
+
+ def inverse_transforms(self, data=None):
+ return []
+
+ def post_transforms(self, data=None) -> Sequence[Callable]:
+ """
+ Post-processing for multilabel sigmoid output.
+
+ IMPORTANT differences from a softmax segmentation:
+ - Activationsd uses sigmoid=True (not softmax=True).
+ - AsDiscreted thresholds each channel at 0.5 independently
+ (not argmax, because channels are not mutually exclusive).
+ - KeepLargestConnectedComponentd is applied per-channel if available.
+ """
+ t = [
+ EnsureTyped(keys="pred", device=data.get("device") if data else None),
+ # Sigmoid: each of the 3 channels (TC, WT, ET) is activated independently.
+ Activationsd(keys="pred", sigmoid=True),
+ # Threshold each channel at 0.5 to produce binary masks.
+ AsDiscreted(keys="pred", threshold=0.5),
+ ]
+
+ if data and data.get("largest_cc", False):
+ # Apply per-channel so TC, WT and ET are each cleaned independently.
+ t.append(
+ KeepLargestConnectedComponentd(
+ keys="pred",
+ independent=True, # treat each channel separately
+ )
+ )
+
+ t.extend(
+ [
+ # Merge 3 binary channels → single-channel integer label map
+ # Must happen before Restored so spatial metadata is applied
+ # to the final (1, H, W, D) output, not the intermediate (3, H, W, D).
+ ConvertFromMultiChannelBasedOnBratsClassesd(keys="pred"),
+ Restored(
+ keys="pred",
+ ref_image="image",
+ config_labels=self.labels if data and data.get("restore_label_idx", False) else None,
+ ),
+ GetCentroidsd(keys="pred", centroids_key="centroids"),
+ ]
+ )
+ return t
diff --git a/sample-apps/radiology/lib/trainers/segmentation_brats.py b/sample-apps/radiology/lib/trainers/segmentation_brats.py
new file mode 100644
index 000000000..34fcbd8be
--- /dev/null
+++ b/sample-apps/radiology/lib/trainers/segmentation_brats.py
@@ -0,0 +1,199 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+import torch
+from lib.transforms.transforms import LoadDirectoryImagesd
+from monai.handlers import TensorBoardImageHandler, from_engine
+from monai.inferers import SlidingWindowInferer
+from monai.losses import DiceLoss
+from monai.transforms import (
+ Activationsd,
+ AsDiscreted,
+ ConvertToMultiChannelBasedOnBratsClassesd,
+ EnsureTyped,
+ LoadImaged,
+ NormalizeIntensityd,
+ Orientationd,
+ RandFlipd,
+ RandScaleIntensityd,
+ RandShiftIntensityd,
+ RandSpatialCropd,
+ Spacingd,
+)
+
+from monailabel.tasks.train.basic_train import BasicTrainTask, Context
+from monailabel.tasks.train.utils import region_wise_metrics
+
+logger = logging.getLogger(__name__)
+
+
class Segmentation(BasicTrainTask):
    """
    Train task for BraTS multilabel segmentation (TC / WT / ET, sigmoid).

    Follows the official MONAI BraTS tutorial: multimodal multichannel input
    stacked by ``LoadDirectoryImagesd``, a 3-channel binary target produced by
    ``ConvertToMultiChannelBasedOnBratsClassesd``, and ``DiceLoss`` with
    per-channel sigmoid activation.
    """

    def __init__(
        self,
        model_dir,
        network,
        roi_size=(224, 224, 144),
        target_spacing=(1.0, 1.0, 1.0),
        num_samples=4,
        description="Train BraTS Segmentation model (TC/WT/ET multilabel)",
        **kwargs,
    ):
        self._network = network
        self.roi_size = roi_size
        self.target_spacing = target_spacing
        self.num_samples = num_samples
        super().__init__(model_dir, description, **kwargs)

    def network(self, context: Context):
        return self._network

    def optimizer(self, context: Context):
        return torch.optim.Adam(context.network.parameters(), lr=1e-4, weight_decay=1e-5)

    def loss_function(self, context: Context):
        # BraTS is a sigmoid multilabel task (TC, WT, ET).
        # to_onehot_y=False because the label is already 3-channel after
        # ConvertToMultiChannelBasedOnBratsClassesd.
        # sigmoid=True because each channel is independent (not mutually exclusive).
        return DiceLoss(
            smooth_nr=0,
            smooth_dr=1e-5,
            squared_pred=True,
            to_onehot_y=False,
            sigmoid=True,
        )

    def lr_scheduler_handler(self, context: Context):
        # No LR scheduling — constant learning rate.
        return None

    def train_data_loader(self, context, num_workers=0, shuffle=False):
        # Shuffling is always forced on for training; the `shuffle` parameter
        # is kept only for signature compatibility with the base class.
        return super().train_data_loader(context, num_workers, True)

    def train_pre_transforms(self, context: Context):
        """
        Training pre-transforms, following the official MONAI BraTS tutorial.

        The image is always loaded with ``LoadDirectoryImagesd``, i.e. the
        study is expected to be a directory of single-modality volumes stacked
        into a (C, H, W, D) tensor.
        NOTE(review): a single-file (pre-stacked 4D volume) loading path via
        plain ``LoadImaged`` keyed on ``context.multi_file`` was removed as
        dead (commented-out) code — restore it there if single-file studies
        must be supported.
        """
        channels = context.input_channels
        return [
            LoadDirectoryImagesd(keys="image", target_spacing=self.target_spacing, channels=channels),
            LoadImaged(keys="label", reader="ITKReader", ensure_channel_first=True),
            # ConvertToMultiChannelBasedOnBratsClassesd converts the integer label map
            # to a 3-channel binary tensor: [TC, WT, ET].
            ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
            EnsureTyped(keys=["image", "label"]),
            Orientationd(keys=["image", "label"], axcodes="RAS"),
            # Use the configured spacing instead of a hard-coded (1.0, 1.0, 1.0)
            # so the constructor's target_spacing parameter actually takes effect
            # (default value is unchanged).
            Spacingd(
                keys=["image", "label"],
                pixdim=self.target_spacing,
                mode=("bilinear", "nearest"),
            ),
            # Random crop matching the official tutorial roi.
            RandSpatialCropd(
                keys=["image", "label"],
                roi_size=self.roi_size,
                random_size=False,
            ),
            RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
            RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=1),
            RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=2),
            # Channel-wise zero-mean / unit-std normalisation on non-zero voxels only.
            NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
            RandScaleIntensityd(keys="image", factors=0.1, prob=1.0),
            RandShiftIntensityd(keys="image", offsets=0.1, prob=1.0),
        ]

    def train_post_transforms(self, context: Context):
        """
        Post-transforms for TRAINING metrics.

        Because this is a sigmoid multilabel task:
        - Apply sigmoid activation per channel.
        - Threshold at 0.5 to get binary predictions.
        - The label is already binary 3-channel — no argmax / to_onehot needed.
        """
        return [
            EnsureTyped(keys="pred", device=context.device),
            Activationsd(keys="pred", sigmoid=True),
            AsDiscreted(keys="pred", threshold=0.5),
            # label is already binary 3-channel, nothing to do
        ]

    def val_pre_transforms(self, context: Context):
        """Validation pre-transforms: same loading/normalisation as training,
        but no random crop or augmentation — sliding window covers the volume."""
        channels = context.input_channels
        return [
            LoadDirectoryImagesd(keys="image", target_spacing=self.target_spacing, channels=channels),
            LoadImaged(keys="label", reader="ITKReader", ensure_channel_first=True),
            ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
            EnsureTyped(keys=["image", "label"]),
            Orientationd(keys=["image", "label"], axcodes="RAS"),
            Spacingd(
                keys=["image", "label"],
                pixdim=self.target_spacing,
                mode=("bilinear", "nearest"),
            ),
            NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        ]

    def val_inferer(self, context: Context):
        return SlidingWindowInferer(
            roi_size=self.roi_size,
            sw_batch_size=2,
            overlap=0.4,
            padding_mode="replicate",
            mode="gaussian",
        )

    def norm_labels(self):
        """Map each non-background label name to its position in the labels dict.

        Note: positions come from the original dict order, so entries after
        "background" keep their original index (indices are not re-packed).
        """
        return {name: idx for idx, name in enumerate(self._labels) if name != "background"}

    def train_key_metric(self, context: Context):
        return region_wise_metrics(self.norm_labels(), "train_mean_dice", "train")

    def val_key_metric(self, context: Context):
        return region_wise_metrics(self.norm_labels(), "val_mean_dice", "val")

    def train_handlers(self, context: Context):
        handlers = super().train_handlers(context)
        # TensorBoard image logging only on the main process in distributed runs.
        if context.local_rank == 0:
            handlers.append(
                TensorBoardImageHandler(
                    log_dir=context.events_dir,
                    batch_transform=from_engine(["image", "label"]),
                    output_transform=from_engine(["pred"]),
                    interval=20,
                    epoch_level=True,
                )
            )
        return handlers
diff --git a/sample-apps/radiology/lib/transforms/transforms.py b/sample-apps/radiology/lib/transforms/transforms.py
index c24202328..4bf71ed38 100644
--- a/sample-apps/radiology/lib/transforms/transforms.py
+++ b/sample-apps/radiology/lib/transforms/transforms.py
@@ -10,6 +10,7 @@
# limitations under the License.
import copy
import logging
+import os
from typing import Any, Dict, Hashable, Mapping
import numpy as np
@@ -18,7 +19,18 @@
from monai.config import KeysCollection, NdarrayOrTensor
from monai.data import MetaTensor
from monai.networks.layers import GaussianFilter
-from monai.transforms import CropForeground, GaussianSmooth, Randomizable, Resize, ScaleIntensity, SpatialCrop
+from monai.transforms import (
+ ConcatItemsd,
+ CropForeground,
+ EnsureChannelFirst,
+ GaussianSmooth,
+ LoadImage,
+ Randomizable,
+ Resize,
+ ScaleIntensity,
+ Spacing,
+ SpatialCrop,
+)
from monai.transforms.transform import MapTransform, Transform
from monai.utils.enums import CommonKeys
@@ -27,6 +39,141 @@
logger = logging.getLogger(__name__)
class ConvertFromMultiChannelBasedOnBratsClassesd(MapTransform):
    """
    Dictionary-based transform that reverses
    ``ConvertToMultiChannelBasedOnBratsClassesd``.

    Converts a 3-channel binary prediction (TC, WT, ET) back to a
    single-channel integer label map:

        1 = oedema (WT minus TC)
        2 = enhancing tumour (ET)
        3 = necrotic core (TC minus ET)

    NOTE(review): these indices follow this app's label configuration, not the
    original BraTS file labels (1/2/4) — confirm against the app's labels dict.

    Output shape: (1, H, W, D), dtype ``torch.long`` by default.

    Args:
        keys: keys of the items to be transformed.
        dtype: output dtype, default ``torch.long``.
        allow_missing_keys: don't raise an error if a key is missing.
    """

    def __init__(
        self,
        keys: KeysCollection,
        dtype: torch.dtype = torch.long,
        allow_missing_keys: bool = False,
    ) -> None:
        super().__init__(keys, allow_missing_keys)
        self.dtype = dtype

    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
        d = dict(data)
        for key in self.key_iterator(d):
            pred = d[key]

            if pred.shape[0] != 3:
                raise ValueError(
                    f"Expected 3-channel input (TC, WT, ET) for key '{key}', " f"got {pred.shape[0]} channels."
                )

            tumour_core, whole_tumour, enhancing = (pred[c].bool() for c in range(3))

            merged = torch.zeros_like(pred[0], dtype=self.dtype)
            # Assignment order matters: on any overlap the later write wins,
            # so ET takes precedence over oedema, and TC-minus-ET over ET-free TC.
            merged[whole_tumour & ~tumour_core] = 1  # Oedema
            merged[enhancing] = 2  # Enhancing tumour
            merged[tumour_core & ~enhancing] = 3  # Necrotic core

            single_channel = merged.unsqueeze(0)

            # Preserve MetaTensor metadata when the input carries it.
            d[key] = MetaTensor(single_channel, meta=pred.meta) if isinstance(pred, MetaTensor) else single_channel

        return d
+
+
+# Adapted from https://github.com/Project-MONAI/MONAILabel/issues/241#issuecomment-1497561538
+class LoadDirectoryImagesd(MapTransform):
+ """
+ Load all 3D images from a directory, stack them along a new axis,
+ and preserve MONAI-style metadata similar to LoadImaged.
+
+ - Each key should be a directory path.
+ - Assumes all images share the same spatial dimensions.
+ - Stores image data in `d[key]` and metadata in `d[f"{key}_meta_dict"]`.
+ """
+
+ def __init__(self, keys: KeysCollection, target_spacing=None, allow_missing_keys: bool = False, channels: int = 2):
+ super().__init__(keys, allow_missing_keys)
+ self.target_spacing = target_spacing
+ self.loader = LoadImage(reader="ITKReader", image_only=False)
+ self.ensure_channel_first = EnsureChannelFirst()
+ self.spacer = Spacing(pixdim=self.target_spacing, mode="bilinear") if target_spacing else None
+ self.resizer = None # initialized later
+ self.channels = int(channels)
+
+ def __call__(self, data: Dict):
+ d = dict(data)
+
+ for key in self.key_iterator(d):
+ dir_path = d[key]
+ if not os.path.isdir(dir_path):
+ raise ValueError(f"Expected a directory path for key '{key}', got: {dir_path}")
+
+ # Gather all files in directory
+ image_files = sorted(
+ [
+ os.path.join(dir_path, f)
+ for f in os.listdir(dir_path)
+ if f.lower().endswith((".nii", ".nii.gz", ".nrrd"))
+ ]
+ )
+ if not image_files:
+ raise FileNotFoundError(f"No NIfTI images found in directory {dir_path}")
+
+ channel_keys = []
+ meta_dicts = []
+
+ logger.info(f"Loading {len(image_files)} images from {dir_path}")
+
+ for idx, img_path in enumerate(image_files):
+ img, meta = self.loader(img_path)
+ img = self.ensure_channel_first(img)
+
+ if self.resizer is None:
+ self.resizer = Resize(spatial_size=img.shape[1:], mode="bilinear")
+
+ img = self.resizer(img) if self.resizer is not None else img
+
+ ch_key = f"{key}_ch{idx + 1}"
+ d[ch_key] = img
+ d[f"{ch_key}_meta_dict"] = meta
+
+ channel_keys.append(ch_key)
+ meta_dicts.append(meta)
+
+ logger.debug(f"Loaded {ch_key}: {img.shape}")
+
+ # MONAI-native concatenation
+ self.concat = ConcatItemsd(keys=channel_keys, name=key, dim=0)
+ d = self.concat(d)
+
+ # Clean up temporary channel keys
+ for ch_key in channel_keys:
+ d.pop(ch_key, None)
+ d.pop(f"{ch_key}_meta_dict", None)
+
+ # Construct merged metadata
+ merged_meta = copy.deepcopy(meta_dicts[0])
+ merged_meta["filename_or_obj"] = image_files
+ merged_meta["num_channels"] = len(channel_keys)
+ merged_meta["original_channel_dim"] = 0
+
+ d[f"{key}_meta_dict"] = merged_meta
+
+ logger.info(f"Concatenated {len(channel_keys)} images → {d[key].shape}")
+
+ return d
+
+
class BinaryMaskd(MapTransform):
def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False):
"""
diff --git a/sample-apps/radiology/main.py b/sample-apps/radiology/main.py
index bb7f8ae18..5a1b2ab1f 100644
--- a/sample-apps/radiology/main.py
+++ b/sample-apps/radiology/main.py
@@ -307,12 +307,18 @@ def main():
parser.add_argument("-s", "--studies", default=studies)
parser.add_argument("-m", "--model", default="segmentation")
parser.add_argument("-t", "--test", default="batch_infer", choices=("train", "infer", "batch_infer"))
+ parser.add_argument("-multi", "--multichannel", default=False)
+ parser.add_argument("-c", "--input_channels")
+ parser.add_argument("-multif", "--multi_file", default=False)
args = parser.parse_args()
app_dir = os.path.dirname(__file__)
studies = args.studies
conf = {
"models": args.model,
+ "multichannel": args.multichannel,
+ "input_channels": args.input_channels,
+ "multi_file": args.multi_file,
"preload": "false",
}