Skip to content

fix(terminations): index before reduce in joint_pos_out_of_limit #3153

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CONTRIBUTORS.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ Guidelines for modifications:
* Lukas Fröhlich
* Manuel Schweiger
* Masoud Moghani
* Maurice Rahme
* Michael Gussert
* Michael Noseworthy
* Miguel Alonso Jr
Expand Down
17 changes: 13 additions & 4 deletions source/isaaclab/isaaclab/envs/mdp/terminations.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,19 @@ def joint_pos_out_of_limit(env: ManagerBasedRLEnv, asset_cfg: SceneEntityCfg = S
if asset_cfg.joint_ids is None:
asset_cfg.joint_ids = slice(None)

limits = asset.data.soft_joint_pos_limits[:, asset_cfg.joint_ids]
out_of_upper_limits = torch.any(asset.data.joint_pos[:, asset_cfg.joint_ids] > limits[..., 1], dim=1)
out_of_lower_limits = torch.any(asset.data.joint_pos[:, asset_cfg.joint_ids] < limits[..., 0], dim=1)
return torch.logical_or(out_of_upper_limits, out_of_lower_limits)
# compute any per-joint violations (avoid reducing before indexing)
out_of_upper_limits = asset.data.joint_pos > asset.data.soft_joint_pos_limits[..., 1] # [N, J]
out_of_lower_limits = asset.data.joint_pos < asset.data.soft_joint_pos_limits[..., 0] # [N, J]

# truncate above output to just the joints we care about
out_of_upper_limits = out_of_upper_limits[:, asset_cfg.joint_ids] # [N, K]
out_of_lower_limits = out_of_lower_limits[:, asset_cfg.joint_ids] # [N, K]

# reduce over selected joints
out_of_upper_limits = torch.any(out_of_upper_limits, dim=1) # [N]
out_of_lower_limits = torch.any(out_of_lower_limits, dim=1) # [N]

return torch.logical_or(out_of_upper_limits, out_of_lower_limits) # [N]


def joint_pos_out_of_manual_limit(
Expand Down
36 changes: 25 additions & 11 deletions source/isaaclab/test/assets/test_articulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
import isaaclab.utils.string as string_utils
from isaaclab.actuators import ActuatorBase, IdealPDActuatorCfg, ImplicitActuatorCfg
from isaaclab.assets import Articulation, ArticulationCfg
from isaaclab.envs.mdp.terminations import joint_pos_out_of_limit
from isaaclab.managers import SceneEntityCfg
from isaaclab.sim import build_simulation_context
from isaaclab.utils.assets import ISAAC_NUCLEUS_DIR

Expand Down Expand Up @@ -658,13 +660,11 @@ def test_out_of_range_default_joint_vel(sim, device):
@pytest.mark.parametrize("add_ground_plane", [True])
def test_joint_pos_limits(sim, num_articulations, device, add_ground_plane):
"""Test write_joint_limits_to_sim API and when default pos falls outside of the new limits.

This test verifies that:
1. Joint limits can be set correctly
2. Default positions are preserved when setting new limits
3. Joint limits can be set with indexing
4. Invalid joint positions are properly handled

Args:
sim: The simulation fixture
num_articulations: Number of articulations to test
Expand All @@ -673,6 +673,14 @@ def test_joint_pos_limits(sim, num_articulations, device, add_ground_plane):
articulation_cfg = generate_articulation_cfg(articulation_type="panda")
articulation, _ = generate_articulation(articulation_cfg, num_articulations, device)

# Minimal fake env that exposes scene["robot"] -> articulation
class _Env:
def __init__(self, art):
self.scene = {"robot": art}

env = _Env(articulation)
robot_all = SceneEntityCfg(name="robot") # all joints

# Play sim
sim.reset()
# Check if articulation is initialized
Expand All @@ -691,6 +699,11 @@ def test_joint_pos_limits(sim, num_articulations, device, add_ground_plane):
torch.testing.assert_close(articulation._data.joint_pos_limits, limits)
torch.testing.assert_close(articulation._data.default_joint_pos, default_joint_pos)

# Set new joint limits that invalidate default joint pos
# Validate via function: no joint should be out of limits
out = joint_pos_out_of_limit(env, robot_all) # [N]
assert torch.all(~out)

# Set new joint limits with indexing
env_ids = torch.arange(1, device=device)
joint_ids = torch.arange(2, device=device)
Expand All @@ -704,28 +717,29 @@ def test_joint_pos_limits(sim, num_articulations, device, add_ground_plane):
torch.testing.assert_close(articulation._data.default_joint_pos, default_joint_pos)

# Set new joint limits that invalidate default joint pos
robot_subset = SceneEntityCfg(name="robot", joint_ids=joint_ids.tolist())
out = joint_pos_out_of_limit(env, robot_subset) # [N]
assert torch.all(~out)

# Set new joint limits that (narrowly) constrain default joint pos
limits = torch.zeros(num_articulations, articulation.num_joints, 2, device=device)
limits[..., 0] = torch.rand(num_articulations, articulation.num_joints, device=device) * -0.1
limits[..., 1] = torch.rand(num_articulations, articulation.num_joints, device=device) * 0.1
articulation.write_joint_position_limit_to_sim(limits)

# Check if all values are within the bounds
within_bounds = (articulation._data.default_joint_pos >= limits[..., 0]) & (
articulation._data.default_joint_pos <= limits[..., 1]
)
assert torch.all(within_bounds)
out = joint_pos_out_of_limit(env, robot_all) # [N]
assert torch.all(~out)

# Set new joint limits that invalidate default joint pos with indexing
limits = torch.zeros(env_ids.shape[0], joint_ids.shape[0], 2, device=device)
limits[..., 0] = torch.rand(env_ids.shape[0], joint_ids.shape[0], device=device) * -0.1
limits[..., 1] = torch.rand(env_ids.shape[0], joint_ids.shape[0], device=device) * 0.1
articulation.write_joint_position_limit_to_sim(limits, env_ids=env_ids, joint_ids=joint_ids)

# Check if all values are within the bounds
within_bounds = (articulation._data.default_joint_pos[env_ids][:, joint_ids] >= limits[..., 0]) & (
articulation._data.default_joint_pos[env_ids][:, joint_ids] <= limits[..., 1]
)
assert torch.all(within_bounds)
# Validate via function on selected joints
out = joint_pos_out_of_limit(env, robot_subset) # [N]
assert torch.all(~out)


@pytest.mark.parametrize("num_articulations", [1, 2])
Expand Down