Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
246 changes: 186 additions & 60 deletions src/mritk/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ class Segmentation:
mri (MRIData): The MRIData object containing the segmentation volume and affine.
lut (Optional[pd.DataFrame], optional): A pandas DataFrame mapping numerical labels
to their descriptions. If None, a default numerical mapping is generated. Defaults to None.
Assumes that entries are indexed by the "label" column. If there is no "label" column
the current index is renamed to "label"
"""

mri: MRIData
Expand All @@ -111,16 +113,27 @@ def __init__(self, mri: MRIData, lut: pd.DataFrame | None = None):
self.rois = np.unique(self.mri.data[self.mri.data > 0])

if lut is None:
self.lut = pd.DataFrame({"Label": self.rois}, index=self.rois)
else:
self.lut = lut
lut = pd.DataFrame(
{
"label": self.rois.astype(int),
"description": self.rois.astype(int).astype(str),
}
).set_index("label")

self.set_lut(lut, label_column="label" if "label" in lut.columns else None)
self._preprocess_lut()

# Identify the primary label column dynamically
self.label_name = "Label" if "Label" in self.lut.columns else self.lut.columns[0]
def _preprocess_lut(self) -> pd.DataFrame:
# dummy function for subclasses to override if they need to preprocess the LUT after loading
pass

@classmethod
def from_file(
cls, seg_path: Path, dtype: npt.DTypeLike | None = None, orient: bool = True, lut_path: Path | None = None
cls,
seg_path: Path,
dtype: npt.DTypeLike | None = None,
orient: bool = True,
lut_path: Path | None = None,
) -> "Segmentation":
"""Loads a Segmentation from a NIfTI file.

