diff --git a/bayes3d/__init__.py b/bayes3d/__init__.py index 6eccde99..320aae4c 100644 --- a/bayes3d/__init__.py +++ b/bayes3d/__init__.py @@ -12,7 +12,7 @@ from .transforms_3d import * from .viz import * -RENDERER = None +RENDERER: "Renderer" = None __version__ = metadata.version("bayes3d") diff --git a/bayes3d/genjax/model.py b/bayes3d/genjax/model.py index 6f667197..911b7775 100644 --- a/bayes3d/genjax/model.py +++ b/bayes3d/genjax/model.py @@ -7,6 +7,7 @@ from genjax.incremental import Diff, NoChange, UnknownChange import bayes3d as b +import bayes3d.scene_graph from .genjax_distributions import ( contact_params_uniform, @@ -127,14 +128,14 @@ def get_far_plane(trace): def add_object(trace, key, obj_id, parent, face_parent, face_child): - N = b.get_indices(trace).shape[0] + 1 + N = get_indices(trace).shape[0] + 1 choices = trace.get_choices() choices[f"parent_{N-1}"] = parent choices[f"id_{N-1}"] = obj_id choices[f"face_parent_{N-1}"] = face_parent choices[f"face_child_{N-1}"] = face_child choices[f"contact_params_{N-1}"] = jnp.zeros(3) - return model.importance(key, choices, (jnp.arange(N), *trace.get_args()[1:]))[1] + return model.importance(key, choices, (jnp.arange(N), *trace.get_args()[1:]))[0] add_object_jit = jax.jit(add_object) @@ -151,7 +152,7 @@ def print_trace(trace): def viz_trace_meshcat(trace, colors=None): - b.clear() + b.clear_visualizer() b.show_cloud( "1", b.apply_transform_jit(trace["image"].reshape(-1, 3), trace["camera_pose"]) ) @@ -223,14 +224,14 @@ def enumerator(trace, key, *args): key, chm_builder(addresses, args, chm_args), argdiff_f(trace), - )[2] + )[0] def enumerator_with_weight(trace, key, *args): return trace.update( key, chm_builder(addresses, args, chm_args), argdiff_f(trace), - )[1:3] + )[0:2] def enumerator_score(trace, key, *args): return enumerator(trace, key, *args).get_score() @@ -301,4 +302,4 @@ def update_address(trace, key, address, value): key, genjax.choice_map({address: value}), tuple(map(lambda v: Diff(v, UnknownChange), trace.args)), - )[2] + )[0] diff --git a/bayes3d/scene_graph.py b/bayes3d/scene_graph.py index 6b704e44..39eba0ae 100644 --- a/bayes3d/scene_graph.py +++ b/bayes3d/scene_graph.py @@ -211,16 +211,16 @@ def relative_pose_from_edge( face_child, dims_child, ): - x, y, angle = contact_params - contact_transform = t3d.transform_from_pos(jnp.array([x, y, 0.0])).dot( - t3d.transform_from_axis_angle(jnp.array([1.0, 1.0, 0.0]), jnp.pi).dot( - t3d.transform_from_axis_angle(jnp.array([0.0, 0.0, 1.0]), angle) - ) - ) + contact_transform = contact_params_to_pose(contact_params) child_plane = get_contact_planes(dims_child)[face_child] return contact_transform.dot(jnp.linalg.inv(child_plane)) +def relative_pose_from_edge_pose(contact_pose, face_child, dims_child): + child_plane = get_contact_planes(dims_child)[face_child] + return contact_pose.dot(jnp.linalg.inv(child_plane)) + + relative_pose_from_edge_jit = jax.jit(relative_pose_from_edge) relative_pose_from_edge_parallel_jit = jax.jit( jax.vmap( diff --git a/bayes3d/viz/viz.py b/bayes3d/viz/viz.py index c611825f..29fa2774 100644 --- a/bayes3d/viz/viz.py +++ b/bayes3d/viz/viz.py @@ -66,11 +66,11 @@ def get_depth_image(image, min_val=None, max_val=None, remove_max=True): depth = np.array(image) if max_val is None: - max_val = depth.max() + max_val = depth[depth < depth.max()].max(initial=0.0) if not remove_max: max_val += 1 if min_val is None: - min_val = depth.min() + min_val = depth[depth > depth.min()].min(initial=0.0) mask = (depth < max_val) * (depth > min_val) depth[np.logical_not(mask)] = np.nan @@ -208,7 +208,7 @@ def scale_image(img, factor): return img.resize((int(w * factor), int(h * factor))) -def vstack_images(images, border=10): +def vstack_images(images, border=10, bg_color=(255, 255, 255)): """Stack images vertically. Args: @@ -224,7 +224,7 @@ def vstack_images(images, border=10): max_w = max(max_w, w) sum_h += h - full_image = Image.new("RGB", (max_w, sum_h), (255, 255, 255)) + full_image = Image.new("RGB", (max_w, sum_h), bg_color) running_h = 0 for img in images: w, h = img.size @@ -233,7 +233,7 @@ def vstack_images(images, border=10): return full_image -def hstack_images(images, border=10): +def hstack_images(images, border=10, bg_color=(255, 255, 255)): """Stack images horizontally. Args: @@ -249,7 +249,7 @@ def hstack_images(images, border=10): max_h = max(max_h, h) sum_w += w - full_image = Image.new("RGB", (sum_w, max_h), (255, 255, 255)) + full_image = Image.new("RGB", (sum_w, max_h), bg_color) running_w = 0 for img in images: w, h = img.size @@ -258,7 +258,7 @@ def hstack_images(images, border=10): return full_image -def hvstack_images(images, h, w, border=10): +def hvstack_images(images, h, w, border=10, bg_color=(255, 255, 255)): """Stack images in a grid. Args: @@ -274,10 +274,12 @@ def hvstack_images(images, h, w, border=10): images_to_vstack = [] for row_idx in range(h): - hstacked_row = hstack_images(images[row_idx * w : (row_idx + 1) * w]) + hstacked_row = hstack_images( + images[row_idx * w : (row_idx + 1) * w], border=border, bg_color=bg_color + ) images_to_vstack.append(hstacked_row) - return vstack_images(images_to_vstack) + return vstack_images(images_to_vstack, border=border, bg_color=bg_color) def multi_panel( diff --git a/pyproject.toml b/pyproject.toml index c67cc688..b84b3c0e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ classifiers = [ ] dependencies = [ "distinctipy", - "genjax==0.1.1", + "genjax>=0.2.0", "graphviz", "imageio", "matplotlib",