diff --git a/perceptionmetrics/utils/segmentation_metrics.py b/perceptionmetrics/utils/segmentation_metrics.py index 3975e5dd..b55ee5b2 100644 --- a/perceptionmetrics/utils/segmentation_metrics.py +++ b/perceptionmetrics/utils/segmentation_metrics.py @@ -255,6 +255,18 @@ def get_averaged_metric( """ metric = getattr(self, f"get_{metric_name}") if method == "macro": + per_class_values = metric(per_class=True) + nan_mask = np.isnan(per_class_values) + missing_count = int(np.sum(nan_mask)) + + if missing_count > 0: + import warnings + # Use implicit string concatenation to keep lines short + msg = ( + f"Warning: {missing_count} class(es) were missing from the confusion matrix. " + f"Their {metric_name.upper()} evaluated to NaN and will be ignored in the macro-average." + ) + warnings.warn(msg, UserWarning) return float(np.nanmean(metric(per_class=True))) if method == "micro": return float(metric(per_class=False))