Skip to content

Commit 6c4bd86

Browse files
committed
feat(detections): ✨ paligemma segmentation support added
Signed-off-by: Onuralp SEZER <thunderbirdtr@gmail.com>
1 parent a6e1f03 commit 6c4bd86

2 files changed

Lines changed: 65 additions & 17 deletions

File tree

supervision/detection/core.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -840,9 +840,10 @@ def from_lmm(
840840

841841
if lmm == LMM.PALIGEMMA:
842842
assert isinstance(result, str)
843-
xyxy, class_id, class_name = from_paligemma(result, **kwargs)
843+
xyxy, class_id, class_name, mask = from_paligemma(result, **kwargs)
844844
data = {CLASS_NAME_DATA_FIELD: class_name}
845-
return cls(xyxy=xyxy, class_id=class_id, data=data)
845+
mask = mask if mask is not None else None
846+
return cls(xyxy=xyxy, class_id=class_id, mask=mask, data=data)
846847

847848
if lmm == LMM.FLORENCE_2:
848849
assert isinstance(result, dict)

supervision/detection/lmm.py

Lines changed: 62 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -69,25 +69,72 @@ def validate_lmm_parameters(
6969

7070
def from_paligemma(
    result: str, resolution_wh: Tuple[int, int], classes: Optional[List[str]] = None
) -> Tuple[np.ndarray, Optional[np.ndarray], np.ndarray, Optional[np.ndarray]]:
    """
    Parse a PaliGemma output string into boxes, class ids, class names and masks.

    The string is first scanned for segmentation entries (four ``<locDDDD>``
    tokens, sixteen ``<segDDD>`` tokens, then a label). If none are found it is
    scanned for plain detection entries (four ``<locDDDD>`` tokens, a single
    space, then a label). The two kinds are never mixed: as soon as one
    segmentation entry matches, only segmentation entries are returned.

    Args:
        result (str): Model output string containing loc and optional seg tokens.
        resolution_wh (Tuple[int, int]): Target resolution as (width, height).
            Loc values are divided by 1024 and scaled to this resolution.
        classes (Optional[List[str]]): If given, only entries whose label is in
            this list are kept, and ``class_id`` is the label's first index in
            the list.

    Returns:
        Tuple containing:
            - xyxy (np.ndarray): Float array of shape (N, 4). The four loc
              values are reordered as [1, 0, 3, 2], i.e. the string is read as
              (y_min, x_min, y_max, x_max) and returned as
              (x_min, y_min, x_max, y_max).
            - class_id (Optional[np.ndarray]): Integer array of shape (N,), or
              None when ``classes`` is None or nothing matched.
            - class_name (np.ndarray): String array of shape (N,) with
              surrounding whitespace stripped.
            - mask (Optional[np.ndarray]): Boolean array of shape (N, h, w) when
              segmentation entries were found, otherwise None.
              NOTE: the seg token values are matched but not decoded here, so
              every returned mask is an all-False placeholder.
    """
    w, h = resolution_wh

    segmentation_pattern = re.compile(
        r"<loc(\d{4})><loc(\d{4})><loc(\d{4})><loc(\d{4})>\s*"
        + r"<seg(\d{3})>" * 16
        + r"\s+([\w\s\-]+)"
    )
    detection_pattern = re.compile(
        r"(?<!<loc\d{4}>)<loc(\d{4})><loc(\d{4})><loc(\d{4})><loc(\d{4})> ([\w\s\-]+)"
    )

    # Segmentation entries take precedence; plain detections are the fallback.
    matches = segmentation_pattern.findall(result)
    has_masks = bool(matches)
    if not has_masks:
        matches = detection_pattern.findall(result)

    if not matches:
        # Keep a string dtype so the empty result matches the populated one.
        return np.empty((0, 4)), None, np.empty(0, dtype=str), None

    # Columns: 4 loc values, (16 seg values when present), label last.
    fields = np.array(matches)
    xyxy = fields[:, [1, 0, 3, 2]].astype(int) / 1024 * np.array([w, h, w, h])
    class_name = np.char.strip(fields[:, -1].astype(str))

    # Placeholder masks: decoding the seg tokens is not implemented here.
    masks = np.zeros((len(fields), h, w), dtype=bool) if has_masks else None

    class_id = None
    if classes is not None:
        # First occurrence wins, mirroring list.index() semantics.
        index_of = {}
        for position, name in enumerate(classes):
            index_of.setdefault(name, position)

        keep = np.array([name in index_of for name in class_name], dtype=bool)
        xyxy = xyxy[keep]
        class_name = class_name[keep]
        if masks is not None:
            masks = masks[keep]
        # Explicit dtype: an empty selection must not degrade to float64.
        class_id = np.array([index_of[name] for name in class_name], dtype=int)

    return xyxy, class_id, class_name, masks
91138

92139

93140
def from_florence_2(

0 commit comments

Comments
 (0)