Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/source/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ Fixed
derived command state (e.g. relative body positions in tracking
environments) is populated before the first observation is returned
(:issue:`761`).
- ``RayCastSensor`` with ``ray_alignment="yaw"`` or ``"world"`` now correctly
aligns the frame offset when attached to a site or geom with a local offset
from its parent body. Previously only ray directions and pattern offsets were
aligned, causing the frame position to swing with body pitch/roll
(:issue:`775`).

Version 1.2.0 (March 6, 2026)
-----------------------------
Expand Down
43 changes: 35 additions & 8 deletions src/mjlab/sensor/raycast_sensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,8 @@ def __init__(self, cfg: RayCastSensorCfg) -> None:
self._cached_frame_pos: torch.Tensor | None = None
self._cached_frame_mat: torch.Tensor | None = None

self._frame_local_pos: torch.Tensor | None = None

self._ctx: SensorContext | None = None

def edit_spec(
Expand Down Expand Up @@ -524,11 +526,21 @@ def initialize(
# Look up parent body for exclusion.
self._frame_body_id = int(mj_model.site_bodyid[self._frame_site_id])
self._frame_type = "site"
self._frame_local_pos = torch.tensor(
mj_model.site_pos[self._frame_site_id],
dtype=torch.float32,
device=device,
)
elif frame.type == "geom":
self._frame_geom_id = mj_model.geom(frame_name).id
# Look up parent body for exclusion.
self._frame_body_id = int(mj_model.geom_bodyid[self._frame_geom_id])
self._frame_type = "geom"
self._frame_local_pos = torch.tensor(
mj_model.geom_pos[self._frame_geom_id],
dtype=torch.float32,
device=device,
)
else:
raise ValueError(
f"RayCastSensor frame must be 'body', 'site', or 'geom', got '{frame.type}'"
Expand Down Expand Up @@ -624,11 +636,21 @@ def debug_vis(self, visualizer: DebugVisualizer) -> None:
frame_pos = self._data.xpos[env_indices, self._frame_body_id]
frame_mat = self._data.xmat[env_indices, self._frame_body_id]
elif self._frame_type == "site":
frame_pos = self._data.site_xpos[env_indices, self._frame_site_id]
body_pos = self._data.xpos[env_indices, self._frame_body_id]
body_mat = self._data.xmat[env_indices, self._frame_body_id]
frame_mat = self._data.site_xmat[env_indices, self._frame_site_id]
body_align = self._compute_alignment_rotation(body_mat.view(-1, 3, 3))
frame_pos = body_pos + torch.einsum(
"bij,j->bi", body_align, self._frame_local_pos
)
else: # geom
frame_pos = self._data.geom_xpos[env_indices, self._frame_geom_id]
body_pos = self._data.xpos[env_indices, self._frame_body_id]
body_mat = self._data.xmat[env_indices, self._frame_body_id]
frame_mat = self._data.geom_xmat[env_indices, self._frame_geom_id]
body_align = self._compute_alignment_rotation(body_mat.view(-1, 3, 3))
frame_pos = body_pos + torch.einsum(
"bij,j->bi", body_align, self._frame_local_pos
)

rot_mats = self._compute_alignment_rotation(frame_mat.view(-1, 3, 3)).cpu().numpy()
origins = frame_pos.cpu().numpy()
Expand Down Expand Up @@ -701,12 +723,17 @@ def prepare_rays(self) -> None:
if self._frame_type == "body":
frame_pos = self._data.xpos[:, self._frame_body_id]
frame_mat = self._data.xmat[:, self._frame_body_id].view(-1, 3, 3)
elif self._frame_type == "site":
frame_pos = self._data.site_xpos[:, self._frame_site_id]
frame_mat = self._data.site_xmat[:, self._frame_site_id].view(-1, 3, 3)
else: # geom
frame_pos = self._data.geom_xpos[:, self._frame_geom_id]
frame_mat = self._data.geom_xmat[:, self._frame_geom_id].view(-1, 3, 3)
else:
body_pos = self._data.xpos[:, self._frame_body_id]
body_mat = self._data.xmat[:, self._frame_body_id].view(-1, 3, 3)
if self._frame_type == "site":
frame_mat = self._data.site_xmat[:, self._frame_site_id].view(-1, 3, 3)
else:
frame_mat = self._data.geom_xmat[:, self._frame_geom_id].view(-1, 3, 3)
body_align = self._compute_alignment_rotation(body_mat)
frame_pos = body_pos + torch.einsum(
"bij,j->bi", body_align, self._frame_local_pos
)

num_envs = frame_pos.shape[0]
rot_mat = self._compute_alignment_rotation(frame_mat)
Expand Down
128 changes: 118 additions & 10 deletions tests/test_raycast_sensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,11 +534,6 @@ def test_raycast_body_rotation_affects_rays(device):
), f"Expected ~{expected_distance:.2f}m, got {data_rotated.distances}"


# ============================================================================
# Pinhole Camera Pattern Tests
# ============================================================================


def test_pinhole_camera_pattern_num_rays(device):
"""Verify pinhole pattern generates width * height rays."""
simple_xml = """
Expand Down Expand Up @@ -671,11 +666,6 @@ def test_pinhole_from_mujoco_camera_fovy_mode(device):
assert torch.all(data.distances >= 0) # Should hit floor


# ============================================================================
# Ray Alignment Tests
# ============================================================================


@pytest.mark.skipif(not torch.cuda.is_available(), reason="Likely bug on CPU MjWarp")
def test_ray_alignment_yaw(device):
"""Verify yaw alignment ignores pitch/roll."""
Expand Down Expand Up @@ -943,3 +933,121 @@ def test_height_scan_misses(device):
assert torch.allclose(
heights, torch.full_like(heights, raycast_cfg.max_distance), atol=1e-5
)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="Likely bug on CPU MjWarp")
def test_ray_alignment_yaw_site_offset(device):
"""Verify yaw alignment correctly aligns frame offset for sites.

When a site has a large Z offset from its parent body, the frame position should be
recomputed using the body's yaw-only rotation so that pitch/roll does not swing the
ray origin.
"""
xml = """
<mujoco>
<option gravity="0 0 0"/>
<worldbody>
<geom name="floor" type="plane" size="50 50 0.1" pos="0 0 0"/>
<body name="base" pos="0 0 2">
<freejoint name="free_joint"/>
<geom name="base_geom" type="sphere" size="0.1" mass="1.0"/>
<site name="high_site" pos="0 0 20"/>
</body>
</worldbody>
</mujoco>
"""

raycast_cfg = RayCastSensorCfg(
name="yaw_site_scan",
frame=ObjRef(type="site", name="high_site", entity="robot"),
pattern=GridPatternCfg(size=(0.0, 0.0), resolution=0.1, direction=(0.0, 0.0, -1.0)),
ray_alignment="yaw",
max_distance=50.0,
)

scene, sim = _make_scene_and_sim(device, xml, sensors=(raycast_cfg,))
sensor = scene["yaw_site_scan"]

# Baseline: unrotated body. Site at z=22, floor at z=0 -> distance ~22m.
sim.step()
scene.update(dt=sim.cfg.mujoco.timestep)
sim.sense()
baseline_dist = sensor.data.distances.clone()
assert torch.allclose(baseline_dist, torch.full_like(baseline_dist, 22.0), atol=0.2)

# Tilt body 30 degrees around X axis (pitch).
angle = math.pi / 6
quat = [math.cos(angle / 2), math.sin(angle / 2), 0, 0]
sim.data.qpos[0, 3:7] = torch.tensor(quat, device=device)
sim.step()
scene.update(dt=sim.cfg.mujoco.timestep)
sim.sense()
tilted_dist = sensor.data.distances.clone()

# With the fix, distance should remain ~22m because yaw alignment prevents the site
# offset from swinging with pitch.
assert torch.allclose(tilted_dist, baseline_dist, atol=0.5), (
f"Expected ~22m after pitch, got {tilted_dist}"
)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="Likely bug on CPU MjWarp")
def test_ray_alignment_world_site_offset(device):
"""Verify world alignment correctly aligns frame offset for sites.

With world alignment the site offset should be rotated by the identity matrix, so
combined pitch and roll on the body should not move the ray origin away from directly
above the body.
"""
xml = """
<mujoco>
<option gravity="0 0 0"/>
<worldbody>
<geom name="floor" type="plane" size="50 50 0.1" pos="0 0 0"/>
<body name="base" pos="0 0 2">
<freejoint name="free_joint"/>
<geom name="base_geom" type="sphere" size="0.1" mass="1.0"/>
<site name="high_site" pos="0 0 20"/>
</body>
</worldbody>
</mujoco>
"""

raycast_cfg = RayCastSensorCfg(
name="world_site_scan",
frame=ObjRef(type="site", name="high_site", entity="robot"),
pattern=GridPatternCfg(size=(0.0, 0.0), resolution=0.1, direction=(0.0, 0.0, -1.0)),
ray_alignment="world",
max_distance=50.0,
)

scene, sim = _make_scene_and_sim(device, xml, sensors=(raycast_cfg,))
sensor = scene["world_site_scan"]

# Baseline: unrotated body. Site at z=22 -> distance ~22m.
sim.step()
scene.update(dt=sim.cfg.mujoco.timestep)
sim.sense()
baseline_dist = sensor.data.distances.clone()
assert torch.allclose(baseline_dist, torch.full_like(baseline_dist, 22.0), atol=0.2)

# Apply combined pitch (30 deg) + roll (20 deg).
pitch = math.pi / 6
roll = math.pi / 9
cp, sp = math.cos(pitch / 2), math.sin(pitch / 2)
cr, sr = math.cos(roll / 2), math.sin(roll / 2)
# q_pitch (around X) then q_roll (around Y): q = q_roll * q_pitch
qw = cr * cp
qx = cr * sp
qy = sr * cp
qz = -sr * sp
sim.data.qpos[0, 3:7] = torch.tensor([qw, qx, qy, qz], device=device)
sim.step()
scene.update(dt=sim.cfg.mujoco.timestep)
sim.sense()
rotated_dist = sensor.data.distances.clone()

# With world alignment, distance should remain ~22m.
assert torch.allclose(rotated_dist, baseline_dist, atol=0.5), (
f"Expected ~22m after pitch+roll, got {rotated_dist}"
)
Loading