Skip to content

Commit 16df595

Browse files
authored
Merge pull request #19 from AdaptiveMotorControlLab/jaap/handle_invalid_2d_estimations
Inference API: add handling for (partially) unsuccessful 2d estimations
2 parents 0c4730c + eb61aba commit 16df595

3 files changed

Lines changed: 396 additions & 27 deletions

File tree

fmpose3d/inference_api/README.md

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,10 @@ Convenience constructor for the **animal** pipeline. Sets `model_type="fmpose3d_
100100
#### `predict(source, *, camera_rotation, seed, progress)` → `Pose3DResult`
101101

102102
End-to-end prediction: 2D estimation followed by 3D lifting in a single call.
103+
Raises `ValueError` when 2D estimation is unusable for lifting
104+
(`Pose2DResult.status` is `ResultStatus.EMPTY` or `ResultStatus.INVALID`).
105+
For partial 2D detections, invalid frames are masked to `NaN` in
106+
`Pose3DResult.poses_3d` and `Pose3DResult.poses_3d_world`.
103107

104108
| Parameter | Type | Description |
105109
|---|---|---|
@@ -121,7 +125,9 @@ Runs only the 2D pose estimation step.
121125
| `source` | `Source` | Same flexible input as `predict()`. |
122126
| `progress` | `ProgressCallback \| None` | Optional progress callback. |
123127
124-
**Returns:** `Pose2DResult` containing `keypoints`, `scores`, and `image_size`.
128+
**Returns:** `Pose2DResult` containing `keypoints`, `scores`, `image_size`,
129+
and `valid_frames_mask`. The object also exposes derived properties
130+
`status` and `status_message`.
125131

126132
---
127133

@@ -168,13 +174,35 @@ Source = Union[str, Path, np.ndarray, Sequence[Union[str, Path, np.ndarray]]]
168174
| `keypoints` | `ndarray` | 2D keypoints, shape `(num_persons, num_frames, J, 2)`. |
169175
| `scores` | `ndarray` | Per-joint confidence, shape `(num_persons, num_frames, J)`. |
170176
| `image_size` | `tuple[int, int]` | `(height, width)` of source frames. |
177+
| `valid_frames_mask` | `ndarray \| None` | Boolean mask, shape `(num_frames,)`, indicating frames with valid detections. |
178+
179+
Computed properties:
180+
181+
- `status``ResultStatus`
182+
- `status_message``str`
183+
184+
#### `ResultStatus`
185+
186+
String enum values:
187+
188+
- `success` — valid detections in all frames
189+
- `partial` — valid detections in a subset of frames
190+
- `empty` — no valid detections in any frame
191+
- `invalid` — output predictions are unusable/malformed
192+
- `unknown` — validity metadata missing or malformed
171193

172194
#### `Pose3DResult`
173195

174196
| Field | Type | Description |
175197
|---|---|---|
176198
| `poses_3d` | `ndarray` | Root-relative 3D poses, shape `(num_frames, J, 3)`. |
177199
| `poses_3d_world` | `ndarray` | Post-processed 3D poses, shape `(num_frames, J, 3)`. For humans: world-coordinate poses. For animals: limb-regularized poses. |
200+
| `valid_frames_mask` | `ndarray \| None` | Boolean mask, shape `(num_frames,)`, indicating frames with valid 3D output. |
201+
202+
Computed properties:
203+
204+
- `status``ResultStatus`
205+
- `status_message``str`
178206

179207

180208

@@ -187,14 +215,14 @@ Source = Union[str, Path, np.ndarray, Sequence[Union[str, Path, np.ndarray]]]
187215
Default 2D estimator for the human pipeline. Wraps HRNet + YOLO with a COCOH36M keypoint conversion.
188216

189217
- `setup_runtime()` — Loads YOLO + HRNet models.
190-
- `predict(frames: ndarray)``(keypoints, scores)` — Returns H36M-format 2D keypoints from BGR frames `(N, H, W, C)`.
218+
- `predict(frames: ndarray)``(keypoints, scores, valid_frames_mask)` — Returns H36M-format 2D keypoints from BGR frames `(N, H, W, C)` plus a frame-level validity mask.
191219

192220
#### `SuperAnimalEstimator(cfg: SuperAnimalConfig | None)`
193221

194222
2D estimator for the animal pipeline. Uses DeepLabCut SuperAnimal and maps quadruped80K keypoints to the 26-joint Animal3D layout.
195223

196224
- `setup_runtime()` — No-op (DLC loads lazily).
197-
- `predict(frames: ndarray)``(keypoints, scores)` — Returns Animal3D-format 2D keypoints from BGR frames.
225+
- `predict(frames: ndarray)``(keypoints, scores, valid_frames_mask)` — Returns Animal3D-format 2D keypoints plus a frame-level validity mask.
198226

199227
---
200228

0 commit comments

Comments
 (0)