Skip to content

Conversation

@sefaaras
Copy link

Description

This PR adds MAGNUS (Multi-Attention Guided Network for Unified Segmentation), a hybrid CNN-Transformer architecture for medical image segmentation.

Key Features

  • Dual-path encoder: CNN path for local features + Vision Transformer path for global context
  • Cross-modal attention fusion: Bidirectional attention between CNN and ViT features
  • Scale-adaptive convolution: Multi-kernel (3, 5, 7) parallel convolutions
  • SE attention: Channel recalibration in decoder blocks
  • Deep supervision: Optional auxiliary outputs for improved training
  • 2D/3D support: Works with both 2D and 3D medical images

New Files

  • monai/networks/nets/magnus.py - Main implementation
  • tests/networks/nets/test_magnus.py - Unit tests (17 tests)

Modified Files

  • monai/networks/nets/__init__.py - Export MAGNUS and components

Usage Example

from monai.networks.nets import MAGNUS

model = MAGNUS(
spatial_dims=3,
in_channels=1,
out_channels=2,
features=(64, 128, 256, 512),
)

Test Results
All 17 unit tests pass ✅

Reference
Aras, E., Kayikcioglu, T., Aras, S., & Merd, N. (2026). MAGNUS: Multi-Attention Guided Network for Unified Segmentation via CNN-ViT Fusion. IEEE Access. DOI: 10.1109/ACCESS.2026.3656667

- Add MAGNUS hybrid CNN-Transformer architecture for medical image segmentation
- Implement CNNPath for hierarchical feature extraction
- Implement TransformerPath for global context modeling
- Add CrossModalAttentionFusion for bidirectional cross-attention
- Add ScaleAdaptiveConv for multi-scale feature extraction
- Add SEBlock for channel recalibration
- Support both 2D and 3D medical images
- Add deep supervision option
- Add comprehensive unit tests

Reference: Aras et al., IEEE Access 2026, DOI: 10.1109/ACCESS.2026.3656667
Signed-off-by: Sefa Aras <sefa666@hotmail.com>
@coderabbitai
Copy link
Contributor

coderabbitai bot commented Jan 22, 2026

📝 Walkthrough

Walkthrough

This PR introduces MAGNUS, a new multi-architecture neural network for CNN-ViT fusion-based segmentation, to the MONAI package. The implementation adds five new public classes (MAGNUS, CNNPath, TransformerPath, CrossModalAttentionFusion, ScaleAdaptiveConv) plus supporting components (SEBlock, DecoderBlock). The model combines hierarchical CNN features with Vision Transformer global context via bidirectional cross-attention fusion, includes multi-scale feature processing, and supports deep supervision. Supporting infrastructure adds comprehensive unit tests across all components and configurations.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes

🚥 Pre-merge checks | ✅ 2 | ❌ 1
❌ Failed checks (1 warning)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 71.43% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
Check name Status Explanation
Title check ✅ Passed Title clearly summarizes the main change: introducing MAGNUS, a hybrid CNN-Transformer architecture for medical image segmentation.
Description check ✅ Passed Description includes key features, files changed, usage example, and test status. Matches template structure with description, types of changes, and test coverage.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing touches
  • 📝 Generate docstrings

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

🤖 Fix all issues with AI agents
In `@monai/networks/nets/magnus.py`:
- Around line 147-166: The transformer path lacks positional embeddings: add a
learnable positional embedding parameter (e.g., self.pos_embed =
nn.Parameter(torch.zeros(1, num_patches, hidden_dim))) initialized (truncated
normal or normal) and sized to match the sequence length produced by
self.embedding (compute num_patches from input spatial dimensions divided by
patch_size or infer from the flattened embedding shape at runtime), then in the
forward pass add this positional embedding to the flattened patch tokens before
passing them into self.transformer; ensure the parameter is registered on the
correct device and that self.norm still applies after the transformer.
🧹 Nitpick comments (4)
monai/networks/nets/magnus.py (2)

37-37: Sort __all__ alphabetically.

Per Ruff RUF022: apply isort-style sorting to __all__.

Proposed fix
-__all__ = ["MAGNUS", "CNNPath", "TransformerPath", "CrossModalAttentionFusion", "ScaleAdaptiveConv"]
+__all__ = ["CNNPath", "CrossModalAttentionFusion", "MAGNUS", "ScaleAdaptiveConv", "TransformerPath"]

703-704: Add strict=True to zip().

Ensures decoder_blocks and cnn_skips have matching lengths, catching bugs if construction changes.

Proposed fix
-        for i, (decoder_block, skip) in enumerate(zip(self.decoder_blocks, cnn_skips)):
+        for i, (decoder_block, skip) in enumerate(zip(self.decoder_blocks, cnn_skips, strict=True)):
tests/networks/nets/test_magnus.py (2)

29-35: Consider importing from monai.networks.nets to verify public exports.

Current imports bypass the public API. Testing via from monai.networks.nets import MAGNUS would also validate the __init__.py exports.


190-194: Add strict=True to zip().

Although length is asserted above, adding strict=True makes intent explicit.

Proposed fix
-        for i, (feat, out) in enumerate(zip(features, outputs)):
+        for i, (feat, out) in enumerate(zip(features, outputs, strict=True)):

