From 94097b30490376febd3955544cf2c54ff5f06159 Mon Sep 17 00:00:00 2001 From: Sefa Aras Date: Fri, 23 Jan 2026 00:36:08 +0300 Subject: [PATCH 1/3] Add MAGNUS: Multi-Attention Guided Network for Unified Segmentation - Add MAGNUS hybrid CNN-Transformer architecture for medical image segmentation - Implement CNNPath for hierarchical feature extraction - Implement TransformerPath for global context modeling - Add CrossModalAttentionFusion for bidirectional cross-attention - Add ScaleAdaptiveConv for multi-scale feature extraction - Add SEBlock for channel recalibration - Support both 2D and 3D medical images - Add deep supervision option - Add comprehensive unit tests Reference: Aras et al., IEEE Access 2026, DOI: 10.1109/ACCESS.2026.3656667 Signed-off-by: Sefa Aras --- monai/networks/nets/__init__.py | 1 + monai/networks/nets/magnus.py | 734 +++++++++++++++++++++++++++++ tests/networks/nets/test_magnus.py | 332 +++++++++++++ 3 files changed, 1067 insertions(+) create mode 100644 monai/networks/nets/magnus.py create mode 100644 tests/networks/nets/test_magnus.py diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index c1917e5293..ecb1930f38 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -73,6 +73,7 @@ MedNeXtSmall, MedNextSmall, ) +from .magnus import MAGNUS, CNNPath, CrossModalAttentionFusion, ScaleAdaptiveConv, TransformerPath from .milmodel import MILModel from .netadapter import NetAdapter from .patchgan_discriminator import MultiScalePatchDiscriminator, PatchDiscriminator diff --git a/monai/networks/nets/magnus.py b/monai/networks/nets/magnus.py new file mode 100644 index 0000000000..b4af9b8ee2 --- /dev/null +++ b/monai/networks/nets/magnus.py @@ -0,0 +1,734 @@ +# Copyright Project MONAI Contributors +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +MAGNUS: Multi-Attention Guided Network for Unified Segmentation via CNN-ViT Fusion + +A hybrid CNN-Transformer architecture that combines multi-scale CNN features +with Vision Transformer representations through cross-modal attention fusion +for advanced medical image segmentation. + +Reference: + Aras, E., Kayikcioglu, T., Aras, S., & Merd, N. (2026). + MAGNUS: Multi-Attention Guided Network for Unified Segmentation via CNN-ViT Fusion. + IEEE Access. DOI: 10.1109/ACCESS.2026.3656667 +""" + +from __future__ import annotations + +from collections.abc import Sequence +from typing import Optional, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from monai.networks.blocks import Convolution, UpSample +from monai.networks.layers.factories import Act, Norm +from monai.networks.layers.utils import get_act_layer, get_norm_layer +from monai.utils import ensure_tuple_rep + +__all__ = ["MAGNUS", "CNNPath", "TransformerPath", "CrossModalAttentionFusion", "ScaleAdaptiveConv"] + + +class CNNPath(nn.Module): + """ + CNN encoder path with strided convolutions for hierarchical feature extraction. + + Args: + spatial_dims: number of spatial dimensions (2 or 3). 
+ in_channels: number of input channels. + features: sequence of output channels for each encoder stage. + norm: feature normalization type, one of ("batch", "instance", "group"). + act: activation type, one of ("relu", "leakyrelu", "prelu", "gelu"). + dropout: dropout ratio after each convolution block. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + features: Sequence[int], + norm: Union[str, tuple] = "batch", + act: Union[str, tuple] = "relu", + dropout: float = 0.0, + ) -> None: + super().__init__() + self.spatial_dims = spatial_dims + self.stages = nn.ModuleList() + current_channels = in_channels + + for feat in features: + stage = nn.Sequential( + Convolution( + spatial_dims=spatial_dims, + in_channels=current_channels, + out_channels=feat, + kernel_size=3, + strides=2, + padding=1, + norm=norm, + act=act, + dropout=dropout, + ), + Convolution( + spatial_dims=spatial_dims, + in_channels=feat, + out_channels=feat, + kernel_size=3, + strides=1, + padding=1, + norm=norm, + act=act, + dropout=dropout, + ), + ) + self.stages.append(stage) + current_channels = feat + + def forward(self, x: torch.Tensor) -> list[torch.Tensor]: + """ + Forward pass returning features from each stage. + + Args: + x: input tensor of shape (B, C, *spatial_dims). + + Returns: + List of feature tensors from each encoder stage, + ordered from shallow to deep. + """ + features = [] + for stage in self.stages: + x = stage(x) + features.append(x) + return features + + +class TransformerPath(nn.Module): + """ + Vision Transformer path for global context modeling. + + Applies patch embedding followed by transformer encoder layers + to capture long-range dependencies. + + Args: + spatial_dims: number of spatial dimensions (2 or 3). + in_channels: number of input channels. + hidden_dim: transformer hidden dimension. + num_heads: number of attention heads. + depth: number of transformer encoder layers. + patch_size: size of patches for embedding. + dropout: dropout rate in transformer layers. + mlp_ratio: ratio of mlp hidden dim to embedding dim. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + hidden_dim: int, + num_heads: int, + depth: int, + patch_size: int = 16, + dropout: float = 0.1, + mlp_ratio: float = 4.0, + ) -> None: + super().__init__() + self.spatial_dims = spatial_dims + self.patch_size = patch_size + self.hidden_dim = hidden_dim + + # Patch embedding via convolution + conv_type = nn.Conv3d if spatial_dims == 3 else nn.Conv2d + self.embedding = conv_type( + in_channels, hidden_dim, kernel_size=patch_size, stride=patch_size + ) + + # Transformer encoder + encoder_layer = nn.TransformerEncoderLayer( + d_model=hidden_dim, + nhead=num_heads, + dim_feedforward=int(hidden_dim * mlp_ratio), + dropout=dropout, + activation="gelu", + batch_first=True, + norm_first=True, + ) + self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=depth) + + # Layer normalization + self.norm = nn.LayerNorm(hidden_dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass through transformer path. + + Args: + x: input tensor of shape (B, C, *spatial_dims). + + Returns: + Transformed features of shape (B, hidden_dim, *reduced_spatial_dims). 
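+
+        Example (illustrative shape check; the sizes below are arbitrary):
+            >>> import torch
+            >>> tp = TransformerPath(spatial_dims=2, in_channels=1, hidden_dim=256,
+            ...                      num_heads=8, depth=2, patch_size=16)
+            >>> x = torch.randn(1, 1, 224, 224)
+            >>> y = tp(x)
+            >>> print(y.shape)  # torch.Size([1, 256, 14, 14]); spatial dims reduced by patch_size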
+ """ + # Patch embedding: (B, C, D, H, W) -> (B, hidden_dim, Dp, Hp, Wp) + x_embedded = self.embedding(x) + B = x_embedded.shape[0] + spatial_shape = x_embedded.shape[2:] + + # Flatten spatial dims: (B, hidden_dim, *spatial) -> (B, N, hidden_dim) + x_flat = x_embedded.flatten(2).transpose(1, 2) + + # Apply transformer + x_transformed = self.transformer(x_flat) + x_transformed = self.norm(x_transformed) + + # Reshape back to spatial: (B, N, hidden_dim) -> (B, hidden_dim, *spatial) + x_reshaped = x_transformed.transpose(1, 2).view(B, self.hidden_dim, *spatial_shape) + + return x_reshaped + + +class CrossModalAttentionFusion(nn.Module): + """ + Cross-modal attention fusion between CNN and Transformer features. + + Performs bidirectional cross-attention where CNN features attend to + Transformer features and vice versa, then combines the results. + + Args: + spatial_dims: number of spatial dimensions (2 or 3). + channels: number of input/output channels. + num_heads: number of attention heads. + dropout: dropout rate for attention weights. + """ + + def __init__( + self, + spatial_dims: int, + channels: int, + num_heads: int, + dropout: float = 0.0, + ) -> None: + super().__init__() + if channels % num_heads != 0: + raise ValueError( + f"channels ({channels}) must be divisible by num_heads ({num_heads})." + ) + + self.spatial_dims = spatial_dims + self.num_heads = num_heads + self.head_dim = channels // num_heads + self.scale = self.head_dim ** -0.5 + self.dropout = nn.Dropout(dropout) + + conv_type = nn.Conv3d if spatial_dims == 3 else nn.Conv2d + + # QKV projections for both paths + self.to_qkv_cnn = conv_type(channels, channels * 3, 1, bias=False) + self.to_qkv_vit = conv_type(channels, channels * 3, 1, bias=False) + + # Output projection + self.to_out = nn.Sequential( + conv_type(channels, channels, 1), + nn.Dropout(dropout) if dropout > 0 else nn.Identity(), + ) + + def forward(self, cnn_feat: torch.Tensor, vit_feat: torch.Tensor) -> torch.Tensor: + """ + Forward pass for cross-modal attention fusion. + + Args: + cnn_feat: CNN features of shape (B, C, *spatial_dims). + vit_feat: ViT features of shape (B, C, *spatial_dims_vit). + + Returns: + Fused features of shape (B, C, *spatial_dims). 
+ """ + B, C = cnn_feat.shape[:2] + spatial_shape = cnn_feat.shape[2:] + heads = self.num_heads + + # Interpolate ViT features to match CNN spatial dimensions + if cnn_feat.shape[2:] != vit_feat.shape[2:]: + mode = "trilinear" if self.spatial_dims == 3 else "bilinear" + vit_feat = F.interpolate( + vit_feat, size=spatial_shape, mode=mode, align_corners=False + ) + + # Compute Q, K, V for both paths + q_c, k_c, v_c = self.to_qkv_cnn(cnn_feat).chunk(3, dim=1) + q_v, k_v, v_v = self.to_qkv_vit(vit_feat).chunk(3, dim=1) + + # Reshape for multi-head attention: (B, heads, head_dim, N) + def reshape_for_attention(t: torch.Tensor) -> torch.Tensor: + return t.view(B, heads, self.head_dim, -1) + + q_c, k_c, v_c = map(reshape_for_attention, (q_c, k_c, v_c)) + q_v, k_v, v_v = map(reshape_for_attention, (q_v, k_v, v_v)) + + # Cross-attention: CNN queries attend to ViT keys/values + attn_cv = torch.einsum("b h d i, b h d j -> b h i j", q_c, k_v) * self.scale + attn_cv = self.dropout(attn_cv.softmax(dim=-1)) + out_c = torch.einsum("b h i j, b h d j -> b h d i", attn_cv, v_v) + + # Cross-attention: ViT queries attend to CNN keys/values + attn_vc = torch.einsum("b h d i, b h d j -> b h i j", q_v, k_c) * self.scale + attn_vc = self.dropout(attn_vc.softmax(dim=-1)) + out_v = torch.einsum("b h i j, b h d j -> b h d i", attn_vc, v_c) + + # Reshape back to spatial + out_c = out_c.contiguous().view(B, C, *spatial_shape) + out_v = out_v.contiguous().view(B, C, *spatial_shape) + + # Combine and project + fused = self.to_out(out_c + out_v) + + return fused + + +class ScaleAdaptiveConv(nn.Module): + """ + Scale-adaptive convolution module with multiple kernel sizes. + + Applies parallel convolutions with different kernel sizes and + combines the outputs for multi-scale feature extraction. + + Args: + spatial_dims: number of spatial dimensions (2 or 3). + in_channels: number of input channels. + out_channels: number of output channels. + kernel_sizes: sequence of kernel sizes to use. + norm: normalization type. + act: activation type. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + kernel_sizes: Sequence[int] = (3, 5, 7), + norm: Union[str, tuple] = "batch", + act: Union[str, tuple] = "relu", + ) -> None: + super().__init__() + self.spatial_dims = spatial_dims + + conv_type = nn.Conv3d if spatial_dims == 3 else nn.Conv2d + + self.convs = nn.ModuleList([ + conv_type(in_channels, out_channels, k, padding=k // 2, bias=False) + for k in kernel_sizes + ]) + + # Shared normalization and activation + self.norm = get_norm_layer( + name=norm, spatial_dims=spatial_dims, channels=out_channels + ) + self.act = get_act_layer(name=act) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass with multi-scale convolutions. + + Args: + x: input tensor of shape (B, C, *spatial_dims). + + Returns: + Multi-scale features of shape (B, out_channels, *spatial_dims). + """ + outs = [conv(x) for conv in self.convs] + out = torch.stack(outs, dim=0).sum(dim=0) + out = self.norm(out) + out = self.act(out) + return out + + +class SEBlock(nn.Module): + """ + Squeeze-and-Excitation block for channel recalibration. + + Args: + spatial_dims: number of spatial dimensions (2 or 3). + channels: number of input/output channels. + reduction: channel reduction ratio for the squeeze operation. 
+ """ + + def __init__( + self, + spatial_dims: int, + channels: int, + reduction: int = 16, + ) -> None: + super().__init__() + self.spatial_dims = spatial_dims + + pool_type = nn.AdaptiveAvgPool3d if spatial_dims == 3 else nn.AdaptiveAvgPool2d + self.avg_pool = pool_type(1) + + reduced_channels = max(channels // reduction, 1) + self.fc = nn.Sequential( + nn.Linear(channels, reduced_channels, bias=False), + nn.ReLU(inplace=True), + nn.Linear(reduced_channels, channels, bias=False), + nn.Sigmoid(), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass for SE block. + + Args: + x: input tensor of shape (B, C, *spatial_dims). + + Returns: + Channel-recalibrated tensor of same shape. + """ + b, c = x.shape[:2] + y = self.avg_pool(x).view(b, c) + y = self.fc(y) + + # Reshape for broadcasting + if self.spatial_dims == 3: + y = y.view(b, c, 1, 1, 1) + else: + y = y.view(b, c, 1, 1) + + return x * y.expand_as(x) + + +class DecoderBlock(nn.Module): + """ + Single decoder block with upsampling, skip connection, and SE attention. + + Args: + spatial_dims: number of spatial dimensions (2 or 3). + in_channels: number of input channels. + skip_channels: number of skip connection channels. + out_channels: number of output channels. + norm: normalization type. + act: activation type. + dropout: dropout ratio. + use_se: whether to use SE block. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + skip_channels: int, + out_channels: int, + norm: Union[str, tuple] = "batch", + act: Union[str, tuple] = "relu", + dropout: float = 0.0, + use_se: bool = True, + ) -> None: + super().__init__() + self.spatial_dims = spatial_dims + + # Upsampling with UpSample block + self.upsample = UpSample( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + scale_factor=2, + mode="nontrainable", + interp_mode="trilinear" if spatial_dims == 3 else "bilinear", + align_corners=False, + ) + + # Convolution after concatenation with skip + self.conv = Convolution( + spatial_dims=spatial_dims, + in_channels=out_channels + skip_channels, + out_channels=out_channels, + kernel_size=3, + strides=1, + padding=1, + norm=norm, + act=act, + dropout=dropout, + ) + + # Optional SE block + self.se = SEBlock(spatial_dims, out_channels) if use_se else nn.Identity() + + def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor: + """ + Forward pass for decoder block. + + Args: + x: input tensor from previous decoder stage. + skip: skip connection tensor from encoder. + + Returns: + Decoded features tensor. + """ + x = self.upsample(x) + + # Handle spatial dimension mismatch + if x.shape[2:] != skip.shape[2:]: + mode = "trilinear" if self.spatial_dims == 3 else "bilinear" + x = F.interpolate(x, size=skip.shape[2:], mode=mode, align_corners=False) + + x = torch.cat([x, skip], dim=1) + x = self.conv(x) + x = self.se(x) + + return x + + +class MAGNUS(nn.Module): + """ + MAGNUS: Multi-scale Attention Guided Network for Unified Segmentation. + + A hybrid CNN-Transformer architecture that combines: + - CNN path with strided convolutions for hierarchical feature extraction + - Vision Transformer path for global context modeling + - Cross-modal attention fusion for enhanced feature representation + - Scale-adaptive convolutions for multi-scale analysis + - Decoder with SE attention and deep supervision support + + Args: + spatial_dims: number of spatial dimensions (2 or 3). + in_channels: number of input image channels. 
+ out_channels: number of output segmentation classes. + features: sequence of feature channels for encoder stages. + Default: (64, 128, 256, 512). + vit_depth: number of transformer encoder layers. Default: 6. + vit_patch_size: patch size for ViT embedding. Default: 16. + vit_num_heads: number of attention heads in ViT. If None, computed as + features[-1] // 32. Default: None. + fusion_num_heads: number of attention heads in cross-modal fusion. + If None, uses vit_num_heads. Default: None. + scale_kernel_sizes: kernel sizes for scale-adaptive conv. Default: (3, 5, 7). + norm: normalization type ("batch", "instance", "group"). Default: "batch". + act: activation type. Default: "relu". + dropout: dropout ratio. Default: 0.0. + vit_dropout: dropout ratio for transformer. Default: 0.1. + deep_supervision: whether to return auxiliary outputs. Default: False. + aux_weights: weights for auxiliary losses. Default: (0.4, 0.3, 0.3). + + Example: + >>> import torch + >>> from monai.networks.nets import MAGNUS + >>> # 3D segmentation + >>> model = MAGNUS(spatial_dims=3, in_channels=1, out_channels=2) + >>> x = torch.randn(1, 1, 64, 64, 64) + >>> y = model(x) + >>> print(y.shape) # torch.Size([1, 2, 64, 64, 64]) + >>> # 2D segmentation + >>> model_2d = MAGNUS(spatial_dims=2, in_channels=3, out_channels=4) + >>> x_2d = torch.randn(1, 3, 256, 256) + >>> y_2d = model_2d(x_2d) + >>> print(y_2d.shape) # torch.Size([1, 4, 256, 256]) + + Reference: + Aras, E., Kayikcioglu, T., Aras, S., & Merd, N. (2026). + MAGNUS: Multi-Attention Guided Network for Unified Segmentation via CNN-ViT Fusion. + IEEE Access. DOI: 10.1109/ACCESS.2026.3656667 + """ + + def __init__( + self, + spatial_dims: int = 3, + in_channels: int = 1, + out_channels: int = 1, + features: Sequence[int] = (64, 128, 256, 512), + vit_depth: int = 6, + vit_patch_size: int = 16, + vit_num_heads: Optional[int] = None, + fusion_num_heads: Optional[int] = None, + scale_kernel_sizes: Sequence[int] = (3, 5, 7), + norm: Union[str, tuple] = "batch", + act: Union[str, tuple] = "relu", + dropout: float = 0.0, + vit_dropout: float = 0.1, + deep_supervision: bool = False, + aux_weights: Sequence[float] = (0.4, 0.3, 0.3), + ) -> None: + super().__init__() + + if spatial_dims not in (2, 3): + raise ValueError(f"spatial_dims must be 2 or 3, got {spatial_dims}.") + + self.spatial_dims = spatial_dims + self.in_channels = in_channels + self.out_channels = out_channels + self.features = list(features) + self.deep_supervision = deep_supervision + self.aux_weights = list(aux_weights) + + # Compute number of attention heads + vit_hidden_dim = self.features[-1] + if vit_num_heads is None: + vit_num_heads = max(vit_hidden_dim // 32, 1) + if fusion_num_heads is None: + fusion_num_heads = vit_num_heads + + # CNN encoder path + self.cnn_path = CNNPath( + spatial_dims=spatial_dims, + in_channels=in_channels, + features=self.features, + norm=norm, + act=act, + dropout=dropout, + ) + + # Transformer path + self.transformer_path = TransformerPath( + spatial_dims=spatial_dims, + in_channels=in_channels, + hidden_dim=vit_hidden_dim, + num_heads=vit_num_heads, + depth=vit_depth, + patch_size=vit_patch_size, + dropout=vit_dropout, + ) + + # Cross-modal attention fusion + self.fusion = CrossModalAttentionFusion( + spatial_dims=spatial_dims, + channels=vit_hidden_dim, + num_heads=fusion_num_heads, + dropout=dropout, + ) + + # Scale-adaptive convolution + self.scale_conv = ScaleAdaptiveConv( + spatial_dims=spatial_dims, + in_channels=vit_hidden_dim, + out_channels=vit_hidden_dim, + 
kernel_sizes=scale_kernel_sizes, + norm=norm, + act=act, + ) + + # Decoder path + reversed_features = list(reversed(self.features)) + self.decoder_blocks = nn.ModuleList() + self.aux_heads = nn.ModuleList() + + for i in range(len(reversed_features) - 1): + in_ch = reversed_features[i] + out_ch = reversed_features[i + 1] + + self.decoder_blocks.append( + DecoderBlock( + spatial_dims=spatial_dims, + in_channels=in_ch, + skip_channels=out_ch, + out_channels=out_ch, + norm=norm, + act=act, + dropout=dropout, + use_se=True, + ) + ) + + # Auxiliary segmentation heads for deep supervision + if deep_supervision: + conv_type = nn.Conv3d if spatial_dims == 3 else nn.Conv2d + self.aux_heads.append(conv_type(out_ch, out_channels, kernel_size=1)) + + # Final segmentation head + conv_type = nn.Conv3d if spatial_dims == 3 else nn.Conv2d + self.final_conv = conv_type(self.features[0], out_channels, kernel_size=1) + + # Initialize weights + self._init_weights() + + def _init_weights(self) -> None: + """Initialize model weights using Kaiming initialization.""" + for m in self.modules(): + if isinstance(m, (nn.Conv2d, nn.Conv3d, nn.ConvTranspose2d, nn.ConvTranspose3d)): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm3d, nn.GroupNorm, nn.InstanceNorm2d, nn.InstanceNorm3d)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward( + self, x: torch.Tensor + ) -> Union[torch.Tensor, tuple[torch.Tensor, list[torch.Tensor]]]: + """ + Forward pass of MAGNUS. + + Args: + x: input tensor of shape (B, in_channels, *spatial_dims). + + Returns: + If deep_supervision is False: + Segmentation logits of shape (B, out_channels, *spatial_dims). + If deep_supervision is True: + Tuple of (main_output, auxiliary_outputs) where auxiliary_outputs + is a list of intermediate segmentation maps. + """ + input_shape = x.shape[2:] + + # 1. CNN feature extraction + cnn_features = self.cnn_path(x) + cnn_deepest = cnn_features[-1] + + # 2. Transformer path + vit_features = self.transformer_path(x) + + # 3. Cross-modal attention fusion + fused_features = self.fusion(cnn_deepest, vit_features) + + # 4. Scale-adaptive convolution + scale_features = self.scale_conv(cnn_deepest) + + # 5. Combine fused and scale features + combined = fused_features + scale_features + + # 6. Decoder with skip connections + decoder_out = combined + cnn_skips = list(reversed(cnn_features[:-1])) + aux_outputs = [] + + for i, (decoder_block, skip) in enumerate(zip(self.decoder_blocks, cnn_skips)): + decoder_out = decoder_block(decoder_out, skip) + + # Auxiliary outputs for deep supervision + if self.deep_supervision and i < len(self.aux_heads): + aux_out = self.aux_heads[i](decoder_out) + aux_out = F.interpolate( + aux_out, + size=input_shape, + mode="trilinear" if self.spatial_dims == 3 else "bilinear", + align_corners=False, + ) + aux_outputs.append(aux_out) + + # 7. 
Final segmentation + seg_logits = self.final_conv(decoder_out) + + # Upsample to original input size if needed + if seg_logits.shape[2:] != input_shape: + seg_logits = F.interpolate( + seg_logits, + size=input_shape, + mode="trilinear" if self.spatial_dims == 3 else "bilinear", + align_corners=False, + ) + + if self.deep_supervision: + return seg_logits, aux_outputs + + return seg_logits diff --git a/tests/networks/nets/test_magnus.py b/tests/networks/nets/test_magnus.py new file mode 100644 index 0000000000..e789a3835e --- /dev/null +++ b/tests/networks/nets/test_magnus.py @@ -0,0 +1,332 @@ +# Copyright Project MONAI Contributors +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Unit tests for MAGNUS network. + +To run tests: + pytest test_magnus.py -v + +Or with unittest: + python -m pytest test_magnus.py -v +""" + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks.nets.magnus import ( + MAGNUS, + CNNPath, + CrossModalAttentionFusion, + ScaleAdaptiveConv, + TransformerPath, +) + + +# Test cases for MAGNUS model +MAGNUS_TEST_CASES = [ + # (spatial_dims, in_channels, out_channels, input_shape, expected_output_shape) + (3, 1, 2, (1, 1, 64, 64, 64), (1, 2, 64, 64, 64)), + (3, 4, 3, (2, 4, 32, 32, 32), (2, 3, 32, 32, 32)), + (2, 1, 2, (1, 1, 128, 128), (1, 2, 128, 128)), + (2, 3, 5, (2, 3, 64, 64), (2, 5, 64, 64)), +] + +# Test cases for individual components +CNN_PATH_TEST_CASES = [ + (3, 1, (32, 64, 128), (1, 1, 64, 64, 64)), + (2, 3, (64, 128, 256), (1, 3, 128, 128)), +] + +TRANSFORMER_PATH_TEST_CASES = [ + (3, 1, 256, 8, 4, 8, (1, 1, 64, 64, 64)), + (2, 3, 128, 4, 2, 16, (1, 3, 128, 128)), +] + +FUSION_TEST_CASES = [ + (3, 256, 8, (1, 256, 8, 8, 8), (1, 256, 4, 4, 4)), + (2, 128, 4, (1, 128, 16, 16), (1, 128, 8, 8)), +] + + +class TestMAGNUS(unittest.TestCase): + """Test cases for MAGNUS model.""" + + @parameterized.expand(MAGNUS_TEST_CASES) + def test_magnus_shape( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + input_shape: tuple, + expected_shape: tuple, + ): + """Test MAGNUS output shape.""" + model = MAGNUS( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + features=(32, 64, 128, 256), # Smaller for testing + vit_depth=2, + vit_patch_size=8, + ) + model.eval() + + x = torch.randn(*input_shape) + with torch.no_grad(): + y = model(x) + + self.assertEqual(y.shape, expected_shape) + + def test_magnus_deep_supervision(self): + """Test MAGNUS with deep supervision.""" + model = MAGNUS( + spatial_dims=3, + in_channels=1, + out_channels=2, + features=(32, 64, 128, 256), + vit_depth=2, + vit_patch_size=8, + deep_supervision=True, + ) + model.eval() + + x = torch.randn(1, 1, 32, 32, 32) + with torch.no_grad(): + main_out, aux_outs = model(x) + + self.assertEqual(main_out.shape, (1, 2, 32, 32, 32)) + self.assertEqual(len(aux_outs), 3) # 4 stages - 1 = 3 aux outputs + for aux_out in aux_outs: + self.assertEqual(aux_out.shape, (1, 2, 32, 32, 32)) + + def 
test_magnus_different_norms(self): + """Test MAGNUS with different normalization types.""" + norms = [ + "batch", + "instance", + ("group", {"num_groups": 8}), # GroupNorm requires num_groups + ] + for norm in norms: + model = MAGNUS( + spatial_dims=3, + in_channels=1, + out_channels=2, + features=(32, 64), + vit_depth=1, + vit_patch_size=8, + norm=norm, + ) + model.eval() + + x = torch.randn(1, 1, 32, 32, 32) + with torch.no_grad(): + y = model(x) + + self.assertEqual(y.shape, (1, 2, 32, 32, 32)) + + def test_magnus_gradient_flow(self): + """Test gradient flow through MAGNUS.""" + model = MAGNUS( + spatial_dims=3, + in_channels=1, + out_channels=2, + features=(32, 64), + vit_depth=1, + vit_patch_size=8, + ) + model.train() + + x = torch.randn(1, 1, 32, 32, 32, requires_grad=True) + y = model(x) + loss = y.sum() + loss.backward() + + self.assertIsNotNone(x.grad) + self.assertFalse(torch.isnan(x.grad).any()) + + def test_magnus_invalid_spatial_dims(self): + """Test MAGNUS raises error for invalid spatial_dims.""" + with self.assertRaises(ValueError): + MAGNUS(spatial_dims=4, in_channels=1, out_channels=2) + + +class TestCNNPath(unittest.TestCase): + """Test cases for CNNPath.""" + + @parameterized.expand(CNN_PATH_TEST_CASES) + def test_cnn_path_shape( + self, + spatial_dims: int, + in_channels: int, + features: tuple, + input_shape: tuple, + ): + """Test CNNPath output shapes.""" + model = CNNPath( + spatial_dims=spatial_dims, + in_channels=in_channels, + features=features, + ) + model.eval() + + x = torch.randn(*input_shape) + with torch.no_grad(): + outputs = model(x) + + self.assertEqual(len(outputs), len(features)) + for i, (feat, out) in enumerate(zip(features, outputs)): + self.assertEqual(out.shape[1], feat) + # Each stage downsamples by factor of 2 + expected_spatial = [s // (2 ** (i + 1)) for s in input_shape[2:]] + self.assertEqual(list(out.shape[2:]), expected_spatial) + + +class TestTransformerPath(unittest.TestCase): + """Test cases for TransformerPath.""" + + @parameterized.expand(TRANSFORMER_PATH_TEST_CASES) + def test_transformer_path_shape( + self, + spatial_dims: int, + in_channels: int, + hidden_dim: int, + num_heads: int, + depth: int, + patch_size: int, + input_shape: tuple, + ): + """Test TransformerPath output shape.""" + model = TransformerPath( + spatial_dims=spatial_dims, + in_channels=in_channels, + hidden_dim=hidden_dim, + num_heads=num_heads, + depth=depth, + patch_size=patch_size, + ) + model.eval() + + x = torch.randn(*input_shape) + with torch.no_grad(): + y = model(x) + + self.assertEqual(y.shape[0], input_shape[0]) # Batch + self.assertEqual(y.shape[1], hidden_dim) # Channels + expected_spatial = [s // patch_size for s in input_shape[2:]] + self.assertEqual(list(y.shape[2:]), expected_spatial) + + +class TestCrossModalAttentionFusion(unittest.TestCase): + """Test cases for CrossModalAttentionFusion.""" + + @parameterized.expand(FUSION_TEST_CASES) + def test_fusion_shape( + self, + spatial_dims: int, + channels: int, + num_heads: int, + cnn_shape: tuple, + vit_shape: tuple, + ): + """Test CrossModalAttentionFusion output shape.""" + model = CrossModalAttentionFusion( + spatial_dims=spatial_dims, + channels=channels, + num_heads=num_heads, + ) + model.eval() + + cnn_feat = torch.randn(*cnn_shape) + vit_feat = torch.randn(*vit_shape) + + with torch.no_grad(): + y = model(cnn_feat, vit_feat) + + # Output should match CNN feature shape + self.assertEqual(y.shape, cnn_shape) + + def test_fusion_invalid_channels(self): + """Test fusion raises error when channels 
not divisible by heads.""" + with self.assertRaises(ValueError): + CrossModalAttentionFusion( + spatial_dims=3, + channels=100, + num_heads=8, # 100 % 8 != 0 + ) + + +class TestScaleAdaptiveConv(unittest.TestCase): + """Test cases for ScaleAdaptiveConv.""" + + def test_scale_adaptive_conv_3d(self): + """Test ScaleAdaptiveConv 3D output shape.""" + model = ScaleAdaptiveConv( + spatial_dims=3, + in_channels=64, + out_channels=128, + kernel_sizes=(3, 5, 7), + ) + model.eval() + + x = torch.randn(1, 64, 16, 16, 16) + with torch.no_grad(): + y = model(x) + + self.assertEqual(y.shape, (1, 128, 16, 16, 16)) + + def test_scale_adaptive_conv_2d(self): + """Test ScaleAdaptiveConv 2D output shape.""" + model = ScaleAdaptiveConv( + spatial_dims=2, + in_channels=32, + out_channels=64, + kernel_sizes=(3, 5), + ) + model.eval() + + x = torch.randn(1, 32, 32, 32) + with torch.no_grad(): + y = model(x) + + self.assertEqual(y.shape, (1, 64, 32, 32)) + + +class TestMAGNUSMemory(unittest.TestCase): + """Memory and performance tests for MAGNUS.""" + + @unittest.skipUnless(torch.cuda.is_available(), "CUDA not available") + def test_magnus_cuda(self): + """Test MAGNUS on CUDA.""" + model = MAGNUS( + spatial_dims=3, + in_channels=1, + out_channels=2, + features=(32, 64, 128), + vit_depth=2, + vit_patch_size=8, + ).cuda() + model.eval() + + x = torch.randn(1, 1, 32, 32, 32, device="cuda") + with torch.no_grad(): + y = model(x) + + self.assertEqual(y.device.type, "cuda") + self.assertEqual(y.shape, (1, 2, 32, 32, 32)) + + +if __name__ == "__main__": + unittest.main() From 89be2478a870323a6812c5add8bf32acc0a1919a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 22 Jan 2026 21:41:19 +0000 Subject: [PATCH 2/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/networks/nets/magnus.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/monai/networks/nets/magnus.py b/monai/networks/nets/magnus.py index b4af9b8ee2..c5a24eb54d 100644 --- a/monai/networks/nets/magnus.py +++ b/monai/networks/nets/magnus.py @@ -32,9 +32,7 @@ import torch.nn.functional as F from monai.networks.blocks import Convolution, UpSample -from monai.networks.layers.factories import Act, Norm from monai.networks.layers.utils import get_act_layer, get_norm_layer -from monai.utils import ensure_tuple_rep __all__ = ["MAGNUS", "CNNPath", "TransformerPath", "CrossModalAttentionFusion", "ScaleAdaptiveConv"] From 64384132aeec6724a117d30a6df5d66abffaeef5 Mon Sep 17 00:00:00 2001 From: Sefa Aras Date: Sun, 25 Jan 2026 13:04:29 +0300 Subject: [PATCH 3/3] Fix TransformerPath positional encoding and aux_weights docs - Add learnable positional embeddings to TransformerPath for proper spatial reasoning - Implement dynamic positional embedding interpolation for varying input sizes - Add positional dropout for regularization - Update aux_weights docstring to clarify it's for external use only Addresses CodeRabbit review comments on PR #8717 Signed-off-by: Sefa Aras --- monai/networks/nets/magnus.py | 47 +++++++++++++++++++++++++++++++++-- 1 file changed, 45 insertions(+), 2 deletions(-) diff --git a/monai/networks/nets/magnus.py b/monai/networks/nets/magnus.py index c5a24eb54d..47cb2c4436 100644 --- a/monai/networks/nets/magnus.py +++ b/monai/networks/nets/magnus.py @@ -115,7 +115,8 @@ class TransformerPath(nn.Module): Vision Transformer path for global context modeling. 
Applies patch embedding followed by transformer encoder layers - to capture long-range dependencies. + to capture long-range dependencies. Includes learnable positional + embeddings that are interpolated to match varying input sizes. Args: spatial_dims: number of spatial dimensions (2 or 3). @@ -150,6 +151,14 @@ def __init__( in_channels, hidden_dim, kernel_size=patch_size, stride=patch_size ) + # Learnable positional embedding (will be interpolated for different input sizes) + # Initialize with a reasonable default size, will adapt dynamically + self.pos_embed = nn.Parameter(torch.zeros(1, 256, hidden_dim)) + nn.init.trunc_normal_(self.pos_embed, std=0.02) + + # Dropout for positional embedding + self.pos_drop = nn.Dropout(p=dropout) + # Transformer encoder encoder_layer = nn.TransformerEncoderLayer( d_model=hidden_dim, @@ -165,6 +174,31 @@ def __init__( # Layer normalization self.norm = nn.LayerNorm(hidden_dim) + def _interpolate_pos_encoding(self, x: torch.Tensor, num_patches: int) -> torch.Tensor: + """ + Interpolate positional embeddings to match the number of patches. + + Args: + x: input tensor for device reference. + num_patches: target number of patches. + + Returns: + Interpolated positional embeddings of shape (1, num_patches, hidden_dim). + """ + if num_patches == self.pos_embed.shape[1]: + return self.pos_embed + + # Interpolate positional embeddings + pos_embed = self.pos_embed.transpose(1, 2) # (1, hidden_dim, N) + pos_embed = F.interpolate( + pos_embed, + size=num_patches, + mode="linear", + align_corners=False, + ) + pos_embed = pos_embed.transpose(1, 2) # (1, num_patches, hidden_dim) + return pos_embed + def forward(self, x: torch.Tensor) -> torch.Tensor: """ Forward pass through transformer path. @@ -182,6 +216,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # Flatten spatial dims: (B, hidden_dim, *spatial) -> (B, N, hidden_dim) x_flat = x_embedded.flatten(2).transpose(1, 2) + num_patches = x_flat.shape[1] + + # Add positional encoding + pos_embed = self._interpolate_pos_encoding(x_flat, num_patches) + x_flat = x_flat + pos_embed + x_flat = self.pos_drop(x_flat) # Apply transformer x_transformed = self.transformer(x_flat) @@ -512,7 +552,10 @@ class MAGNUS(nn.Module): dropout: dropout ratio. Default: 0.0. vit_dropout: dropout ratio for transformer. Default: 0.1. deep_supervision: whether to return auxiliary outputs. Default: False. - aux_weights: weights for auxiliary losses. Default: (0.4, 0.3, 0.3). + aux_weights: suggested weights for auxiliary losses when using deep supervision. + These weights are stored as an attribute for user convenience but are NOT + applied internally. Users should apply them externally when computing the + total loss. Default: (0.4, 0.3, 0.3). Example: >>> import torch