diff --git a/fmpose3d/inference_api/README.md b/fmpose3d/inference_api/README.md index 71a77b1..87f6e58 100644 --- a/fmpose3d/inference_api/README.md +++ b/fmpose3d/inference_api/README.md @@ -100,6 +100,10 @@ Convenience constructor for the **animal** pipeline. Sets `model_type="fmpose3d_ #### `predict(source, *, camera_rotation, seed, progress)` → `Pose3DResult` End-to-end prediction: 2D estimation followed by 3D lifting in a single call. +Raises `ValueError` when 2D estimation is unusable for lifting +(`Pose2DResult.status` is `ResultStatus.EMPTY` or `ResultStatus.INVALID`). +For partial 2D detections, invalid frames are masked to `NaN` in +`Pose3DResult.poses_3d` and `Pose3DResult.poses_3d_world`. | Parameter | Type | Description | |---|---|---| @@ -121,7 +125,9 @@ Runs only the 2D pose estimation step. | `source` | `Source` | Same flexible input as `predict()`. | | `progress` | `ProgressCallback \| None` | Optional progress callback. | -**Returns:** `Pose2DResult` containing `keypoints`, `scores`, and `image_size`. +**Returns:** `Pose2DResult` containing `keypoints`, `scores`, `image_size`, +and `valid_frames_mask`. The object also exposes derived properties +`status` and `status_message`. --- @@ -168,6 +174,22 @@ Source = Union[str, Path, np.ndarray, Sequence[Union[str, Path, np.ndarray]]] | `keypoints` | `ndarray` | 2D keypoints, shape `(num_persons, num_frames, J, 2)`. | | `scores` | `ndarray` | Per-joint confidence, shape `(num_persons, num_frames, J)`. | | `image_size` | `tuple[int, int]` | `(height, width)` of source frames. | +| `valid_frames_mask` | `ndarray \| None` | Boolean mask, shape `(num_frames,)`, indicating frames with valid detections. | + +Computed properties: + +- `status` → `ResultStatus` +- `status_message` → `str` + +#### `ResultStatus` + +String enum values: + +- `success` — valid detections in all frames +- `partial` — valid detections in a subset of frames +- `empty` — no valid detections in any frame +- `invalid` — output predictions are unusable/malformed +- `unknown` — validity metadata missing or malformed #### `Pose3DResult` @@ -175,6 +197,12 @@ Source = Union[str, Path, np.ndarray, Sequence[Union[str, Path, np.ndarray]]] |---|---|---| | `poses_3d` | `ndarray` | Root-relative 3D poses, shape `(num_frames, J, 3)`. | | `poses_3d_world` | `ndarray` | Post-processed 3D poses, shape `(num_frames, J, 3)`. For humans: world-coordinate poses. For animals: limb-regularized poses. | +| `valid_frames_mask` | `ndarray \| None` | Boolean mask, shape `(num_frames,)`, indicating frames with valid 3D output. | + +Computed properties: + +- `status` → `ResultStatus` +- `status_message` → `str` @@ -187,14 +215,14 @@ Source = Union[str, Path, np.ndarray, Sequence[Union[str, Path, np.ndarray]]] Default 2D estimator for the human pipeline. Wraps HRNet + YOLO with a COCO → H36M keypoint conversion. - `setup_runtime()` — Loads YOLO + HRNet models. -- `predict(frames: ndarray)` → `(keypoints, scores)` — Returns H36M-format 2D keypoints from BGR frames `(N, H, W, C)`. +- `predict(frames: ndarray)` → `(keypoints, scores, valid_frames_mask)` — Returns H36M-format 2D keypoints from BGR frames `(N, H, W, C)` plus a frame-level validity mask. #### `SuperAnimalEstimator(cfg: SuperAnimalConfig | None)` 2D estimator for the animal pipeline. Uses DeepLabCut SuperAnimal and maps quadruped80K keypoints to the 26-joint Animal3D layout. - `setup_runtime()` — No-op (DLC loads lazily). -- `predict(frames: ndarray)` → `(keypoints, scores)` — Returns Animal3D-format 2D keypoints from BGR frames. +- `predict(frames: ndarray)` → `(keypoints, scores, valid_frames_mask)` — Returns Animal3D-format 2D keypoints plus a frame-level validity mask. --- diff --git a/fmpose3d/inference_api/fmpose3d.py b/fmpose3d/inference_api/fmpose3d.py index 603277d..ff4e23c 100644 --- a/fmpose3d/inference_api/fmpose3d.py +++ b/fmpose3d/inference_api/fmpose3d.py @@ -12,6 +12,7 @@ import copy from dataclasses import dataclass +from enum import Enum from pathlib import Path from typing import Callable, Sequence, Tuple, Union @@ -82,7 +83,7 @@ def setup_runtime(self) -> None: def predict( self, frames: np.ndarray - ) -> Tuple[np.ndarray, np.ndarray]: + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: """Estimate 2D keypoints from image frames and return in H36M format. Parameters @@ -96,6 +97,9 @@ def predict( H36M-format 2D keypoints, shape ``(num_persons, N, 17, 2)``. scores : ndarray Per-joint confidence scores, shape ``(num_persons, N, 17)``. + valid_frames_mask : ndarray + Boolean mask indicating which frames contain at least one + valid detection, shape ``(N,)``. """ from fmpose3d.lib.preprocess import h36m_coco_format, revise_kpts @@ -104,12 +108,70 @@ def predict( keypoints, scores = self._model.predict(frames) keypoints, scores, valid_frames = h36m_coco_format(keypoints, scores) + keypoints, scores = self._validate_predictions( + keypoints, scores, num_frames=frames.shape[0], + ) + valid_frames_mask = self._compute_valid_frames_mask(keypoints, scores) + # NOTE: revise_kpts is computed for consistency but is NOT applied # to the returned keypoints, matching the demo script behaviour. _revised = revise_kpts(keypoints, scores, valid_frames) # noqa: F841 + return keypoints, scores, valid_frames_mask + + def _validate_predictions( + self, + keypoints: np.ndarray, + scores: np.ndarray, + *, + num_frames: int, + ) -> Tuple[np.ndarray, np.ndarray]: + """Validate and normalise HRNet/H36M predictions.""" + num_joints = 17 + + keypoints = np.asarray(keypoints, dtype=np.float32) + scores = np.asarray(scores, dtype=np.float32) + if keypoints.shape[0] == 0: + # h36m_coco_format can drop all persons when all frames are empty. + return ( + np.zeros((1, num_frames, num_joints, 2), dtype=np.float32), + np.zeros((1, num_frames, num_joints), dtype=np.float32), + ) + + if keypoints.ndim != 4 or keypoints.shape[-2:] != (num_joints, 2): + raise ValueError( + f"Invalid HRNet keypoints shape {keypoints.shape}; " + f"expected (num_persons, num_frames, {num_joints}, 2)." + ) + if scores.ndim != 3 or scores.shape[-1] != num_joints: + raise ValueError( + f"Invalid HRNet scores shape {scores.shape}; " + f"expected (num_persons, num_frames, {num_joints})." + ) + if keypoints.shape[:2] != scores.shape[:2]: + raise ValueError( + "HRNet keypoints/scores leading dimensions do not match: " + f"{keypoints.shape[:2]} vs {scores.shape[:2]}." + ) + if keypoints.shape[1] != num_frames: + raise ValueError( + f"HRNet frame count mismatch: got {keypoints.shape[1]}, " + f"expected {num_frames}." + ) return keypoints, scores + @staticmethod + def _compute_valid_frames_mask( + keypoints: np.ndarray, scores: np.ndarray + ) -> np.ndarray: + """Return frame-level validity mask from estimator outputs.""" + safe_scores = np.nan_to_num(scores, nan=0.0) + has_score = np.any(safe_scores > 0, axis=-1) # (num_persons, num_frames) + + safe_kpts = np.nan_to_num(np.abs(keypoints), nan=0.0) + has_kpt = np.any(safe_kpts > 0, axis=(-1, -2)) # (num_persons, num_frames) + return np.any(has_score | has_kpt, axis=0) + # Quadruped80K → Animal3D (26 keypoints) mapping table. # -1 entries are filled by linear interpolation (see _INTERPOLATION_RULES). @@ -148,7 +210,7 @@ def setup_runtime(self) -> None: def predict( self, frames: np.ndarray - ) -> Tuple[np.ndarray, np.ndarray]: + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: """Estimate 2D keypoints from image frames in Animal3D format. The method writes *frames* to a temporary directory, runs @@ -166,8 +228,11 @@ def predict( Animal3D-format 2D keypoints, shape ``(1, N, 26, 2)``. The first axis is always 1 (single individual). scores : ndarray - Placeholder confidence scores (all ones), + Mapped per-joint confidence scores, shape ``(1, N, 26)``. + valid_frames_mask : ndarray + Boolean mask indicating which frames contain at least one + valid detection, shape ``(N,)``. """ import cv2 import tempfile @@ -178,6 +243,7 @@ def predict( cfg = self.cfg num_frames = frames.shape[0] all_mapped: list[np.ndarray] = [] + all_scores: list[np.ndarray] = [] with tempfile.TemporaryDirectory() as tmpdir: # Write each frame as an image so DLC can read it. @@ -187,8 +253,7 @@ def predict( cv2.imwrite(p, frames[idx]) paths.append(p) - # Run DeepLabCut on each frame individually (the API - # expects a single image path). + # Run DeepLabCut on each frame individually. for img_path in paths: predictions = superanimal_analyze_images( superanimal_name=cfg.superanimal_name, @@ -199,21 +264,33 @@ def predict( out_folder=tmpdir, ) # predictions: {image_path: {"bodyparts": (N_ind, K, 3), ...}} - for _path, payload in predictions.items(): - bodyparts = payload.get("bodyparts") - if bodyparts is None: - # No detection -- fill with zeros. - all_mapped.append(np.zeros((1, 26, 2), dtype="float32")) - continue - xy = bodyparts[..., :2] # (N_ind, K, 2) - mapped = self._map_keypoints(xy) - # Take only the first individual. - all_mapped.append(mapped[:1]) + payload = predictions.get(img_path) if isinstance(predictions, dict) else None + if payload is None and isinstance(predictions, dict) and len(predictions) == 1: + payload = next(iter(predictions.values())) + + bodyparts = None if payload is None else payload.get("bodyparts") + bodyparts = None if bodyparts is None else np.asarray(bodyparts) + if bodyparts is None or bodyparts.shape[0] == 0: + # No detection -- fill with zeros and zero confidence. + all_mapped.append(np.zeros((1, 26, 2), dtype=np.float32)) + all_scores.append(np.zeros((1, 26), dtype=np.float32)) + continue + + xy = bodyparts[..., :2] # (N_ind, K, 2) + conf = bodyparts[..., 2] # (N_ind, K) + mapped = self._map_keypoints(xy) + mapped_scores = self._map_scores(conf) + + # Take only the first individual. + all_mapped.append(mapped[:1]) + all_scores.append(mapped_scores[:1]) # Stack along frame axis → (1, N, 26, 2) kpts = np.stack(all_mapped, axis=1) # (1, N, 26, 2) - scores = np.ones(kpts.shape[:3], dtype="float32") # (1, N, 26) - return kpts, scores + scores = np.stack(all_scores, axis=1) # (1, N, 26) + kpts, scores = self._validate_predictions(kpts, scores, num_frames=num_frames) + valid_frames_mask = self._compute_valid_frames_mask(kpts, scores) + return kpts, scores, valid_frames_mask # ------------------------------------------------------------------ # @@ -247,6 +324,80 @@ def _map_keypoints(xy: np.ndarray) -> np.ndarray: return mapped + @staticmethod + def _map_scores(conf: np.ndarray) -> np.ndarray: + """Map confidence scores from quadruped80K to Animal3D layout.""" + num_ind, num_src = conf.shape + num_tgt = len(_QUADRUPED80K_TO_ANIMAL3D) + mapped = np.full((num_ind, num_tgt), np.nan, dtype=np.float32) + + for tgt_idx, src_idx in enumerate(_QUADRUPED80K_TO_ANIMAL3D): + if src_idx != -1 and src_idx < num_src: + mapped[:, tgt_idx] = conf[:, src_idx] + elif src_idx == -1 and tgt_idx in _INTERPOLATION_RULES: + s1, s2 = _INTERPOLATION_RULES[tgt_idx] + if s1 < num_src and s2 < num_src: + mapped[:, tgt_idx] = (conf[:, s1] + conf[:, s2]) / 2.0 + + return mapped + + def _validate_predictions( + self, + keypoints: np.ndarray, + scores: np.ndarray, + *, + num_frames: int, + ) -> Tuple[np.ndarray, np.ndarray]: + """Validate and normalise SuperAnimal predictions.""" + num_joints = 26 + keypoints = np.asarray(keypoints, dtype=np.float32) + scores = np.asarray(scores, dtype=np.float32) + + if keypoints.shape[0] == 0: + return ( + np.zeros((1, num_frames, num_joints, 2), dtype=np.float32), + np.zeros((1, num_frames, num_joints), dtype=np.float32), + ) + + if keypoints.ndim != 4 or keypoints.shape[-2:] != (num_joints, 2): + raise ValueError( + f"Invalid SuperAnimal keypoints shape {keypoints.shape}; " + f"expected (num_individuals, num_frames, {num_joints}, 2)." + ) + if scores.ndim != 3 or scores.shape[-1] != num_joints: + raise ValueError( + f"Invalid SuperAnimal scores shape {scores.shape}; " + f"expected (num_individuals, num_frames, {num_joints})." + ) + if keypoints.shape[:2] != scores.shape[:2]: + raise ValueError( + "SuperAnimal keypoints/scores leading dimensions do not match: " + f"{keypoints.shape[:2]} vs {scores.shape[:2]}." + ) + if keypoints.shape[1] != num_frames: + raise ValueError( + f"SuperAnimal frame count mismatch: got {keypoints.shape[1]}, " + f"expected {num_frames}." + ) + + # Normalise unknown values to zeros so downstream code can treat these + # joints as invalid via score==0 while retaining shape stability. + keypoints = np.nan_to_num(keypoints, nan=0.0) + scores = np.nan_to_num(scores, nan=0.0) + return keypoints, scores + + @staticmethod + def _compute_valid_frames_mask( + keypoints: np.ndarray, scores: np.ndarray + ) -> np.ndarray: + """Return frame-level validity mask from estimator outputs.""" + safe_scores = np.nan_to_num(scores, nan=0.0) + has_score = np.any(safe_scores > 0, axis=-1) # (num_persons, num_frames) + + safe_kpts = np.nan_to_num(np.abs(keypoints), nan=0.0) + has_kpt = np.any(safe_kpts > 0, axis=(-1, -2)) # (num_persons, num_frames) + return np.any(has_score | has_kpt, axis=0) + # --------------------------------------------------------------------------- # Limb regularisation (animal post-processing) @@ -444,6 +595,16 @@ def _default_components( # --------------------------------------------------------------------------- +class ResultStatus(str, Enum): + """High-level status for pose estimation outputs.""" + + SUCCESS = "success" + PARTIAL = "partial" + EMPTY = "empty" + INVALID = "invalid" + UNKNOWN = "unknown" + + @dataclass class Pose2DResult: """Container returned by :meth:`FMPose3DInference.prepare_2d`. @@ -458,6 +619,43 @@ class Pose2DResult: """Per-joint confidence scores, shape ``(num_persons, num_frames, J)``.""" image_size: tuple[int, int] = (0, 0) """``(height, width)`` of the source frames.""" + valid_frames_mask: np.ndarray | None = None + """Boolean mask of frames with at least one valid detection, shape ``(N,)``.""" + + @property + def status(self) -> ResultStatus: + """Prediction status derived from ``valid_frames_mask``.""" + return self.get_status_info()[0] + + @property + def status_message(self) -> str: + """Human-readable explanation for :attr:`status`.""" + return self.get_status_info()[1] + + def get_status_info(self) -> tuple[ResultStatus, str]: + """Prediction status derived from ``valid_frames_mask``.""" + # Validate canonical shapes and frame-count consistency. + if self.keypoints.ndim != 4 or self.scores.ndim != 3: + return ResultStatus.INVALID, "Incorrect 2D pose keypoints/scores dimensions." + if self.keypoints.shape[1] != self.scores.shape[1]: + return ResultStatus.INVALID, "2D pose keypoints/scores frame counts do not match." + num_frames = int(self.keypoints.shape[1]) + + if self.valid_frames_mask is None: + return ResultStatus.UNKNOWN, "No frame-validity mask provided by the 2D pose." + if not isinstance(self.valid_frames_mask, np.ndarray) or self.valid_frames_mask.ndim != 1: + return ResultStatus.UNKNOWN, "invalid 2D pose valid_frames_mask: must be a 1D numpy array." + if not np.issubdtype(self.valid_frames_mask.dtype, np.bool_): + return ResultStatus.UNKNOWN, "invalid 2D pose valid_frames_mask: must be a boolean numpy array." + if self.valid_frames_mask.shape[0] != num_frames: + return ResultStatus.INVALID, "2D pose valid_frames_mask mismatches the number of frames." + + valid_count = int(np.sum(self.valid_frames_mask)) + if valid_count == 0: + return ResultStatus.EMPTY, "No valid 2D pose predictions in any frame." + if valid_count < num_frames: + return ResultStatus.PARTIAL, "Missing 2D pose predictions in a subset of frames." + return ResultStatus.SUCCESS, "Valid 2D pose predictions for all frames." @dataclass @@ -477,6 +675,47 @@ class Pose3DResult: ``camera_to_world``). For animal poses this contains the limb-regularised output. """ + valid_frames_mask: np.ndarray | None = None + """Boolean mask of frames with valid 3D poses, shape ``(num_frames,)``.""" + status_hint: str | None = None + """Optional extra context for status reporting.""" + + @property + def status(self) -> ResultStatus: + """Prediction status derived from ``valid_frames_mask``.""" + return self.get_status_info()[0] + + @property + def status_message(self) -> str: + """Human-readable explanation for :attr:`status`.""" + return self.get_status_info()[1] + + def get_status_info(self) -> tuple[ResultStatus, str]: + """Prediction status derived from ``valid_frames_mask``.""" + if self.poses_3d.ndim != 3 or self.poses_3d_world.ndim != 3: + return ResultStatus.INVALID, "Incorrect 3D result dimensions." + num_frames = int(self.poses_3d.shape[0]) + if self.poses_3d_world.shape[0] != num_frames: + return ResultStatus.INVALID, "poses_3d and poses_3d_world frame counts differ." + + def _with_hint(message: str) -> str: + return f"{message} {self.status_hint}" if self.status_hint else message + + if self.valid_frames_mask is None: + return ResultStatus.UNKNOWN, _with_hint("No frame-validity mask provided by the 3D pose.") + if not isinstance(self.valid_frames_mask, np.ndarray) or self.valid_frames_mask.ndim != 1: + return ResultStatus.UNKNOWN, _with_hint("invalid 3D pose valid_frames_mask: must be a 1D numpy array.") + if not np.issubdtype(self.valid_frames_mask.dtype, np.bool_): + return ResultStatus.UNKNOWN, _with_hint("invalid 3D pose valid_frames_mask: must be a boolean numpy array.") + if self.valid_frames_mask.shape[0] != num_frames: + return ResultStatus.INVALID, _with_hint("3D pose valid_frames_mask mismatches the number of frames.") + + valid_count = int(np.sum(self.valid_frames_mask)) + if valid_count == 0: + return ResultStatus.EMPTY, _with_hint("No valid 3D pose predictions in any frame.") + if valid_count < num_frames: + return ResultStatus.PARTIAL, _with_hint("Missing 3D pose predictions in a subset of frames.") + return ResultStatus.SUCCESS, _with_hint("Valid 3D pose predictions for all frames.") #: Accepted source types for :meth:`FMPose3DInference.predict`. @@ -689,8 +928,14 @@ def predict( Pose3DResult Root-relative and world-coordinate 3D poses. """ + # 2D pose estimation result_2d = self.prepare_2d(source) - return self.pose_3d( + status, status_msg = result_2d.get_status_info() + if status in {ResultStatus.EMPTY, ResultStatus.INVALID}: + raise ValueError(f"2D pose estimation is not usable for 3D lifting: {status.value}. {status_msg}") + + # 3D pose lifting + result_3d = self.pose_3d( result_2d.keypoints, result_2d.image_size, camera_rotation=camera_rotation, @@ -698,6 +943,18 @@ def predict( progress=progress, ) + # Propagate 2D result status and validity mask to 3D pose result + result_3d.status_hint = f"2D pose status is {status.value}: {status_msg}" + result_3d.valid_frames_mask = result_2d.valid_frames_mask + + # Apply result masking for partial results (set NaN for invalid frames) + if status == ResultStatus.PARTIAL: + invalid = ~result_3d.valid_frames_mask + if np.any(invalid): + result_3d.poses_3d[invalid] = np.nan + result_3d.poses_3d_world[invalid] = np.nan + return result_3d + @torch.no_grad() def prepare_2d( self, @@ -733,13 +990,16 @@ def prepare_2d( self.setup_runtime() if progress: progress(0, 1) - keypoints, scores = self._estimator_2d.predict(ingested.frames) + keypoints, scores, valid_frames_mask = self._estimator_2d.predict( + ingested.frames + ) if progress: progress(1, 1) return Pose2DResult( keypoints=keypoints, scores=scores, image_size=ingested.image_size, + valid_frames_mask=valid_frames_mask, ) @torch.no_grad() diff --git a/tests/fmpose3d_api/test_fmpose3d.py b/tests/fmpose3d_api/test_fmpose3d.py index f7eaf0e..1877baa 100644 --- a/tests/fmpose3d_api/test_fmpose3d.py +++ b/tests/fmpose3d_api/test_fmpose3d.py @@ -25,6 +25,7 @@ HRNetEstimator, HumanPostProcessor, Pose2DResult, + ResultStatus, Pose3DResult, SuperAnimalEstimator, _default_components, @@ -615,8 +616,9 @@ def test_predict_end_to_end_with_mock_estimator(self): mock_kpts = np.random.randn(1, 1, 17, 2).astype("float32") mock_scores = np.ones((1, 1, 17), dtype="float32") + mock_mask = np.array([True], dtype=bool) api._estimator_2d = MagicMock() - api._estimator_2d.predict.return_value = (mock_kpts, mock_scores) + api._estimator_2d.predict.return_value = (mock_kpts, mock_scores, mock_mask) api._estimator_2d.setup_runtime = MagicMock() frame = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8) @@ -626,6 +628,54 @@ def test_predict_end_to_end_with_mock_estimator(self): assert result.poses_3d.shape == (1, 17, 3) api._estimator_2d.predict.assert_called_once() + def test_predict_applies_partial_2d_mask_to_3d(self): + """predict() masks invalid 2D frames to NaN in 3D outputs.""" + api = _make_ready_api("fmpose3d_humans", test_augmentation=False) + mock_kpts = np.random.randn(1, 3, 17, 2).astype("float32") + mock_scores = np.ones((1, 3, 17), dtype="float32") + mask = np.array([True, False, True], dtype=bool) + api._estimator_2d = MagicMock() + api._estimator_2d.predict.return_value = (mock_kpts, mock_scores, mask) + api._estimator_2d.setup_runtime = MagicMock() + + frame = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8) + result = api.predict([frame, frame, frame], seed=42) + + np.testing.assert_array_equal(result.valid_frames_mask, mask) + assert result.status == ResultStatus.PARTIAL + assert np.all(np.isnan(result.poses_3d[1])) + assert np.all(np.isnan(result.poses_3d_world[1])) + + @pytest.mark.parametrize( + "mask,expected_status", + [ + ( + np.array([False, False], dtype=bool), + ResultStatus.EMPTY, + ), + ( + np.array([True], dtype=bool), + ResultStatus.INVALID, + ), + ], + ) + def test_predict_raises_on_unusable_2d_status(self, mask, expected_status): + """predict() raises for EMPTY/INVALID 2D status and skips 3D lifting.""" + api = _make_ready_api("fmpose3d_humans", test_augmentation=False) + mock_kpts = np.random.randn(1, 2, 17, 2).astype("float32") + mock_scores = np.ones((1, 2, 17), dtype="float32") + api._estimator_2d = MagicMock() + api._estimator_2d.predict.return_value = (mock_kpts, mock_scores, mask) + api._estimator_2d.setup_runtime = MagicMock() + api.pose_3d = MagicMock() + + frame = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8) + with pytest.raises(ValueError) as exc_info: + api.predict([frame, frame], seed=42) + + assert f": {expected_status.value}." in str(exc_info.value) + api.pose_3d.assert_not_called() + # ========================================================================= # Unit tests — dataclasses @@ -648,12 +698,41 @@ def test_pose2d_result_default_image_size(self): ) assert result.image_size == (0, 0) + def test_pose2d_status_success(self): + result = Pose2DResult( + keypoints=np.zeros((1, 2, 17, 2)), + scores=np.zeros((1, 2, 17)), + valid_frames_mask=np.array([True, True], dtype=bool), + ) + assert result.status == ResultStatus.SUCCESS + assert "all frames" in result.status_message + + def test_pose2d_status_partial(self): + result = Pose2DResult( + keypoints=np.zeros((1, 2, 17, 2)), + scores=np.zeros((1, 2, 17)), + valid_frames_mask=np.array([True, False], dtype=bool), + ) + assert result.status == ResultStatus.PARTIAL + assert "subset" in result.status_message + + def test_pose2d_status_invalid_mask_length(self): + result = Pose2DResult( + keypoints=np.zeros((1, 2, 17, 2)), + scores=np.zeros((1, 2, 17)), + valid_frames_mask=np.array([True], dtype=bool), + ) + assert result.status == ResultStatus.INVALID + assert "mismatches" in result.status_message + def test_pose3d_result(self): p3d = np.random.randn(10, 17, 3) pw = np.random.randn(10, 17, 3) - result = Pose3DResult(poses_3d=p3d, poses_3d_world=pw) + mask = np.ones((10,), dtype=bool) + result = Pose3DResult(poses_3d=p3d, poses_3d_world=pw, valid_frames_mask=mask) assert result.poses_3d is p3d assert result.poses_3d_world is pw + assert result.status == ResultStatus.SUCCESS # ========================================================================= @@ -672,12 +751,13 @@ def test_predict_returns_zeros_when_no_bodyparts(self): "deeplabcut.pose_estimation_pytorch.apis.superanimal_analyze_images", ) as mock_fn: mock_fn.return_value = {"frame.png": {"bodyparts": None}} - kpts, scores = estimator.predict(frames) + kpts, scores, mask = estimator.predict(frames) assert kpts.shape == (1, 2, 26, 2) np.testing.assert_allclose(kpts, 0.0) assert scores.shape == (1, 2, 26) - np.testing.assert_allclose(scores, 1.0) + np.testing.assert_allclose(scores, 0.0) + np.testing.assert_array_equal(mask, np.array([False, False])) def test_predict_maps_valid_bodyparts(self): """Valid DLC bodyparts are mapped to Animal3D layout.""" @@ -692,9 +772,10 @@ def test_predict_maps_valid_bodyparts(self): "deeplabcut.pose_estimation_pytorch.apis.superanimal_analyze_images", ) as mock_fn: mock_fn.return_value = {"frame.png": {"bodyparts": fake_bp}} - kpts, scores = estimator.predict(frames) + kpts, scores, mask = estimator.predict(frames) assert kpts.shape == (1, 1, 26, 2) assert scores.shape == (1, 1, 26) + np.testing.assert_array_equal(mask, np.array([True])) # target[24] ← source[0] → (0*3, 0*3+1) = (0.0, 1.0) np.testing.assert_allclose(kpts[0, 0, 24], fake_bp[0, 0, :2])