Comment on lines +147 to +166
# Patch embedding via convolution
conv_type = nn.Conv3d if spatial_dims == 3 else nn.Conv2d
self.embedding = conv_type(
in_channels, hidden_dim, kernel_size=patch_size, stride=patch_size
)

# Transformer encoder
encoder_layer = nn.TransformerEncoderLayer(
d_model=hidden_dim,
nhead=num_heads,
dim_feedforward=int(hidden_dim * mlp_ratio),
dropout=dropout,
activation="gelu",
batch_first=True,
norm_first=True,
)
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=depth)

# Layer normalization
self.norm = nn.LayerNorm(hidden_dim)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# First, let's examine the full TransformerPath class to see if positional encoding is present elsewhere
cd /tmp && find . -name "magnus.py" -type f 2>/dev/null | head -1

Repository: Project-MONAI/MONAI

Length of output: 45


🏁 Script executed:

# Let's search for the magnus.py file
find . -path "*/monai/networks/nets/magnus.py" -type f

Repository: Project-MONAI/MONAI

Length of output: 94


🏁 Script executed:

# Alternative search
fd magnus.py

Repository: Project-MONAI/MONAI

Length of output: 127


🏁 Script executed:

# List files in monai/networks/nets to find the file
ls -la monai/networks/nets/ | grep magnus

Repository: Project-MONAI/MONAI

Length of output: 122


🏁 Script executed:

# Read the TransformerPath class to see full implementation
sed -n '113,193p' monai/networks/nets/magnus.py

Repository: Project-MONAI/MONAI

Length of output: 2795


🌐 Web query:

Vision Transformer positional encoding requirement ViT images

💡 Result:

Short answer: Yes—some form of positional encoding (PE) is required for ViT-style models because transformers are permutation-invariant and patch order/2D location must be supplied externally. The original ViT adds learnable 1‑D (flattened) absolute position embeddings to patch embeddings, but many alternatives exist (sinusoidal absolute, 2‑D decomposed, relative, RoPE, semantic/content-aware, or learned 2‑D) and some methods avoid explicit PE by building equivariance into architecture or learning position-like structure implicitly. Choosing which PE (or none) affects generalization, resolution changes, and data efficiency; relative/2‑D and interpolation-friendly schemes are common when varying image size. [1][2][3][4]

Key points and practical notes

  • Why needed: transformers ignore token order by design, so positional info must be provided for spatial reasoning. [1][2]
  • ViT (original): uses learnable 1‑D absolute positional embeddings added to patch embeddings + class token. [2]
  • Alternatives: sinusoidal absolute, learned 2‑D (separate x/y), relative 2‑D encodings, rotary (RoPE), and content-aware/semantic PEs — many improve robustness/generalization. [3][4]
  • Changing image/patch size: absolute learned embeddings must be interpolated or replaced (relative or decomposed 2‑D helps); some works show relative or 2‑D schemes are more resolution‑robust. [3][4]
  • Empirical tradeoffs: learned absolute PE works well with large data; relative/2‑D often improves performance and sample efficiency for vision tasks. [3][4]

If you want, I can:

  • show code examples for adding/interpolating 2D positional embeddings for different input sizes, or
  • compare accuracy and robustness of specific PE choices (with citations/benchmarks).

🏁 Script executed:

# Check the test file to understand expected behavior
grep -n "TransformerPath" tests/networks/nets/test_magnus.py -A 20 | head -50

Repository: Project-MONAI/MONAI

Length of output: 1637


Add positional encoding to the transformer path.

Vision Transformers require positional encoding because transformers are inherently permutation-invariant. Without it, patch positions are indistinguishable, severely limiting spatial reasoning. The original ViT uses learnable 1D absolute position embeddings; alternatives include 2D decomposed, sinusoidal, or relative encodings. Add positional embeddings to the patch representations before passing them to the transformer.

🤖 Prompt for AI Agents
In `@monai/networks/nets/magnus.py` around lines 147 - 166, The transformer path
lacks positional embeddings: add a learnable positional embedding parameter
(e.g., self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_dim)))
initialized (truncated normal or normal) and sized to match the sequence length
produced by self.embedding (compute num_patches from input spatial dimensions
divided by patch_size or infer from the flattened embedding shape at runtime),
then in the forward pass add this positional embedding to the flattened patch
tokens before passing them into self.transformer; ensure the parameter is
registered on the correct device and that self.norm still applies after the
transformer.

Comment on lines +553 to +565
aux_weights: Sequence[float] = (0.4, 0.3, 0.3),
) -> None:
super().__init__()

if spatial_dims not in (2, 3):
raise ValueError(f"spatial_dims must be 2 or 3, got {spatial_dims}.")

self.spatial_dims = spatial_dims
self.in_channels = in_channels
self.out_channels = out_channels
self.features = list(features)
self.deep_supervision = deep_supervision
self.aux_weights = list(aux_weights)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

aux_weights is stored but never used.

The aux_weights parameter is documented and stored as an attribute but not applied anywhere in the model. Either apply them in the forward pass or remove from constructor and document that users should handle weighting externally.

🧰 Tools
🪛 Ruff (0.14.13)

558-558: Avoid specifying long messages outside the exception class

(TRY003)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant