@@ -261,6 +261,7 @@ def predict(
261261 images = paths ,
262262 max_individuals = cfg .max_individuals ,
263263 out_folder = tmpdir ,
264+ progress_bar = False
264265 )
265266 # predictions: {image_path: {"bodyparts": (N_ind, K, 3), ...}}
266267 # Iterate in input order to keep frame alignment stable.
@@ -937,23 +938,11 @@ def predict(
937938
938939 # 3D pose lifting
939940 result_3d = self .pose_3d (
940- result_2d .keypoints ,
941- result_2d .image_size ,
941+ result_2d ,
942942 camera_rotation = camera_rotation ,
943943 seed = seed ,
944944 progress = progress ,
945945 )
946-
947- # Propagate 2D result status and validity mask to 3D pose result
948- result_3d .status_hint = f"2D pose status is { status .value } : { status_msg } "
949- result_3d .valid_frames_mask = result_2d .valid_frames_mask
950-
951- # Apply result masking for partial results (set NaN for invalid frames)
952- if status == ResultStatus .PARTIAL :
953- invalid = ~ result_3d .valid_frames_mask
954- if np .any (invalid ):
955- result_3d .poses_3d [invalid ] = np .nan
956- result_3d .poses_3d_world [invalid ] = np .nan
957946 return result_3d
958947
959948 @torch .no_grad ()
@@ -1006,8 +995,8 @@ def prepare_2d(
1006995 @torch .no_grad ()
1007996 def pose_3d (
1008997 self ,
1009- keypoints_2d : np .ndarray ,
1010- image_size : tuple [int , int ],
998+ keypoints_2d : Pose2DResult | np .ndarray ,
999+ image_size : tuple [int , int ] | None = None ,
10111000 * ,
10121001 camera_rotation : np .ndarray | None = _DEFAULT_CAM_ROTATION ,
10131002 seed : int | None = None ,
@@ -1027,13 +1016,17 @@ def pose_3d(
10271016
10281017 Parameters
10291018 ----------
1030- keypoints_2d : ndarray
1031- 2D keypoints returned by :meth:`prepare_2d`. Accepted shapes:
1019+ keypoints_2d : Pose2DResult or ndarray
1020+ 2D keypoints returned by :meth:`prepare_2d`, either as a full
1021+ :class:`Pose2DResult` or as a raw ndarray. Accepted ndarray shapes:
10321022
10331023 * ``(num_persons, num_frames, J, 2)`` -- first person is used.
10341024 * ``(num_frames, J, 2)`` -- treated as a single person.
1035- image_size : tuple of (int, int)
1025+ image_size : tuple of (int, int) or None
10361026 ``(height, width)`` of the source image / video frames.
1027+ Required when ``keypoints_2d`` is an ndarray. Optional when
1028+ ``keypoints_2d`` is a :class:`Pose2DResult`; if provided, it must
1029+ match ``Pose2DResult.image_size``.
10371030 camera_rotation : ndarray or None
10381031 Length-4 quaternion for the camera-to-world rotation applied
10391032 to produce ``poses_3d_world``. Defaults to the rotation used
@@ -1053,25 +1046,25 @@ def pose_3d(
10531046 Pose3DResult
10541047 Root-relative and post-processed 3D poses.
10551048 """
1049+ result_2d : Pose2DResult = self ._normalize_3d_input (
1050+ keypoints_2d ,
1051+ image_size = image_size
1052+ )
1053+ status , status_msg = result_2d .get_status_info ()
1054+ if status in {ResultStatus .EMPTY , ResultStatus .INVALID }:
1055+ raise ValueError (f"2D pose estimation is not usable for 3D lifting: { status .value } . { status_msg } " )
1056+ # Just use the first person's keypoints for now.
1057+ kpts = result_2d .keypoints [0 ]
1058+ h , w = result_2d .image_size
1059+
10561060 self .setup_runtime ()
10571061 model = self ._model_3d
1058- h , w = image_size
10591062 steps = self .inference_cfg .sample_steps
10601063
10611064 # Optional deterministic seeding
10621065 if seed is not None :
10631066 torch .manual_seed (seed )
10641067
1065- # Normalise input shape to (num_frames, J, 2)
1066- if keypoints_2d .ndim == 4 :
1067- kpts = keypoints_2d [0 ] # first person
1068- elif keypoints_2d .ndim == 3 :
1069- kpts = keypoints_2d
1070- else :
1071- raise ValueError (
1072- f"Expected keypoints_2d with 3 or 4 dims, got { keypoints_2d .ndim } "
1073- )
1074-
10751068 num_frames = kpts .shape [0 ]
10761069 all_poses_3d : list [np .ndarray ] = []
10771070 all_poses_world : list [np .ndarray ] = []
@@ -1091,11 +1084,64 @@ def pose_3d(
10911084 if progress :
10921085 progress (i + 1 , num_frames )
10931086
1094- return Pose3DResult (
1087+ result_3d = Pose3DResult (
10951088 poses_3d = np .stack (all_poses_3d , axis = 0 ),
10961089 poses_3d_world = np .stack (all_poses_world , axis = 0 ),
10971090 )
10981091
1092+ # Mask invalid frames in 3D output for partial 2D predictions.
1093+ result_3d .status_hint = f"2D pose status is { status .value } : { status_msg } "
1094+ result_3d .valid_frames_mask = result_2d .valid_frames_mask
1095+ if status == ResultStatus .PARTIAL and result_3d .valid_frames_mask is not None :
1096+ invalid = ~ result_3d .valid_frames_mask
1097+ if np .any (invalid ):
1098+ result_3d .poses_3d [invalid ] = np .nan
1099+ result_3d .poses_3d_world [invalid ] = np .nan
1100+ return result_3d
1101+
1102+ def _normalize_3d_input (
1103+ self ,
1104+ keypoints_2d : Pose2DResult | np .ndarray ,
1105+ * ,
1106+ image_size : tuple [int , int ] | None ,
1107+ ) -> Pose2DResult :
1108+ """Normalise pose_3d inputs into a Pose2DResult instance."""
1109+ if isinstance (keypoints_2d , Pose2DResult ):
1110+ if image_size is not None and image_size != keypoints_2d .image_size :
1111+ raise ValueError (
1112+ f"Image size mismatch: Pose2DResult.image_size={ keypoints_2d .image_size } , "
1113+ f"image_size={ image_size } . Please provide either a Pose2DResult (containing "
1114+ f"image_size), or keypoints_2d as a numpy ndarray together with "
1115+ f"image_size={ image_size } ."
1116+ )
1117+ return keypoints_2d
1118+
1119+ if not isinstance (keypoints_2d , np .ndarray ):
1120+ raise ValueError ("keypoints_2d must be a Pose2DResult or a numpy ndarray." )
1121+ if image_size is None :
1122+ raise ValueError (
1123+ "image_size is required when keypoints_2d is provided as an ndarray."
1124+ )
1125+
1126+ if keypoints_2d .ndim == 4 :
1127+ keypoints = keypoints_2d
1128+ elif keypoints_2d .ndim == 3 :
1129+ # Treat 3D input as a single-person sequence for consistency.
1130+ keypoints = keypoints_2d [np .newaxis ]
1131+ else :
1132+ raise ValueError (
1133+ f"Expected keypoints_2d with 3 or 4 dims, got { keypoints_2d .ndim } "
1134+ )
1135+
1136+ scores = np .full (keypoints .shape [:- 1 ], np .nan , dtype = np .float32 )
1137+ return Pose2DResult (
1138+ keypoints = keypoints ,
1139+ scores = scores ,
1140+ image_size = image_size ,
1141+ valid_frames_mask = None ,
1142+ )
1143+
1144+
10991145 # ------------------------------------------------------------------
11001146 # Private helpers – sampling & post-processing
11011147 # ------------------------------------------------------------------
0 commit comments