@@ -115,7 +115,15 @@ def calc_complexity_nn_part2_plyr(vision_model, data, dec_features):
115115 ),
116116 )
117117
118- # ---------- 4) ROIHeads ----------
118+ # ---------- 4) Measure sem_seg_head if available ----------
119+ # Panoptic/Semantic models use sem_seg_head(x, None)
120+ is_semseg = hasattr (vision_model , "sem_seg_head" ) and vision_model .sem_seg_head is not None
121+ if is_semseg :
122+ semseg_model = SemSegHeadFvcoreWrapper (vision_model .sem_seg_head ).eval ()
123+ # IMPORTANT: pass dict as a single positional arg
124+ kmacs_sum += measure_kmacs (semseg_model , (feature_pyramid ,))
125+
126+ # ---------- 5) ROIHeads ----------
119127 # Run the proposal generator once to obtain actual proposals.
120128 # Only the image size is required for Detectron2, so a minimal dummy object is used.
121129 class _ImagesDummy :
@@ -128,7 +136,7 @@ def __init__(self, image_sizes):
128136 with torch .no_grad ():
129137 proposals , _ = vision_model .proposal_generator (images , feature_pyramid , None )
130138
131- # ---------- 4 -1) Measure box_head + box_predictor ----------
139+ # ---------- 5 -1) Measure box_head + box_predictor ----------
132140 # ROIAlign/Pooler is excluded from FLOPs due to ambiguity and potential CUDA/JIT issues.
133141 # Instead, pooled features are obtained once and only NN blocks are measured.
134142 if hasattr (vision_model , "roi_heads" ) and vision_model .roi_heads is not None :
@@ -152,8 +160,8 @@ def __init__(self, image_sizes):
152160 box_head_model = BoxHeadPredictorFvcoreWrapper (roi_heads ).eval ()
153161 kmacs_sum += measure_kmacs (box_head_model , pooled )
154162
155- # ---------- 4 -2) Measure mask head if available ----------
156- if (
163+ # ---------- 5 -2) Measure mask head if available ----------
164+ if (not is_semseg ) and (
157165 hasattr (roi_heads , "mask_head" )
158166 and roi_heads .mask_head is not None
159167 and hasattr (roi_heads , "mask_pooler" )
@@ -322,6 +330,17 @@ def _cast(x):
322330 name = tag or module .__class__ .__name__
323331 print (f"[INFO] { name } : KMACs = { kmacs } " )
324332 return kmacs
333+
334+ class SemSegHeadFvcoreWrapper (nn .Module ):
335+ def __init__ (self , sem_seg_head : nn .Module ):
336+ super ().__init__ ()
337+ self .sem_seg_head = sem_seg_head
338+
339+ def forward (self , x ):
340+ # detectron2 style: returns (sem_seg_results, losses) or similar
341+ out = self .sem_seg_head (x , None )
342+ return out [0 ] if isinstance (out , (tuple , list )) else out
343+
325344class RPNHeadOnlyFvcoreWrapper (nn .Module ):
326345 """
327346 Wrapper for Detectron2 RPN to measure FLOPs only for the neural network part.
0 commit comments