77from genjax .incremental import Diff , NoChange , UnknownChange
88
99import bayes3d as b
10- import bayes3d .scene_graph
1110
1211from .genjax_distributions import (
1312 contact_params_uniform ,
@@ -128,14 +127,14 @@ def get_far_plane(trace):
128127
129128
130129def add_object (trace , key , obj_id , parent , face_parent , face_child ):
131- N = get_indices (trace ).shape [0 ] + 1
130+ N = b . get_indices (trace ).shape [0 ] + 1
132131 choices = trace .get_choices ()
133132 choices [f"parent_{ N - 1 } " ] = parent
134133 choices [f"id_{ N - 1 } " ] = obj_id
135134 choices [f"face_parent_{ N - 1 } " ] = face_parent
136135 choices [f"face_child_{ N - 1 } " ] = face_child
137136 choices [f"contact_params_{ N - 1 } " ] = jnp .zeros (3 )
138- return model .importance (key , choices , (jnp .arange (N ), * trace .get_args ()[1 :]))[0 ]
137+ return model .importance (key , choices , (jnp .arange (N ), * trace .get_args ()[1 :]))[1 ]
139138
140139
141140add_object_jit = jax .jit (add_object )
@@ -152,7 +151,7 @@ def print_trace(trace):
152151
153152
154153def viz_trace_meshcat (trace , colors = None ):
155- b .clear_visualizer ()
154+ b .clear ()
156155 b .show_cloud (
157156 "1" , b .apply_transform_jit (trace ["image" ].reshape (- 1 , 3 ), trace ["camera_pose" ])
158157 )
@@ -224,14 +223,14 @@ def enumerator(trace, key, *args):
224223 key ,
225224 chm_builder (addresses , args , chm_args ),
226225 argdiff_f (trace ),
227- )[0 ]
226+ )[2 ]
228227
229228 def enumerator_with_weight (trace , key , * args ):
230229 return trace .update (
231230 key ,
232231 chm_builder (addresses , args , chm_args ),
233232 argdiff_f (trace ),
234- )[0 : 2 ]
233+ )[1 : 3 ]
235234
236235 def enumerator_score (trace , key , * args ):
237236 return enumerator (trace , key , * args ).get_score ()
@@ -302,4 +301,4 @@ def update_address(trace, key, address, value):
302301 key ,
303302 genjax .choice_map ({address : value }),
304303 tuple (map (lambda v : Diff (v , UnknownChange ), trace .args )),
305- )[0 ]
304+ )[2 ]
0 commit comments