Skip to content

Commit 86ed4f7

Browse files
committed
[feat] Support panoptic segmentation in fvcore-based MAC computation
1 parent e678187 commit 86ed4f7

1 file changed

Lines changed: 23 additions & 4 deletions

File tree

compressai_vision/utils/measure_complexity.py

Lines changed: 23 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -115,7 +115,15 @@ def calc_complexity_nn_part2_plyr(vision_model, data, dec_features):
115115
),
116116
)
117117

118-
# ---------- 4) ROIHeads ----------
118+
# ---------- 4) Measure sem_seg_head if available ----------
119+
# Panoptic/Semantic models use sem_seg_head(x, None)
120+
is_semseg = hasattr(vision_model, "sem_seg_head") and vision_model.sem_seg_head is not None
121+
if is_semseg:
122+
semseg_model = SemSegHeadFvcoreWrapper(vision_model.sem_seg_head).eval()
123+
# IMPORTANT: pass dict as a single positional arg
124+
kmacs_sum += measure_kmacs(semseg_model, (feature_pyramid,))
125+
126+
# ---------- 5) ROIHeads ----------
119127
# Run the proposal generator once to obtain actual proposals.
120128
# Only the image size is required for Detectron2, so a minimal dummy object is used.
121129
class _ImagesDummy:
@@ -128,7 +136,7 @@ def __init__(self, image_sizes):
128136
with torch.no_grad():
129137
proposals, _ = vision_model.proposal_generator(images, feature_pyramid, None)
130138

131-
# ---------- 4-1) Measure box_head + box_predictor ----------
139+
# ---------- 5-1) Measure box_head + box_predictor ----------
132140
# ROIAlign/Pooler is excluded from FLOPs due to ambiguity and potential CUDA/JIT issues.
133141
# Instead, pooled features are obtained once and only NN blocks are measured.
134142
if hasattr(vision_model, "roi_heads") and vision_model.roi_heads is not None:
@@ -152,8 +160,8 @@ def __init__(self, image_sizes):
152160
box_head_model = BoxHeadPredictorFvcoreWrapper(roi_heads).eval()
153161
kmacs_sum += measure_kmacs(box_head_model, pooled)
154162

155-
# ---------- 4-2) Measure mask head if available ----------
156-
if (
163+
# ---------- 5-2) Measure mask head if available ----------
164+
if (not is_semseg) and (
157165
hasattr(roi_heads, "mask_head")
158166
and roi_heads.mask_head is not None
159167
and hasattr(roi_heads, "mask_pooler")
@@ -322,6 +330,17 @@ def _cast(x):
322330
name = tag or module.__class__.__name__
323331
print(f"[INFO] {name}: KMACs = {kmacs}")
324332
return kmacs
333+
334+
class SemSegHeadFvcoreWrapper(nn.Module):
    """
    Wrapper that exposes a semantic-segmentation head through a one-argument forward.

    The wrapped head is always invoked as ``sem_seg_head(x, None)``, so callers
    only have to supply the single input ``x``.
    """

    def __init__(self, sem_seg_head: nn.Module):
        """Store the head to be wrapped; no other state is created."""
        super().__init__()
        self.sem_seg_head = sem_seg_head

    def forward(self, x):
        """
        Run the wrapped head on ``x`` and return its primary output.

        If the head returns a tuple or list, only its first element is
        returned; any other return value is passed through unchanged.
        """
        # The second positional argument is fixed to None.
        # NOTE(review): presumably this is the detectron2 `targets` argument,
        # with the head returning (sem_seg_results, losses) or similar — confirm.
        head_output = self.sem_seg_head(x, None)
        if not isinstance(head_output, (tuple, list)):
            return head_output
        return head_output[0]
343+
325344
class RPNHeadOnlyFvcoreWrapper(nn.Module):
326345
"""
327346
Wrapper for Detectron2 RPN to measure FLOPs only for the neural network part.

0 commit comments

Comments
 (0)