Skip to content

Commit 2d8ae59

Browse files
committed
fix(terminations): index before reduce in joint_pos_out_of_limit
Root cause: reduced over joints to [N] then indexed with [:, joint_ids] → IndexError. Change: compare to limits → [N,J], slice columns with joint_ids → [N,K], then .any(dim=1). Result: returns [N] without error for both [1,J,2] and [N,J,2] limit tensors; joint_ids=None still uses all joints.
1 parent 5f71ff4 commit 2d8ae59

File tree

1 file changed

+12
-6
lines changed

1 file changed

+12
-6
lines changed

source/isaaclab/isaaclab/envs/mdp/terminations.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -81,13 +81,19 @@ def joint_pos_out_of_limit(env: ManagerBasedRLEnv, asset_cfg: SceneEntityCfg = S
8181
"""Terminate when the asset's joint positions are outside of the soft joint limits."""
8282
# extract the used quantities (to enable type-hinting)
8383
asset: Articulation = env.scene[asset_cfg.name]
84-
if asset_cfg.joint_ids is None:
85-
asset_cfg.joint_ids = slice(None)
84+
# compute any per-joint violations (avoid reducing before indexing)
85+
out_of_upper_limits = asset.data.joint_pos > asset.data.soft_joint_pos_limits[..., 1] # [N, J]
86+
out_of_lower_limits = asset.data.joint_pos < asset.data.soft_joint_pos_limits[..., 0] # [N, J]
8687

87-
limits = asset.data.soft_joint_pos_limits[:, asset_cfg.joint_ids]
88-
out_of_upper_limits = torch.any(asset.data.joint_pos[:, asset_cfg.joint_ids] > limits[..., 1], dim=1)
89-
out_of_lower_limits = torch.any(asset.data.joint_pos[:, asset_cfg.joint_ids] < limits[..., 0], dim=1)
90-
return torch.logical_or(out_of_upper_limits, out_of_lower_limits)
88+
# truncate above output to just the joints we care about
89+
out_of_upper_limits = out_of_upper_limits[:, asset_cfg.joint_ids] # [N, K]
90+
out_of_lower_limits = out_of_lower_limits[:, asset_cfg.joint_ids] # [N, K]
91+
92+
# reduce over selected joints
93+
out_of_upper_limits = torch.any(out_of_upper_limits, dim=1) # [N]
94+
out_of_lower_limits = torch.any(out_of_lower_limits, dim=1) # [N]
95+
96+
return torch.logical_or(out_of_upper_limits, out_of_lower_limits) # [N]
9197

9298

9399
def joint_pos_out_of_manual_limit(

0 commit comments

Comments
 (0)