@@ -23,6 +23,12 @@ class LocomotionEnv(DirectRLEnv):
2323 def __init__ (self , cfg : DirectRLEnvCfg , render_mode : str | None = None , ** kwargs ):
2424 super ().__init__ (cfg , render_mode , ** kwargs )
2525
26+ self ._compute_intermediate_values_fn = _make_compute_intermediate_values (
27+ lazy .isaacsim .core .utils .torch .rotations .compute_heading_and_up ,
28+ lazy .isaacsim .core .utils .torch .rotations .compute_rot ,
29+ lazy .isaacsim .core .utils .torch .maths .unscale ,
30+ )
31+
2632 self .action_scale = self .cfg .action_scale
2733 self .joint_gears = torch .tensor (self .cfg .joint_gears , dtype = torch .float32 , device = self .sim .device )
2834 self .motor_effort_ratio = torch .ones_like (self .joint_gears , device = self .sim .device )
@@ -88,7 +94,7 @@ def _compute_intermediate_values(self):
8894 self .dof_pos_scaled ,
8995 self .prev_potentials ,
9096 self .potentials ,
91- ) = compute_intermediate_values (
97+ ) = self . _compute_intermediate_values_fn (
9298 self .targets ,
9399 self .torso_position ,
94100 self .torso_rotation ,
@@ -228,55 +234,56 @@ def compute_rewards(
228234 return total_reward
229235
230236
231- @torch .jit .script
232- def compute_intermediate_values (
233- targets : torch .Tensor ,
234- torso_position : torch .Tensor ,
235- torso_rotation : torch .Tensor ,
236- velocity : torch .Tensor ,
237- ang_velocity : torch .Tensor ,
238- dof_pos : torch .Tensor ,
239- dof_lower_limits : torch .Tensor ,
240- dof_upper_limits : torch .Tensor ,
241- inv_start_rot : torch .Tensor ,
242- basis_vec0 : torch .Tensor ,
243- basis_vec1 : torch .Tensor ,
244- potentials : torch .Tensor ,
245- prev_potentials : torch .Tensor ,
246- dt : float ,
247- ):
248- to_target = targets - torso_position
249- to_target [:, 2 ] = 0.0
237+ def _make_compute_intermediate_values (compute_heading_and_up_fn , compute_rot_fn , unscale_fn ):
238+ @torch .jit .script
239+ def _compute_intermediate_values (
240+ targets : torch .Tensor ,
241+ torso_position : torch .Tensor ,
242+ torso_rotation : torch .Tensor ,
243+ velocity : torch .Tensor ,
244+ ang_velocity : torch .Tensor ,
245+ dof_pos : torch .Tensor ,
246+ dof_lower_limits : torch .Tensor ,
247+ dof_upper_limits : torch .Tensor ,
248+ inv_start_rot : torch .Tensor ,
249+ basis_vec0 : torch .Tensor ,
250+ basis_vec1 : torch .Tensor ,
251+ potentials : torch .Tensor ,
252+ prev_potentials : torch .Tensor ,
253+ dt : float ,
254+ ):
255+ to_target = targets - torso_position
256+ to_target [:, 2 ] = 0.0
250257
251- torso_quat , up_proj , heading_proj , up_vec , heading_vec = (
252- lazy .isaacsim .core .utils .torch .rotations .compute_heading_and_up (
258+ torso_quat , up_proj , heading_proj , up_vec , heading_vec = compute_heading_and_up_fn (
253259 torso_rotation , inv_start_rot , to_target , basis_vec0 , basis_vec1 , 2
254260 )
255- )
256261
257- vel_loc , angvel_loc , roll , pitch , yaw , angle_to_target = lazy . isaacsim . core . utils . torch . rotations . compute_rot (
258- torso_quat , velocity , ang_velocity , targets , torso_position
259- )
262+ vel_loc , angvel_loc , roll , pitch , yaw , angle_to_target = compute_rot_fn (
263+ torso_quat , velocity , ang_velocity , targets , torso_position
264+ )
260265
261- dof_pos_scaled = lazy .isaacsim .core .utils .torch .maths .unscale (dof_pos , dof_lower_limits , dof_upper_limits )
262-
263- to_target = targets - torso_position
264- to_target [:, 2 ] = 0.0
265- prev_potentials [:] = potentials
266- potentials = - torch .norm (to_target , p = 2 , dim = - 1 ) / dt
267-
268- return (
269- up_proj ,
270- heading_proj ,
271- up_vec ,
272- heading_vec ,
273- vel_loc ,
274- angvel_loc ,
275- roll ,
276- pitch ,
277- yaw ,
278- angle_to_target ,
279- dof_pos_scaled ,
280- prev_potentials ,
281- potentials ,
282- )
266+ dof_pos_scaled = unscale_fn (dof_pos , dof_lower_limits , dof_upper_limits )
267+
268+ to_target = targets - torso_position
269+ to_target [:, 2 ] = 0.0
270+ prev_potentials [:] = potentials
271+ potentials = - torch .norm (to_target , p = 2 , dim = - 1 ) / dt
272+
273+ return (
274+ up_proj ,
275+ heading_proj ,
276+ up_vec ,
277+ heading_vec ,
278+ vel_loc ,
279+ angvel_loc ,
280+ roll ,
281+ pitch ,
282+ yaw ,
283+ angle_to_target ,
284+ dof_pos_scaled ,
285+ prev_potentials ,
286+ potentials ,
287+ )
288+
289+ return _compute_intermediate_values
0 commit comments