Skip to content

Commit 26c33e2

Browse files
authored
Merge pull request #23 from AdaptiveMotorControlLab/feat/deeplabcut_integration
Minor follow-up refactors for DeepLabCut integration
2 parents d15e500 + 6177ae6 commit 26c33e2

1 file changed

Lines changed: 76 additions & 30 deletions

File tree

fmpose3d/inference_api/fmpose3d.py

Lines changed: 76 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,7 @@ def predict(
261261
images=paths,
262262
max_individuals=cfg.max_individuals,
263263
out_folder=tmpdir,
264+
progress_bar=False
264265
)
265266
# predictions: {image_path: {"bodyparts": (N_ind, K, 3), ...}}
266267
# Iterate in input order to keep frame alignment stable.
@@ -937,23 +938,11 @@ def predict(
937938

938939
# 3D pose lifting
939940
result_3d = self.pose_3d(
940-
result_2d.keypoints,
941-
result_2d.image_size,
941+
result_2d,
942942
camera_rotation=camera_rotation,
943943
seed=seed,
944944
progress=progress,
945945
)
946-
947-
# Propagate 2D result status and validity mask to 3D pose result
948-
result_3d.status_hint = f"2D pose status is {status.value}: {status_msg}"
949-
result_3d.valid_frames_mask = result_2d.valid_frames_mask
950-
951-
# Apply result masking for partial results (set NaN for invalid frames)
952-
if status == ResultStatus.PARTIAL:
953-
invalid = ~result_3d.valid_frames_mask
954-
if np.any(invalid):
955-
result_3d.poses_3d[invalid] = np.nan
956-
result_3d.poses_3d_world[invalid] = np.nan
957946
return result_3d
958947

959948
@torch.no_grad()
@@ -1006,8 +995,8 @@ def prepare_2d(
1006995
@torch.no_grad()
1007996
def pose_3d(
1008997
self,
1009-
keypoints_2d: np.ndarray,
1010-
image_size: tuple[int, int],
998+
keypoints_2d: Pose2DResult | np.ndarray,
999+
image_size: tuple[int, int] | None = None,
10111000
*,
10121001
camera_rotation: np.ndarray | None = _DEFAULT_CAM_ROTATION,
10131002
seed: int | None = None,
@@ -1027,13 +1016,17 @@ def pose_3d(
10271016
10281017
Parameters
10291018
----------
1030-
keypoints_2d : ndarray
1031-
2D keypoints returned by :meth:`prepare_2d`. Accepted shapes:
1019+
keypoints_2d : Pose2DResult or ndarray
1020+
2D keypoints returned by :meth:`prepare_2d`, either as a full
1021+
:class:`Pose2DResult` or as a raw ndarray. Accepted ndarray shapes:
10321022
10331023
* ``(num_persons, num_frames, J, 2)`` -- first person is used.
10341024
* ``(num_frames, J, 2)`` -- treated as a single person.
1035-
image_size : tuple of (int, int)
1025+
image_size : tuple of (int, int) or None
10361026
``(height, width)`` of the source image / video frames.
1027+
Required when ``keypoints_2d`` is an ndarray. Optional when
1028+
``keypoints_2d`` is a :class:`Pose2DResult`; if provided, it must
1029+
match ``Pose2DResult.image_size``.
10371030
camera_rotation : ndarray or None
10381031
Length-4 quaternion for the camera-to-world rotation applied
10391032
to produce ``poses_3d_world``. Defaults to the rotation used
@@ -1053,25 +1046,25 @@ def pose_3d(
10531046
Pose3DResult
10541047
Root-relative and post-processed 3D poses.
10551048
"""
1049+
result_2d: Pose2DResult = self._normalize_3d_input(
1050+
keypoints_2d,
1051+
image_size=image_size
1052+
)
1053+
status, status_msg = result_2d.get_status_info()
1054+
if status in {ResultStatus.EMPTY, ResultStatus.INVALID}:
1055+
raise ValueError(f"2D pose estimation is not usable for 3D lifting: {status.value}. {status_msg}")
1056+
# Just use the first person's keypoints for now.
1057+
kpts = result_2d.keypoints[0]
1058+
h, w = result_2d.image_size
1059+
10561060
self.setup_runtime()
10571061
model = self._model_3d
1058-
h, w = image_size
10591062
steps = self.inference_cfg.sample_steps
10601063

10611064
# Optional deterministic seeding
10621065
if seed is not None:
10631066
torch.manual_seed(seed)
10641067

1065-
# Normalise input shape to (num_frames, J, 2)
1066-
if keypoints_2d.ndim == 4:
1067-
kpts = keypoints_2d[0] # first person
1068-
elif keypoints_2d.ndim == 3:
1069-
kpts = keypoints_2d
1070-
else:
1071-
raise ValueError(
1072-
f"Expected keypoints_2d with 3 or 4 dims, got {keypoints_2d.ndim}"
1073-
)
1074-
10751068
num_frames = kpts.shape[0]
10761069
all_poses_3d: list[np.ndarray] = []
10771070
all_poses_world: list[np.ndarray] = []
@@ -1091,11 +1084,64 @@ def pose_3d(
10911084
if progress:
10921085
progress(i + 1, num_frames)
10931086

1094-
return Pose3DResult(
1087+
result_3d = Pose3DResult(
10951088
poses_3d=np.stack(all_poses_3d, axis=0),
10961089
poses_3d_world=np.stack(all_poses_world, axis=0),
10971090
)
10981091

1092+
# Mask invalid frames in 3D output for partial 2D predictions.
1093+
result_3d.status_hint = f"2D pose status is {status.value}: {status_msg}"
1094+
result_3d.valid_frames_mask = result_2d.valid_frames_mask
1095+
if status == ResultStatus.PARTIAL and result_3d.valid_frames_mask is not None:
1096+
invalid = ~result_3d.valid_frames_mask
1097+
if np.any(invalid):
1098+
result_3d.poses_3d[invalid] = np.nan
1099+
result_3d.poses_3d_world[invalid] = np.nan
1100+
return result_3d
1101+
1102+
def _normalize_3d_input(
    self,
    keypoints_2d: Pose2DResult | np.ndarray,
    *,
    image_size: tuple[int, int] | None,
) -> Pose2DResult:
    """Coerce a ``pose_3d`` input into a :class:`Pose2DResult`.

    Parameters
    ----------
    keypoints_2d : Pose2DResult or ndarray
        Either a ready-made :class:`Pose2DResult`, or a raw keypoint array
        of shape ``(num_persons, num_frames, J, 2)`` or ``(num_frames, J, 2)``
        (the latter is treated as a single person).
    image_size : tuple of (int, int) or None
        ``(height, width)`` of the source frames. Mandatory for ndarray
        input; for a :class:`Pose2DResult` it is optional but, when given,
        must agree with the result's own ``image_size``.

    Returns
    -------
    Pose2DResult
        The input itself when already a :class:`Pose2DResult`, otherwise a
        freshly built result wrapping the array (scores are NaN placeholders
        since raw arrays carry no confidences).

    Raises
    ------
    ValueError
        On an image-size mismatch, an unsupported input type, a missing
        ``image_size`` for ndarray input, or an unexpected array rank.
    """
    if isinstance(keypoints_2d, Pose2DResult):
        # A full result already knows its image size; a caller-supplied
        # value is accepted only when it agrees.
        if image_size is not None and image_size != keypoints_2d.image_size:
            raise ValueError(
                f"Image size mismatch: Pose2DResult.image_size={keypoints_2d.image_size}, "
                f"image_size={image_size}. Please provide either a Pose2DResult (containing "
                f"image_size), or keypoints_2d as a numpy ndarray together with "
                f"image_size={image_size}."
            )
        return keypoints_2d

    if not isinstance(keypoints_2d, np.ndarray):
        raise ValueError("keypoints_2d must be a Pose2DResult or a numpy ndarray.")
    if image_size is None:
        raise ValueError(
            "image_size is required when keypoints_2d is provided as an ndarray."
        )

    if keypoints_2d.ndim not in (3, 4):
        raise ValueError(
            f"Expected keypoints_2d with 3 or 4 dims, got {keypoints_2d.ndim}"
        )
    # Promote a single-person (num_frames, J, 2) array to the canonical
    # (num_persons, num_frames, J, 2) layout; np.newaxis is None.
    stacked = keypoints_2d if keypoints_2d.ndim == 4 else keypoints_2d[None]

    # No per-keypoint confidences exist for raw arrays — mark them NaN.
    placeholder_scores = np.full(stacked.shape[:-1], np.nan, dtype=np.float32)
    return Pose2DResult(
        keypoints=stacked,
        scores=placeholder_scores,
        image_size=image_size,
        valid_frames_mask=None,
    )
1143+
1144+
10991145
# ------------------------------------------------------------------
11001146
# Private helpers – sampling & post-processing
11011147
# ------------------------------------------------------------------

0 commit comments

Comments
 (0)