def get_grasp_pose(
    self,
    cfg: GraspAnnotatorCfg,
    approach_direction: torch.Tensor = None,
    is_visual: bool = False,
) -> torch.Tensor:
    """Compute one grasp pose per object pose from annotated antipodal pairs.

    The antipodal point pairs are produced once by ``GraspAnnotator.annotate``
    on the scaled mesh of the first entity and cached on this object as
    ``_hit_point_pairs``; later calls reuse them unless
    ``cfg.force_regenerate`` is set.

    Args:
        cfg: Annotator configuration. It is used to build the annotator on
            the first call and again whenever ``cfg.force_regenerate`` is
            True, so that a regeneration request is honoured by the
            annotator's own cache check as well.
        approach_direction: (3,) approach direction of the gripper. Defaults
            to ``[0, 0, -1]``. It is moved to ``self.device`` as float32 and
            normalized before use.
        is_visual: If True, show the first object pose together with its
            grasp pose via ``GraspAnnotator.visualize_grasp_pose``.

    Returns:
        Tensor of stacked 4x4 grasp poses, one per pose returned by
        ``self.get_local_pose(to_matrix=True)``.

    Raises:
        RuntimeError: If the annotation step yields no antipodal pairs.
    """

    def _scaled_mesh() -> tuple[torch.Tensor, torch.Tensor]:
        # Mesh of the first entity with the body scale baked into the vertices.
        entity = self._entities[0]
        vertices = torch.tensor(
            entity.get_vertices(), dtype=torch.float32, device=self.device
        )
        triangles = torch.tensor(
            entity.get_triangles(), dtype=torch.int32, device=self.device
        )
        scale = torch.tensor(
            entity.get_body_scale(), dtype=torch.float32, device=self.device
        )
        return vertices * scale, triangles

    if approach_direction is None:
        approach_direction = [0.0, 0.0, -1.0]
    # A caller-supplied direction may live on another device or be an integer
    # tensor; normalisation and the later dot products need float32 on
    # self.device.
    approach_direction = torch.as_tensor(
        approach_direction, dtype=torch.float32, device=self.device
    )
    approach_direction = F.normalize(approach_direction, dim=-1)

    annotator = getattr(self, "_grasp_annotator", None)
    if annotator is None or cfg.force_regenerate:
        # Rebuild on force_regenerate: annotate() consults the annotator's
        # own cfg, which would otherwise still be the one from the first call.
        annotator = GraspAnnotator(cfg=cfg)
        self._grasp_annotator = annotator

    if getattr(self, "_hit_point_pairs", None) is None or cfg.force_regenerate:
        vertices, triangles = _scaled_mesh()
        hit_point_pairs = annotator.annotate(vertices, triangles)
        if hit_point_pairs is None or hit_point_pairs.shape[0] == 0:
            # Do not cache an unusable result; fail with a clear message.
            raise RuntimeError(
                "Grasp annotation produced no antipodal point pairs."
            )
        self._hit_point_pairs = hit_point_pairs

    # NOTE(review): poses come from get_local_pose while the annotator
    # documents object_pose as a world-frame pose - confirm the frames agree.
    poses = torch.as_tensor(
        self.get_local_pose(to_matrix=True),
        dtype=torch.float32,
        device=self.device,
    )
    grasp_poses = []
    open_lengths = []
    for pose in poses:
        grasp_pose, open_length = annotator.get_approach_grasp_poses(
            self._hit_point_pairs, pose, approach_direction
        )
        grasp_poses.append(grasp_pose)
        open_lengths.append(open_length)
    grasp_poses = torch.stack(grasp_poses, dim=0)

    if is_visual:
        vertices, triangles = _scaled_mesh()
        GraspAnnotator.visualize_grasp_pose(
            vertices=vertices,
            triangles=triangles,
            obj_pose=poses[0],
            grasp_pose=grasp_poses[0],
            # The visualizer declares a plain float for the opening width.
            open_length=float(open_lengths[0]),
        )
    return grasp_poses
@dataclass
class GraspAnnotatorCfg:
    """Configuration of the interactive grasp annotator."""

    # Port the viser annotation server listens on.
    viser_port: int = 15531
    # Keep only the largest connected face component of a rectangle selection.
    use_largest_connected_component: bool = False
    # Antipodal sampler settings. ``None`` is replaced by a fresh
    # ``AntipodalSamplerCfg()`` per instance in ``__post_init__``: a dataclass
    # instance used directly as a field default is rejected as a mutable
    # default on Python >= 3.11 and shared between configs on older versions.
    antipodal_sampler_cfg: AntipodalSamplerCfg | None = None
    # Ignore any cached antipodal pairs and run the annotation again.
    force_regenerate: bool = False
    # NOTE(review): not read anywhere in the visible code - confirm intent.
    max_deviation_angle: float = np.pi / 12

    def __post_init__(self) -> None:
        if self.antipodal_sampler_cfg is None:
            self.antipodal_sampler_cfg = AntipodalSamplerCfg()
client.camera.wxyz = np.array([1.0, 0.0, 0.0, 0.0]) + + select_button = client.gui.add_button( + "Rect Select Region", icon=viser.Icon.PAINT + ) + confirm_button = client.gui.add_button("Confirm Selection") + + @select_button.on_click + def _(_evt: viser.GuiEvent) -> None: + select_button.disabled = True + + @client.scene.on_pointer_event(event_type="rect-select") + def _(event: viser.ScenePointerEvent) -> None: + nonlocal mesh_handle + nonlocal selected_overlay + nonlocal selection + nonlocal hit_point_pairs + client.scene.remove_pointer_callback() + + proj, depth = GraspAnnotator._project_vertices_to_screen( + cast(np.ndarray, self.mesh.vertices), + mesh_handle, + event.client.camera, + ) + + lower = np.minimum( + np.array(event.screen_pos[0]), np.array(event.screen_pos[1]) + ) + upper = np.maximum( + np.array(event.screen_pos[0]), np.array(event.screen_pos[1]) + ) + vertex_mask = ((proj >= lower) & (proj <= upper)).all(axis=1) & ( + depth > 1e-6 + ) + + selection = GraspAnnotator._extract_selection( + self.mesh, vertex_mask, self.cfg.use_largest_connected_component + ) + if selection.vertices is None: + logger.log_warning("[Selection] No vertices selected.") + return + + color_mesh = self.mesh.copy() + used_vertex_indices = selection.vertex_indices + vertex_colors = np.tile( + np.array([[0.85, 0.85, 0.85, 1.0]]), + (self.mesh.vertices.shape[0], 1), + ) + vertex_colors[used_vertex_indices] = np.array( + [0.56, 0.17, 0.92, 1.0] + ) + color_mesh.visual.vertex_colors = vertex_colors # type: ignore + mesh_handle = server.scene.add_mesh_trimesh( + name="/mesh", mesh=color_mesh + ) + + if selected_overlay is not None: + selected_overlay.remove() + selected_mesh = trimesh.Trimesh( + vertices=selection.vertices, + faces=selection.faces, + process=False, + ) + selected_mesh.visual.face_colors = (0.9, 0.2, 0.2, 0.65) # type: ignore + selected_overlay = server.scene.add_mesh_trimesh( + name="/selected", mesh=selected_mesh + ) + logger.log_info( + f"[Selection] Selected 
def _get_cache_dir(self, vertices: torch.Tensor, triangles: torch.Tensor) -> str:
    """Return the cache file path for the antipodal pairs of a mesh.

    Despite the name, the result is a file path: an ``.npy`` file in the
    system temporary directory whose name contains the MD5 digest of the raw
    vertex bytes followed by the raw triangle bytes.

    Args:
        vertices: Vertex tensor of the mesh.
        triangles: Triangle index tensor of the mesh.

    Returns:
        Path of the cache file (the file itself may not exist yet).
    """
    # MD5 is only a cache key here, so mark it as non-security usage; this
    # keeps the call working on FIPS-restricted Python builds.
    digest = hashlib.md5(usedforsecurity=False)
    # Feed the buffers one after another instead of concatenating them, which
    # avoids copying both arrays into a third byte string. The resulting
    # digest is identical to hashing the concatenation.
    for tensor in (vertices, triangles):
        digest.update(tensor.detach().to("cpu").numpy().tobytes())
    return os.path.join(
        tempfile.gettempdir(), f"antipodal_cache_{digest.hexdigest()}.npy"
    )
@staticmethod
def _project_vertices_to_screen(
    vertices_mesh: np.ndarray,
    mesh_handle: viser.GlbHandle,
    camera: Any,
) -> tuple[np.ndarray, np.ndarray]:
    """Project mesh-frame vertices into the screen space of a viser camera.

    Args:
        vertices_mesh: (N, 3) vertices expressed in the mesh frame.
        mesh_handle: Scene handle whose ``wxyz`` / ``position`` give the mesh
            pose in the world frame.
        camera: Camera object providing ``wxyz``, ``position``, ``fov`` and
            ``aspect``.

    Returns:
        A tuple ``(projected, depth)``: ``projected`` is (N, 2) with the
        perspective-divided coordinates remapped by ``(1 + p) / 2``, and
        ``depth`` is the (N,) camera-frame z coordinate of each vertex.
    """

    def _pose_of(handle: Any) -> Any:
        # SE3 transform built from a quaternion (wxyz) and a translation.
        return tf.SE3.from_rotation_and_translation(
            tf.SO3(np.asarray(handle.wxyz)),
            np.asarray(handle.position),
        )

    def _apply(matrix: np.ndarray, points: np.ndarray) -> np.ndarray:
        # Homogeneous transform of row-vector points; drops the w column.
        ones = np.ones((points.shape[0], 1))
        homogeneous = np.hstack([points, ones])
        return (matrix @ homogeneous.T).T[:, :3]

    world_from_mesh = _pose_of(mesh_handle).as_matrix()
    camera_from_world = _pose_of(camera).inverse().as_matrix()

    points_world = _apply(world_from_mesh, vertices_mesh)
    points_camera = _apply(camera_from_world, points_world)

    half_fov_tan = np.tan(float(camera.fov) / 2.0)
    aspect_ratio = float(camera.aspect)

    # Clamp z before the perspective divide so points at or behind the
    # camera plane do not divide by zero; callers filter them via the depth.
    screen = points_camera[:, :2] / np.maximum(points_camera[:, 2:3], 1e-8)
    screen /= half_fov_tan
    screen[:, 0] /= aspect_ratio
    screen = (1.0 + screen) / 2.0
    return screen, points_camera[:, 2]
def get_approach_grasp_poses(
    self,
    hit_point_pairs: torch.Tensor,
    object_pose: torch.Tensor,
    approach_direction: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Select the best antipodal pair for an approach direction.

    Every pair is transformed by ``object_pose`` and scored with a weighted
    sum of three costs (lower is better):

    * angle cost (0.4): deviation of the closing axis from being
      perpendicular to ``approach_direction``;
    * length cost (0.3): shorter pairs cost more, relative to the longest;
    * center cost (0.3): distance of the pair midpoint from the mean of all
      midpoints, relative to the largest such distance.

    The grasp frame has x along the closing axis, y = approach x closing axis
    and z = x x y, with the pair midpoint as origin.

    Args:
        hit_point_pairs: (N, 2, 3) antipodal point pairs in the object frame.
        object_pose: (4, 4) homogeneous pose applied to the pairs.
        approach_direction: (3,) unit approach direction in the frame that
            ``object_pose`` maps into.

    Returns:
        A tuple ``(grasp_pose, open_length)``: the (4, 4) float32 grasp pose
        and a scalar tensor with the distance between the two selected points.

    Raises:
        ValueError: If ``hit_point_pairs`` is not (N, 2, 3) with N >= 1.
    """
    if (
        hit_point_pairs.ndim != 3
        or hit_point_pairs.shape[0] == 0
        or tuple(hit_point_pairs.shape[1:]) != (2, 3)
    ):
        raise ValueError(
            "hit_point_pairs must have shape [N, 2, 3] with at least one pair"
        )

    eps = 1e-8
    angle_weight, length_weight, center_weight = 0.4, 0.3, 0.3

    # Bring all inputs to the dtype/device of the pairs so the matrix products
    # below cannot fail on a dtype or device mismatch.
    device, dtype = hit_point_pairs.device, hit_point_pairs.dtype
    approach = approach_direction.to(device=device, dtype=dtype)
    pose = object_pose.to(device=device, dtype=dtype)
    rotation = pose[:3, :3]
    translation = pose[:3, 3]

    origin_points = hit_point_pairs[:, 0, :] @ rotation.T + translation
    hit_points = hit_point_pairs[:, 1, :] @ rotation.T + translation
    centers = (origin_points + hit_points) / 2
    centroid = centers.mean(dim=0)

    closing_vectors = hit_points - origin_points
    closing_axes = F.normalize(closing_vectors, dim=-1)
    pair_lengths = torch.norm(closing_vectors, dim=-1)

    cos_angle = torch.clamp((closing_axes * approach).sum(dim=-1), -1.0, 1.0)
    angle_cost = torch.abs(torch.acos(cos_angle) - 0.5 * torch.pi) / (
        0.5 * torch.pi
    )
    # clamp_min guards the 0/0 case (a single pair, or all midpoints equal),
    # which would otherwise turn every cost into NaN.
    length_cost = 1 - pair_lengths / pair_lengths.max().clamp_min(eps)
    center_distance = torch.norm(centers - centroid, dim=-1)
    center_cost = center_distance / center_distance.max().clamp_min(eps)
    total_cost = (
        angle_weight * angle_cost
        + length_weight * length_cost
        + center_weight * center_cost
    )
    best_idx = torch.argmin(total_cost)

    x_axis = closing_axes[best_idx]
    y_axis = torch.cross(approach, x_axis, dim=0)
    if torch.linalg.norm(y_axis) < eps:
        # Approach is parallel to the closing axis, so the cross product
        # vanishes; fall back to the coordinate axis least aligned with x.
        helper = torch.zeros_like(x_axis)
        helper[torch.argmin(torch.abs(x_axis))] = 1.0
        y_axis = torch.cross(helper, x_axis, dim=0)
    y_axis = F.normalize(y_axis, dim=-1)
    z_axis = F.normalize(torch.cross(x_axis, y_axis, dim=0), dim=-1)

    grasp_pose = torch.eye(4, device=device, dtype=torch.float32)
    grasp_pose[:3, 0] = x_axis
    grasp_pose[:3, 1] = y_axis
    grasp_pose[:3, 2] = z_axis
    grasp_pose[:3, 3] = centers[best_idx]
    return grasp_pose, pair_lengths[best_idx]
visualize_grasp_pose( + vertices: torch.Tensor, + triangles: torch.Tensor, + obj_pose: torch.Tensor, + grasp_pose: torch.Tensor, + open_length: float, + ): + mesh = o3d.geometry.TriangleMesh( + vertices=o3d.utility.Vector3dVector(vertices.to("cpu").numpy()), + triangles=o3d.utility.Vector3iVector(triangles.to("cpu").numpy()), + ) + mesh.compute_vertex_normals() + mesh.paint_uniform_color([0.3, 0.6, 0.3]) + mesh.transform(obj_pose.to("cpu").numpy()) + vertices_ = torch.tensor( + np.asarray(mesh.vertices), device=vertices.device, dtype=vertices.dtype + ) + mesh_scale = (vertices_.max(dim=0)[0] - vertices_.min(dim=0)[0]).max().item() + groud_plane = o3d.geometry.TriangleMesh.create_cylinder( + radius=mesh_scale, height=0.01 * mesh_scale + ) + groud_plane.compute_vertex_normals() + center = vertices_.mean(dim=0) + z_sim = vertices_.min(dim=0)[0][2].item() + groud_plane.translate( + (center[0].item(), center[1].item(), z_sim - 0.005 * mesh_scale) + ) + + draw_thickness = 0.02 * mesh_scale + draw_length = 0.3 * mesh_scale + grasp_finger1 = o3d.geometry.TriangleMesh.create_box( + draw_thickness, draw_thickness, draw_length + ) + grasp_finger1.translate( + (-0.5 * draw_thickness, -0.5 * draw_thickness, -0.5 * draw_length) + ) + grasp_finger2 = o3d.geometry.TriangleMesh.create_box( + draw_thickness, draw_thickness, draw_length + ) + grasp_finger2.translate( + (-0.5 * draw_thickness, -0.5 * draw_thickness, -0.5 * draw_length) + ) + grasp_finger1.translate((-open_length / 2, 0, -0.25 * draw_length)) + grasp_finger2.translate((open_length / 2, 0, -0.25 * draw_length)) + grasp_root1 = o3d.geometry.TriangleMesh.create_box( + open_length, draw_thickness, draw_thickness + ) + grasp_root1.translate( + (-open_length / 2, -0.5 * draw_thickness, -0.5 * draw_thickness) + ) + grasp_root1.translate((0, 0, -0.75 * draw_length)) + grasp_root2 = o3d.geometry.TriangleMesh.create_box( + draw_thickness, draw_thickness, draw_length + ) + grasp_root2.translate( + (-0.5 * draw_thickness, -0.5 * 
def main() -> None:
    """Command-line entry point: annotate a mesh file interactively.

    Loads the mesh given by ``--mesh``, scales its vertices by ``--scale``,
    runs ``GraspAnnotator.annotate`` with regeneration forced, and logs the
    number of antipodal point pairs that were sampled.
    """
    parser = argparse.ArgumentParser(
        description="Viser mesh 标注工具:框选并导出对应顶点与三角面"
    )
    parser.add_argument(
        "--mesh", type=Path, required=True, help="输入 mesh 文件路径,例如 mug.obj"
    )
    parser.add_argument("--scale", type=float, default=1.0, help="加载后整体缩放系数")
    parser.add_argument("--port", type=int, default=12151, help="viser 服务端口")
    parser.add_argument(
        "--output-dir",
        type=Path,
        default=Path("outputs/mesh_annotations"),
        help="标注结果导出目录",
    )
    parser.add_argument(
        "--largest-component",
        action="store_true",
        help="只保留框选结果中的最大连通块(常用于稳定提取把手等局部)",
    )
    args = parser.parse_args()

    mesh = trimesh.load(args.mesh, process=False, force="mesh")
    vertices = mesh.vertices * args.scale
    triangles = mesh.faces

    # Forward the parsed options: previously --port and --largest-component
    # were accepted but had no effect because the cfg used its defaults.
    # NOTE(review): --output-dir is still unused; nothing is exported there.
    cfg = GraspAnnotatorCfg(
        viser_port=args.port,
        use_largest_connected_component=args.largest_component,
        force_regenerate=True,
    )
    tool = GraspAnnotator(cfg=cfg)
    hit_point_pairs = tool.annotate(
        vertices=torch.from_numpy(vertices).float(),
        triangles=torch.from_numpy(triangles).long(),
    )
    if hit_point_pairs is None:
        # annotate() hands back None when no pairs were sampled before the
        # selection was confirmed; avoid dereferencing it.
        logger.log_warning("No antipodal point pairs were generated.")
        return
    logger.log_info(f"Sample {hit_point_pairs.shape[0]} antipodal point pairs.")
+@dataclass +class AntipodalSamplerCfg: + n_sample: int = 10000 + """surface point sample number""" + max_angle: float = np.pi / 12 + """maximum angle (in radians) to randomly disturb the ray direction for antipodal point sampling, used to increase the diversity of sampled antipodal points. Note that setting max_angle to 0 will disable the random disturbance and sample antipodal points strictly along the surface normals, which may result in less diverse antipodal points and may not be ideal for all objects or grasping scenarios.""" + max_length: float = 0.1 + """maximum gripper open width, used to filter out antipodal points that are too far apart to be grasped""" + min_length: float = 0.001 + """minimum gripper open width, used to filter out antipodal points that are too close to be grasped""" + + +class AntipodalSampler: + def __init__( + self, + cfg: AntipodalSamplerCfg = AntipodalSamplerCfg(), + ): + self.mesh: o3d.t.geometry.TriangleMesh | None = None + self.cfg = cfg + + def sample(self, vertices: torch.Tensor, faces: torch.Tensor) -> torch.Tensor: + """Get sample Antipodal point pair + + Returns: + hit_point_pairs: [N, 2, 3] tensor of N antipodal point pairs. Each pair consists of a hit point and its corresponding surface point. 
+ """ + # update mesh + self.mesh = o3d.t.geometry.TriangleMesh() + self.mesh.vertex.positions = o3c.Tensor( + vertices.to("cpu").numpy(), dtype=o3c.float32 + ) + self.mesh.triangle.indices = o3c.Tensor( + faces.to("cpu").numpy(), dtype=o3c.int32 + ) + self.mesh.compute_vertex_normals() + # sample points and normals + sample_pcd = self.mesh.sample_points_uniformly( + number_of_points=self.cfg.n_sample + ) + sample_points = torch.tensor( + sample_pcd.point.positions.numpy(), + device=vertices.device, + dtype=vertices.dtype, + ) + sample_normals = torch.tensor( + sample_pcd.point.normals.numpy(), + device=vertices.device, + dtype=vertices.dtype, + ) + # generate rays + ray_direc = -sample_normals + ray_origin = ( + sample_points + 1e-3 * ray_direc + ) # Offset ray origin slightly along the normal to avoid self-intersection + disturb_direc = AntipodalSampler._random_rotate_unit_vectors( + ray_direc, max_angle=self.cfg.max_angle + ) + ray_origin = torch.vstack([ray_origin, ray_origin]) + ray_direc = torch.vstack([ray_direc, disturb_direc]) + # casting + return self.get_raycast_result( + ray_origin, + ray_direc, + surface_origin=torch.vstack([sample_points, sample_points]), + ) + + def get_raycast_result( + self, + ray_origin: torch.Tensor, + ray_direc: torch.Tensor, + surface_origin: torch.Tensor, + ): + if ray_origin.ndim != 2 or ray_origin.shape[-1] != 3: + raise ValueError("ray_origin must have shape [N, 3]") + if ray_direc.ndim != 2 or ray_direc.shape[-1] != 3: + raise ValueError("ray_direc must have shape [N, 3]") + if ray_origin.shape[0] != ray_direc.shape[0]: + raise ValueError( + "ray_origin and ray_direc must have the same number of rays" + ) + if ray_origin.shape[0] != surface_origin.shape[0]: + raise ValueError( + "ray_origin and surface_origin must have the same number of rays" + ) + + scene = o3d.t.geometry.RaycastingScene() + scene.add_triangles(self.mesh) + + rays = torch.cat([ray_origin, ray_direc], dim=-1) + rays_o3d = 
o3c.Tensor(rays.detach().to("cpu").numpy(), dtype=o3c.float32) + + ans = scene.cast_rays(rays_o3d) + t_hit = torch.from_numpy(ans["t_hit"].numpy()).to( + device=ray_origin.device, dtype=ray_origin.dtype + ) + hit_mask = torch.logical_and( + t_hit > self.cfg.min_length, t_hit < self.cfg.max_length + ) + hit_points = ray_origin[hit_mask] + t_hit[hit_mask, None] * ray_direc[hit_mask] + hit_origins = surface_origin[hit_mask] + hit_point_pairs = torch.cat( + [hit_points[:, None, :], hit_origins[:, None, :]], dim=1 + ) + hit_point_pairs = hit_point_pairs.to(dtype=torch.float32) + return hit_point_pairs + + @staticmethod + def _random_rotate_unit_vectors( + vectors: torch.Tensor, + max_angle: float, + degrees: bool = False, + eps: float = 1e-8, + ) -> torch.Tensor: + """ + Apply random small rotations to a batch of unit vectors [N, 3]. + + Args: + vectors: [N, 3], unit vectors + max_angle: Maximum rotation angle + degrees: If True, `max_angle` is given in degrees + eps: Numerical stability constant + + Returns: + rotated: [N, 3], rotated unit vectors + """ + assert vectors.ndim == 2 and vectors.shape[-1] == 3, "vectors must be [N, 3]" + + v = F.normalize(vectors, dim=-1) + + if degrees: + max_angle = torch.deg2rad( + torch.tensor(max_angle, dtype=v.dtype, device=v.device) + ).item() + + n = v.shape[0] + + # 1) Generate a random direction for each vector + # then project it onto the plane perpendicular to v to get the rotation axis k + rand_dir = torch.randn_like(v) + eps + proj = (rand_dir * v).sum(dim=-1, keepdim=True) * v + k = rand_dir - proj + k = F.normalize(k, dim=-1) + + # 2) Sample rotation angles in the range [eps, max_angle] + theta = ( + torch.rand(n, 1, device=v.device, dtype=v.dtype) * (max_angle - eps) + eps + ) + + # 3) Rodrigues' rotation formula + # R(v) = v*cosθ + (k×v)*sinθ + k*(k·v)*(1-cosθ) + # Since k ⟂ v, the last term is theoretically 0, but keeping the general formula is more robust + cos_t = torch.cos(theta) + sin_t = torch.sin(theta) + + kv = (k 
* v).sum(dim=-1, keepdim=True) + rotated = v * cos_t + torch.cross(k, v, dim=-1) * sin_t + k * kv * (1.0 - cos_t) + + return F.normalize(rotated, dim=-1) + + def visualize(self, hit_point_pairs: torch.Tensor): + if self.mesh is None: + logger.log_warning("Mesh is not initialized. Cannot visualize.") + return + + if hit_point_pairs.shape[0] == 0: + raise ValueError("No point pairs to visualize") + origin_points = hit_point_pairs[:, 0, :] + hit_points = hit_point_pairs[:, 1, :] + + origin_points_np = origin_points.to("cpu").numpy() + hit_points_np = hit_points.detach().to("cpu").numpy() + + n_pairs = hit_point_pairs.shape[0] + line_indices = np.stack( + [np.arange(n_pairs), np.arange(n_pairs) + n_pairs], axis=1 + ) + + mesh_legacy = self.mesh.to_legacy() + mesh_legacy.compute_vertex_normals() + mesh_legacy.paint_uniform_color([0.8, 0.8, 0.8]) + + origin_pcd = o3d.geometry.PointCloud() + origin_pcd.points = o3d.utility.Vector3dVector(origin_points_np) + origin_pcd.colors = o3d.utility.Vector3dVector( + np.tile(np.array([[0.1, 0.4, 1.0]]), (n_pairs, 1)) + ) + + hit_pcd = o3d.geometry.PointCloud() + hit_pcd.points = o3d.utility.Vector3dVector(hit_points_np) + hit_pcd.colors = o3d.utility.Vector3dVector( + np.tile(np.array([[1.0, 0.2, 0.2]]), (n_pairs, 1)) + ) + + line_set = o3d.geometry.LineSet() + mid_points = (origin_points_np + hit_points_np) / 2 + point_diff = hit_points_np - origin_points_np + draw_origin = mid_points - 0.6 * point_diff + draw_end = mid_points + 0.6 * point_diff + draw_pointpair = np.concatenate([draw_origin, draw_end], axis=0) + line_set.points = o3d.utility.Vector3dVector(draw_pointpair) + line_set.lines = o3d.utility.Vector2iVector(line_indices) + line_set.colors = o3d.utility.Vector3dVector( + np.tile(np.array([[0.2, 0.9, 0.2]]), (n_pairs, 1)) + ) + + o3d.visualization.draw_geometries( + [mesh_legacy, origin_pcd, hit_pcd, line_set], + window_name="Antipodal Point Pairs", + mesh_show_back_face=True, + ) + + +if __name__ == "__main__": + 
mesh_path = "/media/chenjian/_abc/project/grasp_annotator/dustpan_saa.ply" + mesh = o3d.t.io.read_triangle_mesh(mesh_path) + vertices = torch.from_numpy(mesh.vertex.positions.cpu().numpy()) + faces = torch.from_numpy(mesh.triangle.indices.cpu().numpy()) + + sampler = AntipodalSampler() + hit_point_pairs = sampler.sample(vertices, faces) + sampler.visualize(hit_point_pairs) + print(f"Sampled {hit_point_pairs.shape[0]} antipodal points") diff --git a/examples/sim/demo/grasp_mug.py b/examples/sim/demo/grasp_mug.py new file mode 100644 index 00000000..a0a138d0 --- /dev/null +++ b/examples/sim/demo/grasp_mug.py @@ -0,0 +1,257 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +""" +This script demonstrates the creation and simulation of a robot with a soft object, +and performs a pressing task in a simulated environment. 
def parse_arguments():
    """
    Build the command-line interface of the demo and parse ``sys.argv``.

    Returns:
        argparse.Namespace: Parsed options ``num_envs``, ``enable_rt``,
        ``headless`` and ``device``.
    """
    # One row per option: the flag and the keyword arguments for add_argument.
    option_table = (
        (
            "--num_envs",
            {"type": int, "default": 1, "help": "Number of parallel environments"},
        ),
        (
            "--enable_rt",
            {"action": "store_true", "help": "Enable ray tracing rendering"},
        ),
        (
            "--headless",
            {"action": "store_true", "help": "Enable headless mode"},
        ),
        (
            "--device",
            {
                "type": str,
                "default": "cpu",
                "help": "device to run the environment on, e.g., 'cpu' or 'cuda'",
            },
        ),
    )

    cli = argparse.ArgumentParser(
        description="Create and simulate a robot in SimulationManager"
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()
def create_robot(sim: SimulationManager, position=None) -> Robot:
    """
    Create a UR10 arm with a DH PGC-140-50 gripper and add it to the simulation.

    Args:
        sim (SimulationManager): The simulation manager instance.
        position: Initial position of the robot, passed as ``init_pos``.
            Defaults to ``[0.0, 0.0, 0.0]``. The value is copied, so the
            caller's list is never shared with the robot configuration.

    Returns:
        Robot: The configured robot instance added to the simulation.
    """
    # A list literal as default argument would be one object shared by every
    # call; build a fresh list per call instead.
    init_pos = [0.0, 0.0, 0.0] if position is None else list(position)

    # Retrieve URDF paths for the robot arm and the gripper
    ur10_urdf_path = get_data_path("UniversalRobots/UR10/UR10.urdf")
    gripper_urdf_path = get_data_path("DH_PGC_140_50_M/DH_PGC_140_50_M.urdf")

    # Configure the robot with its components and control properties
    cfg = RobotCfg(
        uid="UR10",
        urdf_cfg=URDFCfg(
            components=[
                {"component_type": "arm", "urdf_path": ur10_urdf_path},
                {"component_type": "hand", "urdf_path": gripper_urdf_path},
            ]
        ),
        drive_pros=JointDrivePropertiesCfg(
            stiffness={"JOINT[0-9]": 1e4, "FINGER[1-2]": 1e3},
            damping={"JOINT[0-9]": 1e3, "FINGER[1-2]": 1e2},
            max_effort={"JOINT[0-9]": 1e5, "FINGER[1-2]": 1e4},
            drive_type="force",
        ),
        control_parts={
            "arm": ["JOINT[0-9]"],
            "hand": ["FINGER[1-2]"],
        },
        solver_cfg={
            "arm": PytorchSolverCfg(
                end_link_name="ee_link",
                root_link_name="base_link",
                # 4x4 homogeneous tool-centre-point matrix with a 0.12 offset
                # along z.
                tcp=[
                    [0.0, 1.0, 0.0, 0.0],
                    [-1.0, 0.0, 0.0, 0.0],
                    [0.0, 0.0, 1.0, 0.12],
                    [0.0, 0.0, 0.0, 1.0],
                ],
            )
        },
        # Presumably six arm joints followed by two finger joints - TODO confirm.
        init_qpos=[0.0, -np.pi / 2, -np.pi / 2, np.pi / 2, -np.pi / 2, 0.0, 0.0, 0.0],
        init_pos=init_pos,
    )
    return sim.add_robot(cfg=cfg)
attrs=RigidBodyAttributesCfg( + mass=0.01, + dynamic_friction=0.97, + static_friction=0.99, + ), + max_convex_hull_num=16, + init_pos=[0.55, 0.0, 0.01], + init_rot=[0.0, 0.0, -90], + body_scale=(4, 4, 4), + ) + mug = sim.add_rigid_object(cfg=mug_cfg) + return mug + + +def get_grasp_traj(sim: SimulationManager, robot: Robot, grasp_xpos: torch.Tensor): + n_envs = sim.num_envs + rest_arm_qpos = robot.get_qpos("arm") + + approach_xpos = grasp_xpos.clone() + approach_xpos[:, 2, 3] += 0.04 + + _, qpos_approach = robot.compute_ik( + pose=approach_xpos, joint_seed=rest_arm_qpos, name="arm" + ) + _, qpos_grasp = robot.compute_ik( + pose=grasp_xpos, joint_seed=qpos_approach, name="arm" + ) + hand_open_qpos = torch.tensor([0.00, 0.00], dtype=torch.float32, device=sim.device) + hand_close_qpos = torch.tensor( + [0.025, 0.025], dtype=torch.float32, device=sim.device + ) + + arm_trajectory = torch.cat( + [ + rest_arm_qpos[:, None, :], + qpos_approach[:, None, :], + qpos_grasp[:, None, :], + qpos_grasp[:, None, :], + qpos_approach[:, None, :], + rest_arm_qpos[:, None, :], + ], + dim=1, + ) + hand_trajectory = torch.cat( + [ + hand_open_qpos[None, None, :].repeat(n_envs, 1, 1), + hand_open_qpos[None, None, :].repeat(n_envs, 1, 1), + hand_open_qpos[None, None, :].repeat(n_envs, 1, 1), + hand_close_qpos[None, None, :].repeat(n_envs, 1, 1), + hand_close_qpos[None, None, :].repeat(n_envs, 1, 1), + hand_close_qpos[None, None, :].repeat(n_envs, 1, 1), + ], + dim=1, + ) + all_trajectory = torch.cat([arm_trajectory, hand_trajectory], dim=-1) + interp_trajectory = interpolate_with_distance( + trajectory=all_trajectory, interp_num=300, device=sim.device + ) + return interp_trajectory + + +if __name__ == "__main__": + args = parse_arguments() + sim = initialize_simulation(args) + robot = create_robot(sim, position=[0.0, 0.0, 0.0]) + mug = create_mug(sim) + + # get mug grasp pose + grasp_cfg = GraspAnnotatorCfg( + viser_port=11801, + antipodal_sampler_cfg=AntipodalSamplerCfg( + n_sample=5000, 
max_length=0.088, min_length=0.003 + ), + force_regenerate=True, + ) + sim.open_window() + grasp_xpos = mug.get_grasp_pose( + approach_direction=torch.tensor( + [0, 0, -1], dtype=torch.float32, device=sim.device + ), + cfg=grasp_cfg, + is_visual=True, + ) + + grab_traj = get_grasp_traj(sim, robot, grasp_xpos) + input("Press Enter to start the grab mug demo...") + n_waypoint = grab_traj.shape[1] + for i in range(n_waypoint): + robot.set_qpos(grab_traj[:, i, :]) + sim.update(step=4) + time.sleep(1e-2) + input("Press Enter to exit the simulation...") From 1e15c77b29f491fd3c4e86bb189187b5f92bbc08 Mon Sep 17 00:00:00 2001 From: chenjian Date: Mon, 23 Mar 2026 19:17:02 +0800 Subject: [PATCH 02/25] update --- examples/sim/demo/grasp_mug.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/examples/sim/demo/grasp_mug.py b/examples/sim/demo/grasp_mug.py index a0a138d0..18c5ff9c 100644 --- a/examples/sim/demo/grasp_mug.py +++ b/examples/sim/demo/grasp_mug.py @@ -236,15 +236,19 @@ def get_grasp_traj(sim: SimulationManager, robot: Robot, grasp_xpos: torch.Tenso antipodal_sampler_cfg=AntipodalSamplerCfg( n_sample=5000, max_length=0.088, min_length=0.003 ), - force_regenerate=True, + force_regenerate=True, # force user to annotate grasp region each time ) sim.open_window() + + # 1. View grasp object in browser (e.g http://localhost:11801) + # 2. press 'Rect Select Region', select grasp region + # 3. press 'Confirm Selection' to finish grasp region selection. 
grasp_xpos = mug.get_grasp_pose( approach_direction=torch.tensor( [0, 0, -1], dtype=torch.float32, device=sim.device - ), + ), # gripper approach direction in the mug local frame cfg=grasp_cfg, - is_visual=True, + is_visual=True, # visualize selected grasp pose finally ) grab_traj = get_grasp_traj(sim, robot, grasp_xpos) From baf731a8ed19e65268dfead52afc3fff93f472bf Mon Sep 17 00:00:00 2001 From: chenjian Date: Mon, 23 Mar 2026 19:22:12 +0800 Subject: [PATCH 03/25] update comment --- examples/sim/demo/grasp_mug.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/sim/demo/grasp_mug.py b/examples/sim/demo/grasp_mug.py index 18c5ff9c..6ff56d69 100644 --- a/examples/sim/demo/grasp_mug.py +++ b/examples/sim/demo/grasp_mug.py @@ -246,7 +246,7 @@ def get_grasp_traj(sim: SimulationManager, robot: Robot, grasp_xpos: torch.Tenso grasp_xpos = mug.get_grasp_pose( approach_direction=torch.tensor( [0, 0, -1], dtype=torch.float32, device=sim.device - ), # gripper approach direction in the mug local frame + ), # gripper approach direction in the world frame cfg=grasp_cfg, is_visual=True, # visualize selected grasp pose finally ) From f1f043b809c59ff4215ea1fa6d5db3bdd5466281 Mon Sep 17 00:00:00 2001 From: chenjian Date: Tue, 24 Mar 2026 11:23:13 +0800 Subject: [PATCH 04/25] add viser dependence --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 60a12496..25b15290 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,7 +50,8 @@ dependencies = [ "black==24.3.0", "fvcore", "h5py", - "tensordict" + "tensordict", + "viser==1.0.21" ] [project.optional-dependencies] From 63ec5e6667eb0c2419a2a8c47d7d556acd5eb3f0 Mon Sep 17 00:00:00 2001 From: chenjian Date: Wed, 25 Mar 2026 19:14:51 +0800 Subject: [PATCH 05/25] update --- .../pg_grasp/batch_collision_checker.py | 528 ++++++++++++++++++ 1 file changed, 528 insertions(+) create mode 100644 
embodichain/toolkits/graspkit/pg_grasp/batch_collision_checker.py diff --git a/embodichain/toolkits/graspkit/pg_grasp/batch_collision_checker.py b/embodichain/toolkits/graspkit/pg_grasp/batch_collision_checker.py new file mode 100644 index 00000000..f50a12f9 --- /dev/null +++ b/embodichain/toolkits/graspkit/pg_grasp/batch_collision_checker.py @@ -0,0 +1,528 @@ +import trimesh +import numpy as np +import torch +import time +from typing import List, Tuple, Union +from dexsim.kit.meshproc import convex_decomposition_coacd +import hashlib +from dataclasses import dataclass +import os +import pickle +import open3d as o3d +from embodichain.utils import logger + + +CONVEX_CACHE_DIR = os.path.join( + os.path.expanduser("~"), ".cache", "embodichain_cache", "convex_decomposition" +) + + +@dataclass +class BatchConvexCollisionCheckerCfg: + collsion_threshold: float = 0.0 + n_query_mesh_samples: int = 4096 + debug: bool = False + + +class BatchConvexCollisionChecker: + def __init__( + self, + base_mesh_verts: torch.Tensor, + base_mesh_faces: torch.Tensor, + max_decomposition_hulls: int = 32, + ): + if not os.path.isdir(CONVEX_CACHE_DIR): + os.makedirs(CONVEX_CACHE_DIR, exist_ok=True) + base_mesh_verts_np = base_mesh_verts.cpu().numpy() + base_mesh_faces_np = base_mesh_faces.cpu().numpy() + mesh_hash = hashlib.md5( + (base_mesh_verts_np.tobytes() + base_mesh_faces_np.tobytes()) + ).hexdigest() + + # for visualization + self.mesh = o3d.geometry.TriangleMesh( + vertices=o3d.utility.Vector3dVector(base_mesh_verts_np), + triangles=o3d.utility.Vector3iVector(base_mesh_faces_np), + ) + self.mesh.compute_vertex_normals() + self.cache_path = os.path.join( + CONVEX_CACHE_DIR, f"{mesh_hash}_{max_decomposition_hulls}.pkl" + ) + + if not os.path.isfile(self.cache_path): + # generate convex hulls and extract plane equations, then cache to disk + self.plane_equations = BatchConvexCollisionChecker._compute_plane_equations( + base_mesh_verts_np, base_mesh_faces_np, max_decomposition_hulls + ) 
+ pickle.dump(self.plane_equations, open(self.cache_path, "wb")) + else: + # load precomputed plane equations from cache + self.plane_equations = pickle.load(open(self.cache_path, "rb")) + + def query( + self, + query_mesh_verts: torch.Tensor, + query_mesh_faces: torch.Tensor, + poses: torch.Tensor, + cfg: BatchConvexCollisionCheckerCfg = BatchConvexCollisionCheckerCfg(), + ) -> Tuple[torch.Tensor, torch.Tensor]: + query_mesh = trimesh.Trimesh( + vertices=query_mesh_verts.to("cpu").numpy(), + faces=query_mesh_faces.to("cpu").numpy(), + ) + n_query = cfg.n_query_mesh_samples + n_batch = poses.shape[0] + query_points_np = query_mesh.sample(n_query).astype(np.float32) + query_points = torch.tensor( + query_points_np, device=poses.device + ) # [n_query, 3] + penetration_result = torch.zeros(size=(n_batch, n_query), device=poses.device) + penetration_result.fill_(-float("inf")) + collision_result = torch.zeros( + size=(n_batch, n_query), dtype=torch.bool, device=poses.device + ) + collision_result.fill_(False) + for normals, offsets in self.plane_equations: + normals_torch = torch.tensor(normals, device=poses.device) + offsets_torch = torch.tensor(offsets, device=poses.device) + penetration, collides = check_collision_single_hull( + normals_torch, + offsets_torch, + transform_points_batch(query_points, poses), + cfg.collsion_threshold, + ) + penetration_result = torch.max(penetration_result, penetration) + collision_result = torch.logical_or(collision_result, collides) + is_colliding = collision_result.any(dim=-1) # [B] + max_penetration = penetration_result.max(dim=-1)[0] # [B] + + if cfg.debug: + # visualize result + query_points_o3d = o3d.geometry.PointCloud() + query_points_o3d.points = o3d.utility.Vector3dVector(query_points_np) + query_points_o3d.transform(poses[-1].to("cpu").numpy()) + query_points_color = np.zeros_like(query_points_np) + query_points_color[collision_result[-1].cpu().numpy()] = [ + 1.0, + 0, + 0, + ] # red for colliding points + 
query_points_color[~collision_result[-1].cpu().numpy()] = [ + 0, + 1.0, + 0, + ] # green for non-colliding points + query_points_o3d.colors = o3d.utility.Vector3dVector(query_points_color) + o3d.visualization.draw_geometries( + [self.mesh, query_points_o3d], mesh_show_back_face=True + ) + return is_colliding, max_penetration + + @staticmethod + def _compute_plane_equations( + vertices: np.ndarray, faces: np.ndarray, max_decomposition_hulls: int + ) -> Tuple[np.ndarray, np.ndarray]: + """ + Convex decomposition and extract plane equations given mesh vertices and triangles. + Each convex hull is represented by its outward-facing face normals and offsets. + No padding is applied; each hull can have a different number of faces. + + Args: + vertices: [N, 3] vertex positions of the input mesh. + faces: [M, 3] triangle indices of the input mesh. + max_decomposition_hulls: maximum number of convex hulls to decompose into. + + Returns: + List of (normals_i [Ki, 3], offsets_i [Ki]) tuples, one per convex hull. + Ki is the number of faces of the i-th hull and can differ across hulls. + """ + mesh = o3d.t.geometry.TriangleMesh() + mesh.vertex.positions = o3d.core.Tensor(vertices, dtype=o3d.core.Dtype.Float32) + mesh.triangle.indices = o3d.core.Tensor(faces, dtype=o3d.core.Dtype.Int32) + is_success, out_mesh_list = convex_decomposition_coacd( + mesh, max_convex_hull_num=max_decomposition_hulls + ) + convex_vert_face_list = [] + for out_mesh in out_mesh_list: + verts = out_mesh.vertex.positions.numpy() + faces = out_mesh.triangle.indices.numpy() + convex_vert_face_list.append((verts, faces)) + return extract_plane_equations(convex_vert_face_list) + + +def extract_plane_equations( + convex_meshes: List[Tuple[np.ndarray, np.ndarray]], +) -> List[Tuple[np.ndarray, np.ndarray]]: + """ + Extract plane equations from a list of convex hull meshes. + Each convex hull is represented by its outward-facing face normals and offsets. 
+ No padding is applied; each hull can have a different number of faces. + + Args: + convex_meshes: List of convex hull data. + - tuple of (vertices [N,3], faces [M,3]) + + Returns: + List of (normals_i [Ki, 3], offsets_i [Ki]) tuples, one per convex hull. + Ki is the number of faces of the i-th hull and can differ across hulls. + """ + convex_plane_data = [] + + for i, convex_mesh_data in enumerate(convex_meshes): + vertices, faces = convex_mesh_data + hull = trimesh.Trimesh( + vertices=vertices, + faces=faces, + ) + # Outward-facing face normals [Ki, 3] + face_normals = hull.face_normals + # One vertex per face to compute offset [Ki, 3] + face_origins = hull.triangles[:, 0, :] + # Plane equation: n · x + d = 0 => d = -(n · p) + offsets_i = -np.sum(face_normals * face_origins, axis=1) + + convex_plane_data.append( + (face_normals.astype(np.float32), offsets_i.astype(np.float32)) + ) + return convex_plane_data + + +def sample_surface_points(mesh_path: str, num_points: int = 4096) -> np.ndarray: + """ + Sample surface points from a mesh file. + + Args: + mesh_path: Path to the mesh file. + num_points: Number of surface points to sample. + + Returns: + points: [P, 3] numpy array of sampled surface points. + """ + mesh = trimesh.load(mesh_path, force="mesh") + points = mesh.sample(num_points) + return points.astype(np.float32) + + +def check_collision_single_hull( + normals: torch.Tensor, # [K, 3] + offsets: torch.Tensor, # [K] + transformed_points: torch.Tensor, # [B, P, 3] + threshold: float = 0.0, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Check collision between a batch of transformed point clouds and a single convex hull. + + A point p is inside the convex hull iff: + max_k (n_k · p + d_k) <= 0 + + Penetration depth for a point is defined as: + penetration = -(max_k (n_k · p + d_k)) + Positive penetration means the point is inside the hull. + + Args: + normals: [K, 3] outward face normals of the convex hull. + offsets: [K] plane offsets of the convex hull. 
+ transformed_points: [B, P, 3] point cloud already transformed by batch poses. + threshold: collision threshold. A point is considered colliding if + its signed distance to the hull interior is <= threshold. + + Returns: + penetration: [B, P] penetration depth for each point. + Positive values indicate the point is inside the hull. + collides: [B, P] boolean mask, True if the point collides with this hull. + """ + # signed_dist: [B, P, K] = einsum([B,P,3], [K,3]) + [K] + signed_dist = torch.einsum("bpj, kj -> bpk", transformed_points, normals) + offsets + + # For each point, the maximum signed distance across all planes + # If max <= 0, the point satisfies all half-plane constraints => inside the hull + max_over_planes, _ = signed_dist.max(dim=-1) # [B, P] + + # Penetration depth: negate so that positive = inside + penetration = -max_over_planes # [B, P] + + # A point collides if its penetration exceeds the threshold + collides = penetration > threshold # [B, P] + + return penetration, collides + + +def merge_collision_results( + hull_results: List[Tuple[torch.Tensor, torch.Tensor]], device: torch.device +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Merge collision detection results from multiple convex hulls. + + A pose is considered colliding if ANY point penetrates ANY convex hull. + The reported penetration depth is the maximum across all points and all hulls. + + Args: + hull_results: List of (penetration [B, P], collides [B, P]) tuples, + one per convex hull, as returned by check_collision_single_hull. + device: torch device. + + Returns: + overall_collisions: [B] boolean, True if the pose collides with any hull. + overall_max_penetrations: [B] float, maximum penetration depth per pose. 
+ """ + if not hull_results: + raise ValueError("hull_results is empty, nothing to merge.") + + B = hull_results[0][0].shape[0] + + overall_collisions = torch.zeros(B, dtype=torch.bool, device=device) + overall_max_penetrations = torch.full( + (B,), -float("inf"), dtype=torch.float32, device=device + ) + + for penetration, collides in hull_results: + # Update collision flag: OR across hulls + # A pose collides if any point collides with this hull + overall_collisions |= collides.any(dim=-1) # [B] + + # Update max penetration: take the per-pose maximum across all points for this hull, + # then compare with the running maximum + hull_max_pen = penetration.max(dim=-1)[0] # [B] + overall_max_penetrations = torch.max(overall_max_penetrations, hull_max_pen) + + return overall_collisions, overall_max_penetrations + + +def transform_points_batch( + points: torch.Tensor, poses: torch.Tensor # [P, 3] # [B, 4, 4] +) -> torch.Tensor: + """ + Apply a batch of rigid transforms to a point cloud. + + Args: + points: [P, 3] source point cloud. + poses: [B, 4, 4] batch of homogeneous transformation matrices. + + Returns: + transformed: [B, P, 3] transformed point cloud for each pose. + """ + R = poses[:, :3, :3] # [B, 3, 3] + t = poses[:, :3, 3] # [B, 3] + transformed = torch.einsum("bij, pj -> bpi", R, points) + t.unsqueeze(1) + return transformed + + +def batch_collision_detection( + convex_planes: List[Tuple[torch.Tensor, torch.Tensor]], + points_B: torch.Tensor, # [P, 3] + poses: torch.Tensor, # [B, 4, 4] + threshold: float = 0.0, + chunk_size: int = 512, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Full batch collision detection pipeline. + + Iterates over convex hulls sequentially and over pose chunks to control + GPU memory usage. Within each (hull, chunk) pair, the computation is + fully parallelized over B_chunk * P * K. + + Args: + convex_planes: List of (normals [Ki, 3], offsets [Ki]) tensors on device, + one per convex hull. Ki can differ across hulls. 
+ points_B: [P, 3] sampled surface points of mesh B, on device. + poses: [B, 4, 4] batch of relative poses, on device. + threshold: collision threshold (positive = safety margin). + chunk_size: number of poses to process per chunk. + + Returns: + overall_collisions: [B] bool + overall_max_penetrations: [B] float + """ + device = points_B.device + B = poses.shape[0] + + all_hull_results: List[Tuple[torch.Tensor, torch.Tensor]] = [] + + # Sequential iteration over convex hulls to save memory + for hull_idx, (normals, offsets) in enumerate(convex_planes): + hull_pen_chunks = [] + hull_col_chunks = [] + + # Chunk over batch dimension to control peak memory + for start in range(0, B, chunk_size): + end = min(start + chunk_size, B) + poses_chunk = poses[start:end] + + # Transform points for this chunk of poses + transformed_chunk = transform_points_batch( + points_B, poses_chunk + ) # [B_chunk, P, 3] + + # Check collision against this single hull + penetration, collides = check_collision_single_hull( + normals, offsets, transformed_chunk, threshold + ) + + hull_pen_chunks.append(penetration) + hull_col_chunks.append(collides) + + # Concatenate chunks for this hull + hull_penetration = torch.cat(hull_pen_chunks, dim=0) # [B, P] + hull_collides = torch.cat(hull_col_chunks, dim=0) # [B, P] + + all_hull_results.append((hull_penetration, hull_collides)) + + # Merge results across all hulls + overall_collisions, overall_max_penetrations = merge_collision_results( + all_hull_results, device + ) + + return overall_collisions, overall_max_penetrations + + +def main(): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(f"Using device: {device}") + + # --- Create dummy mesh files for testing --- + box1 = trimesh.primitives.Box(extents=[0.5, 0.5, 0.5]) + box2 = trimesh.primitives.Box( + extents=[0.4, 0.4, 0.4], + transform=trimesh.transformations.translation_matrix([1, 0, 0]), + ) + box1.export("mesh_hull_1.obj") + box2.export("mesh_hull_2.obj") + + 
sphere_mesh = trimesh.primitives.Sphere(radius=0.3) + sphere_mesh.export("mesh_B.obj") + print("Created dummy mesh files.\n") + + # ==================== Preprocessing ==================== + + # Load externally decomposed convex hull meshes + convex_mesh_files = ["mesh_hull_1.obj", "mesh_hull_2.obj"] + convex_meshes = load_convex_meshes(convex_mesh_files) + if not convex_meshes: + print("No convex hulls loaded. Exiting.") + return + + # Extract plane equations (list of variable-length arrays) + convex_plane_data_np = extract_plane_equations(convex_meshes) + + # Convert to torch tensors on device + convex_planes_torch: List[Tuple[torch.Tensor, torch.Tensor]] = [] + for normals_np, offsets_np in convex_plane_data_np: + convex_planes_torch.append( + ( + torch.tensor(normals_np, device=device), # [Ki, 3] + torch.tensor(offsets_np, device=device), # [Ki] + ) + ) + + # Sample surface points from mesh B + points_np = sample_surface_points("mesh_B.obj", num_points=2048) + points_B = torch.tensor(points_np, device=device) # [P, 3] + + # ==================== Generate test poses ==================== + B = 10000 + chunk_size = 1024 + + # Random rotation matrices via SVD + random_mat = torch.randn(B, 3, 3, device=device) + U, _, Vt = torch.linalg.svd(random_mat) + R = U @ Vt + # Fix reflections to ensure proper rotations (det = +1) + det = torch.det(R) + R[det < 0] *= -1 + + poses = torch.eye(4, device=device).unsqueeze(0).repeat(B, 1, 1) + poses[:, :3, :3] = R + poses[:, :3, 3] = torch.randn(B, 3, device=device) * 0.5 + + # ==================== Run collision detection ==================== + print( + f"\nRunning collision detection: {B} poses, {points_B.shape[0]} points, " + f"{len(convex_planes_torch)} hulls..." 
+ ) + + torch.cuda.synchronize() if device.type == "cuda" else None + start_time = time.time() + + with torch.no_grad(): + collisions, penetration_depths = batch_collision_detection( + convex_planes_torch, points_B, poses, threshold=0.001, chunk_size=chunk_size + ) + + torch.cuda.synchronize() if device.type == "cuda" else None + elapsed = time.time() - start_time + + # ==================== Report results ==================== + print(f"\n{'='*40}") + print(f"Total poses: {B}") + print(f"Collisions: {collisions.sum().item()} / {B}") + if collisions.any(): + print(f"Max penetration: {penetration_depths[collisions].max().item():.6f}") + else: + print(f"Max penetration: N/A (no collisions)") + print(f"Total time: {elapsed:.3f}s") + print(f"Per pose: {elapsed / B * 1e6:.2f} μs") + print(f"{'='*40}") + + # ==================== Benchmark ==================== + num_iters = 50 + torch.cuda.synchronize() if device.type == "cuda" else None + t0 = time.time() + for _ in range(num_iters): + with torch.no_grad(): + batch_collision_detection( + convex_planes_torch, + points_B, + poses, + threshold=0.001, + chunk_size=chunk_size, + ) + torch.cuda.synchronize() if device.type == "cuda" else None + t1 = time.time() + + avg_ms = (t1 - t0) / num_iters * 1000 + print( + f"\nBenchmark ({num_iters} iters): {avg_ms:.2f} ms/iter, " + f"{avg_ms / B * 1000:.2f} μs/pose" + ) + + +if __name__ == "__main__": + from embodichain.data import get_data_path + + bottle_a_path = get_data_path("ScannedBottle/moliwulong_processed.ply") + bottle_b_path = get_data_path("ScannedBottle/yibao_processed.ply") + + bottle_a_mesh = trimesh.load(bottle_a_path) + bottle_b_mesh = trimesh.load(bottle_b_path) + bottle_a_verts = torch.tensor(bottle_a_mesh.vertices, dtype=torch.float32) + bottle_a_faces = torch.tensor(bottle_a_mesh.faces, dtype=torch.int64) + bottle_b_verts = torch.tensor(bottle_b_mesh.vertices, dtype=torch.float32) + bottle_b_faces = torch.tensor(bottle_b_mesh.faces, dtype=torch.int64) + + 
collision_checker = BatchConvexCollisionChecker(bottle_a_verts, bottle_a_faces) + poses = torch.tensor( + [ + [ + [1, 0, 0, 0], + [0, 1, 0, 0], + [0, 0, 1, 1.0], + [0, 0, 0, 1], + ], + [ + [1, 0, 0, 0.05], + [0, -1, 0, 0], + [0, 0, -1, 0], + [0, 0, 0, 1], + ], + ] + ) + check_cfg = BatchConvexCollisionCheckerCfg( + debug=False, + n_query_mesh_samples=32768, + collsion_threshold=-0.003, + ) + collisions, penetrations = collision_checker.query( + bottle_b_verts, bottle_b_faces, poses, cfg=check_cfg + ) + print("Collisions:", collisions) + print("Penetrations:", penetrations) From 73781d887eb82142224cc35daedc25ad505e964b Mon Sep 17 00:00:00 2001 From: chenjian Date: Thu, 26 Mar 2026 17:40:42 +0800 Subject: [PATCH 06/25] update --- .../graspkit/pg_grasp/antipodal_annotator.py | 16 ++ .../graspkit/pg_grasp/antipodal_sampler.py | 16 ++ .../pg_grasp/batch_collision_checker.py | 179 ------------------ .../pg_grasp/gripper_collision_checker.py | 0 4 files changed, 32 insertions(+), 179 deletions(-) create mode 100644 embodichain/toolkits/graspkit/pg_grasp/gripper_collision_checker.py diff --git a/embodichain/toolkits/graspkit/pg_grasp/antipodal_annotator.py b/embodichain/toolkits/graspkit/pg_grasp/antipodal_annotator.py index 4852879e..5ee3eda4 100644 --- a/embodichain/toolkits/graspkit/pg_grasp/antipodal_annotator.py +++ b/embodichain/toolkits/graspkit/pg_grasp/antipodal_annotator.py @@ -1,3 +1,19 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + import os import argparse import open3d as o3d diff --git a/embodichain/toolkits/graspkit/pg_grasp/antipodal_sampler.py b/embodichain/toolkits/graspkit/pg_grasp/antipodal_sampler.py index 1eb3ec61..09e4858e 100644 --- a/embodichain/toolkits/graspkit/pg_grasp/antipodal_sampler.py +++ b/embodichain/toolkits/graspkit/pg_grasp/antipodal_sampler.py @@ -1,3 +1,19 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ---------------------------------------------------------------------------- + import torch import torch.nn.functional as F import numpy as np diff --git a/embodichain/toolkits/graspkit/pg_grasp/batch_collision_checker.py b/embodichain/toolkits/graspkit/pg_grasp/batch_collision_checker.py index f50a12f9..cf18b76e 100644 --- a/embodichain/toolkits/graspkit/pg_grasp/batch_collision_checker.py +++ b/embodichain/toolkits/graspkit/pg_grasp/batch_collision_checker.py @@ -307,185 +307,6 @@ def transform_points_batch( return transformed -def batch_collision_detection( - convex_planes: List[Tuple[torch.Tensor, torch.Tensor]], - points_B: torch.Tensor, # [P, 3] - poses: torch.Tensor, # [B, 4, 4] - threshold: float = 0.0, - chunk_size: int = 512, -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Full batch collision detection pipeline. - - Iterates over convex hulls sequentially and over pose chunks to control - GPU memory usage. Within each (hull, chunk) pair, the computation is - fully parallelized over B_chunk * P * K. - - Args: - convex_planes: List of (normals [Ki, 3], offsets [Ki]) tensors on device, - one per convex hull. Ki can differ across hulls. - points_B: [P, 3] sampled surface points of mesh B, on device. - poses: [B, 4, 4] batch of relative poses, on device. - threshold: collision threshold (positive = safety margin). - chunk_size: number of poses to process per chunk. 
- - Returns: - overall_collisions: [B] bool - overall_max_penetrations: [B] float - """ - device = points_B.device - B = poses.shape[0] - - all_hull_results: List[Tuple[torch.Tensor, torch.Tensor]] = [] - - # Sequential iteration over convex hulls to save memory - for hull_idx, (normals, offsets) in enumerate(convex_planes): - hull_pen_chunks = [] - hull_col_chunks = [] - - # Chunk over batch dimension to control peak memory - for start in range(0, B, chunk_size): - end = min(start + chunk_size, B) - poses_chunk = poses[start:end] - - # Transform points for this chunk of poses - transformed_chunk = transform_points_batch( - points_B, poses_chunk - ) # [B_chunk, P, 3] - - # Check collision against this single hull - penetration, collides = check_collision_single_hull( - normals, offsets, transformed_chunk, threshold - ) - - hull_pen_chunks.append(penetration) - hull_col_chunks.append(collides) - - # Concatenate chunks for this hull - hull_penetration = torch.cat(hull_pen_chunks, dim=0) # [B, P] - hull_collides = torch.cat(hull_col_chunks, dim=0) # [B, P] - - all_hull_results.append((hull_penetration, hull_collides)) - - # Merge results across all hulls - overall_collisions, overall_max_penetrations = merge_collision_results( - all_hull_results, device - ) - - return overall_collisions, overall_max_penetrations - - -def main(): - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - print(f"Using device: {device}") - - # --- Create dummy mesh files for testing --- - box1 = trimesh.primitives.Box(extents=[0.5, 0.5, 0.5]) - box2 = trimesh.primitives.Box( - extents=[0.4, 0.4, 0.4], - transform=trimesh.transformations.translation_matrix([1, 0, 0]), - ) - box1.export("mesh_hull_1.obj") - box2.export("mesh_hull_2.obj") - - sphere_mesh = trimesh.primitives.Sphere(radius=0.3) - sphere_mesh.export("mesh_B.obj") - print("Created dummy mesh files.\n") - - # ==================== Preprocessing ==================== - - # Load externally decomposed convex hull 
meshes - convex_mesh_files = ["mesh_hull_1.obj", "mesh_hull_2.obj"] - convex_meshes = load_convex_meshes(convex_mesh_files) - if not convex_meshes: - print("No convex hulls loaded. Exiting.") - return - - # Extract plane equations (list of variable-length arrays) - convex_plane_data_np = extract_plane_equations(convex_meshes) - - # Convert to torch tensors on device - convex_planes_torch: List[Tuple[torch.Tensor, torch.Tensor]] = [] - for normals_np, offsets_np in convex_plane_data_np: - convex_planes_torch.append( - ( - torch.tensor(normals_np, device=device), # [Ki, 3] - torch.tensor(offsets_np, device=device), # [Ki] - ) - ) - - # Sample surface points from mesh B - points_np = sample_surface_points("mesh_B.obj", num_points=2048) - points_B = torch.tensor(points_np, device=device) # [P, 3] - - # ==================== Generate test poses ==================== - B = 10000 - chunk_size = 1024 - - # Random rotation matrices via SVD - random_mat = torch.randn(B, 3, 3, device=device) - U, _, Vt = torch.linalg.svd(random_mat) - R = U @ Vt - # Fix reflections to ensure proper rotations (det = +1) - det = torch.det(R) - R[det < 0] *= -1 - - poses = torch.eye(4, device=device).unsqueeze(0).repeat(B, 1, 1) - poses[:, :3, :3] = R - poses[:, :3, 3] = torch.randn(B, 3, device=device) * 0.5 - - # ==================== Run collision detection ==================== - print( - f"\nRunning collision detection: {B} poses, {points_B.shape[0]} points, " - f"{len(convex_planes_torch)} hulls..." 
- ) - - torch.cuda.synchronize() if device.type == "cuda" else None - start_time = time.time() - - with torch.no_grad(): - collisions, penetration_depths = batch_collision_detection( - convex_planes_torch, points_B, poses, threshold=0.001, chunk_size=chunk_size - ) - - torch.cuda.synchronize() if device.type == "cuda" else None - elapsed = time.time() - start_time - - # ==================== Report results ==================== - print(f"\n{'='*40}") - print(f"Total poses: {B}") - print(f"Collisions: {collisions.sum().item()} / {B}") - if collisions.any(): - print(f"Max penetration: {penetration_depths[collisions].max().item():.6f}") - else: - print(f"Max penetration: N/A (no collisions)") - print(f"Total time: {elapsed:.3f}s") - print(f"Per pose: {elapsed / B * 1e6:.2f} μs") - print(f"{'='*40}") - - # ==================== Benchmark ==================== - num_iters = 50 - torch.cuda.synchronize() if device.type == "cuda" else None - t0 = time.time() - for _ in range(num_iters): - with torch.no_grad(): - batch_collision_detection( - convex_planes_torch, - points_B, - poses, - threshold=0.001, - chunk_size=chunk_size, - ) - torch.cuda.synchronize() if device.type == "cuda" else None - t1 = time.time() - - avg_ms = (t1 - t0) / num_iters * 1000 - print( - f"\nBenchmark ({num_iters} iters): {avg_ms:.2f} ms/iter, " - f"{avg_ms / B * 1000:.2f} μs/pose" - ) - - if __name__ == "__main__": from embodichain.data import get_data_path diff --git a/embodichain/toolkits/graspkit/pg_grasp/gripper_collision_checker.py b/embodichain/toolkits/graspkit/pg_grasp/gripper_collision_checker.py new file mode 100644 index 00000000..e69de29b From e0d129d6a7dfce09259e674ec262615370479c43 Mon Sep 17 00:00:00 2001 From: chenjian Date: Mon, 30 Mar 2026 10:39:45 +0800 Subject: [PATCH 07/25] TODO: too slow --- embodichain/lab/sim/objects/rigid_object.py | 18 +- .../graspkit/pg_grasp/antipodal_annotator.py | 126 +++++++--- .../pg_grasp/batch_collision_checker.py | 141 ++++++++--- 
.../pg_grasp/gripper_collision_checker.py | 230 ++++++++++++++++++ examples/sim/demo/grasp_mug.py | 2 +- 5 files changed, 431 insertions(+), 86 deletions(-) diff --git a/embodichain/lab/sim/objects/rigid_object.py b/embodichain/lab/sim/objects/rigid_object.py index 2e058ad3..4da6e632 100644 --- a/embodichain/lab/sim/objects/rigid_object.py +++ b/embodichain/lab/sim/objects/rigid_object.py @@ -1140,8 +1140,6 @@ def get_grasp_pose( ) approach_direction = F.normalize(approach_direction, dim=-1) if hasattr(self, "_grasp_annotator") is False: - self._grasp_annotator = GraspAnnotator(cfg=cfg) - if hasattr(self, "_hit_point_pairs") is False or cfg.force_regenerate: vertices = torch.tensor( self._entities[0].get_vertices(), dtype=torch.float32, @@ -1156,7 +1154,13 @@ def get_grasp_pose( device=self.device, ) vertices = vertices * scale - self._hit_point_pairs = self._grasp_annotator.annotate(vertices, triangles) + self._grasp_annotator = GraspAnnotator( + vertices=vertices, triangles=triangles, cfg=cfg + ) + + # Annotate antipodal point pairs + if hasattr(self, "_hit_point_pairs") is False or cfg.force_regenerate: + self._hit_point_pairs = self._grasp_annotator.annotate() poses = self.get_local_pose(to_matrix=True) poses = torch.as_tensor(poses, dtype=torch.float32, device=self.device) @@ -1177,13 +1181,7 @@ def get_grasp_pose( triangles = self._entities[0].get_triangles() scale = self._entities[0].get_body_scale() vertices = vertices * scale - GraspAnnotator.visualize_grasp_pose( - vertices=torch.tensor( - vertices, dtype=torch.float32, device=self.device - ), - triangles=torch.tensor( - triangles, dtype=torch.int32, device=self.device - ), + self._grasp_annotator.visualize_grasp_pose( obj_pose=poses[0], grasp_pose=grasp_poses[0], open_length=open_lengths[0], diff --git a/embodichain/toolkits/graspkit/pg_grasp/antipodal_annotator.py b/embodichain/toolkits/graspkit/pg_grasp/antipodal_annotator.py index 5ee3eda4..bf84a9f3 100644 --- 
a/embodichain/toolkits/graspkit/pg_grasp/antipodal_annotator.py +++ b/embodichain/toolkits/graspkit/pg_grasp/antipodal_annotator.py @@ -32,6 +32,10 @@ AntipodalSampler, AntipodalSamplerCfg, ) +from .gripper_collision_checker import ( + SimpleGripperCollisionChecker, + SimpleGripperCollisionCfg, +) import hashlib import torch.nn.functional as F import tempfile @@ -55,18 +59,37 @@ class SelectResult: class GraspAnnotator: - def __init__(self, cfg: GraspAnnotatorCfg = GraspAnnotatorCfg()) -> None: + def __init__( + self, + vertices: torch.Tensor, + triangles: torch.Tensor, + cfg: GraspAnnotatorCfg = GraspAnnotatorCfg(), + ) -> None: + self.device = vertices.device + self.vertices = vertices + self.triangles = triangles + self.mesh = trimesh.Trimesh( + vertices=vertices.to("cpu").numpy(), + faces=triangles.to("cpu").numpy(), + process=False, + force="mesh", + ) + self._collision_checker = SimpleGripperCollisionChecker( + object_mesh_verts=vertices, + object_mesh_faces=triangles, + cfg=SimpleGripperCollisionCfg(), + ) self.cfg = cfg self.antipodal_sampler = AntipodalSampler(cfg=cfg.antipodal_sampler_cfg) - def annotate(self, vertices: torch.Tensor, triangles: torch.Tensor): - cache_path = self._get_cache_dir(vertices, triangles) + def annotate(self): + cache_path = self._get_cache_dir(self.vertices, self.triangles) if os.path.exists(cache_path) and not self.cfg.force_regenerate: logger.log_info( f"Found existing antipodal retult. 
Loading cached antipodal pairs from {cache_path}" ) hit_point_pairs = torch.tensor( - np.load(cache_path), dtype=torch.float32, device=vertices.device + np.load(cache_path), dtype=torch.float32, device=self.device ) return hit_point_pairs else: @@ -74,14 +97,6 @@ def annotate(self, vertices: torch.Tensor, triangles: torch.Tensor): f"[Viser] *****Annotate grasp region in http://localhost:{self.cfg.viser_port}" ) - self.mesh = trimesh.Trimesh( - vertices=vertices.to("cpu").numpy(), - faces=triangles.to("cpu").numpy(), - process=False, - force="mesh", - ) - self.device = vertices.device - server = viser.ViserServer(port=self.cfg.viser_port) server.gui.configure_theme(brand_color=(130, 0, 150)) server.scene.set_up_direction("+z") @@ -362,56 +377,93 @@ def get_approach_grasp_poses( """ origin_points = hit_point_pairs[:, 0, :] hit_points = hit_point_pairs[:, 1, :] - print("origin_points dtype:", origin_points.dtype) - print("object_pose dtype:", object_pose.dtype) origin_points_ = self._apply_transform(origin_points, object_pose) hit_points_ = self._apply_transform(hit_points, object_pose) centers = (origin_points_ + hit_points_) / 2 center = centers.mean(dim=0) - # get best grasp pose + # filter perpendicular antipodal point grasp_x = F.normalize(hit_points_ - origin_points_, dim=-1) cos_angle = torch.clamp((grasp_x * approach_direction).sum(dim=-1), -1.0, 1.0) positive_angle = torch.abs(torch.acos(cos_angle)) - antipodal_length = torch.norm(hit_points_ - origin_points_, dim=-1) - length_cost = 1 - antipodal_length / antipodal_length.max() + valid_mask = ( + positive_angle - torch.pi / 2 + ).abs() <= self.cfg.max_deviation_angle + valid_grasp_x = grasp_x[valid_mask] + valid_centers = centers[valid_mask] + + # compute grasp poses using antipodal point pairs and approach direction + valid_grasp_poses = GraspAnnotator._grasp_pose_from_approach_direction( + valid_grasp_x, approach_direction, valid_centers + ) + valid_open_lengths = torch.norm( + origin_points_[valid_mask] - 
hit_points_[valid_mask], dim=-1 + ) + # select non-collide grasp poses + + is_colliding, max_penetration = self._collision_checker.query( + object_pose, valid_grasp_poses, valid_open_lengths + ) + + # get best grasp pose + valid_grasp_poses = valid_grasp_poses[~is_colliding] + valid_open_lengths = valid_open_lengths[~is_colliding] + valid_centers = valid_centers[~is_colliding] + valid_grasp_x = F.normalize(valid_grasp_poses[:, :3, 0], dim=-1) + + cos_angle = torch.clamp( + (valid_grasp_x * approach_direction).sum(dim=-1), -1.0, 1.0 + ) + positive_angle = torch.abs(torch.acos(cos_angle)) angle_cost = torch.abs(positive_angle - 0.5 * torch.pi) / (0.5 * torch.pi) - center_distance = torch.norm(centers - center, dim=-1) + center_distance = torch.norm(valid_centers - center, dim=-1) center_cost = center_distance / center_distance.max() + length_cost = 1 - valid_open_lengths / valid_open_lengths.max() total_cost = 0.4 * angle_cost + 0.3 * length_cost + 0.3 * center_cost best_idx = torch.argmin(total_cost) - - best_open_length = torch.norm(hit_points_[best_idx] - origin_points_[best_idx]) - best_grasp_x = grasp_x[best_idx] - best_grasp_center = centers[best_idx] - best_grasp_y = torch.cross(approach_direction, best_grasp_x, dim=0) - best_grasp_y = F.normalize(best_grasp_y, dim=-1) - best_grasp_z = torch.cross(best_grasp_x, best_grasp_y, dim=0) - best_grasp_z = F.normalize(best_grasp_z, dim=-1) - grasp_pose = torch.eye(4, device=hit_point_pairs.device, dtype=torch.float32) - grasp_pose[:3, 0] = best_grasp_x - grasp_pose[:3, 1] = best_grasp_y - grasp_pose[:3, 2] = best_grasp_z - grasp_pose[:3, 3] = best_grasp_center - return grasp_pose, best_open_length + best_grasp_pose = valid_grasp_poses[best_idx] + best_open_length = valid_open_lengths[best_idx] + return best_grasp_pose, best_open_length @staticmethod + def _grasp_pose_from_approach_direction( + grasp_x: torch.Tensor, approach_direction: torch.Tensor, center: torch.Tensor + ): + approach_direction_repeat = 
approach_direction[None, :].repeat( + grasp_x.shape[0], 1 + ) + grasp_y = torch.cross(approach_direction_repeat, grasp_x, dim=-1) + grasp_y = F.normalize(grasp_y, dim=-1) + grasp_z = torch.cross(grasp_x, grasp_y, dim=-1) + grasp_z = F.normalize(grasp_z, dim=-1) + grasp_poses = ( + torch.eye(4, device=grasp_x.device, dtype=torch.float32) + .unsqueeze(0) + .repeat(grasp_x.shape[0], 1, 1) + ) + grasp_poses[:, :3, 0] = grasp_x + grasp_poses[:, :3, 1] = grasp_y + grasp_poses[:, :3, 2] = grasp_z + grasp_poses[:, :3, 3] = center + return grasp_poses + def visualize_grasp_pose( - vertices: torch.Tensor, - triangles: torch.Tensor, + self, obj_pose: torch.Tensor, grasp_pose: torch.Tensor, open_length: float, ): mesh = o3d.geometry.TriangleMesh( - vertices=o3d.utility.Vector3dVector(vertices.to("cpu").numpy()), - triangles=o3d.utility.Vector3iVector(triangles.to("cpu").numpy()), + vertices=o3d.utility.Vector3dVector(self.vertices.to("cpu").numpy()), + triangles=o3d.utility.Vector3iVector(self.triangles.to("cpu").numpy()), ) mesh.compute_vertex_normals() mesh.paint_uniform_color([0.3, 0.6, 0.3]) mesh.transform(obj_pose.to("cpu").numpy()) vertices_ = torch.tensor( - np.asarray(mesh.vertices), device=vertices.device, dtype=vertices.dtype + np.asarray(mesh.vertices), + device=self.vertices.device, + dtype=self.vertices.dtype, ) mesh_scale = (vertices_.max(dim=0)[0] - vertices_.min(dim=0)[0]).max().item() groud_plane = o3d.geometry.TriangleMesh.create_cylinder( diff --git a/embodichain/toolkits/graspkit/pg_grasp/batch_collision_checker.py b/embodichain/toolkits/graspkit/pg_grasp/batch_collision_checker.py index cf18b76e..25327bea 100644 --- a/embodichain/toolkits/graspkit/pg_grasp/batch_collision_checker.py +++ b/embodichain/toolkits/graspkit/pg_grasp/batch_collision_checker.py @@ -11,7 +11,6 @@ import open3d as o3d from embodichain.utils import logger - CONVEX_CACHE_DIR = os.path.join( os.path.expanduser("~"), ".cache", "embodichain_cache", "convex_decomposition" ) @@ -33,6 +32,7 
@@ def __init__( ): if not os.path.isdir(CONVEX_CACHE_DIR): os.makedirs(CONVEX_CACHE_DIR, exist_ok=True) + self.device = base_mesh_verts.device base_mesh_verts_np = base_mesh_verts.cpu().numpy() base_mesh_faces_np = base_mesh_faces.cpu().numpy() mesh_hash = hashlib.md5( @@ -45,6 +45,7 @@ def __init__( triangles=o3d.utility.Vector3iVector(base_mesh_faces_np), ) self.mesh.compute_vertex_normals() + self.cache_path = os.path.join( CONVEX_CACHE_DIR, f"{mesh_hash}_{max_decomposition_hulls}.pkl" ) @@ -59,6 +60,58 @@ def __init__( # load precomputed plane equations from cache self.plane_equations = pickle.load(open(self.cache_path, "rb")) + def query_batch_points( + self, + batch_points: torch.Tensor, + collision_threshold: float = 0.0, + is_visual: bool = False, + ): + n_batch = batch_points.shape[0] + n_points = batch_points.shape[1] + penetration_result = torch.zeros(size=(n_batch, n_points), device=self.device) + penetration_result.fill_(-float("inf")) + collision_result = torch.zeros( + size=(n_batch, n_points), dtype=torch.bool, device=self.device + ) + collision_result.fill_(False) + for normals, offsets in self.plane_equations: + normals_torch = torch.tensor(normals, device=self.device) + offsets_torch = torch.tensor(offsets, device=self.device) + penetration, collides = check_collision_single_hull( + normals_torch, + offsets_torch, + batch_points, + collision_threshold, + ) + penetration_result = torch.max(penetration_result, penetration) + collision_result = torch.logical_or(collision_result, collides) + is_colliding = collision_result.any(dim=-1) # [B] + max_penetration = penetration_result.max(dim=-1)[0] # [B] + + if is_visual: + # visualize result + frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.1) + for i in range(n_batch): + query_points_o3d = o3d.geometry.PointCloud() + query_points_np = batch_points[i].cpu().numpy() + query_points_o3d.points = o3d.utility.Vector3dVector(query_points_np) + query_points_color = 
np.zeros_like(query_points_np) + query_points_color[collision_result[i].cpu().numpy()] = [ + 1.0, + 0, + 0, + ] # red for colliding points + query_points_color[~collision_result[i].cpu().numpy()] = [ + 0, + 1.0, + 0, + ] # green for non-colliding points + query_points_o3d.colors = o3d.utility.Vector3dVector(query_points_color) + o3d.visualization.draw_geometries( + [self.mesh, query_points_o3d, frame], mesh_show_back_face=True + ) + return is_colliding, max_penetration + def query( self, query_mesh_verts: torch.Tensor, @@ -98,24 +151,25 @@ def query( if cfg.debug: # visualize result - query_points_o3d = o3d.geometry.PointCloud() - query_points_o3d.points = o3d.utility.Vector3dVector(query_points_np) - query_points_o3d.transform(poses[-1].to("cpu").numpy()) - query_points_color = np.zeros_like(query_points_np) - query_points_color[collision_result[-1].cpu().numpy()] = [ - 1.0, - 0, - 0, - ] # red for colliding points - query_points_color[~collision_result[-1].cpu().numpy()] = [ - 0, - 1.0, - 0, - ] # green for non-colliding points - query_points_o3d.colors = o3d.utility.Vector3dVector(query_points_color) - o3d.visualization.draw_geometries( - [self.mesh, query_points_o3d], mesh_show_back_face=True - ) + for i in range(n_batch): + query_points_o3d = o3d.geometry.PointCloud() + query_points_o3d.points = o3d.utility.Vector3dVector(query_points_np) + query_points_o3d.transform(poses[i].to("cpu").numpy()) + query_points_color = np.zeros_like(query_points_np) + query_points_color[collision_result[i].cpu().numpy()] = [ + 1.0, + 0, + 0, + ] # red for colliding points + query_points_color[~collision_result[i].cpu().numpy()] = [ + 0, + 1.0, + 0, + ] # green for non-colliding points + query_points_o3d.colors = o3d.utility.Vector3dVector(query_points_color) + o3d.visualization.draw_geometries( + [self.mesh, query_points_o3d], mesh_show_back_face=True + ) return is_colliding, max_penetration @staticmethod @@ -310,23 +364,21 @@ def transform_points_batch( if __name__ == 
"__main__": from embodichain.data import get_data_path - bottle_a_path = get_data_path("ScannedBottle/moliwulong_processed.ply") - bottle_b_path = get_data_path("ScannedBottle/yibao_processed.ply") - - bottle_a_mesh = trimesh.load(bottle_a_path) - bottle_b_mesh = trimesh.load(bottle_b_path) - bottle_a_verts = torch.tensor(bottle_a_mesh.vertices, dtype=torch.float32) - bottle_a_faces = torch.tensor(bottle_a_mesh.faces, dtype=torch.int64) - bottle_b_verts = torch.tensor(bottle_b_mesh.vertices, dtype=torch.float32) - bottle_b_faces = torch.tensor(bottle_b_mesh.faces, dtype=torch.int64) + mug_path = get_data_path("CoffeeCup/cup.ply") + mug_path = get_data_path("ScannedBottle/moliwulong_processed.ply") + mug_mesh = trimesh.load(mug_path, force="mesh", process=False) + verts = torch.tensor(mug_mesh.vertices, dtype=torch.float32) + faces = torch.tensor(mug_mesh.faces, dtype=torch.int32) + collision_checker = BatchConvexCollisionChecker( + verts, faces, max_decomposition_hulls=16 + ) - collision_checker = BatchConvexCollisionChecker(bottle_a_verts, bottle_a_faces) poses = torch.tensor( [ [ [1, 0, 0, 0], [0, 1, 0, 0], - [0, 0, 1, 1.0], + [0, 0, 1, 0.05], [0, 0, 0, 1], ], [ @@ -337,13 +389,26 @@ def transform_points_batch( ], ] ) - check_cfg = BatchConvexCollisionCheckerCfg( - debug=False, - n_query_mesh_samples=32768, - collsion_threshold=-0.003, + from scipy.spatial.transform import Rotation + + rot = Rotation.from_euler("xyz", [12, 3, 32], degrees=True).as_matrix() + poses[0, :3, :3] = torch.tensor(rot, dtype=torch.float32) + poses[1, :3, :3] = torch.tensor(rot, dtype=torch.float32) + + obj_path = get_data_path("ScannedBottle/yibao_processed.ply") + obj_mesh = trimesh.load(obj_path, force="mesh", process=False) + obj_verts = torch.tensor(obj_mesh.vertices, dtype=torch.float32) + obj_faces = torch.tensor(obj_mesh.faces, dtype=torch.int32) + test_pc = transform_points_batch(obj_verts, poses) + + collision_checker.query_batch_points( + test_pc, collision_threshold=0.003, 
is_visual=True ) - collisions, penetrations = collision_checker.query( - bottle_b_verts, bottle_b_faces, poses, cfg=check_cfg + collision_checker.query( + obj_verts, + obj_faces, + poses, + cfg=BatchConvexCollisionCheckerCfg( + debug=True, n_query_mesh_samples=32768, collsion_threshold=0.000 + ), ) - print("Collisions:", collisions) - print("Penetrations:", penetrations) diff --git a/embodichain/toolkits/graspkit/pg_grasp/gripper_collision_checker.py b/embodichain/toolkits/graspkit/pg_grasp/gripper_collision_checker.py index e69de29b..13dbe162 100644 --- a/embodichain/toolkits/graspkit/pg_grasp/gripper_collision_checker.py +++ b/embodichain/toolkits/graspkit/pg_grasp/gripper_collision_checker.py @@ -0,0 +1,230 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Sequence +from .batch_collision_checker import BatchConvexCollisionChecker +import torch + + +@dataclass +class SimpleGripperCollisionCfg: + max_open_length: float = 0.1 + finger_length: float = 0.16 + y_thickness: float = 0.03 + x_thickness: float = 0.01 + root_z_width: float = 0.06 + device = torch.device("cpu") + rough_dense: float = 0.01 + max_decomposition_hulls: int = 16 + + +class SimpleGripperCollisionChecker: + def __init__( + self, + object_mesh_verts: torch.Tensor, + object_mesh_faces: torch.Tensor, + cfg: SimpleGripperCollisionCfg = SimpleGripperCollisionCfg(), + ): + self._checker = BatchConvexCollisionChecker( + base_mesh_verts=object_mesh_verts, + base_mesh_faces=object_mesh_faces, + max_decomposition_hulls=cfg.max_decomposition_hulls, + ) + self.cfg = cfg + self._init_pc_template() + + def _init_pc_template(self): + self.root_template = box_surface_grid( + size=( + self.cfg.max_open_length, + self.cfg.y_thickness, + self.cfg.root_z_width, + ), + dense=self.cfg.rough_dense, + ) + self.left_template = box_surface_grid( + size=(self.cfg.x_thickness, self.cfg.y_thickness, self.cfg.finger_length), + dense=self.cfg.rough_dense, + ) + self.right_template = box_surface_grid( + size=(self.cfg.x_thickness, self.cfg.y_thickness, self.cfg.finger_length), + dense=self.cfg.rough_dense, + ) + + def _get_gripper_pc( + self, grasp_poses: torch.Tensor, open_lengths: torch.Tensor + ) -> torch.Tensor: + """ + Args: + grasp_poses: [B, 4, 4] homogeneous transformation matrix of the gripper root frame. + open_lengths: [B] opening length of the gripper fingers. + Returns: + gripper_pc: [B, P, 3] point cloud of the gripper in the world frame. 
+ """ + + root_grasp_poses = grasp_poses.clone() + root_grasp_poses[:, :3, 3] -= ( + root_grasp_poses[:, :3, 2] + * 0.5 + * (self.cfg.finger_length + self.cfg.root_z_width) + ) + open_lengths_repeat = open_lengths[:, None].repeat(1, 3) + left_finger_poses = grasp_poses.clone() + left_finger_poses[:, :3, 3] -= left_finger_poses[:, :3, 0] * open_lengths_repeat + + right_finger_poses = grasp_poses.clone() + right_finger_poses[:, :3, 3] += ( + right_finger_poses[:, :3, 0] * open_lengths_repeat + ) + + root_pc = transform_points_batch(self.root_template, root_grasp_poses) + left_pc = transform_points_batch(self.left_template, left_finger_poses) + right_pc = transform_points_batch(self.right_template, right_finger_poses) + gripper_pc = torch.cat([root_pc, left_pc, right_pc], dim=1) + return gripper_pc + + def query( + self, + obj_pose: torch.Tensor, + grasp_poses: torch.Tensor, + open_lengths: torch.Tensor, + ) -> torch.Tensor: + inv_obj_pose = obj_pose.clone() + inv_obj_pose[:3, :3] = obj_pose[:3, :3].T + inv_obj_pose[:3, 3] = -obj_pose[:3, 3] @ obj_pose[:3, :3] + inv_obj_poses = inv_obj_pose[None, :, :].repeat(grasp_poses.shape[0], 1, 1) + grasp_relative_pose = torch.bmm(inv_obj_poses, grasp_poses) + gripper_pc = self._get_gripper_pc(grasp_relative_pose, open_lengths) + return self._checker.query_batch_points( + gripper_pc, collision_threshold=0.005, is_visual=False + ) + + +def transform_points_batch( + points: torch.Tensor, poses: torch.Tensor # [P, 3] # [B, 4, 4] +) -> torch.Tensor: + """ + Apply a batch of rigid transforms to a point cloud. + + Args: + points: [P, 3] source point cloud. + poses: [B, 4, 4] batch of homogeneous transformation matrices. + + Returns: + transformed: [B, P, 3] transformed point cloud for each pose. 
+ """ + R = poses[:, :3, :3] # [B, 3, 3] + t = poses[:, :3, 3] # [B, 3] + transformed = torch.einsum("bij, pj -> bpi", R, points) + t.unsqueeze(1) + return transformed + + +def box_surface_grid( + size: Sequence[float] | torch.Tensor, + dense: float, + device: torch.device | str = "cpu", +) -> torch.Tensor: + """Generate grid-sampled points on the surface of an axis-aligned box. + + Six faces of the box are each sampled independently on a regular 2-D grid. + Grid resolution per face is derived automatically from ``dense``: + the number of sample points along an edge of length *L* is + ``max(2, round(L * dense) + 1)``, so ``dense`` behaves as + *approximate samples per unit length*. + + Edge and corner points are shared across adjacent faces and are included + exactly once (no duplicates). + + Args: + size: Box dimensions ``(sx, sy, sz)``. Accepts a sequence of three + floats or a 1-D :class:`torch.Tensor` of length 3. + dense: Approximate number of grid sample points per unit length along + each edge. Higher values yield denser point clouds. + device: Target PyTorch device for the returned tensor. + + Returns: + Float tensor of shape ``(N, 3)`` containing surface points expressed + in the box's local frame (origin at the box centre). 
+ + Example: + >>> pts = box_surface_grid((0.1, 0.06, 0.03), dense=200.0) + >>> pts.shape + torch.Size([..., 3]) + """ + if isinstance(size, torch.Tensor): + sx, sy, sz = size[0].item(), size[1].item(), size[2].item() + else: + sx, sy, sz = float(size[0]), float(size[1]), float(size[2]) + + hx, hy, hz = sx / 2.0, sy / 2.0, sz / 2.0 + + # ── grid resolution per axis (at least 2 points to span the full edge) ── + nx = max(2, round(sx / dense) + 1) + ny = max(2, round(sy / dense) + 1) + nz = max(2, round(sz / dense) + 1) + + xs = torch.linspace(-hx, hx, nx, device=device) + ys = torch.linspace(-hy, hy, ny, device=device) + zs = torch.linspace(-hz, hz, nz, device=device) + + # Interior slices (exclude first and last to avoid duplicate edges) + xs_inner = xs[1:-1] # length nx-2 + ys_inner = ys[1:-1] # length ny-2 + + def _grid( + u: torch.Tensor, v: torch.Tensor, axis: int, offset: float + ) -> torch.Tensor: + """Build a flat (M, 3) tensor for one face grid. + + Args: + u: 1-D tensor of coordinates along the first in-plane axis. + v: 1-D tensor of coordinates along the second in-plane axis. + axis: Normal axis of the face — 0 (±X), 1 (±Y), or 2 (±Z). + offset: Signed half-extent along ``axis``. + + Returns: + Tensor of shape ``(len(u) * len(v), 3)``. + """ + uu, vv = torch.meshgrid(u, v, indexing="ij") + uu = uu.reshape(-1) + vv = vv.reshape(-1) + cc = torch.full_like(uu, offset) + if axis == 0: + return torch.stack([cc, uu, vv], dim=-1) + elif axis == 1: + return torch.stack([uu, cc, vv], dim=-1) + else: + return torch.stack([uu, vv, cc], dim=-1) + + # ───────────────────────────────────────────────────────────────────────── + # Build 6 faces. 
To avoid duplicate points on shared edges/corners: + # ±X faces → full NY × NZ grids + # ±Y faces → (NX-2) × NZ grids (x-edges owned by ±X faces) + # ±Z faces → (NX-2) × (NY-2) grids (x- and y-edges owned above) + # ───────────────────────────────────────────────────────────────────────── + faces: list[torch.Tensor] = [ + _grid(ys, zs, axis=0, offset=-hx), # −X face (NY × NZ) + _grid(ys, zs, axis=0, offset=+hx), # +X face (NY × NZ) + _grid(xs_inner, zs, axis=1, offset=-hy), # −Y face ((NX-2) × NZ) + _grid(xs_inner, zs, axis=1, offset=+hy), # +Y face ((NX-2) × NZ) + _grid(xs_inner, ys_inner, axis=2, offset=-hz), # −Z face + _grid(xs_inner, ys_inner, axis=2, offset=+hz), # +Z face + ] + + return torch.cat(faces, dim=0) diff --git a/examples/sim/demo/grasp_mug.py b/examples/sim/demo/grasp_mug.py index 6ff56d69..ac68c073 100644 --- a/examples/sim/demo/grasp_mug.py +++ b/examples/sim/demo/grasp_mug.py @@ -236,7 +236,7 @@ def get_grasp_traj(sim: SimulationManager, robot: Robot, grasp_xpos: torch.Tenso antipodal_sampler_cfg=AntipodalSamplerCfg( n_sample=5000, max_length=0.088, min_length=0.003 ), - force_regenerate=True, # force user to annotate grasp region each time + force_regenerate=False, # force user to annotate grasp region each time ) sim.open_window() From 7c55249e4998a0bc410d68d79f40ee4cd0d80223 Mon Sep 17 00:00:00 2001 From: chenjian Date: Mon, 30 Mar 2026 19:17:15 +0800 Subject: [PATCH 08/25] add collision checker --- embodichain/lab/sim/objects/rigid_object.py | 8 +- .../graspkit/pg_grasp/antipodal_annotator.py | 17 +- .../graspkit/pg_grasp/antipodal_sampler.py | 2 +- .../pg_grasp/batch_collision_checker.py | 168 +++++++++++------- .../pg_grasp/gripper_collision_checker.py | 15 +- embodichain/utils/warp/__init__.py | 2 + .../utils/warp/collision_checker/__init__.py | 17 ++ .../warp/collision_checker/convex_query.py | 55 ++++++ .../tutorials/grasp}/grasp_mug.py | 16 +- 9 files changed, 216 insertions(+), 84 deletions(-) create mode 100644 
embodichain/utils/warp/collision_checker/__init__.py create mode 100644 embodichain/utils/warp/collision_checker/convex_query.py rename {examples/sim/demo => scripts/tutorials/grasp}/grasp_mug.py (95%) diff --git a/embodichain/lab/sim/objects/rigid_object.py b/embodichain/lab/sim/objects/rigid_object.py index 4da6e632..f90eee63 100644 --- a/embodichain/lab/sim/objects/rigid_object.py +++ b/embodichain/lab/sim/objects/rigid_object.py @@ -1164,11 +1164,11 @@ def get_grasp_pose( poses = self.get_local_pose(to_matrix=True) poses = torch.as_tensor(poses, dtype=torch.float32, device=self.device) - grasp_poses = [] - open_lengths = [] + grasp_poses: tuple[torch.Tensor] = [] + open_lengths: tuple[torch.Tensor] = [] for pose in poses: grasp_pose, open_length = self._grasp_annotator.get_approach_grasp_poses( - self._hit_point_pairs, pose, approach_direction + self._hit_point_pairs, pose, approach_direction, is_visual=False ) grasp_poses.append(grasp_pose) open_lengths.append(open_length) @@ -1184,6 +1184,6 @@ def get_grasp_pose( self._grasp_annotator.visualize_grasp_pose( obj_pose=poses[0], grasp_pose=grasp_poses[0], - open_length=open_lengths[0], + open_length=open_lengths[0].item(), ) return grasp_poses diff --git a/embodichain/toolkits/graspkit/pg_grasp/antipodal_annotator.py b/embodichain/toolkits/graspkit/pg_grasp/antipodal_annotator.py index bf84a9f3..2770cbfe 100644 --- a/embodichain/toolkits/graspkit/pg_grasp/antipodal_annotator.py +++ b/embodichain/toolkits/graspkit/pg_grasp/antipodal_annotator.py @@ -364,6 +364,7 @@ def get_approach_grasp_poses( hit_point_pairs: torch.Tensor, object_pose: torch.Tensor, approach_direction: torch.Tensor, + is_visual: bool = False, ) -> torch.Tensor: """Get grasp pose given approach direction @@ -380,7 +381,9 @@ def get_approach_grasp_poses( origin_points_ = self._apply_transform(origin_points, object_pose) hit_points_ = self._apply_transform(hit_points, object_pose) centers = (origin_points_ + hit_points_) / 2 - center = 
centers.mean(dim=0) + + mesh_vert_transformed = self._apply_transform(self.vertices, object_pose) + mesh_center = mesh_vert_transformed.mean(dim=0) # filter perpendicular antipodal point grasp_x = F.normalize(hit_points_ - origin_points_, dim=-1) @@ -400,11 +403,13 @@ def get_approach_grasp_poses( origin_points_[valid_mask] - hit_points_[valid_mask], dim=-1 ) # select non-collide grasp poses - is_colliding, max_penetration = self._collision_checker.query( - object_pose, valid_grasp_poses, valid_open_lengths + object_pose, + valid_grasp_poses, + valid_open_lengths, + is_visual=is_visual, + collision_threshold=0.0, ) - # get best grasp pose valid_grasp_poses = valid_grasp_poses[~is_colliding] valid_open_lengths = valid_open_lengths[~is_colliding] @@ -416,10 +421,10 @@ def get_approach_grasp_poses( ) positive_angle = torch.abs(torch.acos(cos_angle)) angle_cost = torch.abs(positive_angle - 0.5 * torch.pi) / (0.5 * torch.pi) - center_distance = torch.norm(valid_centers - center, dim=-1) + center_distance = torch.norm(valid_centers - mesh_center, dim=-1) center_cost = center_distance / center_distance.max() length_cost = 1 - valid_open_lengths / valid_open_lengths.max() - total_cost = 0.4 * angle_cost + 0.3 * length_cost + 0.3 * center_cost + total_cost = 0.3 * angle_cost + 0.3 * length_cost + 0.4 * center_cost best_idx = torch.argmin(total_cost) best_grasp_pose = valid_grasp_poses[best_idx] best_open_length = valid_open_lengths[best_idx] diff --git a/embodichain/toolkits/graspkit/pg_grasp/antipodal_sampler.py b/embodichain/toolkits/graspkit/pg_grasp/antipodal_sampler.py index 09e4858e..a840e147 100644 --- a/embodichain/toolkits/graspkit/pg_grasp/antipodal_sampler.py +++ b/embodichain/toolkits/graspkit/pg_grasp/antipodal_sampler.py @@ -25,7 +25,7 @@ @dataclass class AntipodalSamplerCfg: - n_sample: int = 10000 + n_sample: int = 20000 """surface point sample number""" max_angle: float = np.pi / 12 """maximum angle (in radians) to randomly disturb the ray direction for 
antipodal point sampling, used to increase the diversity of sampled antipodal points. Note that setting max_angle to 0 will disable the random disturbance and sample antipodal points strictly along the surface normals, which may result in less diverse antipodal points and may not be ideal for all objects or grasping scenarios.""" diff --git a/embodichain/toolkits/graspkit/pg_grasp/batch_collision_checker.py b/embodichain/toolkits/graspkit/pg_grasp/batch_collision_checker.py index 25327bea..7cb35be9 100644 --- a/embodichain/toolkits/graspkit/pg_grasp/batch_collision_checker.py +++ b/embodichain/toolkits/graspkit/pg_grasp/batch_collision_checker.py @@ -1,3 +1,19 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ---------------------------------------------------------------------------- + import trimesh import numpy as np import torch @@ -10,6 +26,9 @@ import pickle import open3d as o3d from embodichain.utils import logger +from embodichain.utils.warp import convex_signed_distance_kernel +import warp as wp +from embodichain.utils.device_utils import standardize_device_string CONVEX_CACHE_DIR = os.path.join( os.path.expanduser("~"), ".cache", "embodichain_cache", "convex_decomposition" @@ -51,14 +70,84 @@ def __init__( ) if not os.path.isfile(self.cache_path): + # [n_convex, n_max_faces, 4]: plane equations, normals(3) and offsets(1), padded with zeros if a hull has less than n_max_faces + # [n_convex, ]: number of faces for each convex hull + # generate convex hulls and extract plane equations, then cache to disk - self.plane_equations = BatchConvexCollisionChecker._compute_plane_equations( + plane_equations_np = BatchConvexCollisionChecker._compute_plane_equations( base_mesh_verts_np, base_mesh_faces_np, max_decomposition_hulls ) + # pack as a single tensor + n_convex = len(plane_equations_np) + n_max_equation = max(len(normals) for normals, _ in plane_equations_np) + plane_equations = torch.zeros( + size=(n_convex, n_max_equation, 4), + dtype=torch.float32, + device=self.device, + ) + plane_equations_counts = torch.zeros( + n_convex, dtype=torch.int32, device=self.device + ) + for i in range(n_convex): + n_equation = plane_equations_np[i][0].shape[0] + # plane normals + plane_equations[i, :n_equation, :3] = torch.tensor( + plane_equations_np[i][0], device=self.device + ) + # plane offsets + plane_equations[i, :n_equation, 3] = torch.tensor( + plane_equations_np[i][1], device=self.device + ) + plane_equations_counts[i] = n_equation + self.plane_equations = { + "plane_equations": plane_equations, + "plane_equation_counts": plane_equations_counts, + } pickle.dump(self.plane_equations, open(self.cache_path, "wb")) else: - # load precomputed plane equations from cache 
self.plane_equations = pickle.load(open(self.cache_path, "rb")) + self.plane_equations["plane_equations"] = self.plane_equations[ + "plane_equations" + ].to(self.device) + self.plane_equations["plane_equation_counts"] = self.plane_equations[ + "plane_equation_counts" + ].to(self.device) + + @staticmethod + def batch_point_convex_query( + plane_equations: torch.Tensor, + plane_equation_counts: torch.Tensor, + batch_points: torch.Tensor, + device: torch.device, + collision_threshold: float = -0.003, + ): + plane_equations_wp = wp.from_torch(plane_equations) + plane_equation_counts_wp = wp.from_torch(plane_equation_counts) + batch_points_wp = wp.from_torch(batch_points) + + n_pose = batch_points.shape[0] + n_point = batch_points.shape[1] + n_convex = plane_equations.shape[0] + point_convex_signed_distance_wp = wp.full( + shape=(n_pose, n_point, n_convex), + value=-float("inf"), + dtype=float, + device=standardize_device_string(device), + ) # [n_pose, n_point, n_convex] + wp.launch( + kernel=convex_signed_distance_kernel, + dim=(n_pose, n_point, n_convex), + inputs=(batch_points_wp, plane_equations_wp, plane_equation_counts_wp), + outputs=(point_convex_signed_distance_wp,), + device=standardize_device_string(device), + ) + point_convex_signed_distance = wp.to_torch(point_convex_signed_distance_wp) + # import ipdb; ipdb.set_trace() + point_signed_distance = point_convex_signed_distance.min( + dim=-1 + ).values # [n_pose, n_point] + is_point_collide = point_signed_distance <= collision_threshold + return point_signed_distance, is_point_collide def query_batch_points( self, @@ -67,27 +156,17 @@ def query_batch_points( is_visual: bool = False, ): n_batch = batch_points.shape[0] - n_points = batch_points.shape[1] - penetration_result = torch.zeros(size=(n_batch, n_points), device=self.device) - penetration_result.fill_(-float("inf")) - collision_result = torch.zeros( - size=(n_batch, n_points), dtype=torch.bool, device=self.device - ) - collision_result.fill_(False) - for 
normals, offsets in self.plane_equations: - normals_torch = torch.tensor(normals, device=self.device) - offsets_torch = torch.tensor(offsets, device=self.device) - penetration, collides = check_collision_single_hull( - normals_torch, - offsets_torch, + point_signed_distance, is_point_collide = ( + BatchConvexCollisionChecker.batch_point_convex_query( + self.plane_equations["plane_equations"], + self.plane_equations["plane_equation_counts"], batch_points, - collision_threshold, + device=self.device, + collision_threshold=collision_threshold, ) - penetration_result = torch.max(penetration_result, penetration) - collision_result = torch.logical_or(collision_result, collides) - is_colliding = collision_result.any(dim=-1) # [B] - max_penetration = penetration_result.max(dim=-1)[0] # [B] - + ) + is_pose_collide = is_point_collide.any(dim=-1) # [B] + pose_surface_distance = point_signed_distance.min(dim=-1).values # [B] if is_visual: # visualize result frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.1) @@ -96,12 +175,12 @@ def query_batch_points( query_points_np = batch_points[i].cpu().numpy() query_points_o3d.points = o3d.utility.Vector3dVector(query_points_np) query_points_color = np.zeros_like(query_points_np) - query_points_color[collision_result[i].cpu().numpy()] = [ + query_points_color[is_point_collide[i].cpu().numpy()] = [ 1.0, 0, 0, ] # red for colliding points - query_points_color[~collision_result[i].cpu().numpy()] = [ + query_points_color[~is_point_collide[i].cpu().numpy()] = [ 0, 1.0, 0, @@ -110,7 +189,7 @@ def query_batch_points( o3d.visualization.draw_geometries( [self.mesh, query_points_o3d, frame], mesh_show_back_face=True ) - return is_colliding, max_penetration + return is_pose_collide, pose_surface_distance def query( self, @@ -301,47 +380,6 @@ def check_collision_single_hull( return penetration, collides -def merge_collision_results( - hull_results: List[Tuple[torch.Tensor, torch.Tensor]], device: torch.device -) -> 
Tuple[torch.Tensor, torch.Tensor]: - """ - Merge collision detection results from multiple convex hulls. - - A pose is considered colliding if ANY point penetrates ANY convex hull. - The reported penetration depth is the maximum across all points and all hulls. - - Args: - hull_results: List of (penetration [B, P], collides [B, P]) tuples, - one per convex hull, as returned by check_collision_single_hull. - device: torch device. - - Returns: - overall_collisions: [B] boolean, True if the pose collides with any hull. - overall_max_penetrations: [B] float, maximum penetration depth per pose. - """ - if not hull_results: - raise ValueError("hull_results is empty, nothing to merge.") - - B = hull_results[0][0].shape[0] - - overall_collisions = torch.zeros(B, dtype=torch.bool, device=device) - overall_max_penetrations = torch.full( - (B,), -float("inf"), dtype=torch.float32, device=device - ) - - for penetration, collides in hull_results: - # Update collision flag: OR across hulls - # A pose collides if any point collides with this hull - overall_collisions |= collides.any(dim=-1) # [B] - - # Update max penetration: take the per-pose maximum across all points for this hull, - # then compare with the running maximum - hull_max_pen = penetration.max(dim=-1)[0] # [B] - overall_max_penetrations = torch.max(overall_max_penetrations, hull_max_pen) - - return overall_collisions, overall_max_penetrations - - def transform_points_batch( points: torch.Tensor, poses: torch.Tensor # [P, 3] # [B, 4, 4] ) -> torch.Tensor: diff --git a/embodichain/toolkits/graspkit/pg_grasp/gripper_collision_checker.py b/embodichain/toolkits/graspkit/pg_grasp/gripper_collision_checker.py index 13dbe162..42dfeb1a 100644 --- a/embodichain/toolkits/graspkit/pg_grasp/gripper_collision_checker.py +++ b/embodichain/toolkits/graspkit/pg_grasp/gripper_collision_checker.py @@ -30,8 +30,9 @@ class SimpleGripperCollisionCfg: x_thickness: float = 0.01 root_z_width: float = 0.06 device = torch.device("cpu") - 
rough_dense: float = 0.01 + rough_dense: float = 0.015 max_decomposition_hulls: int = 16 + open_check_margin: float = 0.01 class SimpleGripperCollisionChecker: @@ -46,6 +47,7 @@ def __init__( base_mesh_faces=object_mesh_faces, max_decomposition_hulls=cfg.max_decomposition_hulls, ) + self.device = object_mesh_verts.device self.cfg = cfg self._init_pc_template() @@ -57,14 +59,17 @@ def _init_pc_template(self): self.cfg.root_z_width, ), dense=self.cfg.rough_dense, + device=self.device, ) self.left_template = box_surface_grid( size=(self.cfg.x_thickness, self.cfg.y_thickness, self.cfg.finger_length), dense=self.cfg.rough_dense, + device=self.device, ) self.right_template = box_surface_grid( size=(self.cfg.x_thickness, self.cfg.y_thickness, self.cfg.finger_length), dense=self.cfg.rough_dense, + device=self.device, ) def _get_gripper_pc( @@ -84,7 +89,9 @@ def _get_gripper_pc( * 0.5 * (self.cfg.finger_length + self.cfg.root_z_width) ) - open_lengths_repeat = open_lengths[:, None].repeat(1, 3) + open_lengths_repeat = ( + open_lengths[:, None] + self.cfg.open_check_margin + ).repeat(1, 3) left_finger_poses = grasp_poses.clone() left_finger_poses[:, :3, 3] -= left_finger_poses[:, :3, 0] * open_lengths_repeat @@ -104,6 +111,8 @@ def query( obj_pose: torch.Tensor, grasp_poses: torch.Tensor, open_lengths: torch.Tensor, + collision_threshold: float = 0.0, + is_visual: bool = False, ) -> torch.Tensor: inv_obj_pose = obj_pose.clone() inv_obj_pose[:3, :3] = obj_pose[:3, :3].T @@ -112,7 +121,7 @@ def query( grasp_relative_pose = torch.bmm(inv_obj_poses, grasp_poses) gripper_pc = self._get_gripper_pc(grasp_relative_pose, open_lengths) return self._checker.query_batch_points( - gripper_pc, collision_threshold=0.005, is_visual=False + gripper_pc, collision_threshold=collision_threshold, is_visual=is_visual ) diff --git a/embodichain/utils/warp/__init__.py b/embodichain/utils/warp/__init__.py index 905bc9e7..e0fac57a 100644 --- a/embodichain/utils/warp/__init__.py +++ 
b/embodichain/utils/warp/__init__.py @@ -30,3 +30,5 @@ repeat_first_point, interpolate_along_distance, ) + +from .collision_checker.convex_query import convex_signed_distance_kernel diff --git a/embodichain/utils/warp/collision_checker/__init__.py b/embodichain/utils/warp/collision_checker/__init__.py new file mode 100644 index 00000000..d7e19801 --- /dev/null +++ b/embodichain/utils/warp/collision_checker/__init__.py @@ -0,0 +1,17 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from . import convex_query diff --git a/embodichain/utils/warp/collision_checker/convex_query.py b/embodichain/utils/warp/collision_checker/convex_query.py new file mode 100644 index 00000000..f321e462 --- /dev/null +++ b/embodichain/utils/warp/collision_checker/convex_query.py @@ -0,0 +1,55 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import warp as wp +from typing import Any + + +@wp.kernel(enable_backward=False) +def convex_signed_distance_kernel( + query_points: wp.array(dtype=wp.float32, ndim=3), + plane_equations: wp.array(dtype=wp.float32, ndim=3), + plane_equation_counts: wp.array(dtype=wp.int32, ndim=1), + signed_distances: wp.array(dtype=wp.float32, ndim=3), +): + """ + Compute the signed distance from query points to convex hulls defined by plane equations. + + Args: + query_points: [n_pose, n_point, 3] coordinates of query points. + plane_equations: [n_convex, n_max_equation, 4] plane equations of convex hulls, where each plane equation is represented as (normal_x, normal_y, normal_z, offset). + plane_equation_counts: [n_convex, ] number of valid plane equations for each convex hull. + + Returns: + signed_distances: [n_pose, n_point, n_convex] output signed distances from query points to convex hulls. Must be initialized to -inf before calling this kernel, because the kernel keeps a running maximum over each hull's plane distances.
+ """ + pose_id, point_id, convex_id = wp.tid() + n_equation = plane_equation_counts[convex_id] + for i in range(n_equation): + normal_x = plane_equations[convex_id, i, 0] + normal_y = plane_equations[convex_id, i, 1] + normal_z = plane_equations[convex_id, i, 2] + offset = plane_equations[convex_id, i, 3] + signed_distance = ( + query_points[pose_id, point_id, 0] * normal_x + + query_points[pose_id, point_id, 1] * normal_y + + query_points[pose_id, point_id, 2] * normal_z + + offset + ) + # should initialize as -inf + signed_distances[pose_id, point_id, convex_id] = max( + signed_distance, signed_distances[pose_id, point_id, convex_id] + ) diff --git a/examples/sim/demo/grasp_mug.py b/scripts/tutorials/grasp/grasp_mug.py similarity index 95% rename from examples/sim/demo/grasp_mug.py rename to scripts/tutorials/grasp/grasp_mug.py index ac68c073..5a7c89da 100644 --- a/examples/sim/demo/grasp_mug.py +++ b/scripts/tutorials/grasp/grasp_mug.py @@ -68,7 +68,7 @@ def parse_arguments(): parser.add_argument( "--device", type=str, - default="cpu", + default="cuda", help="device to run the environment on, e.g., 'cpu' or 'cuda'", ) return parser.parse_args() @@ -182,7 +182,7 @@ def get_grasp_traj(sim: SimulationManager, robot: Robot, grasp_xpos: torch.Tenso rest_arm_qpos = robot.get_qpos("arm") approach_xpos = grasp_xpos.clone() - approach_xpos[:, 2, 3] += 0.04 + approach_xpos[:, 2, 3] += 0.1 _, qpos_approach = robot.compute_ik( pose=approach_xpos, joint_seed=rest_arm_qpos, name="arm" @@ -219,12 +219,14 @@ def get_grasp_traj(sim: SimulationManager, robot: Robot, grasp_xpos: torch.Tenso ) all_trajectory = torch.cat([arm_trajectory, hand_trajectory], dim=-1) interp_trajectory = interpolate_with_distance( - trajectory=all_trajectory, interp_num=300, device=sim.device + trajectory=all_trajectory, interp_num=200, device=sim.device ) return interp_trajectory if __name__ == "__main__": + import time + args = parse_arguments() sim = initialize_simulation(args) robot = 
create_robot(sim, position=[0.0, 0.0, 0.0]) @@ -234,15 +236,17 @@ def get_grasp_traj(sim: SimulationManager, robot: Robot, grasp_xpos: torch.Tenso grasp_cfg = GraspAnnotatorCfg( viser_port=11801, antipodal_sampler_cfg=AntipodalSamplerCfg( - n_sample=5000, max_length=0.088, min_length=0.003 + n_sample=20000, max_length=0.088, min_length=0.003 ), - force_regenerate=False, # force user to annotate grasp region each time + force_regenerate=True, # force user to annotate grasp region each time ) sim.open_window() # 1. View grasp object in browser (e.g http://localhost:11801) # 2. press 'Rect Select Region', select grasp region # 3. press 'Confirm Selection' to finish grasp region selection. + + start_time = time.time() grasp_xpos = mug.get_grasp_pose( approach_direction=torch.tensor( [0, 0, -1], dtype=torch.float32, device=sim.device @@ -250,6 +254,8 @@ def get_grasp_traj(sim: SimulationManager, robot: Robot, grasp_xpos: torch.Tenso cfg=grasp_cfg, is_visual=True, # visualize selected grasp pose finally ) + cost_time = time.time() - start_time + logger.log_info(f"Get grasp pose cost time: {cost_time:.2f} seconds") grab_traj = get_grasp_traj(sim, robot, grasp_xpos) input("Press Enter to start the grab mug demo...") From 35dcb4407c6301340b629c8a3f71ccadc959e2ad Mon Sep 17 00:00:00 2001 From: chenjian Date: Mon, 30 Mar 2026 19:33:41 +0800 Subject: [PATCH 09/25] update --- docs/source/tutorial/grasp_generator.rst | 77 ++++++++++++++++++++++++ docs/source/tutorial/index.rst | 1 + 2 files changed, 78 insertions(+) create mode 100644 docs/source/tutorial/grasp_generator.rst diff --git a/docs/source/tutorial/grasp_generator.rst b/docs/source/tutorial/grasp_generator.rst new file mode 100644 index 00000000..51e802e8 --- /dev/null +++ b/docs/source/tutorial/grasp_generator.rst @@ -0,0 +1,77 @@ +Generating and Executing Robot Grasps +====================================== + +.. 
currentmodule:: embodichain.lab.sim + +This tutorial demonstrates how to generate antipodal grasp poses for a target object and execute a full grasp trajectory with a robot arm. It covers scene initialization, robot and object creation, interactive grasp region annotation, grasp pose computation, and trajectory execution in the simulation loop. + +The Code +~~~~~~~~ + +The tutorial corresponds to the ``grasp_generator.py`` script in the ``scripts/tutorials/grasp`` directory. + +.. dropdown:: Code for grasp_generator.py + :icon: code + + .. literalinclude:: ../../../scripts/tutorials/grasp/grasp_generator.py + :language: python + :linenos: + + +The Code Explained +~~~~~~~~~~~~~~~~~~ + +Configuring the simulation +-------------------------- + +Command-line arguments are parsed with ``argparse`` to select the number of parallel environments, the compute device, and optional rendering features such as ray tracing and headless mode. + +.. literalinclude:: ../../../scripts/tutorials/grasp/grasp_generator.py + :language: python + :start-at: def parse_arguments(): + :end-at: return parser.parse_args() + +The parsed arguments are passed to ``initialize_simulation``, which builds a :class:`SimulationManagerCfg` and creates the :class:`SimulationManager` instance. When ray tracing is enabled a directional :class:`cfg.LightCfg` is also added to the scene. + +.. literalinclude:: ../../../scripts/tutorials/grasp/grasp_generator.py + :language: python + :start-at: def initialize_simulation(args) -> SimulationManager: + :end-at: return sim + +Annotating and computing grasp poses +------------------------------------- + +Grasp generation is performed by :meth:`objects.RigidObject.get_grasp_pose`, which internally runs an antipodal sampler on the object mesh. A :class:`toolkits.graspkit.pg_grasp.GraspAnnotatorCfg` controls sampler parameters (sample count, gripper jaw limits) and the interactive annotation workflow: + +1. 
Open the visualization in a browser at the reported port (e.g. ``http://localhost:11801``). +2. Use *Rect Select Region* to highlight the area of the object that should be grasped. +3. Click *Confirm Selection* to finalize the region. + +The function returns a batch of ``(N_envs, 4, 4)`` homogeneous transformation matrices representing candidate grasp frames in the world coordinate system. + +For each grasp pose, the gripper approach direction in world coordinates is required to compute the antipodal grasp. In this tutorial, we use a fixed approach direction (straight down in the world frame) for simplicity, but it can be customized based on the task or object geometry. + +.. literalinclude:: ../../../scripts/tutorials/grasp/grasp_generator.py + :language: python + :start-at: # get mug grasp pose + :end-at: logger.log_info(f"Get grasp pose cost time: {cost_time:.2f} seconds") + + +The Code Execution +~~~~~~~~~~~~~~~~~~ + +To run the script, execute the following command from the project root: + +.. code-block:: bash + + python scripts/tutorials/grasp/grasp_generator.py + +A simulation window will open showing the robot and the mug. A browser-based visualizer will also launch (default port ``11801``) for interactive grasp region annotation. + +You can customize the run with additional arguments: + +.. code-block:: bash + + python scripts/tutorials/grasp/grasp_generator.py --num_envs NUM_ENVS --device DEVICE --enable_rt --headless + +After confirming the grasp region in the browser, the script will compute a grasp pose, print the elapsed time, and then wait for you to press **Enter** before executing the full grasp trajectory in the simulation. Press **Enter** again to exit once the motion is complete.
diff --git a/docs/source/tutorial/index.rst b/docs/source/tutorial/index.rst index 05154047..ac58290c 100644 --- a/docs/source/tutorial/index.rst +++ b/docs/source/tutorial/index.rst @@ -14,6 +14,7 @@ Tutorials sensor motion_gen gizmo + grasp_generator basic_env modular_env rl From 508a7121802a4c158437b208b0689e4acc9c7cfb Mon Sep 17 00:00:00 2001 From: chenjian Date: Mon, 30 Mar 2026 19:41:33 +0800 Subject: [PATCH 10/25] add comments --- embodichain/lab/sim/objects/rigid_object.py | 2 +- .../graspkit/pg_grasp/antipodal_annotator.py | 16 ++++++++++++-- .../graspkit/pg_grasp/antipodal_sampler.py | 10 +++++++-- .../pg_grasp/batch_collision_checker.py | 22 ++++++++++++++++++- .../pg_grasp/gripper_collision_checker.py | 11 ++++++++++ .../{grasp_mug.py => grasp_generator.py} | 1 + 6 files changed, 56 insertions(+), 6 deletions(-) rename scripts/tutorials/grasp/{grasp_mug.py => grasp_generator.py} (98%) diff --git a/embodichain/lab/sim/objects/rigid_object.py b/embodichain/lab/sim/objects/rigid_object.py index f90eee63..a9cde20a 100644 --- a/embodichain/lab/sim/objects/rigid_object.py +++ b/embodichain/lab/sim/objects/rigid_object.py @@ -1167,7 +1167,7 @@ def get_grasp_pose( grasp_poses: tuple[torch.Tensor] = [] open_lengths: tuple[torch.Tensor] = [] for pose in poses: - grasp_pose, open_length = self._grasp_annotator.get_approach_grasp_poses( + grasp_pose, open_length = self._grasp_annotator.get_grasp_poses( self._hit_point_pairs, pose, approach_direction, is_visual=False ) grasp_poses.append(grasp_pose) diff --git a/embodichain/toolkits/graspkit/pg_grasp/antipodal_annotator.py b/embodichain/toolkits/graspkit/pg_grasp/antipodal_annotator.py index 2770cbfe..54cac47d 100644 --- a/embodichain/toolkits/graspkit/pg_grasp/antipodal_annotator.py +++ b/embodichain/toolkits/graspkit/pg_grasp/antipodal_annotator.py @@ -59,12 +59,20 @@ class SelectResult: class GraspAnnotator: + """GraspAnnotator provides functionality to annotate antipodal grasp regions on a given object mesh. 
It allows users to interactively select regions on the mesh and generates antipodal point pairs for grasping based on the selected region. The annotator also includes a collision checker to filter out infeasible grasp poses and can visualize the generated grasp poses in a 3D viewer. + """ def __init__( self, vertices: torch.Tensor, triangles: torch.Tensor, cfg: GraspAnnotatorCfg = GraspAnnotatorCfg(), ) -> None: + """Initialize the GraspAnnotator with the given mesh vertices, triangles, and configuration. + Args: + vertices (torch.Tensor): A tensor of shape (V, 3) representing the vertex positions of the mesh. + triangles (torch.Tensor): A tensor of shape (F, 3) representing the triangle indices of the mesh. + cfg (GraspAnnotatorCfg, optional): Configuration for the grasp annotator. Defaults to GraspAnnotatorCfg(). + """ self.device = vertices.device self.vertices = vertices self.triangles = triangles @@ -82,7 +90,11 @@ def __init__( self.cfg = cfg self.antipodal_sampler = AntipodalSampler(cfg=cfg.antipodal_sampler_cfg) - def annotate(self): + def annotate(self) -> torch.Tensor: + """Annotate antipodal grasp region on the mesh and return sampled antipodal point pairs. + Returns: + torch.Tensor: A tensor of shape (N, 2, 3) representing N antipodal point pairs. Each pair consists of a hit point and its corresponding surface point. 
+ """ cache_path = self._get_cache_dir(self.vertices, self.triangles) if os.path.exists(cache_path) and not self.cfg.force_regenerate: logger.log_info( @@ -359,7 +371,7 @@ def _apply_transform(points: torch.Tensor, transform: torch.Tensor) -> torch.Ten t = transform[:3, 3] return points @ r.T + t - def get_approach_grasp_poses( + def get_grasp_poses( self, hit_point_pairs: torch.Tensor, object_pose: torch.Tensor, diff --git a/embodichain/toolkits/graspkit/pg_grasp/antipodal_sampler.py b/embodichain/toolkits/graspkit/pg_grasp/antipodal_sampler.py index a840e147..cebcafde 100644 --- a/embodichain/toolkits/graspkit/pg_grasp/antipodal_sampler.py +++ b/embodichain/toolkits/graspkit/pg_grasp/antipodal_sampler.py @@ -25,6 +25,7 @@ @dataclass class AntipodalSamplerCfg: + """ Configuration for AntipodalSampler.""" n_sample: int = 20000 """surface point sample number""" max_angle: float = np.pi / 12 @@ -36,6 +37,7 @@ class AntipodalSamplerCfg: class AntipodalSampler: + """ AntipodalSampler samples antipodal point pairs on a given mesh. It uses Open3D's raycasting functionality to find points on the mesh that are visible along the negative normal direction from uniformly sampled points on the mesh surface. The sampler can also apply a random disturbance to the ray direction to increase the diversity of sampled antipodal points. The resulting antipodal point pairs can be used for grasp generation and annotation tasks.""" def __init__( self, cfg: AntipodalSamplerCfg = AntipodalSamplerCfg(), @@ -46,6 +48,10 @@ def __init__( def sample(self, vertices: torch.Tensor, faces: torch.Tensor) -> torch.Tensor: """Get sample Antipodal point pair + Args: + vertices: [V, 3] vertex positions of the mesh + faces: [F, 3] triangle indices of the mesh + Returns: hit_point_pairs: [N, 2, 3] tensor of N antipodal point pairs. Each pair consists of a hit point and its corresponding surface point. 
""" @@ -83,13 +89,13 @@ def sample(self, vertices: torch.Tensor, faces: torch.Tensor) -> torch.Tensor: ray_origin = torch.vstack([ray_origin, ray_origin]) ray_direc = torch.vstack([ray_direc, disturb_direc]) # casting - return self.get_raycast_result( + return self._get_raycast_result( ray_origin, ray_direc, surface_origin=torch.vstack([sample_points, sample_points]), ) - def get_raycast_result( + def _get_raycast_result( self, ray_origin: torch.Tensor, ray_direc: torch.Tensor, diff --git a/embodichain/toolkits/graspkit/pg_grasp/batch_collision_checker.py b/embodichain/toolkits/graspkit/pg_grasp/batch_collision_checker.py index 7cb35be9..a488108b 100644 --- a/embodichain/toolkits/graspkit/pg_grasp/batch_collision_checker.py +++ b/embodichain/toolkits/graspkit/pg_grasp/batch_collision_checker.py @@ -37,18 +37,31 @@ @dataclass class BatchConvexCollisionCheckerCfg: + """ Configuration for BatchConvexCollisionChecker.""" + collsion_threshold: float = 0.0 + """ Collision threshold in meters. A point is considered colliding if its signed distance to the hull interior is <= this threshold. This allows for a margin of error in collision checking, where a small positive threshold can be used to consider points near the surface as colliding, and a small negative threshold can be used to allow for slight penetration without considering it a collision.""" n_query_mesh_samples: int = 4096 + """ Number of points to sample from the query mesh surface for collision checking. A higher number of samples can provide a more accurate collision check at the cost of increased computation time. The optimal number may depend on the complexity of the mesh and the required precision of collision detection.""" debug: bool = False + """ Whether to visualize the collision checking results for debugging purposes. 
If set to True, the code will generate visualizations of the query points colored by their collision status (e.g., red for colliding points and green for non-colliding points) along with the original mesh. This can help in understanding and verifying the collision checking process, especially during development and testing.""" class BatchConvexCollisionChecker: + """ BatchConvexCollisionChecker performs efficient collision checking between a batch of query point clouds and a convex decomposition of a mesh. The convex decomposition is represented by plane equations of the convex hulls, which are precomputed and cached for efficiency. The collision checking is done by computing the signed distance from each query point to the convex hulls using the plane equations, and determining if any points are colliding based on a specified collision threshold. This class can be used to test whether sampled point clouds (for example, a gripper model placed at candidate grasp poses) collide with the mesh.""" + def __init__( self, base_mesh_verts: torch.Tensor, base_mesh_faces: torch.Tensor, max_decomposition_hulls: int = 32, ): + """ Initialize the BatchConvexCollisionChecker by performing convex decomposition on the input mesh and extracting plane equations for the convex hulls. The plane equations are cached to disk to avoid redundant computation in future runs. + Args: + base_mesh_verts: [N, 3] vertex positions of the input mesh. + base_mesh_faces: [M, 3] triangle indices of the input mesh. + max_decomposition_hulls: maximum number of convex hulls to decompose into. A higher number allows for a more accurate approximation of the original mesh but increases computation time and memory usage. The optimal number may depend on the complexity of the mesh and the required precision of collision checking.
+ """ if not os.path.isdir(CONVEX_CACHE_DIR): os.makedirs(CONVEX_CACHE_DIR, exist_ok=True) self.device = base_mesh_verts.device @@ -154,7 +167,14 @@ def query_batch_points( batch_points: torch.Tensor, collision_threshold: float = 0.0, is_visual: bool = False, - ): + ) -> torch.Tensor: + """ Query collision status for a batch of point clouds. The collision status is determined by checking if the signed distance from any point in the cloud to the convex hulls is less than or equal to the specified collision threshold. + Args: + batch_points: [B, n_point, 3] batch of point clouds to query for collision status. + collision_threshold: Collision threshold in meters. A point is considered colliding if its signed distance to the hull interior is <= this threshold. This allows for a margin of error in collision checking, where a small positive threshold can be used to consider points near the surface as colliding, and a small negative threshold can be used to allow for slight penetration without considering it a collision. + is_visual: Whether to visualize the collision checking results for debugging purposes. If set to True, the code will generate visualizations of the query points colored by their collision status (e.g., red for colliding points and green for non-colliding points) along with the original mesh. This can help in understanding and verifying the collision checking process, especially during development and testing. 
+ Returns: + is_pose_collide: [B, ] boolean tensor indicating whether each point cloud in the batch has at least one colliding point. It is returned together with pose_surface_distance, a [B, ] tensor holding the minimum signed distance over each cloud's points.""" n_batch = batch_points.shape[0] point_signed_distance, is_point_collide = ( BatchConvexCollisionChecker.batch_point_convex_query( diff --git a/embodichain/toolkits/graspkit/pg_grasp/gripper_collision_checker.py b/embodichain/toolkits/graspkit/pg_grasp/gripper_collision_checker.py index 42dfeb1a..bacd6037 100644 --- a/embodichain/toolkits/graspkit/pg_grasp/gripper_collision_checker.py +++ b/embodichain/toolkits/graspkit/pg_grasp/gripper_collision_checker.py @@ -24,15 +24,26 @@ @dataclass class SimpleGripperCollisionCfg: + """ Configuration for the SimpleGripperCollisionChecker. This class defines various parameters related to the gripper geometry, point cloud generation, and collision checking process. Users can customize these parameters based on the specific gripper being modeled and the requirements of the application.""" + max_open_length: float = 0.1 + """ Maximum opening length of the gripper fingers. This should be set according to the specific gripper being modeled, and it defines the maximum distance between the two fingers when fully open.""" finger_length: float = 0.16 + """ Length of the gripper fingers from the root to the tip. This should be set according to the specific gripper being modeled, and it defines how far the fingers extend from the gripper root frame.""" y_thickness: float = 0.03 + """ Thickness of the gripper along the Y-axis (the axis perpendicular to the finger opening direction). This should be set according to the specific gripper being modeled, and it defines the width of the gripper's main body and fingers in the Y direction.""" x_thickness: float = 0.01 + """ Thickness of the gripper along the X-axis (the axis parallel to the finger opening direction).
This should be set according to the specific gripper being modeled, and it defines the thickness of the fingers and the root in the X direction.""" root_z_width: float = 0.06 + """ Width of the gripper root along the Z-axis (the axis along the finger length direction). This should be set according to the specific gripper being modeled, and it defines how far the root extends along the Z direction.""" device = torch.device("cpu") + """ Device on which the gripper point cloud will be generated and processed. This should be set according to the computational resources available and the requirements of the application. For example, if using a GPU for collision checking, this should be set to torch.device('cuda'). """ rough_dense: float = 0.015 + """ Approximate number of points per unit length for the gripper point cloud. Higher values will yield denser point clouds, which can improve collision checking accuracy but also increase computational cost. This should be set based on the desired balance between accuracy and efficiency for the specific application.""" max_decomposition_hulls: int = 16 + """ Maximum number of convex hulls to decompose the object mesh into for collision checking. This should be set based on the complexity of the object geometry and the desired accuracy of collision checking. More hulls can provide a tighter approximation of the object shape but will increase computational cost.""" open_check_margin: float = 0.01 + """ Additional margin added to the gripper open length when checking for collisions. 
This can help account for uncertainties in the gripper pose or object geometry, and can be set based on the specific requirements of the application.""" class SimpleGripperCollisionChecker: diff --git a/scripts/tutorials/grasp/grasp_mug.py b/scripts/tutorials/grasp/grasp_generator.py similarity index 98% rename from scripts/tutorials/grasp/grasp_mug.py rename to scripts/tutorials/grasp/grasp_generator.py index 5a7c89da..9f4450d0 100644 --- a/scripts/tutorials/grasp/grasp_mug.py +++ b/scripts/tutorials/grasp/grasp_generator.py @@ -242,6 +242,7 @@ def get_grasp_traj(sim: SimulationManager, robot: Robot, grasp_xpos: torch.Tenso ) sim.open_window() + # Annotate part of the mug to be grasped by following the instructions in the visualization window: # 1. View grasp object in browser (e.g http://localhost:11801) # 2. press 'Rect Select Region', select grasp region # 3. press 'Confirm Selection' to finish grasp region selection. From 3fb7ff5ebba716cd8fa819eb97d82fd87b93c9aa Mon Sep 17 00:00:00 2001 From: yuecideng Date: Wed, 1 Apr 2026 14:19:17 +0800 Subject: [PATCH 11/25] wip --- .../toolkits/graspkit/pg_grasp/antipodal_annotator.py | 4 ++-- .../toolkits/graspkit/pg_grasp/antipodal_sampler.py | 6 ++++-- .../graspkit/pg_grasp/batch_collision_checker.py | 11 ++++++----- .../graspkit/pg_grasp/gripper_collision_checker.py | 2 +- 4 files changed, 13 insertions(+), 10 deletions(-) diff --git a/embodichain/toolkits/graspkit/pg_grasp/antipodal_annotator.py b/embodichain/toolkits/graspkit/pg_grasp/antipodal_annotator.py index 54cac47d..20177656 100644 --- a/embodichain/toolkits/graspkit/pg_grasp/antipodal_annotator.py +++ b/embodichain/toolkits/graspkit/pg_grasp/antipodal_annotator.py @@ -59,8 +59,8 @@ class SelectResult: class GraspAnnotator: - """GraspAnnotator provides functionality to annotate antipodal grasp regions on a given object mesh. 
It allows users to interactively select regions on the mesh and generates antipodal point pairs for grasping based on the selected region. The annotator also includes a collision checker to filter out infeasible grasp poses and can visualize the generated grasp poses in a 3D viewer. - """ + """GraspAnnotator provides functionality to annotate antipodal grasp regions on a given object mesh. It allows users to interactively select regions on the mesh and generates antipodal point pairs for grasping based on the selected region. The annotator also includes a collision checker to filter out infeasible grasp poses and can visualize the generated grasp poses in a 3D viewer.""" + def __init__( self, vertices: torch.Tensor, diff --git a/embodichain/toolkits/graspkit/pg_grasp/antipodal_sampler.py b/embodichain/toolkits/graspkit/pg_grasp/antipodal_sampler.py index cebcafde..57632f95 100644 --- a/embodichain/toolkits/graspkit/pg_grasp/antipodal_sampler.py +++ b/embodichain/toolkits/graspkit/pg_grasp/antipodal_sampler.py @@ -25,7 +25,8 @@ @dataclass class AntipodalSamplerCfg: - """ Configuration for AntipodalSampler.""" + """Configuration for AntipodalSampler.""" + n_sample: int = 20000 """surface point sample number""" max_angle: float = np.pi / 12 @@ -37,7 +38,8 @@ class AntipodalSamplerCfg: class AntipodalSampler: - """ AntipodalSampler samples antipodal point pairs on a given mesh. It uses Open3D's raycasting functionality to find points on the mesh that are visible along the negative normal direction from uniformly sampled points on the mesh surface. The sampler can also apply a random disturbance to the ray direction to increase the diversity of sampled antipodal points. The resulting antipodal point pairs can be used for grasp generation and annotation tasks.""" + """AntipodalSampler samples antipodal point pairs on a given mesh. 
It uses Open3D's raycasting functionality to find points on the mesh that are visible along the negative normal direction from uniformly sampled points on the mesh surface. The sampler can also apply a random disturbance to the ray direction to increase the diversity of sampled antipodal points. The resulting antipodal point pairs can be used for grasp generation and annotation tasks.""" + def __init__( self, cfg: AntipodalSamplerCfg = AntipodalSamplerCfg(), diff --git a/embodichain/toolkits/graspkit/pg_grasp/batch_collision_checker.py b/embodichain/toolkits/graspkit/pg_grasp/batch_collision_checker.py index a488108b..165f9531 100644 --- a/embodichain/toolkits/graspkit/pg_grasp/batch_collision_checker.py +++ b/embodichain/toolkits/graspkit/pg_grasp/batch_collision_checker.py @@ -37,7 +37,7 @@ @dataclass class BatchConvexCollisionCheckerCfg: - """ Configuration for BatchConvexCollisionChecker.""" + """Configuration for BatchConvexCollisionChecker.""" collsion_threshold: float = 0.0 """ Collision threshold in meters. A point is considered colliding if its signed distance to the hull interior is <= this threshold. This allows for a margin of error in collision checking, where a small positive threshold can be used to consider points near the surface as colliding, and a small negative threshold can be used to allow for slight penetration without considering it a collision.""" @@ -48,7 +48,7 @@ class BatchConvexCollisionCheckerCfg: class BatchConvexCollisionChecker: - """ BatchConvexCollisionChecker performs efficient collision checking between a batch of query point clouds and a convex decomposition of a mesh. The convex decomposition is represented by plane equations of the convex hulls, which are precomputed and cached for efficiency. The collision checking is done by computing the signed distance from each query point to the convex hulls using the plane equations, and determining if any points are colliding based on a specified collision threshold. 
This class can be used""" + """BatchConvexCollisionChecker performs efficient collision checking between a batch of query point clouds and a convex decomposition of a mesh. The convex decomposition is represented by plane equations of the convex hulls, which are precomputed and cached for efficiency. The collision checking is done by computing the signed distance from each query point to the convex hulls using the plane equations, and determining if any points are colliding based on a specified collision threshold. This class can be used""" def __init__( self, @@ -56,7 +56,7 @@ def __init__( base_mesh_faces: torch.Tensor, max_decomposition_hulls: int = 32, ): - """ Initialize the BatchConvexCollisionChecker by performing convex decomposition on the input mesh and extracting plane equations for the convex hulls. The plane equations are cached to disk to avoid redundant computation in future runs. + """Initialize the BatchConvexCollisionChecker by performing convex decomposition on the input mesh and extracting plane equations for the convex hulls. The plane equations are cached to disk to avoid redundant computation in future runs. Args: base_mesh_verts: [N, 3] vertex positions of the input mesh. base_mesh_faces: [M, 3] triangle indices of the input mesh. @@ -168,13 +168,14 @@ def query_batch_points( collision_threshold: float = 0.0, is_visual: bool = False, ) -> torch.Tensor: - """ Query collision status for a batch of point clouds. The collision status is determined by checking if the signed distance from any point in the cloud to the convex hulls is less than or equal to the specified collision threshold. + """Query collision status for a batch of point clouds. The collision status is determined by checking if the signed distance from any point in the cloud to the convex hulls is less than or equal to the specified collision threshold. Args: batch_points: [B, n_point, 3] batch of point clouds to query for collision status. 
collision_threshold: Collision threshold in meters. A point is considered colliding if its signed distance to the hull interior is <= this threshold. This allows for a margin of error in collision checking, where a small positive threshold can be used to consider points near the surface as colliding, and a small negative threshold can be used to allow for slight penetration without considering it a collision. is_visual: Whether to visualize the collision checking results for debugging purposes. If set to True, the code will generate visualizations of the query points colored by their collision status (e.g., red for colliding points and green for non-colliding points) along with the original mesh. This can help in understanding and verifying the collision checking process, especially during development and testing. Returns: - is_pose_collide: [B, ] boolean tensor indicating whether each point cloud in the""" + is_pose_collide: [B, ] boolean tensor indicating whether each point cloud in the + """ n_batch = batch_points.shape[0] point_signed_distance, is_point_collide = ( BatchConvexCollisionChecker.batch_point_convex_query( diff --git a/embodichain/toolkits/graspkit/pg_grasp/gripper_collision_checker.py b/embodichain/toolkits/graspkit/pg_grasp/gripper_collision_checker.py index bacd6037..2f13d16f 100644 --- a/embodichain/toolkits/graspkit/pg_grasp/gripper_collision_checker.py +++ b/embodichain/toolkits/graspkit/pg_grasp/gripper_collision_checker.py @@ -24,7 +24,7 @@ @dataclass class SimpleGripperCollisionCfg: - """ Configuration for the SimpleGripperCollisionChecker. This class defines various parameters related to the gripper geometry, point cloud generation, and collision checking process. Users can customize these parameters based on the specific gripper being modeled and the requirements of the application.""" + """Configuration for the SimpleGripperCollisionChecker. 
This class defines various parameters related to the gripper geometry, point cloud generation, and collision checking process. Users can customize these parameters based on the specific gripper being modeled and the requirements of the application.""" max_open_length: float = 0.1 """ Maximum opening length of the gripper fingers. This should be set according to the specific gripper being modeled, and it defines the maximum distance between the two fingers when fully open.""" From ab5506db968687675d0140aa601a0bf8769f0c25 Mon Sep 17 00:00:00 2001 From: chenjian Date: Wed, 1 Apr 2026 14:43:11 +0800 Subject: [PATCH 12/25] style --- .../toolkits/graspkit/pg_grasp/antipodal_annotator.py | 4 ++-- .../toolkits/graspkit/pg_grasp/antipodal_sampler.py | 6 ++++-- .../graspkit/pg_grasp/batch_collision_checker.py | 11 ++++++----- .../graspkit/pg_grasp/gripper_collision_checker.py | 2 +- 4 files changed, 13 insertions(+), 10 deletions(-) diff --git a/embodichain/toolkits/graspkit/pg_grasp/antipodal_annotator.py b/embodichain/toolkits/graspkit/pg_grasp/antipodal_annotator.py index 54cac47d..20177656 100644 --- a/embodichain/toolkits/graspkit/pg_grasp/antipodal_annotator.py +++ b/embodichain/toolkits/graspkit/pg_grasp/antipodal_annotator.py @@ -59,8 +59,8 @@ class SelectResult: class GraspAnnotator: - """GraspAnnotator provides functionality to annotate antipodal grasp regions on a given object mesh. It allows users to interactively select regions on the mesh and generates antipodal point pairs for grasping based on the selected region. The annotator also includes a collision checker to filter out infeasible grasp poses and can visualize the generated grasp poses in a 3D viewer. - """ + """GraspAnnotator provides functionality to annotate antipodal grasp regions on a given object mesh. It allows users to interactively select regions on the mesh and generates antipodal point pairs for grasping based on the selected region. 
The annotator also includes a collision checker to filter out infeasible grasp poses and can visualize the generated grasp poses in a 3D viewer.""" + def __init__( self, vertices: torch.Tensor, diff --git a/embodichain/toolkits/graspkit/pg_grasp/antipodal_sampler.py b/embodichain/toolkits/graspkit/pg_grasp/antipodal_sampler.py index cebcafde..57632f95 100644 --- a/embodichain/toolkits/graspkit/pg_grasp/antipodal_sampler.py +++ b/embodichain/toolkits/graspkit/pg_grasp/antipodal_sampler.py @@ -25,7 +25,8 @@ @dataclass class AntipodalSamplerCfg: - """ Configuration for AntipodalSampler.""" + """Configuration for AntipodalSampler.""" + n_sample: int = 20000 """surface point sample number""" max_angle: float = np.pi / 12 @@ -37,7 +38,8 @@ class AntipodalSamplerCfg: class AntipodalSampler: - """ AntipodalSampler samples antipodal point pairs on a given mesh. It uses Open3D's raycasting functionality to find points on the mesh that are visible along the negative normal direction from uniformly sampled points on the mesh surface. The sampler can also apply a random disturbance to the ray direction to increase the diversity of sampled antipodal points. The resulting antipodal point pairs can be used for grasp generation and annotation tasks.""" + """AntipodalSampler samples antipodal point pairs on a given mesh. It uses Open3D's raycasting functionality to find points on the mesh that are visible along the negative normal direction from uniformly sampled points on the mesh surface. The sampler can also apply a random disturbance to the ray direction to increase the diversity of sampled antipodal points. 
The resulting antipodal point pairs can be used for grasp generation and annotation tasks.""" + def __init__( self, cfg: AntipodalSamplerCfg = AntipodalSamplerCfg(), diff --git a/embodichain/toolkits/graspkit/pg_grasp/batch_collision_checker.py b/embodichain/toolkits/graspkit/pg_grasp/batch_collision_checker.py index a488108b..165f9531 100644 --- a/embodichain/toolkits/graspkit/pg_grasp/batch_collision_checker.py +++ b/embodichain/toolkits/graspkit/pg_grasp/batch_collision_checker.py @@ -37,7 +37,7 @@ @dataclass class BatchConvexCollisionCheckerCfg: - """ Configuration for BatchConvexCollisionChecker.""" + """Configuration for BatchConvexCollisionChecker.""" collsion_threshold: float = 0.0 """ Collision threshold in meters. A point is considered colliding if its signed distance to the hull interior is <= this threshold. This allows for a margin of error in collision checking, where a small positive threshold can be used to consider points near the surface as colliding, and a small negative threshold can be used to allow for slight penetration without considering it a collision.""" @@ -48,7 +48,7 @@ class BatchConvexCollisionCheckerCfg: class BatchConvexCollisionChecker: - """ BatchConvexCollisionChecker performs efficient collision checking between a batch of query point clouds and a convex decomposition of a mesh. The convex decomposition is represented by plane equations of the convex hulls, which are precomputed and cached for efficiency. The collision checking is done by computing the signed distance from each query point to the convex hulls using the plane equations, and determining if any points are colliding based on a specified collision threshold. This class can be used""" + """BatchConvexCollisionChecker performs efficient collision checking between a batch of query point clouds and a convex decomposition of a mesh. The convex decomposition is represented by plane equations of the convex hulls, which are precomputed and cached for efficiency. 
The collision checking is done by computing the signed distance from each query point to the convex hulls using the plane equations, and determining if any points are colliding based on a specified collision threshold. This class can be used""" def __init__( self, @@ -56,7 +56,7 @@ def __init__( base_mesh_faces: torch.Tensor, max_decomposition_hulls: int = 32, ): - """ Initialize the BatchConvexCollisionChecker by performing convex decomposition on the input mesh and extracting plane equations for the convex hulls. The plane equations are cached to disk to avoid redundant computation in future runs. + """Initialize the BatchConvexCollisionChecker by performing convex decomposition on the input mesh and extracting plane equations for the convex hulls. The plane equations are cached to disk to avoid redundant computation in future runs. Args: base_mesh_verts: [N, 3] vertex positions of the input mesh. base_mesh_faces: [M, 3] triangle indices of the input mesh. @@ -168,13 +168,14 @@ def query_batch_points( collision_threshold: float = 0.0, is_visual: bool = False, ) -> torch.Tensor: - """ Query collision status for a batch of point clouds. The collision status is determined by checking if the signed distance from any point in the cloud to the convex hulls is less than or equal to the specified collision threshold. + """Query collision status for a batch of point clouds. The collision status is determined by checking if the signed distance from any point in the cloud to the convex hulls is less than or equal to the specified collision threshold. Args: batch_points: [B, n_point, 3] batch of point clouds to query for collision status. collision_threshold: Collision threshold in meters. A point is considered colliding if its signed distance to the hull interior is <= this threshold. 
This allows for a margin of error in collision checking, where a small positive threshold can be used to consider points near the surface as colliding, and a small negative threshold can be used to allow for slight penetration without considering it a collision. is_visual: Whether to visualize the collision checking results for debugging purposes. If set to True, the code will generate visualizations of the query points colored by their collision status (e.g., red for colliding points and green for non-colliding points) along with the original mesh. This can help in understanding and verifying the collision checking process, especially during development and testing. Returns: - is_pose_collide: [B, ] boolean tensor indicating whether each point cloud in the""" + is_pose_collide: [B, ] boolean tensor indicating whether each point cloud in the + """ n_batch = batch_points.shape[0] point_signed_distance, is_point_collide = ( BatchConvexCollisionChecker.batch_point_convex_query( diff --git a/embodichain/toolkits/graspkit/pg_grasp/gripper_collision_checker.py b/embodichain/toolkits/graspkit/pg_grasp/gripper_collision_checker.py index bacd6037..2f13d16f 100644 --- a/embodichain/toolkits/graspkit/pg_grasp/gripper_collision_checker.py +++ b/embodichain/toolkits/graspkit/pg_grasp/gripper_collision_checker.py @@ -24,7 +24,7 @@ @dataclass class SimpleGripperCollisionCfg: - """ Configuration for the SimpleGripperCollisionChecker. This class defines various parameters related to the gripper geometry, point cloud generation, and collision checking process. Users can customize these parameters based on the specific gripper being modeled and the requirements of the application.""" + """Configuration for the SimpleGripperCollisionChecker. This class defines various parameters related to the gripper geometry, point cloud generation, and collision checking process. 
Users can customize these parameters based on the specific gripper being modeled and the requirements of the application.""" max_open_length: float = 0.1 """ Maximum opening length of the gripper fingers. This should be set according to the specific gripper being modeled, and it defines the maximum distance between the two fingers when fully open.""" From 4b31ae8086af6ba5985ace4cba60d27fd1697e23 Mon Sep 17 00:00:00 2001 From: chenjian Date: Wed, 1 Apr 2026 15:48:01 +0800 Subject: [PATCH 13/25] update --- embodichain/lab/sim/objects/rigid_object.py | 11 ++++- .../graspkit/pg_grasp/antipodal_annotator.py | 9 ++-- .../graspkit/pg_grasp/antipodal_sampler.py | 4 +- .../pg_grasp/batch_collision_checker.py | 49 ++++++++++++------- .../pg_grasp/gripper_collision_checker.py | 19 +++---- scripts/tutorials/grasp/grasp_generator.py | 13 ++++- 6 files changed, 68 insertions(+), 37 deletions(-) diff --git a/embodichain/lab/sim/objects/rigid_object.py b/embodichain/lab/sim/objects/rigid_object.py index a9cde20a..e42b0005 100644 --- a/embodichain/lab/sim/objects/rigid_object.py +++ b/embodichain/lab/sim/objects/rigid_object.py @@ -39,6 +39,9 @@ GraspAnnotator, GraspAnnotatorCfg, ) +from embodichain.toolkits.graspkit.pg_grasp.gripper_collision_checker import ( + SimpleGripperCollisionCfg, +) import torch.nn.functional as F @@ -1130,7 +1133,8 @@ def destroy(self) -> None: def get_grasp_pose( self, - cfg: GraspAnnotatorCfg, + cfg: GraspAnnotatorCfg = GraspAnnotatorCfg(), + gripper_collision_cfg: SimpleGripperCollisionCfg = SimpleGripperCollisionCfg(), approach_direction: torch.Tensor = None, is_visual: bool = False, ) -> torch.Tensor: @@ -1155,7 +1159,10 @@ def get_grasp_pose( ) vertices = vertices * scale self._grasp_annotator = GraspAnnotator( - vertices=vertices, triangles=triangles, cfg=cfg + vertices=vertices, + triangles=triangles, + cfg=cfg, + gripper_collision_cfg=gripper_collision_cfg, ) # Annotate antipodal point pairs diff --git 
a/embodichain/toolkits/graspkit/pg_grasp/antipodal_annotator.py b/embodichain/toolkits/graspkit/pg_grasp/antipodal_annotator.py index 20177656..1f824ad2 100644 --- a/embodichain/toolkits/graspkit/pg_grasp/antipodal_annotator.py +++ b/embodichain/toolkits/graspkit/pg_grasp/antipodal_annotator.py @@ -27,7 +27,7 @@ import viser import viser.transforms as tf from embodichain.utils import logger -from dataclasses import dataclass +from embodichain.utils import configclass from embodichain.toolkits.graspkit.pg_grasp.antipodal_sampler import ( AntipodalSampler, AntipodalSamplerCfg, @@ -41,7 +41,7 @@ import tempfile -@dataclass +@configclass class GraspAnnotatorCfg: viser_port: int = 15531 use_largest_connected_component: bool = False @@ -50,7 +50,7 @@ class GraspAnnotatorCfg: max_deviation_angle: float = np.pi / 12 -@dataclass +@configclass class SelectResult: vertex_indices: np.ndarray | None = None face_indices: np.ndarray | None = None @@ -66,6 +66,7 @@ def __init__( vertices: torch.Tensor, triangles: torch.Tensor, cfg: GraspAnnotatorCfg = GraspAnnotatorCfg(), + gripper_collision_cfg: SimpleGripperCollisionCfg = SimpleGripperCollisionCfg(), ) -> None: """Initialize the GraspAnnotator with the given mesh vertices, triangles, and configuration. 
Args: @@ -85,7 +86,7 @@ def __init__( self._collision_checker = SimpleGripperCollisionChecker( object_mesh_verts=vertices, object_mesh_faces=triangles, - cfg=SimpleGripperCollisionCfg(), + cfg=gripper_collision_cfg, ) self.cfg = cfg self.antipodal_sampler = AntipodalSampler(cfg=cfg.antipodal_sampler_cfg) diff --git a/embodichain/toolkits/graspkit/pg_grasp/antipodal_sampler.py b/embodichain/toolkits/graspkit/pg_grasp/antipodal_sampler.py index 57632f95..2d0aa518 100644 --- a/embodichain/toolkits/graspkit/pg_grasp/antipodal_sampler.py +++ b/embodichain/toolkits/graspkit/pg_grasp/antipodal_sampler.py @@ -19,11 +19,11 @@ import numpy as np import open3d as o3d import open3d.core as o3c -from dataclasses import dataclass +from embodichain.utils import configclass from embodichain.utils import logger -@dataclass +@configclass class AntipodalSamplerCfg: """Configuration for AntipodalSampler.""" diff --git a/embodichain/toolkits/graspkit/pg_grasp/batch_collision_checker.py b/embodichain/toolkits/graspkit/pg_grasp/batch_collision_checker.py index 165f9531..41bb4a89 100644 --- a/embodichain/toolkits/graspkit/pg_grasp/batch_collision_checker.py +++ b/embodichain/toolkits/graspkit/pg_grasp/batch_collision_checker.py @@ -21,7 +21,7 @@ from typing import List, Tuple, Union from dexsim.kit.meshproc import convex_decomposition_coacd import hashlib -from dataclasses import dataclass +from embodichain.utils import configclass import os import pickle import open3d as o3d @@ -35,7 +35,7 @@ ) -@dataclass +@configclass class BatchConvexCollisionCheckerCfg: """Configuration for BatchConvexCollisionChecker.""" @@ -134,10 +134,21 @@ def batch_point_convex_query( device: torch.device, collision_threshold: float = -0.003, ): - plane_equations_wp = wp.from_torch(plane_equations) - plane_equation_counts_wp = wp.from_torch(plane_equation_counts) - batch_points_wp = wp.from_torch(batch_points) - + # always use cuda for batch grasp pose query + is_cpu = device == torch.device("cpu") + if is_cpu: 
+ plane_equations_wp = wp.from_torch(plane_equations.to("cuda")) + plane_equation_counts_wp = wp.from_torch(plane_equation_counts.to("cuda")) + batch_points_wp = wp.from_torch(batch_points.to("cuda")) + else: + plane_equations_wp = wp.from_torch(plane_equations) + plane_equation_counts_wp = wp.from_torch(plane_equation_counts) + batch_points_wp = wp.from_torch(batch_points) + + if is_cpu: + wp_device = standardize_device_string(torch.device("cuda")) + else: + wp_device = standardize_device_string(device) n_pose = batch_points.shape[0] n_point = batch_points.shape[1] n_convex = plane_equations.shape[0] @@ -145,14 +156,14 @@ def batch_point_convex_query( shape=(n_pose, n_point, n_convex), value=-float("inf"), dtype=float, - device=standardize_device_string(device), + device=wp_device, ) # [n_pose, n_point, n_convex] wp.launch( kernel=convex_signed_distance_kernel, dim=(n_pose, n_point, n_convex), inputs=(batch_points_wp, plane_equations_wp, plane_equation_counts_wp), outputs=(point_convex_signed_distance_wp,), - device=standardize_device_string(device), + device=wp_device, ) point_convex_signed_distance = wp.to_torch(point_convex_signed_distance_wp) # import ipdb; ipdb.set_trace() @@ -160,7 +171,10 @@ def batch_point_convex_query( dim=-1 ).values # [n_pose, n_point] is_point_collide = point_signed_distance <= collision_threshold - return point_signed_distance, is_point_collide + if is_cpu: + return point_signed_distance.to("cpu"), is_point_collide.to("cpu") + else: + return point_signed_distance, is_point_collide def query_batch_points( self, @@ -423,7 +437,6 @@ def transform_points_batch( if __name__ == "__main__": from embodichain.data import get_data_path - mug_path = get_data_path("CoffeeCup/cup.ply") mug_path = get_data_path("ScannedBottle/moliwulong_processed.ply") mug_mesh = trimesh.load(mug_path, force="mesh", process=False) verts = torch.tensor(mug_mesh.vertices, dtype=torch.float32) @@ -463,11 +476,11 @@ def transform_points_batch( 
collision_checker.query_batch_points( test_pc, collision_threshold=0.003, is_visual=True ) - collision_checker.query( - obj_verts, - obj_faces, - poses, - cfg=BatchConvexCollisionCheckerCfg( - debug=True, n_query_mesh_samples=32768, collsion_threshold=0.000 - ), - ) + # collision_checker.query( + # obj_verts, + # obj_faces, + # poses, + # cfg=BatchConvexCollisionCheckerCfg( + # debug=True, n_query_mesh_samples=32768, collsion_threshold=0.000 + # ), + # ) diff --git a/embodichain/toolkits/graspkit/pg_grasp/gripper_collision_checker.py b/embodichain/toolkits/graspkit/pg_grasp/gripper_collision_checker.py index 2f13d16f..78ce4eb4 100644 --- a/embodichain/toolkits/graspkit/pg_grasp/gripper_collision_checker.py +++ b/embodichain/toolkits/graspkit/pg_grasp/gripper_collision_checker.py @@ -16,29 +16,30 @@ from __future__ import annotations -from dataclasses import dataclass +from embodichain.utils import configclass + from typing import Sequence from .batch_collision_checker import BatchConvexCollisionChecker import torch -@dataclass +@configclass class SimpleGripperCollisionCfg: """Configuration for the SimpleGripperCollisionChecker. This class defines various parameters related to the gripper geometry, point cloud generation, and collision checking process. Users can customize these parameters based on the specific gripper being modeled and the requirements of the application.""" max_open_length: float = 0.1 """ Maximum opening length of the gripper fingers. This should be set according to the specific gripper being modeled, and it defines the maximum distance between the two fingers when fully open.""" - finger_length: float = 0.16 - """ Length of the gripper fingers from the root to the tip. This should be set according to the specific gripper being modeled, and it defines how far the fingers extend from the gripper root frame.""" + finger_length: float = 0.08 + """ Length of the gripper fingers from the root to the tip, in z axis. 
This should be set according to the specific gripper being modeled, and it defines how far the fingers extend from the gripper root frame.""" y_thickness: float = 0.03 """ Thickness of the gripper along the Y-axis (the axis perpendicular to the finger opening direction). This should be set according to the specific gripper being modeled, and it defines the width of the gripper's main body and fingers in the Y direction.""" x_thickness: float = 0.01 """ Thickness of the gripper along the X-axis (the axis parallel to the finger opening direction). This should be set according to the specific gripper being modeled, and it defines the thickness of the fingers and the root in the X direction.""" - root_z_width: float = 0.06 + root_z_width: float = 0.08 """ Width of the gripper root along the Z-axis (the axis along the finger length direction). This should be set according to the specific gripper being modeled, and it defines how far the root extends along the Z direction.""" device = torch.device("cpu") """ Device on which the gripper point cloud will be generated and processed. This should be set according to the computational resources available and the requirements of the application. For example, if using a GPU for collision checking, this should be set to torch.device('cuda'). """ - rough_dense: float = 0.015 + point_sample_dense: float = 0.01 """ Approximate number of points per unit length for the gripper point cloud. Higher values will yield denser point clouds, which can improve collision checking accuracy but also increase computational cost. This should be set based on the desired balance between accuracy and efficiency for the specific application.""" max_decomposition_hulls: int = 16 """ Maximum number of convex hulls to decompose the object mesh into for collision checking. This should be set based on the complexity of the object geometry and the desired accuracy of collision checking. 
More hulls can provide a tighter approximation of the object shape but will increase computational cost.""" @@ -69,17 +70,17 @@ def _init_pc_template(self): self.cfg.y_thickness, self.cfg.root_z_width, ), - dense=self.cfg.rough_dense, + dense=self.cfg.point_sample_dense, device=self.device, ) self.left_template = box_surface_grid( size=(self.cfg.x_thickness, self.cfg.y_thickness, self.cfg.finger_length), - dense=self.cfg.rough_dense, + dense=self.cfg.point_sample_dense, device=self.device, ) self.right_template = box_surface_grid( size=(self.cfg.x_thickness, self.cfg.y_thickness, self.cfg.finger_length), - dense=self.cfg.rough_dense, + dense=self.cfg.point_sample_dense, device=self.device, ) diff --git a/scripts/tutorials/grasp/grasp_generator.py b/scripts/tutorials/grasp/grasp_generator.py index 9f4450d0..087d4b19 100644 --- a/scripts/tutorials/grasp/grasp_generator.py +++ b/scripts/tutorials/grasp/grasp_generator.py @@ -46,6 +46,9 @@ GraspAnnotatorCfg, AntipodalSamplerCfg, ) +from embodichain.toolkits.graspkit.pg_grasp.gripper_collision_checker import ( + SimpleGripperCollisionCfg, +) def parse_arguments(): @@ -238,7 +241,7 @@ def get_grasp_traj(sim: SimulationManager, robot: Robot, grasp_xpos: torch.Tenso antipodal_sampler_cfg=AntipodalSamplerCfg( n_sample=20000, max_length=0.088, min_length=0.003 ), - force_regenerate=True, # force user to annotate grasp region each time + force_regenerate=False, # force user to annotate grasp region each time ) sim.open_window() @@ -248,12 +251,18 @@ def get_grasp_traj(sim: SimulationManager, robot: Robot, grasp_xpos: torch.Tenso # 3. press 'Confirm Selection' to finish grasp region selection. 
start_time = time.time() + + gripper_collision_cfg = SimpleGripperCollisionCfg( + max_open_length=0.088, finger_length=0.078, + point_sample_dense=0.012 + ) grasp_xpos = mug.get_grasp_pose( approach_direction=torch.tensor( [0, 0, -1], dtype=torch.float32, device=sim.device ), # gripper approach direction in the world frame cfg=grasp_cfg, - is_visual=True, # visualize selected grasp pose finally + gripper_collision_cfg=gripper_collision_cfg, + is_visual=False, # visualize selected grasp pose finally ) cost_time = time.time() - start_time logger.log_info(f"Get grasp pose cost time: {cost_time:.2f} seconds") From c05b1305e40a1ea8f51cc2b1120aa902727a7972 Mon Sep 17 00:00:00 2001 From: chenjian Date: Wed, 1 Apr 2026 16:37:25 +0800 Subject: [PATCH 14/25] add batch convex unittest --- .../pg_grasp/batch_collision_checker.py | 39 ++++---- scripts/tutorials/grasp/grasp_generator.py | 3 +- tests/toolkits/test_batch_convex_collision.py | 98 +++++++++++++++++++ 3 files changed, 121 insertions(+), 19 deletions(-) create mode 100644 tests/toolkits/test_batch_convex_collision.py diff --git a/embodichain/toolkits/graspkit/pg_grasp/batch_collision_checker.py b/embodichain/toolkits/graspkit/pg_grasp/batch_collision_checker.py index 41bb4a89..eea46c73 100644 --- a/embodichain/toolkits/graspkit/pg_grasp/batch_collision_checker.py +++ b/embodichain/toolkits/graspkit/pg_grasp/batch_collision_checker.py @@ -28,6 +28,7 @@ from embodichain.utils import logger from embodichain.utils.warp import convex_signed_distance_kernel import warp as wp + from embodichain.utils.device_utils import standardize_device_string CONVEX_CACHE_DIR = os.path.join( @@ -144,7 +145,7 @@ def batch_point_convex_query( plane_equations_wp = wp.from_torch(plane_equations) plane_equation_counts_wp = wp.from_torch(plane_equation_counts) batch_points_wp = wp.from_torch(batch_points) - + if is_cpu: wp_device = standardize_device_string(torch.device("cuda")) else: @@ -166,7 +167,6 @@ def batch_point_convex_query( 
device=wp_device, ) point_convex_signed_distance = wp.to_torch(point_convex_signed_distance_wp) - # import ipdb; ipdb.set_trace() point_signed_distance = point_convex_signed_distance.min( dim=-1 ).values # [n_pose, n_point] @@ -439,8 +439,10 @@ def transform_points_batch( mug_path = get_data_path("ScannedBottle/moliwulong_processed.ply") mug_mesh = trimesh.load(mug_path, force="mesh", process=False) - verts = torch.tensor(mug_mesh.vertices, dtype=torch.float32) - faces = torch.tensor(mug_mesh.faces, dtype=torch.int32) + verts = torch.tensor( + mug_mesh.vertices, dtype=torch.float32, device=torch.device("cuda") + ) + faces = torch.tensor(mug_mesh.faces, dtype=torch.int32, device=torch.device("cuda")) collision_checker = BatchConvexCollisionChecker( verts, faces, max_decomposition_hulls=16 ) @@ -459,28 +461,31 @@ def transform_points_batch( [0, 0, -1, 0], [0, 0, 0, 1], ], - ] + ], + device=torch.device("cuda"), ) from scipy.spatial.transform import Rotation + wp.init() + rot = Rotation.from_euler("xyz", [12, 3, 32], degrees=True).as_matrix() - poses[0, :3, :3] = torch.tensor(rot, dtype=torch.float32) - poses[1, :3, :3] = torch.tensor(rot, dtype=torch.float32) + poses[0, :3, :3] = torch.tensor( + rot, dtype=torch.float32, device=torch.device("cuda") + ) + poses[1, :3, :3] = torch.tensor( + rot, dtype=torch.float32, device=torch.device("cuda") + ) obj_path = get_data_path("ScannedBottle/yibao_processed.ply") obj_mesh = trimesh.load(obj_path, force="mesh", process=False) - obj_verts = torch.tensor(obj_mesh.vertices, dtype=torch.float32) - obj_faces = torch.tensor(obj_mesh.faces, dtype=torch.int32) + obj_verts = torch.tensor( + obj_mesh.vertices, dtype=torch.float32, device=torch.device("cuda") + ) + obj_faces = torch.tensor( + obj_mesh.faces, dtype=torch.int32, device=torch.device("cuda") + ) test_pc = transform_points_batch(obj_verts, poses) collision_checker.query_batch_points( test_pc, collision_threshold=0.003, is_visual=True ) - # collision_checker.query( - # 
obj_verts, - # obj_faces, - # poses, - # cfg=BatchConvexCollisionCheckerCfg( - # debug=True, n_query_mesh_samples=32768, collsion_threshold=0.000 - # ), - # ) diff --git a/scripts/tutorials/grasp/grasp_generator.py b/scripts/tutorials/grasp/grasp_generator.py index 087d4b19..ce940423 100644 --- a/scripts/tutorials/grasp/grasp_generator.py +++ b/scripts/tutorials/grasp/grasp_generator.py @@ -253,8 +253,7 @@ def get_grasp_traj(sim: SimulationManager, robot: Robot, grasp_xpos: torch.Tenso start_time = time.time() gripper_collision_cfg = SimpleGripperCollisionCfg( - max_open_length=0.088, finger_length=0.078, - point_sample_dense=0.012 + max_open_length=0.088, finger_length=0.078, point_sample_dense=0.012 ) grasp_xpos = mug.get_grasp_pose( approach_direction=torch.tensor( diff --git a/tests/toolkits/test_batch_convex_collision.py b/tests/toolkits/test_batch_convex_collision.py new file mode 100644 index 00000000..5a4269d8 --- /dev/null +++ b/tests/toolkits/test_batch_convex_collision.py @@ -0,0 +1,98 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ---------------------------------------------------------------------------- +import torch +from embodichain.data import get_data_path +import trimesh +from embodichain.toolkits.graspkit.pg_grasp.batch_collision_checker import ( + BatchConvexCollisionChecker, + BatchConvexCollisionCheckerCfg, +) +from embodichain.utils.warp import convex_signed_distance_kernel +import warp as wp + + +def transform_points_batch( + points: torch.Tensor, poses: torch.Tensor # [P, 3] # [B, 4, 4] +) -> torch.Tensor: + """ + Apply a batch of rigid transforms to a point cloud. + + Args: + points: [P, 3] source point cloud. + poses: [B, 4, 4] batch of homogeneous transformation matrices. + + Returns: + transformed: [B, P, 3] transformed point cloud for each pose. + """ + R = poses[:, :3, :3] # [B, 3, 3] + t = poses[:, :3, 3] # [B, 3] + transformed = torch.einsum("bij, pj -> bpi", R, points) + t.unsqueeze(1) + return transformed + + +def batch_convex_collision_query(device=torch.device("cuda")): + mug_path = get_data_path("ScannedBottle/moliwulong_processed.ply") + mug_mesh = trimesh.load(mug_path, force="mesh", process=False) + verts = torch.tensor(mug_mesh.vertices, dtype=torch.float32, device=device) + faces = torch.tensor(mug_mesh.faces, dtype=torch.int32, device=device) + collision_checker = BatchConvexCollisionChecker( + verts, faces, max_decomposition_hulls=16 + ) + + poses = torch.tensor( + [ + [ + [1, 0, 0, 0], + [0, 1, 0, 0], + [0, 0, 1, 1.05], + [0, 0, 0, 1], + ], + [ + [1, 0, 0, 0.05], + [0, -1, 0, 0], + [0, 0, -1, 0], + [0, 0, 0, 1], + ], + ], + device=device, + ) + from scipy.spatial.transform import Rotation + + rot = Rotation.from_euler("xyz", [12, 3, 32], degrees=True).as_matrix() + poses[0, :3, :3] = torch.tensor(rot, dtype=torch.float32, device=device) + poses[1, :3, :3] = torch.tensor(rot, dtype=torch.float32, device=device) + + obj_path = get_data_path("ScannedBottle/yibao_processed.ply") + obj_mesh = trimesh.load(obj_path, force="mesh", process=False) + obj_verts = 
torch.tensor(obj_mesh.vertices, dtype=torch.float32, device=device) + obj_faces = torch.tensor(obj_mesh.faces, dtype=torch.int32, device=device) + test_pc = transform_points_batch(obj_verts, poses) + + is_pose_collide, pose_surface_distance = collision_checker.query_batch_points( + test_pc, collision_threshold=0.003, is_visual=False + ) + assert is_pose_collide.sum().item() == 1 + assert abs(pose_surface_distance.max().item() - 0.8492) < 1e-2 + + +def test_batch_convex_collision_cpu(): + wp.init() + batch_convex_collision_query(torch.device("cpu")) + + +def test_batch_convex_collision_gpu(): + wp.init() + batch_convex_collision_query(torch.device("cuda")) From a4c834133e9a40daafa132be0027c8e9201afd22 Mon Sep 17 00:00:00 2001 From: chenjian Date: Wed, 1 Apr 2026 16:46:25 +0800 Subject: [PATCH 15/25] add transform points mat --- .../pg_grasp/batch_collision_checker.py | 80 +------------------ .../pg_grasp/gripper_collision_checker.py | 26 +----- embodichain/utils/math.py | 19 +++++ tests/toolkits/test_batch_convex_collision.py | 24 +----- 4 files changed, 27 insertions(+), 122 deletions(-) diff --git a/embodichain/toolkits/graspkit/pg_grasp/batch_collision_checker.py b/embodichain/toolkits/graspkit/pg_grasp/batch_collision_checker.py index eea46c73..234a6813 100644 --- a/embodichain/toolkits/graspkit/pg_grasp/batch_collision_checker.py +++ b/embodichain/toolkits/graspkit/pg_grasp/batch_collision_checker.py @@ -28,7 +28,7 @@ from embodichain.utils import logger from embodichain.utils.warp import convex_signed_distance_kernel import warp as wp - +from embodichain.utils.math import transform_points_mat from embodichain.utils.device_utils import standardize_device_string CONVEX_CACHE_DIR = os.path.join( @@ -255,7 +255,7 @@ def query( penetration, collides = check_collision_single_hull( normals_torch, offsets_torch, - transform_points_batch(query_points, poses), + transform_points_mat(query_points, poses), cfg.collsion_threshold, ) penetration_result = 
torch.max(penetration_result, penetration) @@ -413,79 +413,3 @@ def check_collision_single_hull( collides = penetration > threshold # [B, P] return penetration, collides - - -def transform_points_batch( - points: torch.Tensor, poses: torch.Tensor # [P, 3] # [B, 4, 4] -) -> torch.Tensor: - """ - Apply a batch of rigid transforms to a point cloud. - - Args: - points: [P, 3] source point cloud. - poses: [B, 4, 4] batch of homogeneous transformation matrices. - - Returns: - transformed: [B, P, 3] transformed point cloud for each pose. - """ - R = poses[:, :3, :3] # [B, 3, 3] - t = poses[:, :3, 3] # [B, 3] - transformed = torch.einsum("bij, pj -> bpi", R, points) + t.unsqueeze(1) - return transformed - - -if __name__ == "__main__": - from embodichain.data import get_data_path - - mug_path = get_data_path("ScannedBottle/moliwulong_processed.ply") - mug_mesh = trimesh.load(mug_path, force="mesh", process=False) - verts = torch.tensor( - mug_mesh.vertices, dtype=torch.float32, device=torch.device("cuda") - ) - faces = torch.tensor(mug_mesh.faces, dtype=torch.int32, device=torch.device("cuda")) - collision_checker = BatchConvexCollisionChecker( - verts, faces, max_decomposition_hulls=16 - ) - - poses = torch.tensor( - [ - [ - [1, 0, 0, 0], - [0, 1, 0, 0], - [0, 0, 1, 0.05], - [0, 0, 0, 1], - ], - [ - [1, 0, 0, 0.05], - [0, -1, 0, 0], - [0, 0, -1, 0], - [0, 0, 0, 1], - ], - ], - device=torch.device("cuda"), - ) - from scipy.spatial.transform import Rotation - - wp.init() - - rot = Rotation.from_euler("xyz", [12, 3, 32], degrees=True).as_matrix() - poses[0, :3, :3] = torch.tensor( - rot, dtype=torch.float32, device=torch.device("cuda") - ) - poses[1, :3, :3] = torch.tensor( - rot, dtype=torch.float32, device=torch.device("cuda") - ) - - obj_path = get_data_path("ScannedBottle/yibao_processed.ply") - obj_mesh = trimesh.load(obj_path, force="mesh", process=False) - obj_verts = torch.tensor( - obj_mesh.vertices, dtype=torch.float32, device=torch.device("cuda") - ) - obj_faces = 
torch.tensor( - obj_mesh.faces, dtype=torch.int32, device=torch.device("cuda") - ) - test_pc = transform_points_batch(obj_verts, poses) - - collision_checker.query_batch_points( - test_pc, collision_threshold=0.003, is_visual=True - ) diff --git a/embodichain/toolkits/graspkit/pg_grasp/gripper_collision_checker.py b/embodichain/toolkits/graspkit/pg_grasp/gripper_collision_checker.py index 78ce4eb4..2fb496f1 100644 --- a/embodichain/toolkits/graspkit/pg_grasp/gripper_collision_checker.py +++ b/embodichain/toolkits/graspkit/pg_grasp/gripper_collision_checker.py @@ -21,7 +21,7 @@ from typing import Sequence from .batch_collision_checker import BatchConvexCollisionChecker import torch - +from embodichain.utils.math import transform_points_mat @configclass class SimpleGripperCollisionCfg: @@ -112,9 +112,9 @@ def _get_gripper_pc( right_finger_poses[:, :3, 0] * open_lengths_repeat ) - root_pc = transform_points_batch(self.root_template, root_grasp_poses) - left_pc = transform_points_batch(self.left_template, left_finger_poses) - right_pc = transform_points_batch(self.right_template, right_finger_poses) + root_pc = transform_points_mat(self.root_template, root_grasp_poses) + left_pc = transform_points_mat(self.left_template, left_finger_poses) + right_pc = transform_points_mat(self.right_template, right_finger_poses) gripper_pc = torch.cat([root_pc, left_pc, right_pc], dim=1) return gripper_pc @@ -137,24 +137,6 @@ def query( ) -def transform_points_batch( - points: torch.Tensor, poses: torch.Tensor # [P, 3] # [B, 4, 4] -) -> torch.Tensor: - """ - Apply a batch of rigid transforms to a point cloud. - - Args: - points: [P, 3] source point cloud. - poses: [B, 4, 4] batch of homogeneous transformation matrices. - - Returns: - transformed: [B, P, 3] transformed point cloud for each pose. 
- """ - R = poses[:, :3, :3] # [B, 3, 3] - t = poses[:, :3, 3] # [B, 3] - transformed = torch.einsum("bij, pj -> bpi", R, points) + t.unsqueeze(1) - return transformed - def box_surface_grid( size: Sequence[float] | torch.Tensor, diff --git a/embodichain/utils/math.py b/embodichain/utils/math.py index 084e51f4..caaa39d2 100644 --- a/embodichain/utils/math.py +++ b/embodichain/utils/math.py @@ -1206,6 +1206,25 @@ def transform_points( return points_batch +def transform_points_mat( + points: torch.Tensor, poses: torch.Tensor # [P, 3] # [B, 4, 4] +) -> torch.Tensor: + """ + Apply a batch of rigid transforms to a point cloud. + + Args: + points: [P, 3] source point cloud. + poses: [B, 4, 4] batch of homogeneous transformation matrices. + + Returns: + transformed: [B, P, 3] transformed point cloud for each pose. + """ + R = poses[:, :3, :3] # [B, 3, 3] + t = poses[:, :3, 3] # [B, 3] + transformed = torch.einsum("bij, pj -> bpi", R, points) + t.unsqueeze(1) + return transformed + + """ Projection operations. """ diff --git a/tests/toolkits/test_batch_convex_collision.py b/tests/toolkits/test_batch_convex_collision.py index 5a4269d8..18143321 100644 --- a/tests/toolkits/test_batch_convex_collision.py +++ b/tests/toolkits/test_batch_convex_collision.py @@ -20,27 +20,7 @@ BatchConvexCollisionChecker, BatchConvexCollisionCheckerCfg, ) -from embodichain.utils.warp import convex_signed_distance_kernel -import warp as wp - - -def transform_points_batch( - points: torch.Tensor, poses: torch.Tensor # [P, 3] # [B, 4, 4] -) -> torch.Tensor: - """ - Apply a batch of rigid transforms to a point cloud. - - Args: - points: [P, 3] source point cloud. - poses: [B, 4, 4] batch of homogeneous transformation matrices. - - Returns: - transformed: [B, P, 3] transformed point cloud for each pose. 
- """ - R = poses[:, :3, :3] # [B, 3, 3] - t = poses[:, :3, 3] # [B, 3] - transformed = torch.einsum("bij, pj -> bpi", R, points) + t.unsqueeze(1) - return transformed +from embodichain.utils.math import transform_points_mat def batch_convex_collision_query(device=torch.device("cuda")): @@ -79,7 +59,7 @@ def batch_convex_collision_query(device=torch.device("cuda")): obj_mesh = trimesh.load(obj_path, force="mesh", process=False) obj_verts = torch.tensor(obj_mesh.vertices, dtype=torch.float32, device=device) obj_faces = torch.tensor(obj_mesh.faces, dtype=torch.int32, device=device) - test_pc = transform_points_batch(obj_verts, poses) + test_pc = transform_points_mat(obj_verts, poses) is_pose_collide, pose_surface_distance = collision_checker.query_batch_points( test_pc, collision_threshold=0.003, is_visual=False From 7865db1ee75756267378a7c3a40bcab6875c14fc Mon Sep 17 00:00:00 2001 From: chenjian Date: Wed, 1 Apr 2026 16:47:03 +0800 Subject: [PATCH 16/25] style --- .../toolkits/graspkit/pg_grasp/gripper_collision_checker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/embodichain/toolkits/graspkit/pg_grasp/gripper_collision_checker.py b/embodichain/toolkits/graspkit/pg_grasp/gripper_collision_checker.py index 2fb496f1..5c7d3dbc 100644 --- a/embodichain/toolkits/graspkit/pg_grasp/gripper_collision_checker.py +++ b/embodichain/toolkits/graspkit/pg_grasp/gripper_collision_checker.py @@ -23,6 +23,7 @@ import torch from embodichain.utils.math import transform_points_mat + @configclass class SimpleGripperCollisionCfg: """Configuration for the SimpleGripperCollisionChecker. This class defines various parameters related to the gripper geometry, point cloud generation, and collision checking process. 
Users can customize these parameters based on the specific gripper being modeled and the requirements of the application.""" @@ -137,7 +138,6 @@ def query( ) - def box_surface_grid( size: Sequence[float] | torch.Tensor, dense: float, From 0043d066e7aac65be68c1a5a04e461dfa4829636 Mon Sep 17 00:00:00 2001 From: chenjian Date: Wed, 1 Apr 2026 17:04:33 +0800 Subject: [PATCH 17/25] update docs --- .../features/toolkits/grasp_generator.rst | 167 ++++++++++++++++++ docs/source/features/toolkits/index.rst | 1 + docs/source/tutorial/grasp_generator.rst | 77 -------- docs/source/tutorial/index.rst | 1 - 4 files changed, 168 insertions(+), 78 deletions(-) create mode 100644 docs/source/features/toolkits/grasp_generator.rst delete mode 100644 docs/source/tutorial/grasp_generator.rst diff --git a/docs/source/features/toolkits/grasp_generator.rst b/docs/source/features/toolkits/grasp_generator.rst new file mode 100644 index 00000000..a4b39a17 --- /dev/null +++ b/docs/source/features/toolkits/grasp_generator.rst @@ -0,0 +1,167 @@ +Generating and Executing Robot Grasps +====================================== + +.. currentmodule:: embodichain.lab.sim + +This tutorial demonstrates how to generate antipodal grasp poses for a target object and execute a full grasp trajectory with a robot arm. It covers scene initialization, robot and object creation, interactive grasp region annotation, grasp pose computation, and trajectory execution in the simulation loop. + +The Code +~~~~~~~~ + +The tutorial corresponds to the ``grasp_generator.py`` script in the ``scripts/tutorials/grasp`` directory. + +.. dropdown:: Code for grasp_generator.py + :icon: code + + .. 
literalinclude:: ../../../../scripts/tutorials/grasp/grasp_generator.py + :language: python + :linenos: + + +The Code Explained +~~~~~~~~~~~~~~~~~~ + +Configuring the simulation +-------------------------- + +Command-line arguments are parsed with ``argparse`` to select the number of parallel environments, the compute device, and optional rendering features such as ray tracing and headless mode. + +.. literalinclude:: ../../../../scripts/tutorials/grasp/grasp_generator.py + :language: python + :start-at: def parse_arguments(): + :end-at: return parser.parse_args() + +The parsed arguments are passed to ``initialize_simulation``, which builds a :class:`SimulationManagerCfg` and creates the :class:`SimulationManager` instance. When ray tracing is enabled a directional :class:`cfg.LightCfg` is also added to the scene. + +.. literalinclude:: ../../../../scripts/tutorials/grasp/grasp_generator.py + :language: python + :start-at: def initialize_simulation(args) -> SimulationManager: + :end-at: return sim + +Annotating and computing grasp poses +------------------------------------- + +Grasp generation is performed by :meth:`objects.RigidObject.get_grasp_pose`, which internally runs an antipodal sampler on the object mesh. A :class:`toolkits.graspkit.pg_grasp.GraspAnnotatorCfg` controls sampler parameters (sample count, gripper jaw limits) and the interactive annotation workflow: + +1. Open the visualization in a browser at the reported port (e.g. ``http://localhost:11801``). +2. Use *Rect Select Region* to highlight the area of the object that should be grasped. +3. Click *Confirm Selection* to finalize the region. + +The function returns a batch of ``(N_envs, 4, 4)`` homogeneous transformation matrices representing candidate grasp frames in the world coordinate system. + +For each grasp pose, the gripper approach direction in the world coordinate frame is required to compute the antipodal grasp. 
In this tutorial, we use a fixed approach direction (straight down in world frame) for simplicity, but it can be customized based on the task or object geometry. + +.. literalinclude:: ../../../../scripts/tutorials/grasp/grasp_generator.py + :language: python + :start-at: # get mug grasp pose + :end-at: logger.log_info(f"Get grasp pose cost time: {cost_time:.2f} seconds") + +Configuring GraspAnnotatorCfg +------------------------------ + +:class:`toolkits.graspkit.pg_grasp.GraspAnnotatorCfg` controls the overall grasp annotation workflow. The key parameters are listed below. + +.. list-table:: GraspAnnotatorCfg parameters + :header-rows: 1 + :widths: 25 15 60 + + * - Parameter + - Default + - Description + * - ``viser_port`` + - ``15531`` + - Port used by the Viser browser-based visualizer for interactive grasp region annotation. + * - ``use_largest_connected_component`` + - ``False`` + - When ``True``, only the largest connected component of the object mesh is used for sampling. Useful for meshes that contain disconnected fragments. + * - ``antipodal_sampler_cfg`` + - ``AntipodalSamplerCfg()`` + - Nested configuration for the antipodal point sampler. See the table below for its parameters. + * - ``force_regenerate`` + - ``False`` + - When ``True``, the user is required to annotate the grasp region every time, bypassing any cached results from a previous run. + * - ``max_deviation_angle`` + - ``π / 12`` + - Maximum allowed angle (in radians) between the specified approach direction and the axis connecting an antipodal point pair. Pairs that deviate more than this threshold are discarded. + +The ``antipodal_sampler_cfg`` field accepts an :class:`toolkits.graspkit.pg_grasp.AntipodalSamplerCfg` instance, which controls how antipodal point pairs are sampled on the mesh surface. + +.. 
list-table:: AntipodalSamplerCfg parameters + :header-rows: 1 + :widths: 25 15 60 + + * - Parameter + - Default + - Description + * - ``n_sample`` + - ``20000`` + - Number of surface points uniformly sampled from the mesh before ray casting. Higher values yield denser coverage but increase computation time. + * - ``max_angle`` + - ``π / 12`` + - Maximum angle (in radians) used to randomly perturb the ray direction away from the inward normal. Larger values increase diversity of sampled antipodal pairs. Setting this to ``0`` disables perturbation and samples strictly along surface normals. + * - ``max_length`` + - ``0.1`` + - Maximum allowed distance (in metres) between an antipodal pair. Pairs farther apart than this value are discarded; set this to match the maximum gripper jaw opening width. + * - ``min_length`` + - ``0.001`` + - Minimum allowed distance (in metres) between an antipodal pair. Pairs closer together than this value are discarded to avoid degenerate or self-intersecting grasps. + +Configuring SimpleGripperCollisionCfg +-------------------------------------- + +:class:`toolkits.graspkit.pg_grasp.SimpleGripperCollisionCfg` models the geometry of a parallel-jaw gripper as a point cloud and is used to filter out grasp candidates that would collide with the object. All length parameters are in metres. + +.. list-table:: SimpleGripperCollisionCfg parameters + :header-rows: 1 + :widths: 25 15 60 + + * - Parameter + - Default + - Description + * - ``max_open_length`` + - ``0.1`` + - Maximum finger separation of the gripper when fully open. Should match the physical gripper specification. + * - ``finger_length`` + - ``0.08`` + - Length of each finger along the Z-axis (depth direction from the root). Should match the physical gripper specification. + * - ``y_thickness`` + - ``0.03`` + - Thickness of the gripper body and fingers along the Y-axis (perpendicular to the opening direction). 
+ * - ``x_thickness`` + - ``0.01`` + - Thickness of each finger along the X-axis (parallel to the opening direction). + * - ``root_z_width`` + - ``0.08`` + - Extent of the gripper root block along the Z-axis. + * - ``device`` + - ``cpu`` + - PyTorch device on which the gripper point cloud is generated and processed. Set to ``cuda`` when GPU-accelerated collision checking is required. + * - ``point_sample_dense`` + - ``0.01`` + - Approximate number of sample points per unit length along each edge of the gripper point cloud. Higher values produce denser point clouds and improve collision-check accuracy at the cost of additional computation. + * - ``max_decomposition_hulls`` + - ``16`` + - Maximum number of convex hulls used when decomposing the object mesh for collision checking. More hulls give a tighter shape approximation but increase cost. + * - ``open_check_margin`` + - ``0.01`` + - Extra clearance added to the gripper open length during collision checking to account for pose uncertainty or mesh inaccuracies. + + +The Code Execution +~~~~~~~~~~~~~~~~~~ + +To run the script, execute the following command from the project root: + +.. code-block:: bash + + python scripts/tutorials/grasp/grasp_generator.py + +A simulation window will open showing the robot and the mug. A browser-based visualizer will also launch (default port ``11801``) for interactive grasp region annotation. + +You can customize the run with additional arguments: + +.. code-block:: bash + + python scripts/tutorials/grasp/grasp_generator.py --num_envs --device --enable_rt --headless + +After confirming the grasp region in the browser, the script will compute a grasp pose, print the elapsed time, and then wait for you to press **Enter** before executing the full grasp trajectory in the simulation. Press **Enter** again to exit once the motion is complete. 
diff --git a/docs/source/features/toolkits/index.rst b/docs/source/features/toolkits/index.rst index b6f0b22d..f6886746 100644 --- a/docs/source/features/toolkits/index.rst +++ b/docs/source/features/toolkits/index.rst @@ -7,4 +7,5 @@ ToolKits convex_decomposition urdf_assembly + grasp_generator \ No newline at end of file diff --git a/docs/source/tutorial/grasp_generator.rst b/docs/source/tutorial/grasp_generator.rst deleted file mode 100644 index 51e802e8..00000000 --- a/docs/source/tutorial/grasp_generator.rst +++ /dev/null @@ -1,77 +0,0 @@ -Generating and Executing Robot Grasps -====================================== - -.. currentmodule:: embodichain.lab.sim - -This tutorial demonstrates how to generate antipodal grasp poses for a target object and execute a full grasp trajectory with a robot arm. It covers scene initialization, robot and object creation, interactive grasp region annotation, grasp pose computation, and trajectory execution in the simulation loop. - -The Code -~~~~~~~~ - -The tutorial corresponds to the ``grasp_generator.py`` script in the ``scripts/tutorials/grasp`` directory. - -.. dropdown:: Code for grasp_generator.py - :icon: code - - .. literalinclude:: ../../../scripts/tutorials/grasp/grasp_generator.py - :language: python - :linenos: - - -The Code Explained -~~~~~~~~~~~~~~~~~~ - -Configuring the simulation --------------------------- - -Command-line arguments are parsed with ``argparse`` to select the number of parallel environments, the compute device, and optional rendering features such as ray tracing and headless mode. - -.. literalinclude:: ../../../scripts/tutorials/grasp/grasp_generator.py - :language: python - :start-at: def parse_arguments(): - :end-at: return parser.parse_args() - -The parsed arguments are passed to ``initialize_simulation``, which builds a :class:`SimulationManagerCfg` and creates the :class:`SimulationManager` instance. 
When ray tracing is enabled a directional :class:`cfg.LightCfg` is also added to the scene. - -.. literalinclude:: ../../../scripts/tutorials/grasp/grasp_generator.py - :language: python - :start-at: def initialize_simulation(args) -> SimulationManager: - :end-at: return sim - -Annotating and computing grasp poses -------------------------------------- - -Grasp generation is performed by :meth:`objects.RigidObject.get_grasp_pose`, which internally runs an antipodal sampler on the object mesh. A :class:`toolkits.graspkit.pg_grasp.GraspAnnotatorCfg` controls sampler parameters (sample count, gripper jaw limits) and the interactive annotation workflow: - -1. Open the visualization in a browser at the reported port (e.g. ``http://localhost:11801``). -2. Use *Rect Select Region* to highlight the area of the object that should be grasped. -3. Click *Confirm Selection* to finalize the region. - -The function returns a batch of ``(N_envs, 4, 4)`` homogeneous transformation matrices representing candidate grasp frames in the world coordinate system. - -For each grasp pose, gripper approach direction in world coordinate is required to compute the antipodal grasp. In this tutorial, we use a fixed approach direction (straight down in world frame) for simplicity, but it can be customized based on the task or object geometry. - -.. literalinclude:: ../../../scripts/tutorials/grasp/grasp_generator.py - :language: python - :start-at: # get mug grasp pose - :end-at: logger.log_info(f"Get grasp pose cost time: {cost_time:.2f} seconds") - - -The Code Execution -~~~~~~~~~~~~~~~~~~ - -To run the script, execute the following command from the project root: - -.. code-block:: bash - - python scripts/tutorials/grasp/grasp_generator.py - -A simulation window will open showing the robot and the mug. A browser-based visualizer will also launch (default port ``11801``) for interactive grasp region annotation. - -You can customize the run with additional arguments: - -.. 
code-block:: bash - - python scripts/tutorials/grasp/grasp_generator.py --num_envs --device --enable_rt --headless - -After confirming the grasp region in the browser, the script will compute a grasp pose, print the elapsed time, and then wait for you to press **Enter** before executing the full grasp trajectory in the simulation. Press **Enter** again to exit once the motion is complete. diff --git a/docs/source/tutorial/index.rst b/docs/source/tutorial/index.rst index 0a28a8e9..ef6efe79 100644 --- a/docs/source/tutorial/index.rst +++ b/docs/source/tutorial/index.rst @@ -15,7 +15,6 @@ Tutorials sensor motion_gen gizmo - grasp_generator basic_env modular_env rl From 90568fd494520c04a3335c18e69e8f65dbb81c95 Mon Sep 17 00:00:00 2001 From: chenjian Date: Wed, 1 Apr 2026 18:14:21 +0800 Subject: [PATCH 18/25] fix unittest --- tests/toolkits/test_batch_convex_collision.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/toolkits/test_batch_convex_collision.py b/tests/toolkits/test_batch_convex_collision.py index 18143321..819079df 100644 --- a/tests/toolkits/test_batch_convex_collision.py +++ b/tests/toolkits/test_batch_convex_collision.py @@ -21,6 +21,7 @@ BatchConvexCollisionCheckerCfg, ) from embodichain.utils.math import transform_points_mat +import warp as wp def batch_convex_collision_query(device=torch.device("cuda")): From 82a0c55e828e292a3d31280ce4a4e69b18a30b6a Mon Sep 17 00:00:00 2001 From: yuecideng Date: Wed, 1 Apr 2026 19:16:00 +0800 Subject: [PATCH 19/25] wip --- .../graspkit/pg_grasp/antipodal_annotator.py | 3 ++- .../graspkit/pg_grasp/batch_collision_checker.py | 12 +++++++----- embodichain/utils/warp/__init__.py | 2 +- .../{collision_checker => collision}/__init__.py | 0 .../{collision_checker => collision}/convex_query.py | 1 - scripts/tutorials/grasp/grasp_generator.py | 10 +++------- 6 files changed, 13 insertions(+), 15 deletions(-) rename embodichain/utils/warp/{collision_checker => collision}/__init__.py (100%) rename 
embodichain/utils/warp/{collision_checker => collision}/convex_query.py (99%) diff --git a/embodichain/toolkits/graspkit/pg_grasp/antipodal_annotator.py b/embodichain/toolkits/graspkit/pg_grasp/antipodal_annotator.py index 20177656..11c3b09e 100644 --- a/embodichain/toolkits/graspkit/pg_grasp/antipodal_annotator.py +++ b/embodichain/toolkits/graspkit/pg_grasp/antipodal_annotator.py @@ -32,6 +32,7 @@ AntipodalSampler, AntipodalSamplerCfg, ) +from embodichain.utils import configclass from .gripper_collision_checker import ( SimpleGripperCollisionChecker, SimpleGripperCollisionCfg, @@ -41,7 +42,7 @@ import tempfile -@dataclass +@configclass class GraspAnnotatorCfg: viser_port: int = 15531 use_largest_connected_component: bool = False diff --git a/embodichain/toolkits/graspkit/pg_grasp/batch_collision_checker.py b/embodichain/toolkits/graspkit/pg_grasp/batch_collision_checker.py index 165f9531..106803af 100644 --- a/embodichain/toolkits/graspkit/pg_grasp/batch_collision_checker.py +++ b/embodichain/toolkits/graspkit/pg_grasp/batch_collision_checker.py @@ -17,25 +17,27 @@ import trimesh import numpy as np import torch +import warp as wp import time -from typing import List, Tuple, Union -from dexsim.kit.meshproc import convex_decomposition_coacd import hashlib -from dataclasses import dataclass import os import pickle import open3d as o3d + +from typing import List, Tuple, Union +from dexsim.kit.meshproc import convex_decomposition_coacd + from embodichain.utils import logger from embodichain.utils.warp import convex_signed_distance_kernel -import warp as wp from embodichain.utils.device_utils import standardize_device_string +from embodichain.utils import configclass CONVEX_CACHE_DIR = os.path.join( os.path.expanduser("~"), ".cache", "embodichain_cache", "convex_decomposition" ) -@dataclass +@configclass class BatchConvexCollisionCheckerCfg: """Configuration for BatchConvexCollisionChecker.""" diff --git a/embodichain/utils/warp/__init__.py 
b/embodichain/utils/warp/__init__.py index e0fac57a..c08be1d5 100644 --- a/embodichain/utils/warp/__init__.py +++ b/embodichain/utils/warp/__init__.py @@ -31,4 +31,4 @@ interpolate_along_distance, ) -from .collision_checker.convex_query import convex_signed_distance_kernel +from .collision.convex_query import convex_signed_distance_kernel diff --git a/embodichain/utils/warp/collision_checker/__init__.py b/embodichain/utils/warp/collision/__init__.py similarity index 100% rename from embodichain/utils/warp/collision_checker/__init__.py rename to embodichain/utils/warp/collision/__init__.py diff --git a/embodichain/utils/warp/collision_checker/convex_query.py b/embodichain/utils/warp/collision/convex_query.py similarity index 99% rename from embodichain/utils/warp/collision_checker/convex_query.py rename to embodichain/utils/warp/collision/convex_query.py index f321e462..ce2e7c1f 100644 --- a/embodichain/utils/warp/collision_checker/convex_query.py +++ b/embodichain/utils/warp/collision/convex_query.py @@ -15,7 +15,6 @@ # ---------------------------------------------------------------------------- import warp as wp -from typing import Any @wp.kernel(enable_backward=False) diff --git a/scripts/tutorials/grasp/grasp_generator.py b/scripts/tutorials/grasp/grasp_generator.py index 9f4450d0..d6b718e9 100644 --- a/scripts/tutorials/grasp/grasp_generator.py +++ b/scripts/tutorials/grasp/grasp_generator.py @@ -15,8 +15,8 @@ # ---------------------------------------------------------------------------- """ -This script demonstrates the creation and simulation of a robot with a soft object, -and performs a pressing task in a simulated environment. +This script demonstrates the creation and simulation of a robot that grasps a rigid mug +in a simulated environment using the SimulationManager and grasp planning utilities. 
""" import argparse @@ -24,8 +24,6 @@ import time import torch -from dexsim.utility.path import get_resources_data_path - from embodichain.lab.sim import SimulationManager, SimulationManagerCfg from embodichain.lab.sim.objects import Robot, RigidObject from embodichain.lab.sim.utility.action_utils import interpolate_with_distance @@ -41,7 +39,6 @@ RigidObjectCfg, URDFCfg, ) -from embodichain.lab.sim.shapes import MeshCfg from embodichain.toolkits.graspkit.pg_grasp.antipodal_annotator import ( GraspAnnotatorCfg, AntipodalSamplerCfg, @@ -68,7 +65,7 @@ def parse_arguments(): parser.add_argument( "--device", type=str, - default="cuda", + default="cpu", help="device to run the environment on, e.g., 'cpu' or 'cuda'", ) return parser.parse_args() @@ -89,7 +86,6 @@ def initialize_simulation(args) -> SimulationManager: sim_device=args.device, enable_rt=args.enable_rt, physics_dt=1.0 / 100.0, - num_envs=args.num_envs, arena_space=2.5, ) sim = SimulationManager(config) From ae5622ac6e6a8654b8a9bffeca05318fab1e76d4 Mon Sep 17 00:00:00 2001 From: yuecideng Date: Wed, 1 Apr 2026 20:39:39 +0800 Subject: [PATCH 20/25] wip --- embodichain/lab/sim/objects/rigid_object.py | 1 - embodichain/lab/sim/sim_manager.py | 1 + .../graspkit/pg_grasp/antipodal_annotator.py | 4 +- .../pg_grasp/gripper_collision_checker.py | 54 +++++++++++++++---- 4 files changed, 48 insertions(+), 12 deletions(-) diff --git a/embodichain/lab/sim/objects/rigid_object.py b/embodichain/lab/sim/objects/rigid_object.py index e42b0005..656ce925 100644 --- a/embodichain/lab/sim/objects/rigid_object.py +++ b/embodichain/lab/sim/objects/rigid_object.py @@ -1170,7 +1170,6 @@ def get_grasp_pose( self._hit_point_pairs = self._grasp_annotator.annotate() poses = self.get_local_pose(to_matrix=True) - poses = torch.as_tensor(poses, dtype=torch.float32, device=self.device) grasp_poses: tuple[torch.Tensor] = [] open_lengths: tuple[torch.Tensor] = [] for pose in poses: diff --git a/embodichain/lab/sim/sim_manager.py 
b/embodichain/lab/sim/sim_manager.py index 6a8bf537..4aef6896 100644 --- a/embodichain/lab/sim/sim_manager.py +++ b/embodichain/lab/sim/sim_manager.py @@ -35,6 +35,7 @@ MATERIAL_CACHE_DIR = SIM_CACHE_DIR / "mat_cache" CONVEX_DECOMP_DIR = SIM_CACHE_DIR / "convex_decomposition" REACHABLE_XPOS_DIR = SIM_CACHE_DIR / "robot_reachable_xpos" +GRASP_ANNOTATOR_CACHE_DIR = SIM_CACHE_DIR / "grasp_annotator_cache" from dexsim.types import ( Backend, diff --git a/embodichain/toolkits/graspkit/pg_grasp/antipodal_annotator.py b/embodichain/toolkits/graspkit/pg_grasp/antipodal_annotator.py index 45f710ca..06d82391 100644 --- a/embodichain/toolkits/graspkit/pg_grasp/antipodal_annotator.py +++ b/embodichain/toolkits/graspkit/pg_grasp/antipodal_annotator.py @@ -240,11 +240,13 @@ def _(_evt: viser.GuiEvent) -> None: return hit_point_pairs def _get_cache_dir(self, vertices: torch.Tensor, triangles: torch.Tensor): + from embodichain.lab.sim.sim_manager import GRASP_ANNOTATOR_CACHE_DIR + vert_bytes = vertices.to("cpu").numpy().tobytes() face_bytes = triangles.to("cpu").numpy().tobytes() md5_hash = hashlib.md5(vert_bytes + face_bytes).hexdigest() cache_path = os.path.join( - tempfile.gettempdir(), f"antipodal_cache_{md5_hash}.npy" + GRASP_ANNOTATOR_CACHE_DIR, f"antipodal_cache_{md5_hash}.npy" ) return cache_path diff --git a/embodichain/toolkits/graspkit/pg_grasp/gripper_collision_checker.py b/embodichain/toolkits/graspkit/pg_grasp/gripper_collision_checker.py index 5c7d3dbc..e2a54558 100644 --- a/embodichain/toolkits/graspkit/pg_grasp/gripper_collision_checker.py +++ b/embodichain/toolkits/graspkit/pg_grasp/gripper_collision_checker.py @@ -26,26 +26,60 @@ @configclass class SimpleGripperCollisionCfg: - """Configuration for the SimpleGripperCollisionChecker. This class defines various parameters related to the gripper geometry, point cloud generation, and collision checking process. 
Users can customize these parameters based on the specific gripper being modeled and the requirements of the application.""" + """Configuration for the SimpleGripperCollisionChecker. This class defines various parameters related to the + gripper geometry, point cloud generation, and collision checking process. Users can customize these parameters + based on the specific gripper being modeled and the requirements of the application. + """ max_open_length: float = 0.1 - """ Maximum opening length of the gripper fingers. This should be set according to the specific gripper being modeled, and it defines the maximum distance between the two fingers when fully open.""" + """ Maximum opening length of the gripper fingers. This should be set according to the specific gripper being modeled, + and it defines the maximum distance between the two fingers when fully open. + """ + finger_length: float = 0.08 - """ Length of the gripper fingers from the root to the tip, in z axis. This should be set according to the specific gripper being modeled, and it defines how far the fingers extend from the gripper root frame.""" + """ Length of the gripper fingers from the root to the tip, in z axis. This should be set according to the specific + gripper being modeled, and it defines how far the fingers extend from the gripper root frame. + """ + y_thickness: float = 0.03 - """ Thickness of the gripper along the Y-axis (the axis perpendicular to the finger opening direction). This should be set according to the specific gripper being modeled, and it defines the width of the gripper's main body and fingers in the Y direction.""" + """ Thickness of the gripper along the Y-axis (the axis perpendicular to the finger opening direction). This should + be set according to the specific gripper being modeled, and it defines the width of the gripper's main body and fingers + in the Y direction. 
+ """ + x_thickness: float = 0.01 - """ Thickness of the gripper along the X-axis (the axis parallel to the finger opening direction). This should be set according to the specific gripper being modeled, and it defines the thickness of the fingers and the root in the X direction.""" + """ Thickness of the gripper along the X-axis (the axis parallel to the finger opening direction). This should + be set according to the specific gripper being modeled, and it defines the thickness of the fingers and the root + in the X direction. + """ + root_z_width: float = 0.08 - """ Width of the gripper root along the Z-axis (the axis along the finger length direction). This should be set according to the specific gripper being modeled, and it defines how far the root extends along the Z direction.""" + """ Width of the gripper root along the Z-axis (the axis along the finger length direction). This should be set + according to the specific gripper being modeled, and it defines how far the root extends along the Z direction. + """ + device = torch.device("cpu") - """ Device on which the gripper point cloud will be generated and processed. This should be set according to the computational resources available and the requirements of the application. For example, if using a GPU for collision checking, this should be set to torch.device('cuda'). """ + """ Device on which the gripper point cloud will be generated and processed. This should be set according to + the computational resources available and the requirements of the application. For example, if using a GPU for collision + checking, this should be set to torch.device('cuda'). + """ + point_sample_dense: float = 0.01 - """ Approximate number of points per unit length for the gripper point cloud. Higher values will yield denser point clouds, which can improve collision checking accuracy but also increase computational cost. 
This should be set based on the desired balance between accuracy and efficiency for the specific application.""" + """ Approximate number of points per unit length for the gripper point cloud. Higher values will yield denser point + clouds, which can improve collision checking accuracy but also increase computational cost. This should be set based + on the desired balance between accuracy and efficiency for the specific application. + """ + max_decomposition_hulls: int = 16 - """ Maximum number of convex hulls to decompose the object mesh into for collision checking. This should be set based on the complexity of the object geometry and the desired accuracy of collision checking. More hulls can provide a tighter approximation of the object shape but will increase computational cost.""" + """ Maximum number of convex hulls to decompose the object mesh into for collision checking. This should be set based + on the complexity of the object geometry and the desired accuracy of collision checking. More hulls can provide a tighter + approximation of the object shape but will increase computational cost. + """ + open_check_margin: float = 0.01 - """ Additional margin added to the gripper open length when checking for collisions. This can help account for uncertainties in the gripper pose or object geometry, and can be set based on the specific requirements of the application.""" + """ Additional margin added to the gripper open length when checking for collisions. This can help account for + uncertainties in the gripper pose or object geometry, and can be set based on the specific requirements of the application. 
+ """ class SimpleGripperCollisionChecker: From ac1a3f197aac209b821cb4a7747b4b58d309eb6b Mon Sep 17 00:00:00 2001 From: yuecideng Date: Thu, 2 Apr 2026 10:36:09 +0800 Subject: [PATCH 21/25] wip --- .../features/toolkits/grasp_generator.rst | 14 +- embodichain/lab/sim/objects/rigid_object.py | 47 +- embodichain/lab/sim/sim_manager.py | 1 + .../toolkits/graspkit/pg_grasp/__init__.py | 9 +- .../toolkits/graspkit/pg_grasp/antipodal.py | 670 ------------------ ...al_annotator.py => antipodal_generator.py} | 487 +++++++++---- .../graspkit/pg_grasp/antipodal_sampler.py | 24 +- ...lision_checker.py => collision_checker.py} | 37 +- .../graspkit/pg_grasp/cone_sampler.py | 121 ---- .../pg_grasp/gripper_collision_checker.py | 21 +- scripts/tutorials/grasp/grasp_generator.py | 10 +- tests/toolkits/test_batch_convex_collision.py | 10 +- tests/toolkits/test_pg_grasp.py | 96 --- 13 files changed, 436 insertions(+), 1111 deletions(-) delete mode 100644 embodichain/toolkits/graspkit/pg_grasp/antipodal.py rename embodichain/toolkits/graspkit/pg_grasp/{antipodal_annotator.py => antipodal_generator.py} (54%) rename embodichain/toolkits/graspkit/pg_grasp/{batch_collision_checker.py => collision_checker.py} (92%) delete mode 100644 embodichain/toolkits/graspkit/pg_grasp/cone_sampler.py delete mode 100644 tests/toolkits/test_pg_grasp.py diff --git a/docs/source/features/toolkits/grasp_generator.rst b/docs/source/features/toolkits/grasp_generator.rst index a4b39a17..aac49ddc 100644 --- a/docs/source/features/toolkits/grasp_generator.rst +++ b/docs/source/features/toolkits/grasp_generator.rst @@ -41,7 +41,7 @@ The parsed arguments are passed to ``initialize_simulation``, which builds a :cl Annotating and computing grasp poses ------------------------------------- -Grasp generation is performed by :meth:`objects.RigidObject.get_grasp_pose`, which internally runs an antipodal sampler on the object mesh. 
A :class:`toolkits.graspkit.pg_grasp.GraspAnnotatorCfg` controls sampler parameters (sample count, gripper jaw limits) and the interactive annotation workflow: +Grasp generation is performed by :meth:`objects.RigidObject.get_grasp_pose`, which internally runs an antipodal sampler on the object mesh. A :class:`toolkits.graspkit.pg_grasp.GraspGeneratorCfg` controls sampler parameters (sample count, gripper jaw limits) and the interactive annotation workflow: 1. Open the visualization in a browser at the reported port (e.g. ``http://localhost:11801``). 2. Use *Rect Select Region* to highlight the area of the object that should be grasped. @@ -56,12 +56,12 @@ For each grasp pose, gripper approach direction in world coordinate is required :start-at: # get mug grasp pose :end-at: logger.log_info(f"Get grasp pose cost time: {cost_time:.2f} seconds") -Configuring GraspAnnotatorCfg +Configuring GraspGeneratorCfg ------------------------------ -:class:`toolkits.graspkit.pg_grasp.GraspAnnotatorCfg` controls the overall grasp annotation workflow. The key parameters are listed below. +:class:`toolkits.graspkit.pg_grasp.GraspGeneratorCfg` controls the overall grasp annotation workflow. The key parameters are listed below. -.. list-table:: GraspAnnotatorCfg parameters +.. list-table:: GraspGeneratorCfg parameters :header-rows: 1 :widths: 25 15 60 @@ -106,12 +106,12 @@ The ``antipodal_sampler_cfg`` field accepts an :class:`toolkits.graspkit.pg_gras - ``0.001`` - Minimum allowed distance (in metres) between an antipodal pair. Pairs closer together than this value are discarded to avoid degenerate or self-intersecting grasps. -Configuring SimpleGripperCollisionCfg +Configuring GripperCollisionCfg -------------------------------------- -:class:`toolkits.graspkit.pg_grasp.SimpleGripperCollisionCfg` models the geometry of a parallel-jaw gripper as a point cloud and is used to filter out grasp candidates that would collide with the object. All length parameters are in metres. 
+:class:`toolkits.graspkit.pg_grasp.GripperCollisionCfg` models the geometry of a parallel-jaw gripper as a point cloud and is used to filter out grasp candidates that would collide with the object. All length parameters are in metres. -.. list-table:: SimpleGripperCollisionCfg parameters +.. list-table:: GripperCollisionCfg parameters :header-rows: 1 :widths: 25 15 60 diff --git a/embodichain/lab/sim/objects/rigid_object.py b/embodichain/lab/sim/objects/rigid_object.py index 656ce925..41dc2f83 100644 --- a/embodichain/lab/sim/objects/rigid_object.py +++ b/embodichain/lab/sim/objects/rigid_object.py @@ -35,14 +35,13 @@ from embodichain.utils.math import convert_quat from embodichain.utils.math import matrix_from_quat, quat_from_matrix, matrix_from_euler from embodichain.utils import logger -from embodichain.toolkits.graspkit.pg_grasp.antipodal_annotator import ( - GraspAnnotator, - GraspAnnotatorCfg, +from embodichain.toolkits.graspkit.pg_grasp import ( + GraspGenerator, + GraspGeneratorCfg, ) from embodichain.toolkits.graspkit.pg_grasp.gripper_collision_checker import ( - SimpleGripperCollisionCfg, + GripperCollisionCfg, ) -import torch.nn.functional as F @dataclass @@ -957,25 +956,51 @@ def set_body_type(self, body_type: str) -> None: self.body_type = body_type - def get_vertices(self, env_ids: Sequence[int] | None = None) -> torch.Tensor: + def get_vertices( + self, env_ids: Sequence[int] | None = None, scale: bool = False + ) -> torch.Tensor: """ Retrieve the vertices of the rigid objects. Args: env_ids (Sequence[int] | None): A sequence of environment IDs for which to retrieve vertices. If None, retrieves vertices for all instances. + scale (bool): Whether to multiply the vertices by the body scale. Defaults to False. Returns: torch.Tensor: A tensor containing the user IDs of the specified rigid objects with shape (N, num_verts, 3). 
""" ids = env_ids if env_ids is not None else range(self.num_instances) - return torch.as_tensor( + verts = torch.as_tensor( np.array( [self._entities[id].get_vertices() for id in ids], ), dtype=torch.float32, device=self.device, ) + if scale: + verts = verts * self.get_body_scale(env_ids).unsqueeze_(1) + return verts + + def get_triangles(self, env_ids: Sequence[int] | None = None) -> torch.Tensor: + """ + Retrieve the triangle indices of the rigid objects. + + Args: + env_ids (Sequence[int] | None): A sequence of environment IDs for which to retrieve triangle indices. + If None, retrieves triangle indices for all instances. + + Returns: + torch.Tensor: A tensor containing the triangle indices of the specified rigid objects with shape (N, num_tris, 3). + """ + ids = env_ids if env_ids is not None else range(self.num_instances) + return torch.as_tensor( + np.array( + [self._entities[id].get_triangles() for id in ids], + ), + dtype=torch.int32, + device=self.device, + ) def get_user_ids(self, env_ids: Sequence[int] | None = None) -> torch.Tensor: """Get the user ids of the rigid bodies. 
@@ -1133,8 +1158,8 @@ def destroy(self) -> None: def get_grasp_pose( self, - cfg: GraspAnnotatorCfg = GraspAnnotatorCfg(), - gripper_collision_cfg: SimpleGripperCollisionCfg = SimpleGripperCollisionCfg(), + cfg: GraspGeneratorCfg = GraspGeneratorCfg(), + gripper_collision_cfg: GripperCollisionCfg = GripperCollisionCfg(), approach_direction: torch.Tensor = None, is_visual: bool = False, ) -> torch.Tensor: @@ -1142,7 +1167,7 @@ def get_grasp_pose( approach_direction = torch.tensor( [0, 0, -1], dtype=torch.float32, device=self.device ) - approach_direction = F.normalize(approach_direction, dim=-1) + approach_direction = torch.nn.functional.normalize(approach_direction, dim=-1) if hasattr(self, "_grasp_annotator") is False: vertices = torch.tensor( self._entities[0].get_vertices(), @@ -1158,7 +1183,7 @@ def get_grasp_pose( device=self.device, ) vertices = vertices * scale - self._grasp_annotator = GraspAnnotator( + self._grasp_annotator = GraspGenerator( vertices=vertices, triangles=triangles, cfg=cfg, diff --git a/embodichain/lab/sim/sim_manager.py b/embodichain/lab/sim/sim_manager.py index 4aef6896..5e3e2880 100644 --- a/embodichain/lab/sim/sim_manager.py +++ b/embodichain/lab/sim/sim_manager.py @@ -90,6 +90,7 @@ "MATERIAL_CACHE_DIR", "CONVEX_DECOMP_DIR", "REACHABLE_XPOS_DIR", + "GRASP_ANNOTATOR_CACHE_DIR", ] diff --git a/embodichain/toolkits/graspkit/pg_grasp/__init__.py b/embodichain/toolkits/graspkit/pg_grasp/__init__.py index 82c25ce0..d9719a08 100644 --- a/embodichain/toolkits/graspkit/pg_grasp/__init__.py +++ b/embodichain/toolkits/graspkit/pg_grasp/__init__.py @@ -14,8 +14,7 @@ # limitations under the License. 
# ---------------------------------------------------------------------------- -from .antipodal import AntipodalGenerator, GraspSelectMethod - -__all__ = ["AntipodalGenerator", "GraspSelectMethod"] - -del antipodal +from .antipodal_sampler import * +from .collision_checker import * +from .gripper_collision_checker import * +from .antipodal_generator import * diff --git a/embodichain/toolkits/graspkit/pg_grasp/antipodal.py b/embodichain/toolkits/graspkit/pg_grasp/antipodal.py deleted file mode 100644 index 1b7b7f1a..00000000 --- a/embodichain/toolkits/graspkit/pg_grasp/antipodal.py +++ /dev/null @@ -1,670 +0,0 @@ -# ---------------------------------------------------------------------------- -# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ---------------------------------------------------------------------------- - -import open3d as o3d -import numpy as np -import time -import pathlib -import pickle -import os - -from enum import Enum -from copy import deepcopy -from typing import List - -from .cone_sampler import ConeSampler -from embodichain.utils.utility import get_mesh_md5 -from embodichain.utils import logger - - -class GraspSelectMethod(Enum): - NORMAL_SCORE = 0 - NEAR_APPROACH = 1 - CENTER = 2 - - -class AntipodalGrasp: - def __init__(self, pose: np.ndarray, open_len: float, score: float) -> None: - self.pose = pose # [4, 4] of float grasp pose - self.open_len = open_len # gripper open length - self.score = score # grasp pose score - - def grasp_pose_visual_mesh(self, gripper_open_length: float = None): - if gripper_open_length is None: - gripper_open_length = self.open_len - open_ratio = 1.0 - else: - open_ratio = self.open_len / gripper_open_length - open_ratio = max(1e-4, open_ratio) - gripper_finger = o3d.geometry.TriangleMesh.create_box( - gripper_open_length * 0.04, - gripper_open_length * 0.04, - gripper_open_length * 0.5, - ) - gripper_finger.translate( - np.array( - [ - -gripper_open_length * 0.02, - -gripper_open_length * 0.02, - -gripper_open_length * 0.25, - ] - ) - ) - gripper_left = deepcopy(gripper_finger) - gripper_left = gripper_left.translate( - np.array( - [ - -gripper_open_length * open_ratio * 0.5, - 0, - -gripper_open_length * 0.25, - ] - ) - ) - - gripper_right = deepcopy(gripper_finger) - gripper_right = gripper_right.translate( - np.array( - [gripper_open_length * open_ratio * 0.5, 0, -gripper_open_length * 0.25] - ) - ) - - gripper_root1 = o3d.geometry.TriangleMesh.create_box( - gripper_open_length * open_ratio, - gripper_open_length * 0.04, - gripper_open_length * 0.04, - ) - gripper_root1.translate( - np.array( - [ - gripper_open_length * open_ratio * -0.5, - gripper_open_length * -0.02, - gripper_open_length * -0.02, - ] - ) - ) - gripper_root1.translate( - 
np.array( - [ - 0, - 0, - gripper_open_length * -0.5, - ] - ) - ) - - gripper_root2 = o3d.geometry.TriangleMesh.create_box( - gripper_open_length * 0.04, - gripper_open_length * 0.04, - gripper_open_length * 0.5, - ) - gripper_root2.translate( - np.array( - [ - gripper_open_length * -0.02, - gripper_open_length * -0.02, - gripper_open_length * -0.25, - ] - ) - ) - gripper_root2.translate( - np.array( - [ - 0, - 0, - gripper_open_length * -0.75, - ] - ) - ) - - gripper_visual = gripper_left + gripper_right + gripper_root1 + gripper_root2 - gripper_visual.compute_vertex_normals() - return gripper_visual - - -class Antipodal: - def __init__( - self, - point_a: np.ndarray, - point_b: np.ndarray, - normal_a: np.ndarray, - normal_b: np.ndarray, - ) -> None: - """antipodal contact pair - - Args: - point_a (np.ndarray): position in point a - point_b (np.ndarray): position in point b - normal_a (np.ndarray): normal in point a - normal_b (np.ndarray): normal in point b - """ - self.point_a = point_a - self.point_b = point_b - self.normal_a = normal_a - self.normal_b = normal_b - self.dis = np.linalg.norm(point_a - point_b) - self.angle_cos = self.normal_a.dot(-self.normal_b) - self.score = self._get_score() - self._canonical_pose = self._get_canonical_pose() - - def _get_score(self): - # TODO: only normal angle is taken into account - return self.angle_cos - - def get_dis(self, another) -> float: - """get distance acoording to another antipodal - - Args: - other (Antipodal): another antipodal - - Returns: - float: distance - """ - aa_dis = np.linalg.norm(self.point_a - another.point_a) - bb_dis = np.linalg.norm(self.point_b - another.point_b) - ab_dis = np.linalg.norm(self.point_a - another.point_b) - ba_dis = np.linalg.norm(self.point_b - another.point_a) - return min(aa_dis, bb_dis, ab_dis, ba_dis) - - def get_dis_arr(self, others) -> np.ndarray: - """get distance acoording to others antipodals - - Args: - others (List[Antipodal]): other antipodals - - Returns: - 
np.ndarray: distance array - """ - if not others: - return np.array([], dtype=float) - # Vectorized extraction of points using list comprehension and np.array - other_a = np.array([o.point_a for o in others], dtype=float) - other_b = np.array([o.point_b for o in others], dtype=float) - aa_dis = np.linalg.norm(other_a - self.point_a, axis=1) - ab_dis = np.linalg.norm(other_a - self.point_b, axis=1) - ba_dis = np.linalg.norm(other_b - self.point_a, axis=1) - bb_dis = np.linalg.norm(other_b - self.point_b, axis=1) - dis_arr = np.vstack([aa_dis, ab_dis, ba_dis, bb_dis]).min(axis=0) - return dis_arr - - def _get_canonical_pose(self) -> np.ndarray: - """get canonical pose of antipodal contact pair - - Returns: - np.ndarray: canonical pose - """ - # assume gripper closing along x_axis - x_d = self.point_a - self.point_b - x_d = x_d / np.linalg.norm(x_d) - y_d = np.cross(np.array([0, 0, 1.0], dtype=float), x_d) - if np.linalg.norm(y_d) < 1e-4: - y_d = np.cross(np.array([1, 0, 0.0], dtype=float), x_d) - y_d = y_d / np.linalg.norm(y_d) - z_d = np.cross(x_d, y_d) - pose = np.eye(4, dtype=float) - pose[:3, 0] = x_d # rotation x - pose[:3, 1] = y_d # rotation y - pose[:3, 2] = z_d # rotation z - pose[:3, 3] = (self.point_a + self.point_b) / 2 # position - return pose - - def sample_pose(self, sample_num: int = 36) -> np.ndarray: - """sample parallel gripper grasp poses given antipodal contact pairs - - Args: - sample_num (int, optional): sample number. Defaults to 36. - - Returns: - np.ndarray: [sample_num, 4, 4] of float. Sample poses. 
- """ - # assume gripper closing along x_axis - x_d = self._canonical_pose[:3, 0] - y_d = self._canonical_pose[:3, 1] - z_d = self._canonical_pose[:3, 2] - position = self._canonical_pose[:3, 3] - beta_list = np.linspace(2 * np.pi / sample_num, 2 * np.pi, sample_num) - grasp_poses = np.empty(shape=(sample_num, 4, 4), dtype=float) - for i in range(sample_num): - sample_z = np.sin(beta_list[i]) * y_d + np.cos(beta_list[i]) * z_d - sample_z = sample_z / np.linalg.norm(sample_z) - sample_y = np.cross(sample_z, x_d) - pose = np.eye(4, dtype=float) - pose[:3, 0] = x_d # rotation x - pose[:3, 1] = sample_y # rotation y - pose[:3, 2] = sample_z # rotation z - pose[:3, 3] = position # position - grasp_poses[i] = pose - return grasp_poses - - -class AntipodalGenerator: - def __init__( - self, - mesh_o3dt: o3d.t.geometry.TriangleMesh, - open_length: float, - min_open_length: float = 0.002, - max_angle: float = np.pi / 10, - surface_sample_num: int = 5000, - layer_num: int = 4, - sample_each_layer: int = 6, - nms_ratio: float = 0.02, - antipodal_sample_num: int = 16, - unique_id: str = None, - cache_dir: str = None, - ): - """antipodal grasp pose generator - - Args: - mesh_o3dt (o3d.t.geometry.TriangleMesh): input mesh - open_length (float): gripper maximum open length - max_angle (float, optional): maximum grasp direction with surface normal. Defaults to np.pi/10. - surface_sample_num (int, optional): contact sample number in mesh surface. Defaults to 5000. - layer_num (int, optional): cone sample layer number . Defaults to 4. - sample_each_layer (int, optional): cone sample number in each layer. Defaults to 6. - nms_ratio (float, optional): nms distance ratio. Defaults to 0.02. - antipodal_sample_num (int, optional): grasp poses sample on each antipodal contact pair. Defaults to 16. - cache_dir (str, optional): file cache directory. Defaults to None. 
- """ - self._antipodal_max_angle = max_angle - self._open_length = open_length - self._min_open_length = min_open_length - self._mesh_o3dt = mesh_o3dt - verts = mesh_o3dt.vertex.positions.numpy() - self._center_of_mass = verts.mean(axis=0) - if unique_id is None: - unique_file_name = self._get_unique_id( - mesh_o3dt, open_length, max_angle, surface_sample_num - ) - else: - unique_file_name = f"{unique_id}_{str(open_length)}_{str(max_angle)}_{str(surface_sample_num)}" - if cache_dir is None: - cache_dir = os.path.join(pathlib.Path.home(), "grasp_pose") - logger.log_debug(f"Set cache directory to {cache_dir}.") - if not os.path.isdir(cache_dir): - os.mkdir(cache_dir) - cache_file = os.path.join(cache_dir, f"{unique_file_name}.pickle") - if not os.path.isfile(cache_file): - # generate cache - grasp_list = self._generate_cache( - cache_file, - mesh_o3dt=mesh_o3dt, - max_angle=max_angle, - surface_sample_num=surface_sample_num, - layer_num=layer_num, - sample_each_layer=sample_each_layer, - nms_ratio=nms_ratio, - antipodal_sample_num=antipodal_sample_num, - ) - else: - # load cache - grasp_list = self._load_cache(cache_file) - self._grasp_list = grasp_list - - def _get_unique_id( - self, - mesh_o3dt: o3d.t.geometry.TriangleMesh, - open_length: float, - max_angle: float, - surface_sample_num: int, - ) -> str: - mesh_md5 = get_mesh_md5(mesh_o3dt) - return ( - f"{mesh_md5}_{str(open_length)}_{str(max_angle)}_{str(surface_sample_num)}" - ) - - def _generate_cache( - self, - cache_file: str, - mesh_o3dt: o3d.t.geometry.TriangleMesh, - max_angle: float = np.pi / 10, - surface_sample_num: int = 5000, - layer_num: int = 4, - sample_each_layer: int = 6, - nms_ratio: float = 0.02, - antipodal_sample_num: int = 16, - ): - self._mesh_o3dt = mesh_o3dt - self._cone_sampler = ConeSampler( - max_angle=max_angle, - layer_num=layer_num, - sample_each_layer=sample_each_layer, - ) - mesh_o3dt = mesh_o3dt.compute_vertex_normals() - assert 1e-4 < max_angle < np.pi / 2 - self._mesh_len = 
self._get_pc_size(mesh_o3dt.vertex.positions.numpy()).max() - start_time = time.time() - antipodal_list = self._antipodal_generate(mesh_o3dt, surface_sample_num) - logger.log_debug( - f"Antipodal sampling cost {(time.time() - start_time) * 1000} ms." - ) - logger.log_debug(f"Find {len(antipodal_list)} initial antipodal pairs.") - - valid_antipodal_list = self._antipodal_clean(antipodal_list) - - start_time = time.time() - nms_antipodal_list = self._antipodal_nms( - valid_antipodal_list, nms_ratio=nms_ratio - ) - logger.log_debug(f"NMS cost {(time.time() - start_time) * 1000} ms.") - logger.log_debug( - f"There are {len(nms_antipodal_list)} antipodal pair after nms." - ) - # all poses - start_time = time.time() - grasp_poses, score, open_length = self._sample_grasp_pose( - nms_antipodal_list, antipodal_sample_num - ) - logger.log_debug(f"Pose sampling cost {(time.time() - start_time) * 1000} ms.") - logger.log_debug( - f"There are {grasp_poses.shape[0]} poses after grasp pose sampling." - ) - # write data - data_dict = { - "grasp_poses": grasp_poses, - "score": score, - "open_length": open_length, - } - pickle.dump(data_dict, open(cache_file, "wb")) - # TODO: contact pair visualization - # self.antipodal_visual(nms_antipodal_list) - grasp_num = grasp_poses.shape[0] - logger.log_debug(f"Write {grasp_num} poses to pickle file {cache_file}.") - # Use list comprehension for efficient list construction - grasp_list = [ - AntipodalGrasp(grasp_poses[i], open_length[i], score[i]) - for i in range(grasp_num) - ] - return grasp_list - - def _load_cache(self, cache_file: str): - data_dict = pickle.load(open(cache_file, "rb")) - grasp_num = data_dict["grasp_poses"].shape[0] - logger.log_debug(f"Load {grasp_num} poses from pickle file {cache_file}.") - # Use list comprehension for efficient list construction - grasp_list = [ - AntipodalGrasp( - data_dict["grasp_poses"][i], - data_dict["open_length"][i], - data_dict["score"][i], - ) - for i in range(grasp_num) - ] - return 
grasp_list - - def _get_pc_size(self, vertices): - return np.array( - [ - vertices[:, 0].max() - vertices[:, 0].min(), - vertices[:, 1].max() - vertices[:, 1].min(), - vertices[:, 2].max() - vertices[:, 2].min(), - ] - ) - - def _antipodal_generate( - self, mesh_o3dt: o3d.t.geometry.TriangleMesh, surface_sample_num: int = 5000 - ): - surface_pcd = mesh_o3dt.to_legacy().sample_points_uniformly(surface_sample_num) - points = np.array(surface_pcd.points) - normals = np.array(surface_pcd.normals) - point_num = points.shape[0] - scene = o3d.t.geometry.RaycastingScene() - scene.add_triangles(mesh_o3dt) - # raycast - ray_n = self._cone_sampler._ray_num - ray_num = point_num * ray_n - ray_begins = np.empty(shape=(ray_num, 3), dtype=float) - ray_direcs = np.empty(shape=(ray_num, 3), dtype=float) - origin_normals = np.empty(shape=(ray_num, 3), dtype=float) - origin_points = np.empty(shape=(ray_num, 3), dtype=float) - start_time = time.time() - for i in range(point_num): - ray_direc = self._cone_sampler.cone_sample_direc( - normals[i], is_visual=False - ) - # raycast from outside of object - ray_begin = points[i] - 2 * self._mesh_len * ray_direc - ray_direcs[i * ray_n : (i + 1) * ray_n, :] = ray_direc - ray_begins[i * ray_n : (i + 1) * ray_n, :] = ray_begin - origin_normals[i * ray_n : (i + 1) * ray_n, :] = normals[i] - origin_points[i * ray_n : (i + 1) * ray_n, :] = points[i] - logger.log_debug(f"Cone sampling cost {(time.time() - start_time) * 1000} ms.") - - start_time = time.time() - rays = o3d.core.Tensor( - np.hstack([ray_begins, ray_direcs]), dtype=o3d.core.Dtype.Float32 - ) - logger.log_debug(f"Open3d raycast {(time.time() - start_time) * 1000} ms.") - - raycast_rtn = scene.cast_rays(rays) - hit_dis_all = raycast_rtn["t_hit"].numpy() - hit_normal_all = raycast_rtn["primitive_normals"].numpy() - - # max_angle_cos = np.cos(self._antipodal_max_angle) - antipodal_list = [] - # get antipodal points - start_time = time.time() - for i in range(ray_num): - hit_dis = 
hit_dis_all[i] - hit_normal = hit_normal_all[i] - hit_point = ray_begins[i] + ray_direcs[i] * hit_dis - antipodal_dis = np.linalg.norm(origin_points[i] - hit_point) - if ( - antipodal_dis > self._min_open_length - and antipodal_dis < self._open_length - ): - # reject thin close object - antipodal = Antipodal( - origin_points[i], hit_point, origin_normals[i], hit_normal - ) - antipodal_list.append(antipodal) - logger.log_debug( - f"Antipodal initialize cost {(time.time() - start_time) * 1000} ms." - ) - return antipodal_list - - def _sample_grasp_pose( - self, antipodal_list: List[Antipodal], antipodal_sample_num: int = 36 - ): - antipodal_num = len(antipodal_list) - grasp_poses = np.empty( - shape=(antipodal_sample_num * antipodal_num, 4, 4), dtype=float - ) - scores = np.empty(shape=(antipodal_sample_num * antipodal_num,), dtype=float) - open_length = np.empty( - shape=(antipodal_sample_num * antipodal_num,), dtype=float - ) - for i in range(antipodal_num): - grasp_poses[i * antipodal_sample_num : (i + 1) * antipodal_sample_num] = ( - antipodal_list[i].sample_pose(antipodal_sample_num) - ) - scores[i * antipodal_sample_num : (i + 1) * antipodal_sample_num] = ( - antipodal_list[i].score - ) - open_length[i * antipodal_sample_num : (i + 1) * antipodal_sample_num] = ( - antipodal_list[i].dis - ) - return grasp_poses, scores, open_length - - def get_all_grasp(self) -> List[AntipodalGrasp]: - """get all grasp - - Returns: - List[AntipodalGrasp]: list of grasp - """ - return self._grasp_list - - def select_grasp( - self, - approach_direction: np.ndarray, - select_num: int = 20, - max_angle: float = np.pi / 10, - select_method: GraspSelectMethod = GraspSelectMethod.NORMAL_SCORE, - ) -> List[AntipodalGrasp]: - """Select grasps. Masked by max_angle and sort by grasp score - - Args: - approach_direction (np.ndarray): gripper approach direction - select_num (int, optional): select grasp number. Defaults to 10. 
- max_angle (float, optional): max angle threshold (angle with surface normal). Defaults to np.pi/10. - select_method (select_method, optional) - Return: - List[AntipodalGrasp]: list of grasp - """ - grasp_num = len(self._grasp_list) - all_idx = np.arange(grasp_num) - - # Vectorized extraction of poses and scores using list comprehension - grasp_poses = np.array([g.pose for g in self._grasp_list], dtype=float) - scores = np.array([g.score for g in self._grasp_list], dtype=float) - position = grasp_poses[:, :3, 3] - - # mask acoording to table up direction - grasp_z = grasp_poses[:, :3, 2] - direc_dot = (grasp_z * approach_direction).sum(axis=1) - valid_mask = direc_dot > np.cos(max_angle) - valid_id = all_idx[valid_mask] - - # sort acoording to different grasp score - if select_method == GraspSelectMethod.NORMAL_SCORE: - valid_score = scores[valid_id] - sort_valid_idx = np.argsort(valid_score)[::-1] # large is better - elif select_method == GraspSelectMethod.NEAR_APPROACH: - position_dot = (position * approach_direction).sum(axis=1) - valid_height = position_dot[valid_id] - sort_valid_idx = np.argsort(valid_height) - elif select_method == GraspSelectMethod.CENTER: - center_dis = np.linalg.norm(position - self._center_of_mass, axis=1) - valid_center_dis = center_dis[valid_id] - sort_valid_idx = np.argsort(valid_center_dis) - else: - logger.log_warning(f"select_method {select_method.name} not implemented.") - # return all grasp - return self._grasp_list - - # get best score sample index - result_num = min(len(sort_valid_idx), select_num) - best_valid_idx = sort_valid_idx[:result_num] - best_idx = valid_id[best_valid_idx] - - # Use list comprehension for faster list construction - return [self._grasp_list[idx] for idx in best_idx] - - def _antipodal_nms( - self, antipodal_list: List[Antipodal], nms_ratio: float = 0.02 - ) -> List[Antipodal]: - antipodal_num = len(antipodal_list) - nms_mask = np.empty(shape=(antipodal_num,), dtype=bool) - nms_mask.fill(True) - 
score_list = np.empty(shape=(antipodal_num,), dtype=float) - - for i in range(antipodal_num): - score_list[i] = antipodal_list[i].score - - sort_idx = np.argsort(score_list)[::-1] - - dis_th = self._mesh_len * nms_ratio - for i in range(antipodal_num): - if not nms_mask[sort_idx[i]]: - continue - antipodal_max = antipodal_list[sort_idx[i]] - other_antipodal = [] - other_idx = [] - for j in range(i + 1, antipodal_num): - if nms_mask[sort_idx[j]]: - other_antipodal.append(antipodal_list[sort_idx[j]]) - other_idx.append(sort_idx[j]) - dis_arr = antipodal_max.get_dis_arr(other_antipodal) - invalid_mask = dis_arr < dis_th - for j, flag in enumerate(invalid_mask): - if flag: - nms_mask[other_idx[j]] = False - nms_antipodal_list = [] - for i in range(antipodal_num): - if nms_mask[sort_idx[i]]: - nms_antipodal_list.append(antipodal_list[sort_idx[i]]) - - # TODO: nms validation check. remove in future - # antipodal_num = len(nms_antipodal_list) - # for i in range(antipodal_num): - # for j in range(i + 1, antipodal_num): - # antipodal_dis = nms_antipodal_list[i].get_dis(nms_antipodal_list[j]) - # if antipodal_dis < dis_th: - # logger.log_warning(f"find near antipodal {i} and {j} with dis {antipodal_dis}") - return nms_antipodal_list - - def _antipodal_clean(self, antipodal_list: List[Antipodal]): - # TODO: need collision checker - - valid_antipodal = [] - max_angle_cos = np.cos(self._antipodal_max_angle) - for antipodal in antipodal_list: - if ( - 1e-5 < antipodal.dis < self._open_length - and antipodal.angle_cos > max_angle_cos - ): - valid_antipodal.append(antipodal) - return valid_antipodal - - def antipodal_visual(self, antipodal_list): - mesh_visual = self._mesh_o3dt.to_legacy() - antipodal_num = len(antipodal_list) - draw_points = np.empty(shape=(antipodal_num * 2, 3), dtype=float) - draw_lines = np.empty(shape=(antipodal_num, 2), dtype=int) - for i in range(antipodal_num): - direc = antipodal_list[i].point_b - antipodal_list[i].point_a - direc = direc / 
np.linalg.norm(direc) - anti_begin = antipodal_list[i].point_a - direc * 0.005 - anti_end = antipodal_list[i].point_b + direc * 0.005 - draw_points[i * 2] = anti_begin - draw_points[i * 2 + 1] = anti_end - draw_lines[i] = np.array([i * 2, i * 2 + 1], dtype=int) - - line_set = o3d.geometry.LineSet( - points=o3d.utility.Vector3dVector(draw_points), - lines=o3d.utility.Vector2iVector(draw_lines), - ) - draw_colors = np.empty(shape=(antipodal_num, 3), dtype=float) - draw_colors[:] = np.array([0.0, 1.0, 1.0]) - line_set.colors = o3d.utility.Vector3dVector(draw_colors) - o3d.visualization.draw_geometries([line_set, mesh_visual]) - - def grasp_pose_visual( - self, grasp_list: List[AntipodalGrasp] - ) -> List[o3d.t.geometry.TriangleMesh]: - """visualize grasp pose - - Args: - grasp_list (List[AntipodalGrasp]): list of grasp - - Returns: - List[o3d.t.geometry.TriangleMesh]: list of visualization mesh - """ - pose_num = len(grasp_list) - visual_mesh_list = [self._mesh_o3dt.compute_vertex_normals()] - - max_angle_cos = np.cos(self._antipodal_max_angle) - - for i in range(pose_num): - grasp_mesh = grasp_list[i].grasp_pose_visual_mesh( - gripper_open_length=self._open_length - ) - grasp_mesh.transform(grasp_list[i].pose) - # low score: red | hight score: blue - score_ratio = (grasp_list[i].score - max_angle_cos) / (1 - max_angle_cos) - score_ratio = min(1.0, score_ratio) - score_ratio = max(0.0, score_ratio) - grasp_mesh.paint_uniform_color(np.array([1 - score_ratio, 0, score_ratio])) - visual_mesh_list.append(o3d.t.geometry.TriangleMesh.from_legacy(grasp_mesh)) - return visual_mesh_list diff --git a/embodichain/toolkits/graspkit/pg_grasp/antipodal_annotator.py b/embodichain/toolkits/graspkit/pg_grasp/antipodal_generator.py similarity index 54% rename from embodichain/toolkits/graspkit/pg_grasp/antipodal_annotator.py rename to embodichain/toolkits/graspkit/pg_grasp/antipodal_generator.py index 06d82391..a47efbb7 100644 --- 
a/embodichain/toolkits/graspkit/pg_grasp/antipodal_annotator.py +++ b/embodichain/toolkits/graspkit/pg_grasp/antipodal_generator.py @@ -18,14 +18,17 @@ import argparse import open3d as o3d import time -from pathlib import Path -from typing import Any, cast import torch import numpy as np import trimesh +import hashlib +import torch.nn.functional as F import viser import viser.transforms as tf +from pathlib import Path +from typing import Any, cast + from embodichain.utils import logger from embodichain.utils import configclass from embodichain.toolkits.graspkit.pg_grasp.antipodal_sampler import ( @@ -33,47 +36,100 @@ AntipodalSamplerCfg, ) from embodichain.utils import configclass -from .gripper_collision_checker import ( - SimpleGripperCollisionChecker, - SimpleGripperCollisionCfg, +from embodichain.toolkits.graspkit.pg_grasp import ( + GripperCollisionChecker, + GripperCollisionCfg, ) -import hashlib -import torch.nn.functional as F -import tempfile + + +__all__ = ["GraspGenerator", "GraspGeneratorCfg"] @configclass -class GraspAnnotatorCfg: +class GraspGeneratorCfg: + """Configuration for :class:`GraspGenerator`. + + Controls the interactive grasp region annotation workflow, including the + browser-based visualizer settings, antipodal sampling parameters, and + grasp-pose filtering thresholds. + """ + viser_port: int = 15531 + """Port used by the Viser browser-based visualizer for interactive grasp + region annotation.""" + use_largest_connected_component: bool = False + """When ``True``, only the largest connected component of the selected mesh + region is retained. Useful for meshes that contain disconnected fragments + or when selecting a local feature such as a handle.""" + antipodal_sampler_cfg: AntipodalSamplerCfg = AntipodalSamplerCfg() + """Nested configuration for the antipodal point sampler. Controls the + number of sampled surface points, ray perturbation angle, and gripper jaw + distance limits. 
See :class:`AntipodalSamplerCfg` for details.""" + force_regenerate: bool = False + """When ``True``, the user is required to annotate the grasp region every + time, bypassing any cached results from a previous run.""" + max_deviation_angle: float = np.pi / 12 + """Maximum allowed angle (in radians) between the specified approach + direction and the axis connecting an antipodal point pair. Pairs that + deviate more than this threshold from perpendicular to the approach are + discarded during grasp pose computation.""" -@configclass -class SelectResult: - vertex_indices: np.ndarray | None = None - face_indices: np.ndarray | None = None - vertices: np.ndarray | None = None - faces: np.ndarray | None = None +class GraspGenerator: + """Antipodal grasp-pose generator for parallel-jaw grippers. + + Given an object mesh, ``GraspGenerator`` produces feasible grasp poses + through a three-stage pipeline: + + 1. **Antipodal sampling** — Surface points are uniformly sampled and + rays are cast along (and near) the inward normal to find antipodal + point pairs on opposite sides of the mesh (:meth:`generate`). + Alternatively, an interactive Viser-based annotator lets a human + select the graspable region (:meth:`annotate`). + 2. **Pose construction** — For each antipodal pair, a 6-DoF grasp + frame is built so that the gripper opening aligns with the pair axis + and the approach direction is consistent with a user-specified + vector (:meth:`get_grasp_poses`). + 3. **Filtering & ranking** — Grasp candidates that would cause the + gripper to collide with the object are discarded. Surviving poses + are scored by a weighted cost that penalises angular deviation from + the approach direction, narrow opening length, and distance to the + mesh centroid. + + Antipodal pairs are cached to disk (keyed on mesh geometry) and + automatically reused across sessions unless ``force_regenerate`` is set. 
+ + Typical usage:: + + generator = GraspGenerator(vertices, triangles, cfg=cfg) + + # Programmatic: sample on the whole mesh or a sub-region + generator.generate() # whole mesh + generator.generate(face_indices=some_idx) # specific faces + # Interactive: pick region in a browser UI + generator.annotate() -class GraspAnnotator: - """GraspAnnotator provides functionality to annotate antipodal grasp regions on a given object mesh. It allows users to interactively select regions on the mesh and generates antipodal point pairs for grasping based on the selected region. The annotator also includes a collision checker to filter out infeasible grasp poses and can visualize the generated grasp poses in a 3D viewer.""" + # Then compute the best grasp pose + pose, open_length = generator.get_grasp_poses(object_pose, approach_dir) + """ def __init__( self, vertices: torch.Tensor, triangles: torch.Tensor, - cfg: GraspAnnotatorCfg = GraspAnnotatorCfg(), - gripper_collision_cfg: SimpleGripperCollisionCfg = SimpleGripperCollisionCfg(), + cfg: GraspGeneratorCfg = GraspGeneratorCfg(), + gripper_collision_cfg: GripperCollisionCfg = GripperCollisionCfg(), ) -> None: - """Initialize the GraspAnnotator with the given mesh vertices, triangles, and configuration. + """Initialize the GraspGenerator with the given mesh vertices, triangles, and configuration. Args: vertices (torch.Tensor): A tensor of shape (V, 3) representing the vertex positions of the mesh. triangles (torch.Tensor): A tensor of shape (F, 3) representing the triangle indices of the mesh. - cfg (GraspAnnotatorCfg, optional): Configuration for the grasp annotator. Defaults to GraspAnnotatorCfg(). + cfg (GraspGeneratorCfg, optional): Configuration for the grasp annotator. Defaults to GraspGeneratorCfg(). 
""" self.device = vertices.device self.vertices = vertices @@ -84,40 +140,119 @@ def __init__( process=False, force="mesh", ) - self._collision_checker = SimpleGripperCollisionChecker( + self._collision_checker = GripperCollisionChecker( object_mesh_verts=vertices, object_mesh_faces=triangles, cfg=gripper_collision_cfg, ) self.cfg = cfg - self.antipodal_sampler = AntipodalSampler(cfg=cfg.antipodal_sampler_cfg) + self._antipodal_sampler = AntipodalSampler(cfg=cfg.antipodal_sampler_cfg) + self._hit_point_pairs: torch.Tensor | None = None - def annotate(self) -> torch.Tensor: - """Annotate antipodal grasp region on the mesh and return sampled antipodal point pairs. - Returns: - torch.Tensor: A tensor of shape (N, 2, 3) representing N antipodal point pairs. Each pair consists of a hit point and its corresponding surface point. - """ + # Load cached antipodal pairs for the whole mesh if available. cache_path = self._get_cache_dir(self.vertices, self.triangles) if os.path.exists(cache_path) and not self.cfg.force_regenerate: - logger.log_info( - f"Found existing antipodal retult. Loading cached antipodal pairs from {cache_path}" - ) - hit_point_pairs = torch.tensor( + logger.log_info(f"Found cached antipodal pairs at {cache_path}. Loading.") + self._hit_point_pairs = torch.tensor( np.load(cache_path), dtype=torch.float32, device=self.device ) - return hit_point_pairs + + def generate( + self, + vertex_indices: torch.Tensor | None = None, + face_indices: torch.Tensor | None = None, + ) -> torch.Tensor: + """Generate antipodal point pairs for grasping on the given mesh region. + + At most one of ``vertex_indices`` or ``face_indices`` may be provided + to define the grasp region. When both are ``None``, the whole mesh is + used. + + Results are cached to disk and reused when ``force_regenerate`` is + ``False``. + + Args: + vertex_indices: 1-D ``torch.Tensor`` of vertex indices defining the + grasp region. 
+ face_indices: 1-D ``torch.Tensor`` of face indices defining the + grasp region. + + Raises: + ValueError: If both ``vertex_indices`` and ``face_indices`` are + provided at the same time. + + Returns: + torch.Tensor: A tensor of shape ``(N, 2, 3)`` representing N + antipodal point pairs. Each pair consists of a hit point and + its corresponding surface point. + """ + if vertex_indices is not None and face_indices is not None: + raise ValueError( + "Only one of vertex_indices or face_indices should be provided, not both." + ) + + if vertex_indices is None and face_indices is None: + sub_vertices = self.vertices + sub_faces = self.triangles else: - logger.log_info( - f"[Viser] *****Annotate grasp region in http://localhost:{self.cfg.viser_port}" + if face_indices is not None: + face_idx_np = face_indices.cpu().numpy() + else: + vertex_idx_np = vertex_indices.cpu().numpy() + vertex_mask = np.zeros(self.mesh.vertices.shape[0], dtype=bool) + vertex_mask[vertex_idx_np] = True + face_all = cast(np.ndarray, self.mesh.faces) + face_idx_np = np.flatnonzero(np.all(vertex_mask[face_all], axis=1)) + ( + _, + _, + sub_vertices_np, + sub_faces_np, + ) = GraspGenerator._extract_selection_from_faces( + self.mesh, face_idx_np, self.cfg.use_largest_connected_component + ) + if sub_vertices_np is None: + return torch.empty(0, 2, 3, dtype=torch.float32, device=self.device) + sub_vertices = torch.as_tensor( + sub_vertices_np, dtype=torch.float32, device=self.device + ) + sub_faces = torch.as_tensor( + sub_faces_np, dtype=torch.int64, device=self.device + ) + + cache_path = self._get_cache_dir(sub_vertices, sub_faces) + if os.path.exists(cache_path) and not self.cfg.force_regenerate: + logger.log_info(f"Found cached antipodal pairs at {cache_path}") + return torch.tensor( + np.load(cache_path), dtype=torch.float32, device=self.device ) + self._hit_point_pairs = self._antipodal_sampler.sample(sub_vertices, sub_faces) + self._save_cache(cache_path, self._hit_point_pairs) + return 
self._hit_point_pairs + + def annotate(self) -> torch.Tensor: + """Annotate antipodal grasp region on the mesh and return sampled antipodal point pairs. + + Returns: + torch.Tensor: A tensor of shape (N, 2, 3) representing N antipodal point pairs. + Each pair consists of a hit point and its corresponding surface point. + """ + + logger.log_info( + f"[Viser] *****Annotate grasp region in http://localhost:{self.cfg.viser_port}" + ) + server = viser.ViserServer(port=self.cfg.viser_port) server.gui.configure_theme(brand_color=(130, 0, 150)) server.scene.set_up_direction("+z") mesh_handle = server.scene.add_mesh_trimesh(name="/mesh", mesh=self.mesh) selected_overlay: viser.GlbHandle | None = None - selection: SelectResult = SelectResult() + sel_vertex_indices: np.ndarray | None = None + sel_face_indices: np.ndarray | None = None + sel_vertices: np.ndarray | None = None + sel_faces: np.ndarray | None = None hit_point_pairs = None return_flag = False @@ -126,7 +261,10 @@ def annotate(self) -> torch.Tensor: def _(client: viser.ClientHandle) -> None: nonlocal mesh_handle nonlocal selected_overlay - nonlocal selection + nonlocal sel_vertex_indices + nonlocal sel_face_indices + nonlocal sel_vertices + nonlocal sel_faces # client.camera.position = np.array([0.0, 0.0, -0.5]) # client.camera.wxyz = np.array([1.0, 0.0, 0.0, 0.0]) @@ -144,11 +282,14 @@ def _(_evt: viser.GuiEvent) -> None: def _(event: viser.ScenePointerEvent) -> None: nonlocal mesh_handle nonlocal selected_overlay - nonlocal selection + nonlocal sel_vertex_indices + nonlocal sel_face_indices + nonlocal sel_vertices + nonlocal sel_faces nonlocal hit_point_pairs client.scene.remove_pointer_callback() - proj, depth = GraspAnnotator._project_vertices_to_screen( + proj, depth = GraspGenerator._project_vertices_to_screen( cast(np.ndarray, self.mesh.vertices), mesh_handle, event.client.camera, @@ -164,20 +305,24 @@ def _(event: viser.ScenePointerEvent) -> None: depth > 1e-6 ) - selection = 
GraspAnnotator._extract_selection( + ( + sel_vertex_indices, + sel_face_indices, + sel_vertices, + sel_faces, + ) = GraspGenerator._extract_selection_from_vertex_mask( self.mesh, vertex_mask, self.cfg.use_largest_connected_component ) - if selection.vertices is None: + if sel_vertices is None: logger.log_warning("[Selection] No vertices selected.") return color_mesh = self.mesh.copy() - used_vertex_indices = selection.vertex_indices vertex_colors = np.tile( np.array([[0.85, 0.85, 0.85, 1.0]]), (self.mesh.vertices.shape[0], 1), ) - vertex_colors[used_vertex_indices] = np.array( + vertex_colors[sel_vertex_indices] = np.array( [0.56, 0.17, 0.92, 1.0] ) color_mesh.visual.vertex_colors = vertex_colors # type: ignore @@ -188,8 +333,8 @@ def _(event: viser.ScenePointerEvent) -> None: if selected_overlay is not None: selected_overlay.remove() selected_mesh = trimesh.Trimesh( - vertices=selection.vertices, - faces=selection.faces, + vertices=sel_vertices, + faces=sel_faces, process=False, ) selected_mesh.visual.face_colors = (0.9, 0.2, 0.2, 0.65) # type: ignore @@ -197,14 +342,14 @@ def _(event: viser.ScenePointerEvent) -> None: name="/selected", mesh=selected_mesh ) logger.log_info( - f"[Selection] Selected {selection.vertex_indices.size} vertices and {selection.face_indices.size} faces." + f"[Selection] Selected {sel_vertex_indices.size} vertices and {sel_face_indices.size} faces." 
) - hit_point_pairs = self.antipodal_sampler.sample( - torch.tensor(selection.vertices, device=self.device), - torch.tensor(selection.faces, device=self.device), + hit_point_pairs = self._antipodal_sampler.sample( + torch.tensor(sel_vertices, device=self.device), + torch.tensor(sel_faces, device=self.device), ) - extended_hit_point_pairs = GraspAnnotator._extend_hit_point_pairs( + extended_hit_point_pairs = GraspGenerator._extend_hit_point_pairs( hit_point_pairs ) server.scene.add_line_segments( @@ -221,23 +366,24 @@ def _() -> None: @confirm_button.on_click def _(_evt: viser.GuiEvent) -> None: nonlocal return_flag - if selection.vertices is None: + if sel_vertices is None: logger.log_warning("[Selection] No vertex selected.") return else: logger.log_info( - f"[Selection] {selection.vertices.shape[0]}vertices selected. Generating antipodal point pairs." + f"[Selection] {sel_vertices.shape[0]}vertices selected. Generating antipodal point pairs." ) return_flag = True while True: if return_flag: - # save result to cache if hit_point_pairs is not None: + self._hit_point_pairs = hit_point_pairs + cache_path = self._get_cache_dir(self.vertices, self.triangles) self._save_cache(cache_path, hit_point_pairs) break time.sleep(0.5) - return hit_point_pairs + return self._hit_point_pairs def _get_cache_dir(self, vertices: torch.Tensor, triangles: torch.Tensor): from embodichain.lab.sim.sim_manager import GRASP_ANNOTATOR_CACHE_DIR @@ -300,59 +446,82 @@ def _project_vertices_to_screen( projected = (1.0 + projected) / 2.0 return projected, vertices_camera[:, 2] - def _extract_selection( + @staticmethod + def _extract_selection_from_vertex_mask( mesh: trimesh.Trimesh, vertex_mask: np.ndarray, largest_component: bool, - ) -> SelectResult: - def _largest_connected_face_component(face_ids: np.ndarray) -> np.ndarray: - if face_ids.size <= 1: - return face_ids - - face_id_set = set(face_ids.tolist()) - parent: dict[int, int] = { - int(face_id): int(face_id) for face_id in face_ids - } 
- - def find(x: int) -> int: - root = x - while parent[root] != root: - root = parent[root] - while parent[x] != x: - x_parent = parent[x] - parent[x] = root - x = x_parent - return root - - def union(a: int, b: int) -> None: - ra, rb = find(a), find(b) - if ra != rb: - parent[rb] = ra - - face_adjacency = cast(np.ndarray, mesh.face_adjacency) - for face_a, face_b in face_adjacency: - if int(face_a) in face_id_set and int(face_b) in face_id_set: - union(int(face_a), int(face_b)) - - groups: dict[int, list[int]] = {} - for face_id in face_ids: - root = find(int(face_id)) - groups.setdefault(root, []).append(int(face_id)) - - largest_group = max(groups.values(), key=len) - return np.array(largest_group, dtype=np.int32) + ) -> tuple[ + np.ndarray | None, np.ndarray | None, np.ndarray | None, np.ndarray | None + ]: + """Extract a sub-mesh from *mesh* using a per-vertex boolean mask. + + Args: + mesh: The source mesh. + vertex_mask: Boolean array of shape ``(V,)`` indicating which + vertices are selected. + largest_component: If ``True``, keep only the largest connected + component among the selected faces. + Returns: + A tuple ``(vertex_indices, face_indices, sub_vertices, sub_faces)`` + where ``sub_vertices`` and ``sub_faces`` define the extracted + sub-mesh with remapped indices. Returns ``(None, None, None, None)`` + if no faces are selected. 
+ """ faces = cast(np.ndarray, mesh.faces) face_mask = np.all(vertex_mask[faces], axis=1) - face_indices = np.flatnonzero(face_mask) if face_indices.size == 0: - return SelectResult() + return None, None, None, None + if largest_component: + face_indices = GraspGenerator._largest_connected_face_component( + mesh, face_indices + ) + if face_indices.size == 0: + return None, None, None, None + return GraspGenerator._build_sub_mesh(mesh, face_indices) + + @staticmethod + def _extract_selection_from_faces( + mesh: trimesh.Trimesh, + face_indices: np.ndarray, + largest_component: bool, + ) -> tuple[ + np.ndarray | None, np.ndarray | None, np.ndarray | None, np.ndarray | None + ]: + """Extract a sub-mesh from *mesh* using face indices. + + Args: + mesh: The source mesh. + face_indices: Array of face indices to include. + largest_component: If ``True``, keep only the largest connected + component among the selected faces. + + Returns: + Same as :meth:`_extract_selection_from_vertex_mask`. + """ + if face_indices.size == 0: + return None, None, None, None if largest_component: - face_indices = _largest_connected_face_component(face_indices) + face_indices = GraspGenerator._largest_connected_face_component( + mesh, face_indices + ) if face_indices.size == 0: - return SelectResult() + return None, None, None, None + return GraspGenerator._build_sub_mesh(mesh, face_indices) + + @staticmethod + def _build_sub_mesh( + mesh: trimesh.Trimesh, + face_indices: np.ndarray, + ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """Build a sub-mesh with remapped vertex indices from selected faces. 
+ Returns: + ``(vertex_indices, face_indices, sub_vertices, sub_faces)`` + """ + faces = cast(np.ndarray, mesh.faces) selected_face_vertices = faces[face_indices] vertex_indices = np.unique(selected_face_vertices.reshape(-1)) @@ -362,12 +531,47 @@ def union(a: int, b: int) -> None: sub_vertices = np.asarray(mesh.vertices)[vertex_indices] sub_faces = np.asarray(old_to_new)[selected_face_vertices] - return SelectResult( - vertex_indices=vertex_indices, - face_indices=face_indices, - vertices=sub_vertices, - faces=sub_faces, - ) + return vertex_indices, face_indices, sub_vertices, sub_faces + + @staticmethod + def _largest_connected_face_component( + mesh: trimesh.Trimesh, + face_ids: np.ndarray, + ) -> np.ndarray: + """Return the face indices of the largest connected component.""" + if face_ids.size <= 1: + return face_ids + + face_id_set = set(face_ids.tolist()) + parent: dict[int, int] = {int(face_id): int(face_id) for face_id in face_ids} + + def find(x: int) -> int: + root = x + while parent[root] != root: + root = parent[root] + while parent[x] != x: + x_parent = parent[x] + parent[x] = root + x = x_parent + return root + + def union(a: int, b: int) -> None: + ra, rb = find(a), find(b) + if ra != rb: + parent[rb] = ra + + face_adjacency = cast(np.ndarray, mesh.face_adjacency) + for face_a, face_b in face_adjacency: + if int(face_a) in face_id_set and int(face_b) in face_id_set: + union(int(face_a), int(face_b)) + + groups: dict[int, list[int]] = {} + for face_id in face_ids: + root = find(int(face_id)) + groups.setdefault(root, []).append(int(face_id)) + + largest_group = max(groups.values(), key=len) + return np.array(largest_group, dtype=np.int32) @staticmethod def _apply_transform(points: torch.Tensor, transform: torch.Tensor) -> torch.Tensor: @@ -377,23 +581,42 @@ def _apply_transform(points: torch.Tensor, transform: torch.Tensor) -> torch.Ten def get_grasp_poses( self, - hit_point_pairs: torch.Tensor, object_pose: torch.Tensor, approach_direction: 
torch.Tensor, is_visual: bool = False, - ) -> torch.Tensor: - """Get grasp pose given approach direction + ) -> tuple[torch.Tensor, torch.Tensor]: + """Get grasp pose given approach direction. + + Uses the antipodal point pairs stored in ``self._hit_point_pairs`` + (populated by :meth:`generate` or :meth:`annotate`). + + TODO: + 1. Support Top-k grasp poses selection. + 2. Support more selection criteria. Args: - hit_point_pairs (torch.Tensor): (N, 2, 3) tensor of N antipodal point pairs. Each pair consists of a hit point and its corresponding surface point. - object_pose (torch.Tensor): (4, 4) homogeneous transformation matrix representing the pose of the object in the world frame. - approach_direction (torch.Tensor): (3,) unit vector representing the desired approach direction of the gripper in the world frame. + object_pose: ``(4, 4)`` homogeneous transformation matrix + representing the pose of the object in the world frame. + approach_direction: ``(3,)`` unit vector representing the desired + approach direction of the gripper in the world frame. + is_visual: If ``True``, enable visual collision checking. Returns: - torch.Tensor: (4, 4) homogeneous transformation matrix representing the grasp pose in the world frame that aligns the gripper's approach direction with the given approach_direction. Returns None if no valid grasp pose can be found. + A tuple ``(best_grasp_pose, best_open_length)`` where + ``best_grasp_pose`` is a ``(4, 4)`` homogeneous matrix and + ``best_open_length`` is a scalar. + + Raises: + RuntimeError: If :meth:`generate` or :meth:`annotate` has not + been called yet. """ - origin_points = hit_point_pairs[:, 0, :] - hit_points = hit_point_pairs[:, 1, :] + if self._hit_point_pairs is None: + raise RuntimeError( + "No antipodal point pairs available. " + "Call generate() or annotate() first." 
+ ) + origin_points = self._hit_point_pairs[:, 0, :] + hit_points = self._hit_point_pairs[:, 1, :] origin_points_ = self._apply_transform(origin_points, object_pose) hit_points_ = self._apply_transform(hit_points, object_pose) centers = (origin_points_ + hit_points_) / 2 @@ -412,7 +635,7 @@ def get_grasp_poses( valid_centers = centers[valid_mask] # compute grasp poses using antipodal point pairs and approach direction - valid_grasp_poses = GraspAnnotator._grasp_pose_from_approach_direction( + valid_grasp_poses = GraspGenerator._grasp_pose_from_approach_direction( valid_grasp_x, approach_direction, valid_centers ) valid_open_lengths = torch.norm( @@ -536,43 +759,3 @@ def visualize_grasp_pose( window_name="Grasp Pose Visualization", mesh_show_back_face=True, ) - - -def main() -> None: - parser = argparse.ArgumentParser( - description="Viser mesh 标注工具:框选并导出对应顶点与三角面" - ) - parser.add_argument( - "--mesh", type=Path, required=True, help="输入 mesh 文件路径,例如 mug.obj" - ) - parser.add_argument("--scale", type=float, default=1.0, help="加载后整体缩放系数") - parser.add_argument("--port", type=int, default=12151, help="viser 服务端口") - parser.add_argument( - "--output-dir", - type=Path, - default=Path("outputs/mesh_annotations"), - help="标注结果导出目录", - ) - parser.add_argument( - "--largest-component", - action="store_true", - help="只保留框选结果中的最大连通块(常用于稳定提取把手等局部)", - ) - args = parser.parse_args() - - mesh = trimesh.load(args.mesh, process=False, force="mesh") - vertices = mesh.vertices * args.scale - triangles = mesh.faces - cfg = GraspAnnotatorCfg( - force_regenerate=True, - ) - tool = GraspAnnotator(cfg=cfg) - hit_point_pairs = tool.annotate( - vertices=torch.from_numpy(vertices).float(), - triangles=torch.from_numpy(triangles).long(), - ) - logger.log_info(f"Sample {hit_point_pairs.shape[0]} antipodal point pairs.") - - -if __name__ == "__main__": - main() diff --git a/embodichain/toolkits/graspkit/pg_grasp/antipodal_sampler.py b/embodichain/toolkits/graspkit/pg_grasp/antipodal_sampler.py 
index 2d0aa518..5e4b7594 100644 --- a/embodichain/toolkits/graspkit/pg_grasp/antipodal_sampler.py +++ b/embodichain/toolkits/graspkit/pg_grasp/antipodal_sampler.py @@ -22,6 +22,8 @@ from embodichain.utils import configclass from embodichain.utils import logger +__all__ = ["AntipodalSamplerCfg", "AntipodalSampler"] + @configclass class AntipodalSamplerCfg: @@ -29,10 +31,18 @@ class AntipodalSamplerCfg: n_sample: int = 20000 """surface point sample number""" + max_angle: float = np.pi / 12 - """maximum angle (in radians) to randomly disturb the ray direction for antipodal point sampling, used to increase the diversity of sampled antipodal points. Note that setting max_angle to 0 will disable the random disturbance and sample antipodal points strictly along the surface normals, which may result in less diverse antipodal points and may not be ideal for all objects or grasping scenarios.""" + """maximum angle (in radians) to randomly disturb the ray direction for antipodal point sampling, + used to increase the diversity of sampled antipodal points. Note that setting max_angle to 0 will + disable the random disturbance and sample antipodal points strictly along the surface normals, + which may result in less diverse antipodal points and may not be ideal for all objects or grasping + scenarios. 
+ """ + max_length: float = 0.1 """maximum gripper open width, used to filter out antipodal points that are too far apart to be grasped""" + min_length: float = 0.001 """minimum gripper open width, used to filter out antipodal points that are too close to be grasped""" @@ -241,15 +251,3 @@ def visualize(self, hit_point_pairs: torch.Tensor): window_name="Antipodal Point Pairs", mesh_show_back_face=True, ) - - -if __name__ == "__main__": - mesh_path = "/media/chenjian/_abc/project/grasp_annotator/dustpan_saa.ply" - mesh = o3d.t.io.read_triangle_mesh(mesh_path) - vertices = torch.from_numpy(mesh.vertex.positions.cpu().numpy()) - faces = torch.from_numpy(mesh.triangle.indices.cpu().numpy()) - - sampler = AntipodalSampler() - hit_point_pairs = sampler.sample(vertices, faces) - sampler.visualize(hit_point_pairs) - print(f"Sampled {hit_point_pairs.shape[0]} antipodal points") diff --git a/embodichain/toolkits/graspkit/pg_grasp/batch_collision_checker.py b/embodichain/toolkits/graspkit/pg_grasp/collision_checker.py similarity index 92% rename from embodichain/toolkits/graspkit/pg_grasp/batch_collision_checker.py rename to embodichain/toolkits/graspkit/pg_grasp/collision_checker.py index e8baf202..fcbfb850 100644 --- a/embodichain/toolkits/graspkit/pg_grasp/batch_collision_checker.py +++ b/embodichain/toolkits/graspkit/pg_grasp/collision_checker.py @@ -27,30 +27,30 @@ from typing import List, Tuple, Union from dexsim.kit.meshproc import convex_decomposition_coacd -from embodichain.utils import logger from embodichain.utils.warp import convex_signed_distance_kernel from embodichain.utils.device_utils import standardize_device_string +from embodichain.utils.math import transform_points_mat from embodichain.utils import configclass -CONVEX_CACHE_DIR = os.path.join( - os.path.expanduser("~"), ".cache", "embodichain_cache", "convex_decomposition" -) +__all__ = ["ConvexCollisionCheckerCfg", "ConvexCollisionChecker"] @configclass -class BatchConvexCollisionCheckerCfg: - 
"""Configuration for BatchConvexCollisionChecker.""" +class ConvexCollisionCheckerCfg: + """Configuration for ConvexCollisionChecker.""" - collsion_threshold: float = 0.0 + collision_threshold: float = 0.0 """ Collision threshold in meters. A point is considered colliding if its signed distance to the hull interior is <= this threshold. This allows for a margin of error in collision checking, where a small positive threshold can be used to consider points near the surface as colliding, and a small negative threshold can be used to allow for slight penetration without considering it a collision.""" + n_query_mesh_samples: int = 4096 """ Number of points to sample from the query mesh surface for collision checking. A higher number of samples can provide a more accurate collision check at the cost of increased computation time. The optimal number may depend on the complexity of the mesh and the required precision of collision detection.""" + debug: bool = False """ Whether to visualize the collision checking results for debugging purposes. If set to True, the code will generate visualizations of the query points colored by their collision status (e.g., red for colliding points and green for non-colliding points) along with the original mesh. This can help in understanding and verifying the collision checking process, especially during development and testing.""" -class BatchConvexCollisionChecker: - """BatchConvexCollisionChecker performs efficient collision checking between a batch of query point clouds and a convex decomposition of a mesh. The convex decomposition is represented by plane equations of the convex hulls, which are precomputed and cached for efficiency. The collision checking is done by computing the signed distance from each query point to the convex hulls using the plane equations, and determining if any points are colliding based on a specified collision threshold. 
This class can be used""" +class ConvexCollisionChecker: + """ConvexCollisionChecker performs efficient collision checking between a batch of query point clouds and a convex decomposition of a mesh. The convex decomposition is represented by plane equations of the convex hulls, which are precomputed and cached for efficiency. The collision checking is done by computing the signed distance from each query point to the convex hulls using the plane equations, and determining if any points are colliding based on a specified collision threshold.""" def __init__( self, @@ -58,14 +58,17 @@ def __init__( base_mesh_faces: torch.Tensor, max_decomposition_hulls: int = 32, ): - """Initialize the BatchConvexCollisionChecker by performing convex decomposition on the input mesh and extracting plane equations for the convex hulls. The plane equations are cached to disk to avoid redundant computation in future runs. + """Initialize the ConvexCollisionChecker by performing convex decomposition on the input mesh and extracting plane equations for the convex hulls. The plane equations are cached to disk to avoid redundant computation in future runs. + Args: base_mesh_verts: [N, 3] vertex positions of the input mesh. base_mesh_faces: [M, 3] triangle indices of the input mesh. max_decomposition_hulls: maximum number of convex hulls to decompose into. A higher number allows for a more accurate approximation of the original mesh but increases computation time and memory usage. The optimal number may depend on the complexity of the mesh and the required precision of collision checking.
""" - if not os.path.isdir(CONVEX_CACHE_DIR): - os.makedirs(CONVEX_CACHE_DIR, exist_ok=True) + from embodichain.lab.sim import CONVEX_DECOMP_DIR + + if not os.path.isdir(CONVEX_DECOMP_DIR): + os.makedirs(CONVEX_DECOMP_DIR, exist_ok=True) self.device = base_mesh_verts.device base_mesh_verts_np = base_mesh_verts.cpu().numpy() base_mesh_faces_np = base_mesh_faces.cpu().numpy() @@ -81,7 +84,7 @@ def __init__( self.mesh.compute_vertex_normals() self.cache_path = os.path.join( - CONVEX_CACHE_DIR, f"{mesh_hash}_{max_decomposition_hulls}.pkl" + CONVEX_DECOMP_DIR, f"{mesh_hash}_{max_decomposition_hulls}.pkl" ) if not os.path.isfile(self.cache_path): @@ -89,7 +92,7 @@ def __init__( # [n_convex, ]: number of faces for each convex hull # generate convex hulls and extract plane equations, then cache to disk - plane_equations_np = BatchConvexCollisionChecker._compute_plane_equations( + plane_equations_np = ConvexCollisionChecker._compute_plane_equations( base_mesh_verts_np, base_mesh_faces_np, max_decomposition_hulls ) # pack as a single tensor @@ -193,7 +196,7 @@ def query_batch_points( """ n_batch = batch_points.shape[0] point_signed_distance, is_point_collide = ( - BatchConvexCollisionChecker.batch_point_convex_query( + ConvexCollisionChecker.batch_point_convex_query( self.plane_equations["plane_equations"], self.plane_equations["plane_equation_counts"], batch_points, @@ -232,7 +235,7 @@ def query( query_mesh_verts: torch.Tensor, query_mesh_faces: torch.Tensor, poses: torch.Tensor, - cfg: BatchConvexCollisionCheckerCfg = BatchConvexCollisionCheckerCfg(), + cfg: ConvexCollisionCheckerCfg = ConvexCollisionCheckerCfg(), ) -> Tuple[torch.Tensor, torch.Tensor]: query_mesh = trimesh.Trimesh( vertices=query_mesh_verts.to("cpu").numpy(), @@ -257,7 +260,7 @@ def query( normals_torch, offsets_torch, transform_points_mat(query_points, poses), - cfg.collsion_threshold, + cfg.collision_threshold, ) penetration_result = torch.max(penetration_result, penetration) collision_result = 
torch.logical_or(collision_result, collides) diff --git a/embodichain/toolkits/graspkit/pg_grasp/cone_sampler.py b/embodichain/toolkits/graspkit/pg_grasp/cone_sampler.py deleted file mode 100644 index 7c9738fb..00000000 --- a/embodichain/toolkits/graspkit/pg_grasp/cone_sampler.py +++ /dev/null @@ -1,121 +0,0 @@ -# ---------------------------------------------------------------------------- -# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ---------------------------------------------------------------------------- - -import open3d as o3d -import numpy as np -from scipy.spatial.transform import Rotation as R - - -def rotate_to_ref(direc: np.ndarray, rotate_ref: np.ndarray): - assert direc.shape == (3,) - direc_len = np.linalg.norm(direc) - assert direc_len > 1e-5 - direc_unit = direc / direc_len - - assert rotate_ref.shape == (3,) - rotate_ref_len = np.linalg.norm(rotate_ref) - assert rotate_ref_len > 1e-5 - rotate_ref_unit = rotate_ref / rotate_ref_len - - rotate_axis = np.cross(rotate_ref_unit, direc_unit) - rotate_axis_len = np.linalg.norm(rotate_axis) - if rotate_axis_len < 1e-5: - # co axis, no need to do rotation - dot_res = direc_unit.dot(rotate_ref_unit) - if dot_res > 0: - # identity rotation - return np.eye(3, dtype=float) - else: - # negative, rotate 180 degree - # rotate with a perpendicular axis - random_axis = np.random.random(size=(3,)) - perpendicular_axis = np.cross(random_axis, rotate_ref_unit) - perpendicular_axis = perpendicular_axis / np.linalg.norm(perpendicular_axis) - ref_rotation = R.from_rotvec(perpendicular_axis * np.pi).as_matrix() - return ref_rotation - else: - rotate_axis = rotate_axis / rotate_axis_len - angle = np.arccos(direc_unit.dot(rotate_ref_unit)) - ref_rotation = R.from_rotvec(angle * rotate_axis, degrees=False).as_matrix() - return ref_rotation - - -class ConeSampler: - def __init__( - self, max_angle: float, layer_num: int = 2, sample_each_layer: int = 4 - ) -> None: - """cone ray sampler - - Args: - max_angle (float): maximum ray angle to surface normal - layer_num (int, optional): circle layer. Defaults to 2. - sample_each_layer (int, optional): ray samples in each circle layer. Defaults to 4. 
- """ - self._max_angle = max_angle - self._layer_num = layer_num - self._ray_num = layer_num * sample_each_layer + 1 - alpha_list = np.linspace(max_angle / layer_num, max_angle, layer_num) - beta_list = np.linspace( - 2 * np.pi / sample_each_layer, 2 * np.pi, sample_each_layer - ) - self._direc_ref = np.array([0, 0, 1]) - - rotation_list = np.empty(shape=(self._ray_num, 3, 3), dtype=float) - - for i, alpha in enumerate(alpha_list): - for j, beta in enumerate(beta_list): - x_rotation = R.from_euler( - seq="XYZ", angles=np.array([alpha, 0, 0]), degrees=False - ).as_matrix() - z_rotation = R.from_euler( - seq="XYZ", angles=np.array([0, 0, beta]), degrees=False - ).as_matrix() - rotation_list[i * sample_each_layer + j + 1] = z_rotation @ x_rotation - # original direction - rotation_list[0] = np.eye(3) - self._sample_direc = rotation_list[:, :3, 2] # z-axis - - def cone_sample_direc(self, direc: np.ndarray, is_visual: bool = False): - """sample cone directly - - Args: - direc (np.ndarray): direction to sample a cone - is_visual (bool, optional): use visualization or not. Defaults to False. 
- - Returns: - np.ndarray: [_ray_num, 3] of float, cone direction list - """ - ref_rotation = rotate_to_ref(direc, self._direc_ref) - cone_direc_list = self._sample_direc @ ref_rotation.T - if is_visual: - self._visual(cone_direc_list) - return cone_direc_list - - def _visual(self, cone_direc_list: np.ndarray): - drawer = o3d.geometry.TriangleMesh.create_coordinate_frame(0.5) - for cone_direc in cone_direc_list: - arrow = o3d.geometry.TriangleMesh.create_arrow( - cylinder_radius=0.02, - cone_radius=0.03, - cylinder_height=0.9, - cone_height=0.1, - ) - arrow.compute_vertex_normals() - arrow_rotation = rotate_to_ref(cone_direc, self._direc_ref) - arrow.rotate(arrow_rotation, center=(0, 0, 0)) - arrow.paint_uniform_color(np.array([0.5, 0.5, 0.5])) - drawer += arrow - o3d.visualization.draw_geometries([drawer]) diff --git a/embodichain/toolkits/graspkit/pg_grasp/gripper_collision_checker.py b/embodichain/toolkits/graspkit/pg_grasp/gripper_collision_checker.py index e2a54558..d19aa332 100644 --- a/embodichain/toolkits/graspkit/pg_grasp/gripper_collision_checker.py +++ b/embodichain/toolkits/graspkit/pg_grasp/gripper_collision_checker.py @@ -16,17 +16,22 @@ from __future__ import annotations -from embodichain.utils import configclass +import torch from typing import Sequence -from .batch_collision_checker import BatchConvexCollisionChecker -import torch + +from embodichain.utils import configclass +from embodichain.toolkits.graspkit.pg_grasp.collision_checker import ( + ConvexCollisionChecker, +) from embodichain.utils.math import transform_points_mat +__all__ = ["GripperCollisionCfg", "GripperCollisionChecker", "box_surface_grid"] + @configclass -class SimpleGripperCollisionCfg: - """Configuration for the SimpleGripperCollisionChecker. This class defines various parameters related to the +class GripperCollisionCfg: + """Configuration for the GripperCollisionChecker. 
This class defines various parameters related to the gripper geometry, point cloud generation, and collision checking process. Users can customize these parameters based on the specific gripper being modeled and the requirements of the application. """ @@ -82,14 +87,14 @@ class SimpleGripperCollisionCfg: """ -class SimpleGripperCollisionChecker: +class GripperCollisionChecker: def __init__( self, object_mesh_verts: torch.Tensor, object_mesh_faces: torch.Tensor, - cfg: SimpleGripperCollisionCfg = SimpleGripperCollisionCfg(), + cfg: GripperCollisionCfg = GripperCollisionCfg(), ): - self._checker = BatchConvexCollisionChecker( + self._checker = ConvexCollisionChecker( base_mesh_verts=object_mesh_verts, base_mesh_faces=object_mesh_faces, max_decomposition_hulls=cfg.max_decomposition_hulls, diff --git a/scripts/tutorials/grasp/grasp_generator.py b/scripts/tutorials/grasp/grasp_generator.py index 72077a72..75a88eb6 100644 --- a/scripts/tutorials/grasp/grasp_generator.py +++ b/scripts/tutorials/grasp/grasp_generator.py @@ -39,12 +39,12 @@ RigidObjectCfg, URDFCfg, ) -from embodichain.toolkits.graspkit.pg_grasp.antipodal_annotator import ( - GraspAnnotatorCfg, +from embodichain.toolkits.graspkit.pg_grasp.antipodal_generator import ( + GraspGeneratorCfg, AntipodalSamplerCfg, ) from embodichain.toolkits.graspkit.pg_grasp.gripper_collision_checker import ( - SimpleGripperCollisionCfg, + GripperCollisionCfg, ) @@ -232,7 +232,7 @@ def get_grasp_traj(sim: SimulationManager, robot: Robot, grasp_xpos: torch.Tenso mug = create_mug(sim) # get mug grasp pose - grasp_cfg = GraspAnnotatorCfg( + grasp_cfg = GraspGeneratorCfg( viser_port=11801, antipodal_sampler_cfg=AntipodalSamplerCfg( n_sample=20000, max_length=0.088, min_length=0.003 @@ -248,7 +248,7 @@ def get_grasp_traj(sim: SimulationManager, robot: Robot, grasp_xpos: torch.Tenso start_time = time.time() - gripper_collision_cfg = SimpleGripperCollisionCfg( + gripper_collision_cfg = GripperCollisionCfg( max_open_length=0.088, 
finger_length=0.078, point_sample_dense=0.012 ) grasp_xpos = mug.get_grasp_pose( diff --git a/tests/toolkits/test_batch_convex_collision.py b/tests/toolkits/test_batch_convex_collision.py index 819079df..4bf852c8 100644 --- a/tests/toolkits/test_batch_convex_collision.py +++ b/tests/toolkits/test_batch_convex_collision.py @@ -16,9 +16,9 @@ import torch from embodichain.data import get_data_path import trimesh -from embodichain.toolkits.graspkit.pg_grasp.batch_collision_checker import ( - BatchConvexCollisionChecker, - BatchConvexCollisionCheckerCfg, +from embodichain.toolkits.graspkit.pg_grasp.collision_checker import ( + ConvexCollisionChecker, + ConvexCollisionCheckerCfg, ) from embodichain.utils.math import transform_points_mat import warp as wp @@ -29,9 +29,7 @@ def batch_convex_collision_query(device=torch.device("cuda")): mug_mesh = trimesh.load(mug_path, force="mesh", process=False) verts = torch.tensor(mug_mesh.vertices, dtype=torch.float32, device=device) faces = torch.tensor(mug_mesh.faces, dtype=torch.int32, device=device) - collision_checker = BatchConvexCollisionChecker( - verts, faces, max_decomposition_hulls=16 - ) + collision_checker = ConvexCollisionChecker(verts, faces, max_decomposition_hulls=16) poses = torch.tensor( [ diff --git a/tests/toolkits/test_pg_grasp.py b/tests/toolkits/test_pg_grasp.py deleted file mode 100644 index 10f96bd7..00000000 --- a/tests/toolkits/test_pg_grasp.py +++ /dev/null @@ -1,96 +0,0 @@ -# ---------------------------------------------------------------------------- -# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ---------------------------------------------------------------------------- - -import open3d as o3d -import numpy as np -import os -from embodichain.toolkits.graspkit.pg_grasp import ( - AntipodalGenerator, - GraspSelectMethod, -) -from embodichain.data import get_data_path - - -def test_antipodal_score_selector(is_visual: bool = False): - mesh_path = get_data_path("ChainRainSec/mesh.ply") - mesh_o3dt = o3d.t.io.read_triangle_mesh(mesh_path) - generator = AntipodalGenerator( - mesh_o3dt=mesh_o3dt, - open_length=0.1, - max_angle=np.pi / 6, - surface_sample_num=5000, - cache_dir=None, - ) - grasp_list = generator.select_grasp( - approach_direction=np.array([0, 0, -1]), - select_num=5, - select_method=GraspSelectMethod.NORMAL_SCORE, - ) - assert len(grasp_list) == 5 - if is_visual: - visual_mesh_list = generator.grasp_pose_visual(grasp_list) - visual_mesh_list = [visual_mesh.to_legacy() for visual_mesh in visual_mesh_list] - o3d.visualization.draw_geometries(visual_mesh_list) - - -def test_antipodal_position_selector(is_visual: bool = False): - mesh_path = get_data_path("ChainRainSec/mesh.ply") - mesh_o3dt = o3d.t.io.read_triangle_mesh(mesh_path) - generator = AntipodalGenerator( - mesh_o3dt=mesh_o3dt, - open_length=0.1, - max_angle=np.pi / 6, - surface_sample_num=5000, - cache_dir=None, - ) - grasp_list = generator.select_grasp( - approach_direction=np.array([0, 0, -1]), - select_num=5, - select_method=GraspSelectMethod.NEAR_APPROACH, - ) - assert len(grasp_list) == 5 - if is_visual: - visual_mesh_list = generator.grasp_pose_visual(grasp_list) - 
visual_mesh_list = [visual_mesh.to_legacy() for visual_mesh in visual_mesh_list] - o3d.visualization.draw_geometries(visual_mesh_list) - - -def test_antipodal_center_selector(is_visual: bool = False): - mesh_path = get_data_path("ChainRainSec/mesh.ply") - mesh_o3dt = o3d.t.io.read_triangle_mesh(mesh_path) - generator = AntipodalGenerator( - mesh_o3dt=mesh_o3dt, - open_length=0.1, - max_angle=np.pi / 6, - surface_sample_num=5000, - cache_dir=None, - ) - grasp_list = generator.select_grasp( - approach_direction=np.array([0, 0, -1]), - select_num=5, - select_method=GraspSelectMethod.CENTER, - ) - assert len(grasp_list) == 5 - if is_visual: - visual_mesh_list = generator.grasp_pose_visual(grasp_list) - visual_mesh_list = [visual_mesh.to_legacy() for visual_mesh in visual_mesh_list] - o3d.visualization.draw_geometries(visual_mesh_list) - - -if __name__ == "__main__": - test_antipodal_score_selector(True) - test_antipodal_position_selector(True) - test_antipodal_center_selector(True) From 87b0095b0df242d1b19d3a64c55467d8e8bed426 Mon Sep 17 00:00:00 2001 From: yuecideng Date: Thu, 2 Apr 2026 11:26:08 +0800 Subject: [PATCH 22/25] wip --- .../features/toolkits/grasp_generator.rst | 4 +- embodichain/lab/sim/objects/rigid_object.py | 70 ------------------- embodichain/lab/sim/sim_manager.py | 2 - .../graspkit/pg_grasp/antipodal_generator.py | 16 ++++- scripts/tutorials/grasp/grasp_generator.py | 29 ++++++-- 5 files changed, 40 insertions(+), 81 deletions(-) diff --git a/docs/source/features/toolkits/grasp_generator.rst b/docs/source/features/toolkits/grasp_generator.rst index aac49ddc..2cb846d3 100644 --- a/docs/source/features/toolkits/grasp_generator.rst +++ b/docs/source/features/toolkits/grasp_generator.rst @@ -41,13 +41,13 @@ The parsed arguments are passed to ``initialize_simulation``, which builds a :cl Annotating and computing grasp poses ------------------------------------- -Grasp generation is performed by :meth:`objects.RigidObject.get_grasp_pose`, which 
internally runs an antipodal sampler on the object mesh. A :class:`toolkits.graspkit.pg_grasp.GraspGeneratorCfg` controls sampler parameters (sample count, gripper jaw limits) and the interactive annotation workflow: +Grasp generation is performed by :class:`toolkits.graspkit.pg_grasp.GraspGenerator`, which runs an antipodal sampler on the object mesh. The mesh data (vertices and triangles) is extracted from the :class:`objects.RigidObject` via its accessor methods. A :class:`toolkits.graspkit.pg_grasp.GraspGeneratorCfg` controls sampler parameters (sample count, gripper jaw limits) and the interactive annotation workflow: 1. Open the visualization in a browser at the reported port (e.g. ``http://localhost:11801``). 2. Use *Rect Select Region* to highlight the area of the object that should be grasped. 3. Click *Confirm Selection* to finalize the region. -The function returns a batch of ``(N_envs, 4, 4)`` homogeneous transformation matrices representing candidate grasp frames in the world coordinate system. +For each environment, a grasp pose is computed by calling :meth:`toolkits.graspkit.pg_grasp.GraspGenerator.get_grasp_poses` with the object pose and desired approach direction. The result is a ``(4, 4)`` homogeneous transformation matrix representing the grasp frame in world coordinates. For each grasp pose, gripper approach direction in world coordinate is required to compute the antipodal grasp. In this tutorial, we use a fixed approach direction (straight down in world frame) for simplicity, but it can be customized based on the task or object geometry. 
diff --git a/embodichain/lab/sim/objects/rigid_object.py b/embodichain/lab/sim/objects/rigid_object.py index 41dc2f83..24de293b 100644 --- a/embodichain/lab/sim/objects/rigid_object.py +++ b/embodichain/lab/sim/objects/rigid_object.py @@ -35,13 +35,6 @@ from embodichain.utils.math import convert_quat from embodichain.utils.math import matrix_from_quat, quat_from_matrix, matrix_from_euler from embodichain.utils import logger -from embodichain.toolkits.graspkit.pg_grasp import ( - GraspGenerator, - GraspGeneratorCfg, -) -from embodichain.toolkits.graspkit.pg_grasp.gripper_collision_checker import ( - GripperCollisionCfg, -) @dataclass @@ -1155,66 +1148,3 @@ def destroy(self) -> None: arenas = [env] for i, entity in enumerate(self._entities): arenas[i].remove_actor(entity) - - def get_grasp_pose( - self, - cfg: GraspGeneratorCfg = GraspGeneratorCfg(), - gripper_collision_cfg: GripperCollisionCfg = GripperCollisionCfg(), - approach_direction: torch.Tensor = None, - is_visual: bool = False, - ) -> torch.Tensor: - if approach_direction is None: - approach_direction = torch.tensor( - [0, 0, -1], dtype=torch.float32, device=self.device - ) - approach_direction = torch.nn.functional.normalize(approach_direction, dim=-1) - if hasattr(self, "_grasp_annotator") is False: - vertices = torch.tensor( - self._entities[0].get_vertices(), - dtype=torch.float32, - device=self.device, - ) - triangles = torch.tensor( - self._entities[0].get_triangles(), dtype=torch.int32, device=self.device - ) - scale = torch.tensor( - self._entities[0].get_body_scale(), - dtype=torch.float32, - device=self.device, - ) - vertices = vertices * scale - self._grasp_annotator = GraspGenerator( - vertices=vertices, - triangles=triangles, - cfg=cfg, - gripper_collision_cfg=gripper_collision_cfg, - ) - - # Annotate antipodal point pairs - if hasattr(self, "_hit_point_pairs") is False or cfg.force_regenerate: - self._hit_point_pairs = self._grasp_annotator.annotate() - - poses = 
self.get_local_pose(to_matrix=True) - grasp_poses: tuple[torch.Tensor] = [] - open_lengths: tuple[torch.Tensor] = [] - for pose in poses: - grasp_pose, open_length = self._grasp_annotator.get_grasp_poses( - self._hit_point_pairs, pose, approach_direction, is_visual=False - ) - grasp_poses.append(grasp_pose) - open_lengths.append(open_length) - grasp_poses = torch.cat( - [grasp_pose.unsqueeze(0) for grasp_pose in grasp_poses], dim=0 - ) - - if is_visual: - vertices = self._entities[0].get_vertices() - triangles = self._entities[0].get_triangles() - scale = self._entities[0].get_body_scale() - vertices = vertices * scale - self._grasp_annotator.visualize_grasp_pose( - obj_pose=poses[0], - grasp_pose=grasp_poses[0], - open_length=open_lengths[0].item(), - ) - return grasp_poses diff --git a/embodichain/lab/sim/sim_manager.py b/embodichain/lab/sim/sim_manager.py index 5e3e2880..6a8bf537 100644 --- a/embodichain/lab/sim/sim_manager.py +++ b/embodichain/lab/sim/sim_manager.py @@ -35,7 +35,6 @@ MATERIAL_CACHE_DIR = SIM_CACHE_DIR / "mat_cache" CONVEX_DECOMP_DIR = SIM_CACHE_DIR / "convex_decomposition" REACHABLE_XPOS_DIR = SIM_CACHE_DIR / "robot_reachable_xpos" -GRASP_ANNOTATOR_CACHE_DIR = SIM_CACHE_DIR / "grasp_annotator_cache" from dexsim.types import ( Backend, @@ -90,7 +89,6 @@ "MATERIAL_CACHE_DIR", "CONVEX_DECOMP_DIR", "REACHABLE_XPOS_DIR", - "GRASP_ANNOTATOR_CACHE_DIR", ] diff --git a/embodichain/toolkits/graspkit/pg_grasp/antipodal_generator.py b/embodichain/toolkits/graspkit/pg_grasp/antipodal_generator.py index a47efbb7..cdb1bf7d 100644 --- a/embodichain/toolkits/graspkit/pg_grasp/antipodal_generator.py +++ b/embodichain/toolkits/graspkit/pg_grasp/antipodal_generator.py @@ -41,6 +41,11 @@ GripperCollisionCfg, ) +GRASP_ANNOTATOR_CACHE_DIR = ( + Path.home() / ".cache" / "embodichain" / "grasp_annotator_cache" +) +GRASP_ANNOTATOR_CACHE_DIR.mkdir(parents=True, exist_ok=True) + __all__ = ["GraspGenerator", "GraspGeneratorCfg"] @@ -386,8 +391,6 @@ def _(_evt: 
viser.GuiEvent) -> None: return self._hit_point_pairs def _get_cache_dir(self, vertices: torch.Tensor, triangles: torch.Tensor): - from embodichain.lab.sim.sim_manager import GRASP_ANNOTATOR_CACHE_DIR - vert_bytes = vertices.to("cpu").numpy().tobytes() face_bytes = triangles.to("cpu").numpy().tobytes() md5_hash = hashlib.md5(vert_bytes + face_bytes).hexdigest() @@ -584,6 +587,7 @@ def get_grasp_poses( object_pose: torch.Tensor, approach_direction: torch.Tensor, is_visual: bool = False, + visualize: bool = False, ) -> tuple[torch.Tensor, torch.Tensor]: """Get grasp pose given approach direction. @@ -600,6 +604,8 @@ def get_grasp_poses( approach_direction: ``(3,)`` unit vector representing the desired approach direction of the gripper in the world frame. is_visual: If ``True``, enable visual collision checking. + visualize: If ``True``, visualize the best grasp pose using Open3D + after computation. Returns: A tuple ``(best_grasp_pose, best_open_length)`` where @@ -667,6 +673,12 @@ def get_grasp_poses( best_idx = torch.argmin(total_cost) best_grasp_pose = valid_grasp_poses[best_idx] best_open_length = valid_open_lengths[best_idx] + if visualize: + self.visualize_grasp_pose( + obj_pose=object_pose, + grasp_pose=best_grasp_pose, + open_length=best_open_length.item(), + ) return best_grasp_pose, best_open_length @staticmethod diff --git a/scripts/tutorials/grasp/grasp_generator.py b/scripts/tutorials/grasp/grasp_generator.py index 75a88eb6..0875e7bf 100644 --- a/scripts/tutorials/grasp/grasp_generator.py +++ b/scripts/tutorials/grasp/grasp_generator.py @@ -40,6 +40,7 @@ URDFCfg, ) from embodichain.toolkits.graspkit.pg_grasp.antipodal_generator import ( + GraspGenerator, GraspGeneratorCfg, AntipodalSamplerCfg, ) @@ -251,14 +252,32 @@ def get_grasp_traj(sim: SimulationManager, robot: Robot, grasp_xpos: torch.Tenso gripper_collision_cfg = GripperCollisionCfg( max_open_length=0.088, finger_length=0.078, point_sample_dense=0.012 ) - grasp_xpos = mug.get_grasp_pose( - 
approach_direction=torch.tensor( - [0, 0, -1], dtype=torch.float32, device=sim.device - ), # gripper approach direction in the world frame + + # Extract mesh data from the mug and create grasp generator + vertices = mug.get_vertices(env_ids=[0], scale=True)[0] + triangles = mug.get_triangles(env_ids=[0])[0] + grasp_generator = GraspGenerator( + vertices=vertices, + triangles=triangles, cfg=grasp_cfg, gripper_collision_cfg=gripper_collision_cfg, - is_visual=False, # visualize selected grasp pose finally ) + + # Annotate grasp region (populates internal antipodal point pairs) + grasp_generator.annotate() + + # Compute grasp poses per environment + approach_direction = torch.tensor( + [0, 0, -1], dtype=torch.float32, device=sim.device + ) + poses = mug.get_local_pose(to_matrix=True) + grasp_xpos_list = [] + for pose in poses: + grasp_pose, _ = grasp_generator.get_grasp_poses( + pose, approach_direction, visualize=False + ) + grasp_xpos_list.append(grasp_pose.unsqueeze(0)) + grasp_xpos = torch.cat(grasp_xpos_list, dim=0) cost_time = time.time() - start_time logger.log_info(f"Get grasp pose cost time: {cost_time:.2f} seconds") From e2151f954dd140adcff213665148f0e48b4ad119 Mon Sep 17 00:00:00 2001 From: yuecideng Date: Thu, 2 Apr 2026 11:35:57 +0800 Subject: [PATCH 23/25] wip --- .../embodichain/embodichain.toolkits.rst | 86 +++++++++++++++++-- .../features/toolkits/grasp_generator.rst | 45 ++++++++-- 2 files changed, 118 insertions(+), 13 deletions(-) diff --git a/docs/source/api_reference/embodichain/embodichain.toolkits.rst b/docs/source/api_reference/embodichain/embodichain.toolkits.rst index cc2639d5..d773b9f9 100644 --- a/docs/source/api_reference/embodichain/embodichain.toolkits.rst +++ b/docs/source/api_reference/embodichain/embodichain.toolkits.rst @@ -1,4 +1,4 @@ -embodichain.toolkits +embodichain.toolkits ==================== .. 
automodule:: embodichain.toolkits @@ -11,12 +11,88 @@ urdf_assembly -GraspKit --------- +GraspKit — Parallel-Gripper Grasp Sampling +------------------------------------------- -.. automodule:: embodichain.toolkits.graspkit +The ``embodichain.toolkits.graspkit.pg_grasp`` module provides a complete pipeline for generating antipodal grasp poses for parallel-jaw grippers. The pipeline consists of three stages: + +1. **Antipodal sampling** — Surface points are uniformly sampled on the mesh and rays are cast to find antipodal point pairs on opposite sides. +2. **Pose construction** — For each antipodal pair, a 6-DoF grasp frame is built aligned with the approach direction. +3. **Filtering & ranking** — Grasp candidates that cause the gripper to collide with the object are discarded; survivors are scored by a weighted cost. + +.. rubric:: Public API + +.. currentmodule:: embodichain.toolkits.graspkit.pg_grasp + +The main entry point is :class:`GraspGenerator`. It is configured via :class:`GraspGeneratorCfg` and :class:`GripperCollisionCfg`. + +.. autosummary:: + :nosignatures: + + GraspGenerator + GraspGeneratorCfg + AntipodalSampler + AntipodalSamplerCfg + GripperCollisionChecker + GripperCollisionCfg + ConvexCollisionChecker + ConvexCollisionCheckerCfg + + +GraspGenerator +~~~~~~~~~~~~~~~ + +.. autoclass:: GraspGenerator + :members: generate, annotate, get_grasp_poses, visualize_grasp_pose + :show-inheritance: + +GraspGeneratorCfg +~~~~~~~~~~~~~~~~~~ + +.. autoclass:: GraspGeneratorCfg + :members: + :show-inheritance: + +AntipodalSampler +~~~~~~~~~~~~~~~~~ + +.. autoclass:: AntipodalSampler + :members: sample + :show-inheritance: + +AntipodalSamplerCfg +~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: AntipodalSamplerCfg + :members: + :show-inheritance: + +GripperCollisionChecker +~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: GripperCollisionChecker + :members: query + :show-inheritance: + +GripperCollisionCfg +~~~~~~~~~~~~~~~~~~~~ + +.. 
autoclass:: GripperCollisionCfg + :members: + :show-inheritance: + +ConvexCollisionChecker +~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: ConvexCollisionChecker + :members: query, query_batch + :show-inheritance: + +ConvexCollisionCheckerCfg +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: ConvexCollisionCheckerCfg :members: - :undoc-members: :show-inheritance: diff --git a/docs/source/features/toolkits/grasp_generator.rst b/docs/source/features/toolkits/grasp_generator.rst index 2cb846d3..3f9f2c81 100644 --- a/docs/source/features/toolkits/grasp_generator.rst +++ b/docs/source/features/toolkits/grasp_generator.rst @@ -38,28 +38,57 @@ The parsed arguments are passed to ``initialize_simulation``, which builds a :cl :start-at: def initialize_simulation(args) -> SimulationManager: :end-at: return sim +Creating a robot and a target object +------------------------------------ + +A UR10 arm with a parallel-jaw gripper is created via :meth:`SimulationManager.add_robot`. The gripper URDF and drive properties are configured so that the arm joints and finger joints can be controlled independently. + +.. literalinclude:: ../../../../scripts/tutorials/grasp/grasp_generator.py + :language: python + :start-at: def create_robot(sim: SimulationManager + :end-at: return sim.add_robot(cfg=cfg) + +The target object (a mug) is loaded as a :class:`objects.RigidObject` from a PLY mesh file: + +.. literalinclude:: ../../../../scripts/tutorials/grasp/grasp_generator.py + :language: python + :start-at: def create_mug(sim: SimulationManager): + :end-at: return mug + Annotating and computing grasp poses ------------------------------------- -Grasp generation is performed by :class:`toolkits.graspkit.pg_grasp.GraspGenerator`, which runs an antipodal sampler on the object mesh. The mesh data (vertices and triangles) is extracted from the :class:`objects.RigidObject` via its accessor methods. 
A :class:`toolkits.graspkit.pg_grasp.GraspGeneratorCfg` controls sampler parameters (sample count, gripper jaw limits) and the interactive annotation workflow: +Grasp generation is performed by :class:`~embodichain.toolkits.graspkit.pg_grasp.GraspGenerator`, which runs an antipodal sampler on the object mesh. The mesh data (vertices and triangles) is extracted from the :class:`objects.RigidObject` via its accessor methods. A :class:`~embodichain.toolkits.graspkit.pg_grasp.GraspGeneratorCfg` controls sampler parameters (sample count, gripper jaw limits) and the interactive annotation workflow: 1. Open the visualization in a browser at the reported port (e.g. ``http://localhost:11801``). 2. Use *Rect Select Region* to highlight the area of the object that should be grasped. 3. Click *Confirm Selection* to finalize the region. -For each environment, a grasp pose is computed by calling :meth:`toolkits.graspkit.pg_grasp.GraspGenerator.get_grasp_poses` with the object pose and desired approach direction. The result is a ``(4, 4)`` homogeneous transformation matrix representing the grasp frame in world coordinates. +After annotation, antipodal point pairs are cached to disk and automatically reused unless ``force_regenerate`` is set. + +For each environment, a grasp pose is computed by calling :meth:`~embodichain.toolkits.graspkit.pg_grasp.GraspGenerator.get_grasp_poses` with the object pose and desired approach direction. The result is a ``(4, 4)`` homogeneous transformation matrix representing the grasp frame in world coordinates. Set ``visualize=True`` to open an Open3D window showing the selected grasp on the object. -For each grasp pose, gripper approach direction in world coordinate is required to compute the antipodal grasp. In this tutorial, we use a fixed approach direction (straight down in world frame) for simplicity, but it can be customized based on the task or object geometry. 
+The approach direction is the unit vector along which the gripper approaches the object. In this tutorial, we use a fixed approach direction (straight down in world frame) for simplicity, but it can be customized based on the task or object geometry. .. literalinclude:: ../../../../scripts/tutorials/grasp/grasp_generator.py :language: python - :start-at: # get mug grasp pose + :start-at: gripper_collision_cfg = GripperCollisionCfg( :end-at: logger.log_info(f"Get grasp pose cost time: {cost_time:.2f} seconds") +Building and executing the grasp trajectory +------------------------------------------- + +Once a grasp pose is obtained, a waypoint trajectory is built that moves the arm from its rest configuration to an approach pose (offset above the grasp), down to the grasp pose, closes the fingers, lifts, and returns. The trajectory is interpolated for smooth motion and executed step-by-step in the simulation loop. + +.. literalinclude:: ../../../../scripts/tutorials/grasp/grasp_generator.py + :language: python + :start-at: def get_grasp_traj(sim: SimulationManager + :end-at: return interp_trajectory + Configuring GraspGeneratorCfg ------------------------------ -:class:`toolkits.graspkit.pg_grasp.GraspGeneratorCfg` controls the overall grasp annotation workflow. The key parameters are listed below. +:class:`~embodichain.toolkits.graspkit.pg_grasp.GraspGeneratorCfg` controls the overall grasp annotation workflow. The key parameters are listed below. .. list-table:: GraspGeneratorCfg parameters :header-rows: 1 @@ -84,7 +113,7 @@ Configuring GraspGeneratorCfg - ``π / 12`` - Maximum allowed angle (in radians) between the specified approach direction and the axis connecting an antipodal point pair. Pairs that deviate more than this threshold are discarded. -The ``antipodal_sampler_cfg`` field accepts an :class:`toolkits.graspkit.pg_grasp.AntipodalSamplerCfg` instance, which controls how antipodal point pairs are sampled on the mesh surface. 
+The ``antipodal_sampler_cfg`` field accepts an :class:`~embodichain.toolkits.graspkit.pg_grasp.AntipodalSamplerCfg` instance, which controls how antipodal point pairs are sampled on the mesh surface. .. list-table:: AntipodalSamplerCfg parameters :header-rows: 1 @@ -107,9 +136,9 @@ The ``antipodal_sampler_cfg`` field accepts an :class:`toolkits.graspkit.pg_gras - Minimum allowed distance (in metres) between an antipodal pair. Pairs closer together than this value are discarded to avoid degenerate or self-intersecting grasps. Configuring GripperCollisionCfg --------------------------------------- +------------------------------- -:class:`toolkits.graspkit.pg_grasp.GripperCollisionCfg` models the geometry of a parallel-jaw gripper as a point cloud and is used to filter out grasp candidates that would collide with the object. All length parameters are in metres. +:class:`~embodichain.toolkits.graspkit.pg_grasp.GripperCollisionCfg` models the geometry of a parallel-jaw gripper as a point cloud and is used to filter out grasp candidates that would collide with the object. All length parameters are in metres. .. 
list-table:: GripperCollisionCfg parameters :header-rows: 1 From c862be5d46fce3587f4571537458506ecd7d8fa8 Mon Sep 17 00:00:00 2001 From: yuecideng Date: Thu, 2 Apr 2026 12:24:05 +0800 Subject: [PATCH 24/25] wip --- .../features/toolkits/grasp_generator.rst | 56 +++++++ embodichain/__main__.py | 12 ++ .../toolkits/graspkit/scripts/__init__.py | 15 ++ .../graspkit/scripts/annotate_grasp.py | 138 ++++++++++++++++++ 4 files changed, 221 insertions(+) create mode 100644 embodichain/toolkits/graspkit/scripts/__init__.py create mode 100644 embodichain/toolkits/graspkit/scripts/annotate_grasp.py diff --git a/docs/source/features/toolkits/grasp_generator.rst b/docs/source/features/toolkits/grasp_generator.rst index 3f9f2c81..d5a8123a 100644 --- a/docs/source/features/toolkits/grasp_generator.rst +++ b/docs/source/features/toolkits/grasp_generator.rst @@ -194,3 +194,59 @@ You can customize the run with additional arguments: python scripts/tutorials/grasp/grasp_generator.py --num_envs --device --enable_rt --headless After confirming the grasp region in the browser, the script will compute a grasp pose, print the elapsed time, and then wait for you to press **Enter** before executing the full grasp trajectory in the simulation. Press **Enter** again to exit once the motion is complete. + + +Grasp Annotation CLI +~~~~~~~~~~~~~~~~~~~~ + +EmbodiChain provides a dedicated CLI for interactively annotating grasp regions on a mesh and caching the resulting antipodal point pairs, without requiring a full simulation environment. + +Basic usage:: + + python -m embodichain annotate-grasp --mesh_path /path/to/object.ply + +This will: + +1. Load the mesh file via ``trimesh``. +2. Launch a browser-based annotator (default port ``15531``). +3. Open http://localhost:15531 in your browser, use *Rect Select Region* to highlight the graspable area, then click *Confirm Selection*. +4. Compute antipodal point pairs on the selected region and cache them to disk. 
+ +Common options:: + + python -m embodichain annotate-grasp \ + --mesh_path /path/to/object.ply \ + --viser_port 15531 \ + --n_sample 20000 \ + --max_length 0.1 \ + --min_length 0.001 \ + --force_regenerate + +.. list-table:: CLI options + :header-rows: 1 + :widths: 25 15 60 + + * - Option + - Default + - Description + * - ``--mesh_path`` + - *(required)* + - Path to the mesh file (``.ply``, ``.obj``, ``.stl``, etc.). + * - ``--viser_port`` + - ``15531`` + - Port for the browser-based annotation UI. + * - ``--n_sample`` + - ``20000`` + - Number of surface points to sample for antipodal pair detection. + * - ``--max_length`` + - ``0.1`` + - Maximum distance (metres) between antipodal pairs; should match the gripper's maximum opening width. + * - ``--min_length`` + - ``0.001`` + - Minimum distance (metres) between antipodal pairs; filters out degenerate pairs. + * - ``--force_regenerate`` + - ``False`` + - Force re-annotation and ignore any cached antipodal pairs. + * - ``--device`` + - ``cpu`` + - Compute device (``cpu`` or ``cuda``). 
diff --git a/embodichain/__main__.py b/embodichain/__main__.py index 5125296b..522ca48f 100644 --- a/embodichain/__main__.py +++ b/embodichain/__main__.py @@ -20,6 +20,7 @@ python -m embodichain preview-asset --asset_path /path/to/asset.usda --preview python -m embodichain run-env --env_name my_env + python -m embodichain annotate-grasp --mesh_path /path/to/object.ply """ from __future__ import annotations @@ -63,6 +64,17 @@ def main() -> None: run_env_parser.set_defaults(func=run_env_cli) + # -- annotate-grasp ------------------------------------------------------ + annotate_grasp_parser = subparsers.add_parser( + "annotate-grasp", + help="Interactively annotate grasp region on a mesh.", + ) + from embodichain.toolkits.graspkit.scripts.annotate_grasp import ( + cli as annotate_grasp_cli, + ) + + annotate_grasp_parser.set_defaults(func=annotate_grasp_cli) + # -- Parse --------------------------------------------------------------- # If no sub-command is given, print help and exit. if len(sys.argv) < 2 or sys.argv[1] in ("-h", "--help"): diff --git a/embodichain/toolkits/graspkit/scripts/__init__.py b/embodichain/toolkits/graspkit/scripts/__init__.py new file mode 100644 index 00000000..dd650e90 --- /dev/null +++ b/embodichain/toolkits/graspkit/scripts/__init__.py @@ -0,0 +1,15 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ---------------------------------------------------------------------------- diff --git a/embodichain/toolkits/graspkit/scripts/annotate_grasp.py b/embodichain/toolkits/graspkit/scripts/annotate_grasp.py new file mode 100644 index 00000000..435aacaf --- /dev/null +++ b/embodichain/toolkits/graspkit/scripts/annotate_grasp.py @@ -0,0 +1,138 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +"""CLI for interactive grasp region annotation on a mesh. + +Loads a mesh file via *trimesh*, launches a browser-based annotator so the +user can select the graspable region, and saves the resulting antipodal +point pairs to the grasp-annotator cache. + +Usage examples:: + + python -m embodichain annotate-grasp --mesh_path /path/to/object.ply + python -m embodichain annotate-grasp --mesh_path mug.obj --force_regenerate +""" + +from __future__ import annotations + +import argparse + +import torch +import trimesh + +from embodichain.toolkits.graspkit.pg_grasp import ( + AntipodalSamplerCfg, + GraspGenerator, + GraspGeneratorCfg, +) +from embodichain.utils.logger import log_info + + +def cli() -> None: + """Command-line interface for grasp pose annotation. + + Parses CLI arguments, loads the mesh, and launches interactive + annotation via the Viser browser UI. 
+ """ + parser = argparse.ArgumentParser( + description=( + "Interactively annotate a grasp region on a mesh and " + "compute antipodal point pairs." + ), + ) + + parser.add_argument( + "--mesh_path", + type=str, + required=True, + help="Path to the mesh file (e.g. .ply, .obj, .stl).", + ) + parser.add_argument( + "--viser_port", + type=int, + default=15531, + help="Port for the browser-based annotation UI (default: 15531).", + ) + parser.add_argument( + "--n_sample", + type=int, + default=20000, + help="Number of surface points to sample (default: 20000).", + ) + parser.add_argument( + "--max_length", + type=float, + default=0.1, + help="Maximum distance between antipodal pairs in metres (default: 0.1).", + ) + parser.add_argument( + "--min_length", + type=float, + default=0.001, + help="Minimum distance between antipodal pairs in metres (default: 0.001).", + ) + parser.add_argument( + "--force_regenerate", + action="store_true", + default=False, + help="Force re-annotation, ignoring cached antipodal pairs.", + ) + parser.add_argument( + "--device", + type=str, + default="cpu", + help="Compute device, e.g. 'cpu' or 'cuda' (default: cpu).", + ) + + args = parser.parse_args() + + # Load mesh via trimesh + log_info(f"Loading mesh from {args.mesh_path}", color="green") + mesh = trimesh.load(args.mesh_path, force="mesh") + vertices = torch.tensor(mesh.vertices, dtype=torch.float32, device=args.device) + triangles = torch.tensor(mesh.faces, dtype=torch.int64, device=args.device) + + # Build configuration + sampler_cfg = AntipodalSamplerCfg( + n_sample=args.n_sample, + max_length=args.max_length, + min_length=args.min_length, + ) + cfg = GraspGeneratorCfg( + viser_port=args.viser_port, + antipodal_sampler_cfg=sampler_cfg, + force_regenerate=args.force_regenerate, + ) + + # Create generator and run annotation + generator = GraspGenerator(vertices=vertices, triangles=triangles, cfg=cfg) + log_info( + "Annotate the grasp region in the browser window:\n" + " 1. 
Open http://localhost:{port}\n" + " 2. Click 'Rect Select Region' and drag to select\n" + " 3. Click 'Confirm Selection' to finish", + color="green", + ) + hit_point_pairs = generator.annotate() + + log_info( + f"Annotation complete. {hit_point_pairs.shape[0]} antipodal pairs cached.", + color="green", + ) + + +if __name__ == "__main__": + cli() From 762775601531d59b6f37646e56b74fcadfea8b87 Mon Sep 17 00:00:00 2001 From: chenjian Date: Thu, 2 Apr 2026 15:52:28 +0800 Subject: [PATCH 25/25] update docs; remove deprecated cfg --- .../features/toolkits/grasp_generator.rst | 12 +------ .../graspkit/pg_grasp/antipodal_generator.py | 31 ++++++++----------- .../pg_grasp/gripper_collision_checker.py | 6 ---- .../graspkit/scripts/annotate_grasp.py | 9 +----- scripts/tutorials/grasp/grasp_generator.py | 7 ++--- 5 files changed, 18 insertions(+), 47 deletions(-) diff --git a/docs/source/features/toolkits/grasp_generator.rst b/docs/source/features/toolkits/grasp_generator.rst index d5a8123a..ba77e77b 100644 --- a/docs/source/features/toolkits/grasp_generator.rst +++ b/docs/source/features/toolkits/grasp_generator.rst @@ -64,7 +64,7 @@ Grasp generation is performed by :class:`~embodichain.toolkits.graspkit.pg_grasp 2. Use *Rect Select Region* to highlight the area of the object that should be grasped. 3. Click *Confirm Selection* to finalize the region. -After annotation, antipodal point pairs are cached to disk and automatically reused unless ``force_regenerate`` is set. +After annotation, antipodal point pairs are cached to disk and automatically reused unless the user calls ``GraspGenerator.annotate()``. For each environment, a grasp pose is computed by calling :meth:`~embodichain.toolkits.graspkit.pg_grasp.GraspGenerator.get_grasp_poses` with the object pose and desired approach direction. The result is a ``(4, 4)`` homogeneous transformation matrix representing the grasp frame in world coordinates. 
Set ``visualize=True`` to open an Open3D window showing the selected grasp on the object. @@ -106,9 +106,6 @@ Configuring GraspGeneratorCfg * - ``antipodal_sampler_cfg`` - ``AntipodalSamplerCfg()`` - Nested configuration for the antipodal point sampler. See the table below for its parameters. - * - ``force_regenerate`` - - ``False`` - - When ``True``, the user is required to annotate the grasp region every time, bypassing any cached results from a previous run. * - ``max_deviation_angle`` - ``π / 12`` - Maximum allowed angle (in radians) between the specified approach direction and the axis connecting an antipodal point pair. Pairs that deviate more than this threshold are discarded. @@ -162,9 +159,6 @@ Configuring GripperCollisionCfg * - ``root_z_width`` - ``0.08`` - Extent of the gripper root block along the Z-axis. - * - ``device`` - - ``cpu`` - - PyTorch device on which the gripper point cloud is generated and processed. Set to ``cuda`` when GPU-accelerated collision checking is required. * - ``point_sample_dense`` - ``0.01`` - Approximate number of sample points per unit length along each edge of the gripper point cloud. Higher values produce denser point clouds and improve collision-check accuracy at the cost of additional computation. @@ -220,7 +214,6 @@ Common options:: --n_sample 20000 \ --max_length 0.1 \ --min_length 0.001 \ - --force_regenerate .. list-table:: CLI options :header-rows: 1 @@ -244,9 +237,6 @@ Common options:: * - ``--min_length`` - ``0.001`` - Minimum distance (metres) between antipodal pairs; filters out degenerate pairs. - * - ``--force_regenerate`` - - ``False`` - - Force re-annotation and ignore any cached antipodal pairs. * - ``--device`` - ``cpu`` - Compute device (``cpu`` or ``cuda``). 
diff --git a/embodichain/toolkits/graspkit/pg_grasp/antipodal_generator.py b/embodichain/toolkits/graspkit/pg_grasp/antipodal_generator.py index cdb1bf7d..f6389ff8 100644 --- a/embodichain/toolkits/graspkit/pg_grasp/antipodal_generator.py +++ b/embodichain/toolkits/graspkit/pg_grasp/antipodal_generator.py @@ -73,10 +73,6 @@ class GraspGeneratorCfg: number of sampled surface points, ray perturbation angle, and gripper jaw distance limits. See :class:`AntipodalSamplerCfg` for details.""" - force_regenerate: bool = False - """When ``True``, the user is required to annotate the grasp region every - time, bypassing any cached results from a previous run.""" - max_deviation_angle: float = np.pi / 12 """Maximum allowed angle (in radians) between the specified approach direction and the axis connecting an antipodal point pair. Pairs that @@ -105,9 +101,6 @@ class GraspGenerator: the approach direction, narrow opening length, and distance to the mesh centroid. - Antipodal pairs are cached to disk (keyed on mesh geometry) and - automatically reused across sessions unless ``force_regenerate`` is set. - Typical usage:: generator = GraspGenerator(vertices, triangles, cfg=cfg) @@ -156,7 +149,7 @@ def __init__( # Load cached antipodal pairs for the whole mesh if available. cache_path = self._get_cache_dir(self.vertices, self.triangles) - if os.path.exists(cache_path) and not self.cfg.force_regenerate: + if os.path.exists(cache_path): logger.log_info(f"Found cached antipodal pairs at {cache_path}. Loading.") self._hit_point_pairs = torch.tensor( np.load(cache_path), dtype=torch.float32, device=self.device @@ -167,14 +160,14 @@ def generate( vertex_indices: torch.Tensor | None = None, face_indices: torch.Tensor | None = None, ) -> torch.Tensor: - """Generate antipodal point pairs for grasping on the given mesh region. + """ + Generate antipodal point pairs for grasping on the given mesh region. 
Exactly one of ``vertex_indices`` or ``face_indices`` must be provided to define the grasp region. When both are ``None``, the whole mesh is used. - Results are cached to disk and reused when ``force_regenerate`` is - ``False``. + Results are cached to disk. Args: vertex_indices: 1-D ``torch.Tensor`` of vertex indices defining the @@ -226,7 +219,7 @@ def generate( ) cache_path = self._get_cache_dir(sub_vertices, sub_faces) - if os.path.exists(cache_path) and not self.cfg.force_regenerate: + if os.path.exists(cache_path): logger.log_info(f"Found cached antipodal pairs at {cache_path}") return torch.tensor( np.load(cache_path), dtype=torch.float32, device=self.device @@ -354,6 +347,8 @@ def _(event: viser.ScenePointerEvent) -> None: torch.tensor(sel_vertices, device=self.device), torch.tensor(sel_faces, device=self.device), ) + + # for visualization only extended_hit_point_pairs = GraspGenerator._extend_hit_point_pairs( hit_point_pairs ) @@ -586,8 +581,8 @@ def get_grasp_poses( self, object_pose: torch.Tensor, approach_direction: torch.Tensor, - is_visual: bool = False, - visualize: bool = False, + visualize_collision: bool = False, + visualize_pose: bool = False, ) -> tuple[torch.Tensor, torch.Tensor]: """Get grasp pose given approach direction. @@ -603,8 +598,8 @@ def get_grasp_poses( representing the pose of the object in the world frame. approach_direction: ``(3,)`` unit vector representing the desired approach direction of the gripper in the world frame. - is_visual: If ``True``, enable visual collision checking. - visualize: If ``True``, visualize the best grasp pose using Open3D + visualize_collision: If ``True``, enable visual collision checking. + visualize_pose: If ``True``, visualize the best grasp pose using Open3D after computation. 
Returns: @@ -652,7 +647,7 @@ def get_grasp_poses( object_pose, valid_grasp_poses, valid_open_lengths, - is_visual=is_visual, + is_visual=visualize_collision, collision_threshold=0.0, ) # get best grasp pose @@ -673,7 +668,7 @@ def get_grasp_poses( best_idx = torch.argmin(total_cost) best_grasp_pose = valid_grasp_poses[best_idx] best_open_length = valid_open_lengths[best_idx] - if visualize: + if visualize_pose: self.visualize_grasp_pose( obj_pose=object_pose, grasp_pose=best_grasp_pose, diff --git a/embodichain/toolkits/graspkit/pg_grasp/gripper_collision_checker.py b/embodichain/toolkits/graspkit/pg_grasp/gripper_collision_checker.py index d19aa332..5f02176c 100644 --- a/embodichain/toolkits/graspkit/pg_grasp/gripper_collision_checker.py +++ b/embodichain/toolkits/graspkit/pg_grasp/gripper_collision_checker.py @@ -63,12 +63,6 @@ class GripperCollisionCfg: according to the specific gripper being modeled, and it defines how far the root extends along the Z direction. """ - device = torch.device("cpu") - """ Device on which the gripper point cloud will be generated and processed. This should be set according to - the computational resources available and the requirements of the application. For example, if using a GPU for collision - checking, this should be set to torch.device('cuda'). - """ - point_sample_dense: float = 0.01 """ Approximate number of points per unit length for the gripper point cloud. Higher values will yield denser point clouds, which can improve collision checking accuracy but also increase computational cost. 
This should be set based diff --git a/embodichain/toolkits/graspkit/scripts/annotate_grasp.py b/embodichain/toolkits/graspkit/scripts/annotate_grasp.py index 435aacaf..3e1c9211 100644 --- a/embodichain/toolkits/graspkit/scripts/annotate_grasp.py +++ b/embodichain/toolkits/graspkit/scripts/annotate_grasp.py @@ -23,7 +23,7 @@ Usage examples:: python -m embodichain annotate-grasp --mesh_path /path/to/object.ply - python -m embodichain annotate-grasp --mesh_path mug.obj --force_regenerate + python -m embodichain annotate-grasp --mesh_path mug.obj """ from __future__ import annotations @@ -84,12 +84,6 @@ def cli() -> None: default=0.001, help="Minimum distance between antipodal pairs in metres (default: 0.001).", ) - parser.add_argument( - "--force_regenerate", - action="store_true", - default=False, - help="Force re-annotation, ignoring cached antipodal pairs.", - ) parser.add_argument( "--device", type=str, @@ -114,7 +108,6 @@ def cli() -> None: cfg = GraspGeneratorCfg( viser_port=args.viser_port, antipodal_sampler_cfg=sampler_cfg, - force_regenerate=args.force_regenerate, ) # Create generator and run annotation diff --git a/scripts/tutorials/grasp/grasp_generator.py b/scripts/tutorials/grasp/grasp_generator.py index 0875e7bf..bab09c03 100644 --- a/scripts/tutorials/grasp/grasp_generator.py +++ b/scripts/tutorials/grasp/grasp_generator.py @@ -238,7 +238,6 @@ def get_grasp_traj(sim: SimulationManager, robot: Robot, grasp_xpos: torch.Tenso antipodal_sampler_cfg=AntipodalSamplerCfg( n_sample=20000, max_length=0.088, min_length=0.003 ), - force_regenerate=False, # force user to annotate grasp region each time ) sim.open_window() @@ -270,11 +269,11 @@ def get_grasp_traj(sim: SimulationManager, robot: Robot, grasp_xpos: torch.Tenso approach_direction = torch.tensor( [0, 0, -1], dtype=torch.float32, device=sim.device ) - poses = mug.get_local_pose(to_matrix=True) + obj_poses = mug.get_local_pose(to_matrix=True) grasp_xpos_list = [] - for pose in poses: + for obj_pose in 
obj_poses: grasp_pose, _ = grasp_generator.get_grasp_poses( - pose, approach_direction, visualize=False + obj_pose, approach_direction, visualize_pose=False ) grasp_xpos_list.append(grasp_pose.unsqueeze(0)) grasp_xpos = torch.cat(grasp_xpos_list, dim=0)