From 6e731bcc1c2c0d3ef5fa13bbb5fcc21460ce965c Mon Sep 17 00:00:00 2001 From: Kevin Zakka Date: Wed, 11 Mar 2026 11:56:12 -0700 Subject: [PATCH] Fix RayCastSensor frame offset alignment for site/geom frames When a RayCastSensor is attached to a site or geom with a local offset, ray_alignment="yaw" and "world" only aligned ray directions and pattern offsets. The frame position was read directly from MuJoCo's site_xpos (or geom_xpos for geom frames), which bakes in the full body rotation on the offset. This caused the ray origin to swing with body pitch/roll. The fix decomposes the frame position for site/geom frames: instead of reading site_xpos/geom_xpos, recompute it as body_pos + alignment(body_mat) @ frame_local_pos. For "base" alignment this is identical to the previous behavior. For "yaw" and "world", the offset now respects the alignment mode. The same decomposition is applied in debug_vis() so visualization matches the actual ray computation. Fixes #775 --- docs/source/changelog.rst | 5 ++ src/mjlab/sensor/raycast_sensor.py | 43 ++++++++-- tests/test_raycast_sensor.py | 128 ++++++++++++++++++++++++++--- 3 files changed, 158 insertions(+), 18 deletions(-) diff --git a/docs/source/changelog.rst b/docs/source/changelog.rst index 547e6f63c..bccf1843d 100644 --- a/docs/source/changelog.rst +++ b/docs/source/changelog.rst @@ -45,6 +45,11 @@ Fixed derived command state (e.g. relative body positions in tracking environments) is populated before the first observation is returned (:issue:`761`). +- ``RayCastSensor`` with ``ray_alignment="yaw"`` or ``"world"`` now correctly + aligns the frame offset when attached to a site or geom with a local offset + from its parent body. Previously only ray directions and pattern offsets were + aligned, causing the frame position to swing with body pitch/roll + (:issue:`775`). 
Version 1.2.0 (March 6, 2026) ----------------------------- diff --git a/src/mjlab/sensor/raycast_sensor.py b/src/mjlab/sensor/raycast_sensor.py index 013bb1f76..0b166f249 100644 --- a/src/mjlab/sensor/raycast_sensor.py +++ b/src/mjlab/sensor/raycast_sensor.py @@ -490,6 +490,8 @@ def __init__(self, cfg: RayCastSensorCfg) -> None: self._cached_frame_pos: torch.Tensor | None = None self._cached_frame_mat: torch.Tensor | None = None + self._frame_local_pos: torch.Tensor | None = None + self._ctx: SensorContext | None = None def edit_spec( @@ -524,11 +526,21 @@ def initialize( # Look up parent body for exclusion. self._frame_body_id = int(mj_model.site_bodyid[self._frame_site_id]) self._frame_type = "site" + self._frame_local_pos = torch.tensor( + mj_model.site_pos[self._frame_site_id], + dtype=torch.float32, + device=device, + ) elif frame.type == "geom": self._frame_geom_id = mj_model.geom(frame_name).id # Look up parent body for exclusion. self._frame_body_id = int(mj_model.geom_bodyid[self._frame_geom_id]) self._frame_type = "geom" + self._frame_local_pos = torch.tensor( + mj_model.geom_pos[self._frame_geom_id], + dtype=torch.float32, + device=device, + ) else: raise ValueError( f"RayCastSensor frame must be 'body', 'site', or 'geom', got '{frame.type}'" @@ -624,11 +636,21 @@ def debug_vis(self, visualizer: DebugVisualizer) -> None: frame_pos = self._data.xpos[env_indices, self._frame_body_id] frame_mat = self._data.xmat[env_indices, self._frame_body_id] elif self._frame_type == "site": - frame_pos = self._data.site_xpos[env_indices, self._frame_site_id] + body_pos = self._data.xpos[env_indices, self._frame_body_id] + body_mat = self._data.xmat[env_indices, self._frame_body_id] frame_mat = self._data.site_xmat[env_indices, self._frame_site_id] + body_align = self._compute_alignment_rotation(body_mat.view(-1, 3, 3)) + frame_pos = body_pos + torch.einsum( + "bij,j->bi", body_align, self._frame_local_pos + ) else: # geom - frame_pos = self._data.geom_xpos[env_indices, 
self._frame_geom_id] + body_pos = self._data.xpos[env_indices, self._frame_body_id] + body_mat = self._data.xmat[env_indices, self._frame_body_id] frame_mat = self._data.geom_xmat[env_indices, self._frame_geom_id] + body_align = self._compute_alignment_rotation(body_mat.view(-1, 3, 3)) + frame_pos = body_pos + torch.einsum( + "bij,j->bi", body_align, self._frame_local_pos + ) rot_mats = self._compute_alignment_rotation(frame_mat.view(-1, 3, 3)).cpu().numpy() origins = frame_pos.cpu().numpy() @@ -701,12 +723,17 @@ def prepare_rays(self) -> None: if self._frame_type == "body": frame_pos = self._data.xpos[:, self._frame_body_id] frame_mat = self._data.xmat[:, self._frame_body_id].view(-1, 3, 3) - elif self._frame_type == "site": - frame_pos = self._data.site_xpos[:, self._frame_site_id] - frame_mat = self._data.site_xmat[:, self._frame_site_id].view(-1, 3, 3) - else: # geom - frame_pos = self._data.geom_xpos[:, self._frame_geom_id] - frame_mat = self._data.geom_xmat[:, self._frame_geom_id].view(-1, 3, 3) + else: + body_pos = self._data.xpos[:, self._frame_body_id] + body_mat = self._data.xmat[:, self._frame_body_id].view(-1, 3, 3) + if self._frame_type == "site": + frame_mat = self._data.site_xmat[:, self._frame_site_id].view(-1, 3, 3) + else: + frame_mat = self._data.geom_xmat[:, self._frame_geom_id].view(-1, 3, 3) + body_align = self._compute_alignment_rotation(body_mat) + frame_pos = body_pos + torch.einsum( + "bij,j->bi", body_align, self._frame_local_pos + ) num_envs = frame_pos.shape[0] rot_mat = self._compute_alignment_rotation(frame_mat) diff --git a/tests/test_raycast_sensor.py b/tests/test_raycast_sensor.py index 44b78a142..1ea3fafde 100644 --- a/tests/test_raycast_sensor.py +++ b/tests/test_raycast_sensor.py @@ -534,11 +534,6 @@ def test_raycast_body_rotation_affects_rays(device): ), f"Expected ~{expected_distance:.2f}m, got {data_rotated.distances}" -# ============================================================================ -# Pinhole Camera Pattern 
Tests -# ============================================================================ - - def test_pinhole_camera_pattern_num_rays(device): """Verify pinhole pattern generates width * height rays.""" simple_xml = """ @@ -671,11 +666,6 @@ def test_pinhole_from_mujoco_camera_fovy_mode(device): assert torch.all(data.distances >= 0) # Should hit floor -# ============================================================================ -# Ray Alignment Tests -# ============================================================================ - - @pytest.mark.skipif(not torch.cuda.is_available(), reason="Likely bug on CPU MjWarp") def test_ray_alignment_yaw(device): """Verify yaw alignment ignores pitch/roll.""" @@ -943,3 +933,121 @@ def test_height_scan_misses(device): assert torch.allclose( heights, torch.full_like(heights, raycast_cfg.max_distance), atol=1e-5 ) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="Likely bug on CPU MjWarp") +def test_ray_alignment_yaw_site_offset(device): + """Verify yaw alignment correctly aligns frame offset for sites. + + When a site has a large Z offset from its parent body, the frame position should be + recomputed using the body's yaw-only rotation so that pitch/roll does not swing the + ray origin. + """ + xml = """ + + + """ + + raycast_cfg = RayCastSensorCfg( + name="yaw_site_scan", + frame=ObjRef(type="site", name="high_site", entity="robot"), + pattern=GridPatternCfg(size=(0.0, 0.0), resolution=0.1, direction=(0.0, 0.0, -1.0)), + ray_alignment="yaw", + max_distance=50.0, + ) + + scene, sim = _make_scene_and_sim(device, xml, sensors=(raycast_cfg,)) + sensor = scene["yaw_site_scan"] + + # Baseline: unrotated body. Site at z=22, floor at z=0 -> distance ~22m. + sim.step() + scene.update(dt=sim.cfg.mujoco.timestep) + sim.sense() + baseline_dist = sensor.data.distances.clone() + assert torch.allclose(baseline_dist, torch.full_like(baseline_dist, 22.0), atol=0.2) + + # Tilt body 30 degrees around X axis (pitch). 
+ angle = math.pi / 6 + quat = [math.cos(angle / 2), math.sin(angle / 2), 0, 0] + sim.data.qpos[0, 3:7] = torch.tensor(quat, device=device) + sim.step() + scene.update(dt=sim.cfg.mujoco.timestep) + sim.sense() + tilted_dist = sensor.data.distances.clone() + + # With the fix, distance should remain ~22m because yaw alignment prevents the site + # offset from swinging with pitch. + assert torch.allclose(tilted_dist, baseline_dist, atol=0.5), ( + f"Expected ~22m after pitch, got {tilted_dist}" + ) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="Likely bug on CPU MjWarp") +def test_ray_alignment_world_site_offset(device): + """Verify world alignment correctly aligns frame offset for sites. + + With world alignment the site offset should be rotated by the identity matrix, so + combined pitch and roll on the body should not move the ray origin away from directly + above the body. + """ + xml = """ + + + """ + + raycast_cfg = RayCastSensorCfg( + name="world_site_scan", + frame=ObjRef(type="site", name="high_site", entity="robot"), + pattern=GridPatternCfg(size=(0.0, 0.0), resolution=0.1, direction=(0.0, 0.0, -1.0)), + ray_alignment="world", + max_distance=50.0, + ) + + scene, sim = _make_scene_and_sim(device, xml, sensors=(raycast_cfg,)) + sensor = scene["world_site_scan"] + + # Baseline: unrotated body. Site at z=22 -> distance ~22m. + sim.step() + scene.update(dt=sim.cfg.mujoco.timestep) + sim.sense() + baseline_dist = sensor.data.distances.clone() + assert torch.allclose(baseline_dist, torch.full_like(baseline_dist, 22.0), atol=0.2) + + # Apply combined pitch (30 deg) + roll (20 deg). 
+ pitch = math.pi / 6 + roll = math.pi / 9 + cp, sp = math.cos(pitch / 2), math.sin(pitch / 2) + cr, sr = math.cos(roll / 2), math.sin(roll / 2) + # q_pitch (around X) then q_roll (around Y): q = q_roll * q_pitch + qw = cr * cp + qx = cr * sp + qy = sr * cp + qz = -sr * sp + sim.data.qpos[0, 3:7] = torch.tensor([qw, qx, qy, qz], device=device) + sim.step() + scene.update(dt=sim.cfg.mujoco.timestep) + sim.sense() + rotated_dist = sensor.data.distances.clone() + + # With world alignment, distance should remain ~22m. + assert torch.allclose(rotated_dist, baseline_dist, atol=0.5), ( + f"Expected ~22m after pitch+roll, got {rotated_dist}" + )