Expand All @@ -136,19 +149,29 @@ def from_file(
logger.info(f"Loading segmentation from {seg_path}.")
mri = MRIData.from_file(seg_path, dtype=dtype, orient=orient)

if lut_path is None and seg_path.with_suffix(".json").exists():
lut_path = seg_path.with_suffix(".json")
if lut_path is None:
if seg_path.with_suffix(".csv").exists():
lut_path = seg_path.with_suffix(".csv")
lut = pd.read_csv(lut_path)
elif seg_path.with_suffix(".json").exists():
lut_path = seg_path.with_suffix(".json")
lut = pd.read_json(lut_path)

if lut_path is not None:
logger.info(f"Loading LUT from {lut_path}.")
lut = pd.read_json(lut_path)
else:
rois = np.unique(mri.data[mri.data > 0])
lut = pd.DataFrame({"Label": rois}, index=rois)
lut = None

return cls(mri=mri, lut=lut)

def save(self, output_path: Path, dtype: npt.DTypeLike | None = None, intent_code: int = 1006, lut_path: Path | None = None):
def save(
self,
output_path: Path,
dtype: npt.DTypeLike | None = None,
intent_code: int = 1006,
lut_path: Path | None = None,
lut_suffix=".csv",
):
"""Saves the Segmentation to a NIfTI file.

Args:
Expand All @@ -157,25 +180,35 @@ def save(self, output_path: Path, dtype: npt.DTypeLike | None = None, intent_cod
intent_code (int, optional): The NIfTI intent code to set in the header. Defaults to 1006 (NIFTI_INTENT_LABEL).
"""
self.mri.save(output_path, dtype=dtype, intent_code=intent_code)

if lut_path is not None:
self.lut.to_json(lut_path, orient="index")
write_lut(lut_path, self.lut)
Comment thread
finsberg marked this conversation as resolved.
else:
self.lut.to_json(output_path.with_suffix(".json"), orient="index")
filename = output_path.name.removesuffix("".join(output_path.suffixes))
write_lut(output_path.parent.joinpath(filename).with_suffix(lut_suffix), self.lut)

def set_lut(self, lut: pd.DataFrame, label_column: str = "Label"):
def set_lut(self, lut: pd.DataFrame, label_column: str | None = None):
"""Sets the Lookup Table (LUT) for the segmentation, ensuring it matches the present ROIs.

Args:
lut (pd.DataFrame): A pandas DataFrame mapping numerical labels
to their descriptions. If None, a default numerical mapping is generated. Defaults to None.
label_column (str, optional): The name of the column in the LUT that contains the label
descriptions. Defaults to "Label".
descriptions which will be used as the index. If None, use the current index. Defaults to None.
If the index is not already named, it is renamed to "label".
"""

self.lut = lut
self.label_name = label_column
if self.label_name not in self.lut.columns:
raise ValueError(f"Specified label column '{self.label_name}' not found in LUT.")

if label_column is not None:
self.lut = lut.set_index(label_column)
self.label_name = label_column
else:
if lut.index.name is not None: # If lut index already is named, use it
self.label_name = lut.index.name
else: # Use label as default name for axis
self.label_name = "label"
self.lut = lut.rename_axis(self.label_name)

@property
def num_rois(self) -> int:
Expand Down Expand Up @@ -209,8 +242,7 @@ def get_roi_labels(self, rois: npt.NDArray[np.int32] | None = None) -> pd.DataFr

if not np.isin(rois, self.rois).all():
raise ValueError("Some of the provided ROIs are not present in the segmentation.")

return self.lut.loc[self.lut.index.isin(rois), [self.label_name]].rename_axis("ROI").reset_index()
return self.lut.loc[rois.astype(self.lut.index.dtype)]

def resample_to_reference(self, reference_mri: MRIData) -> "Segmentation":
"""
Expand Down Expand Up @@ -292,7 +324,11 @@ class FreeSurferSegmentation(Segmentation):

@classmethod
def from_file(
cls, filepath: Path | str, dtype: npt.DTypeLike | None = None, orient: bool = True, lut_path: Path | None = None
cls,
filepath: Path | str,
dtype: npt.DTypeLike | None = None,
orient: bool = True,
lut_path: Path | None = None,
) -> "FreeSurferSegmentation":
"""
Load a FreeSurfer segmentation from a NIfTI file, automatically resolving the LUT.
Expand All @@ -309,13 +345,13 @@ def from_file(
"""
resolved_lut_path = resolve_freesurfer_lut_path(lut_path)
lut = read_freesurfer_lut(resolved_lut_path)

# FreeSurfer LUTs index by the "label" column
lut = lut.set_index("label") if "label" in lut.columns else lut

mri = MRIData.from_file(filepath, dtype=dtype, orient=orient)
return cls(mri=mri, lut=lut)

def _preprocess_lut(self) -> pd.DataFrame:
# FreeSurfer LUTs index by the "label" column
self.lut = self.lut.query("label < 10000") # Most used FreeSurfer labels


class ExtendedFreeSurferSegmentation(FreeSurferSegmentation):
"""
Expand All @@ -326,6 +362,22 @@ class ExtendedFreeSurferSegmentation(FreeSurferSegmentation):
the base FreeSurfer anatomical label (modulus 10000).
"""

def _preprocess_lut(self) -> pd.DataFrame:
super()._preprocess_lut()

# Add CSF and dura tags
base_lut = self.lut.copy()
for i, tissue_type in enumerate(["CSF", "Dura"]):
tissue_lut = base_lut.copy()
tissue_lut.index += 10000 if tissue_type == "CSF" else 20000
tissue_lut["description"] = tissue_lut["description"] + f"-{tissue_type}"
if np.all(np.isin(["R", "G", "B"], base_lut.columns)):
for col in ["R", "G", "B"]:
tissue_lut[col] = np.clip(
tissue_lut[col] * (0.5 + 0.5 * i), 0, 1
) # Shift colors towards blue for CSF and red for Dura
self.lut = pd.concat([self.lut, tissue_lut])

def get_roi_labels(self, rois: npt.NDArray[np.int32] | None = None) -> pd.DataFrame:
"""
Retrieves descriptive mappings including the augmented tissue type classifications.
Expand All @@ -338,21 +390,12 @@ def get_roi_labels(self, rois: npt.NDArray[np.int32] | None = None) -> pd.DataFr
pd.DataFrame: A DataFrame mapping the requested ROIs to their base descriptions
and their computed 'tissue_type'.
"""
rois = self.rois if rois is None else rois

# Use modulus 10000 to extract the base anatomical label from the superclass LUT
freesurfer_labels = super().get_roi_labels(rois % 10000).rename(columns={"ROI": "FreeSurfer_ROI"})
roi_labels = super().get_roi_labels(rois)

# Get the broad tissue categories based on the numerical offsets
# Add column specifying tissue_type:
tissue_type = self.get_tissue_type(rois)

# Merge the base anatomical names with the tissue types
return freesurfer_labels.merge(
tissue_type,
left_on="FreeSurfer_ROI",
right_on="FreeSurfer_ROI",
how="outer",
).drop(columns=["FreeSurfer_ROI"])[["ROI", self.label_name, "tissue_type"]]
return pd.merge(roi_labels, tissue_type, on="label")

def get_tissue_type(self, rois: npt.NDArray[np.int32] | None = None) -> pd.DataFrame:
"""
Expand All @@ -372,15 +415,14 @@ def get_tissue_type(self, rois: npt.NDArray[np.int32] | None = None) -> pd.DataF
"""
rois = self.rois if rois is None else rois

tissue_types = pd.Series(
data=np.where(rois < 10000, "Parenchyma", np.where(rois < 20000, "CSF", "Dura")),
index=rois,
name="tissue_type",
)
tissue_types = pd.DataFrame(
{
self.label_name: rois,
"tissue_type": np.where(rois < 10000, "Parenchyma", np.where(rois < 20000, "CSF", "Dura")),
}
).set_index(self.label_name)

ret = pd.DataFrame(tissue_types, columns=["tissue_type"]).rename_axis("ROI").reset_index()
ret["FreeSurfer_ROI"] = ret["ROI"] % 10000
return ret
return tissue_types


@dataclass
Expand Down Expand Up @@ -574,15 +616,63 @@ def write_lut(filename: Path, table: pd.DataFrame):
"""
newtable = table.copy()

# Re-scale RGB values to [0, 255]
for col in ["R", "G", "B"]:
newtable[col] = (newtable[col] * 255).astype(int)
if np.all(np.isin(["R", "G", "B"], table.columns)):
# Re-scale RGB values to [0, 255]
for col in ["R", "G", "B"]:
newtable[col] = (newtable[col] * 255).astype(int)

# Reverse Alpha inversion and scale to [0, 255]
newtable["A"] = 255 - (newtable["A"] * 255).astype(int)
# Reverse Alpha inversion and scale to [0, 255]
newtable["A"] = 255 - (newtable["A"] * 255).astype(int)

# Save as tab-separated values without headers or indices
newtable.to_csv(filename, sep="\t", index=False, header=False)
if filename.suffix == ".csv":
newtable.to_csv(filename, sep="\t", index=True, header=False)
elif filename.suffix == ".json":
newtable.to_json(filename, index=True, header=False)
else:
newtable.to_csv(filename, sep="\t", index=True, header=False)


def procedural_freesurfer_lut(labels: list, descriptions: list, cmap: str | None = None) -> pd.DataFrame:
"""
Generates a FreeSurfer compatible lut with colors for each label in a procedural manner

Args:
labels (list): list of labels to include in the lut
descriptions (list): list of descriptions associated to each label
cmap (str, optional): Colormap for label regions. Defaults to "hsv".

Returns:
pd.DataFrame: DataFrame indexed by the label, with RGBA columns
"""
N = len(labels)
if not N == len(descriptions):
raise ValueError("Label and descriptions lists must have same length")

if cmap is not None: # If a colormap is specified, use cmap from matplotlib
import matplotlib.pyplot as plt

# Get evenly spaced values between 0 and 1 based on the number of labels
color_indices = np.linspace(0, 0.95, N)
# Sample a colormap
rgb_float = plt.get_cmap(cmap)(color_indices)
else:
rgb_float = []
import colorsys

for i in range(N):
h = i / N
rgb = list(colorsys.hsv_to_rgb(h, 1.0, 1.0))
rgb.append(1.0) # Add transparency
rgb_float.append(rgb)
rgb_float = np.array(rgb_float)

# Create the DataFrame
df_colors = pd.DataFrame(rgb_float, columns=["R", "G", "B", "A"], index=labels)
df_colors.index.name = "label"
df_colors["description"] = descriptions
lut = df_colors[["description", "R", "G", "B", "A"]]
return lut


def add_arguments(
Expand All @@ -592,7 +682,9 @@ def add_arguments(
subparser = parser.add_subparsers(dest="seg-command", help="Commands for segmentation processing")

resample_parser = subparser.add_parser(
"resample", help="Resample a segmentation to match the space of a reference MRI", formatter_class=parser.formatter_class
"resample",
help="Resample a segmentation to match the space of a reference MRI",
formatter_class=parser.formatter_class,
)
resample_parser.add_argument("-i", "--input", type=Path, help="Path to the input segmentation NIfTI file")
resample_parser.add_argument(
Expand All @@ -602,19 +694,43 @@ def add_arguments(
help="Path to the reference MRI \
- usually a registered T1 weighted anatomical scan",
)
resample_parser.add_argument("-o", "--output", type=Path, help="Desired output path for the resampled segmentation")
resample_parser.add_argument(
"-o",
"--output",
type=Path,
help="Desired output path for the resampled segmentation",
)

smooth_parser = subparser.add_parser(
"smooth",
help="Apply Gaussian smoothing to a segmentation to create a soft probabilistic map",
formatter_class=parser.formatter_class,
)
smooth_parser.add_argument("-i", "--input", type=Path, help="Path to the input (refined) segmentation NIfTI file")
smooth_parser.add_argument("-s", "--sigma", type=float, help="Standard deviation for the Gaussian kernel used in smoothing")
smooth_parser.add_argument(
"-c", "--cutoff", type=float, default=0.5, help="Cutoff score to remove low-confidence voxels (default: 0.5)"
"-i",
"--input",
type=Path,
help="Path to the input (refined) segmentation NIfTI file",
)
smooth_parser.add_argument(
"-s",
"--sigma",
type=float,
help="Standard deviation for the Gaussian kernel used in smoothing",
)
smooth_parser.add_argument(
"-c",
"--cutoff",
type=float,
default=0.5,
help="Cutoff score to remove low-confidence voxels (default: 0.5)",
)
smooth_parser.add_argument(
"-o",
"--output",
type=Path,
help="Desired output path for the smoothed segmentation",
)
smooth_parser.add_argument("-o", "--output", type=Path, help="Desired output path for the smoothed segmentation")

refine_parser = subparser.add_parser(
"refine",
Expand All @@ -629,8 +745,18 @@ def add_arguments(
help="Path to the reference MRI \
- usually a registered T1 weighted anatomical scan",
)
refine_parser.add_argument("-s", "--smooth", type=float, help="Standard deviation for the Gaussian kernel used in smoothing")
refine_parser.add_argument("-o", "--output", type=Path, help="Desired output path for the refined segmentation")
refine_parser.add_argument(
"-s",
"--smooth",
type=float,
help="Standard deviation for the Gaussian kernel used in smoothing",
)
refine_parser.add_argument(
"-o",
"--output",
type=Path,
help="Desired output path for the refined segmentation",
)

if extra_args_cb is not None:
extra_args_cb(resample_parser)
Expand Down
Loading
Loading