Skip to content

Commit 1dad458

Browse files
committed
add normal rerun viz
1 parent 09ad80a commit 1dad458

2 files changed

Lines changed: 30 additions & 6 deletions

File tree

b3d/utils.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,22 @@ def update_choices_get_score(trace, key, addr_const, *values):
395395
enumerate_choices_get_scores, static_argnums=(2,)
396396
)
397397

398+
def unproject_depth(depth, renderer):
399+
"""Unprojects a depth image into a point cloud.
400+
401+
Args:
402+
depth (jnp.ndarray): The depth image. Shape (H, W)
403+
intrinsics (b.camera.Intrinsics): The camera intrinsics.
404+
Returns:
405+
jnp.ndarray: The point cloud. Shape (H, W, 3)
406+
"""
407+
mask = (depth < renderer.far) * (depth > renderer.near)
408+
depth = depth * mask + renderer.far * (1.0 - mask)
409+
y, x = jnp.mgrid[: depth.shape[0], : depth.shape[1]]
410+
x = (x - renderer.cx) / renderer.fx
411+
y = (y - renderer.cy) / renderer.fy
412+
point_cloud_image = jnp.stack([x, y, jnp.ones_like(x)], axis=-1) * depth[:, :, None]
413+
return point_cloud_image
398414

399415
def nn_background_segmentation(images):
400416
import torch

test/test_render_ycb_model.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,11 @@
22
import jax.numpy as jnp
33
import trimesh
44
import b3d
5+
import rerun as rr
56

7+
PORT = 8812
8+
rr.init("real")
9+
rr.connect(addr=f"127.0.0.1:{PORT}")
610

711
def test_renderer_full(renderer):
812
mesh_path = os.path.join(
@@ -15,7 +19,7 @@ def test_renderer_full(renderer):
1519
object_library.add_trimesh(mesh)
1620

1721
pose = b3d.Pose.from_position_and_target(
18-
jnp.array([0.2, 0.2, 0.0]), jnp.array([0.0, 0.0, 0.0])
22+
jnp.array([0.2, 0.2, 0.2]), jnp.array([0.0, 0.0, 0.0])
1923
).inv()
2024

2125
rgb, depth = renderer.render_attribute(
@@ -39,17 +43,21 @@ def test_renderer_normal_full(renderer):
3943
object_library.add_trimesh(mesh)
4044

4145
pose = b3d.Pose.from_position_and_target(
42-
jnp.array([0.2, 0.2, 0.0]), jnp.array([0.0, 0.0, 0.0])
46+
jnp.array([0.2, 0.2, 0.2]), jnp.array([0.0, 0.0, 0.0])
4347
).inv()
4448

45-
_, _, normal = renderer.render_attribute_normal(
49+
rgb, depth, normal = renderer.render_attribute_normal(
4650
pose[None, ...],
4751
object_library.vertices,
4852
object_library.faces,
4953
jnp.array([[0, len(object_library.faces)]]),
5054
object_library.attributes,
5155
)
5256

53-
normal = jnp.abs(normal)
54-
b3d.get_rgb_pil_image(normal).save(b3d.get_root_path() / "assets/test_results/test_ycb_normal.png")
55-
assert normal.sum() > 0
57+
b3d.get_rgb_pil_image((normal+1)/2).save(b3d.get_root_path() / "assets/test_results/test_ycb_normal.png")
58+
59+
point_im = b3d.utils.unproject_depth(depth, renderer)
60+
rr.log("pc", rr.Points3D(point_im.reshape(-1,3), colors=rgb.reshape(-1,3)))
61+
rr.log("arrows", rr.Arrows3D(origins=point_im[::5,::5,:].reshape(-1,3), vectors=normal[::5,::5,:].reshape(-1,3)/100))
62+
63+
assert jnp.abs(normal).sum() > 0

0 commit comments

Comments
 (0)