diff --git a/docs/source/changelog.rst b/docs/source/changelog.rst
index 547e6f63c..bccf1843d 100644
--- a/docs/source/changelog.rst
+++ b/docs/source/changelog.rst
@@ -45,6 +45,11 @@ Fixed
derived command state (e.g. relative body positions in tracking
environments) is populated before the first observation is returned
(:issue:`761`).
+- ``RayCastSensor`` with ``ray_alignment="yaw"`` or ``"world"`` now also
+ aligns the frame's local offset when attached to a site or geom that is
+ offset from its parent body. Previously only ray directions and pattern
+ offsets were aligned, so the ray origin swung with body pitch/roll
+ (:issue:`775`).
Version 1.2.0 (March 6, 2026)
-----------------------------
diff --git a/src/mjlab/sensor/raycast_sensor.py b/src/mjlab/sensor/raycast_sensor.py
index 013bb1f76..0b166f249 100644
--- a/src/mjlab/sensor/raycast_sensor.py
+++ b/src/mjlab/sensor/raycast_sensor.py
@@ -490,6 +490,8 @@ def __init__(self, cfg: RayCastSensorCfg) -> None:
self._cached_frame_pos: torch.Tensor | None = None
self._cached_frame_mat: torch.Tensor | None = None
+ self._frame_local_pos: torch.Tensor | None = None
+
self._ctx: SensorContext | None = None
def edit_spec(
@@ -524,11 +526,21 @@ def initialize(
# Look up parent body for exclusion.
self._frame_body_id = int(mj_model.site_bodyid[self._frame_site_id])
self._frame_type = "site"
+ self._frame_local_pos = torch.tensor(
+ mj_model.site_pos[self._frame_site_id],
+ dtype=torch.float32,
+ device=device,
+ )
elif frame.type == "geom":
self._frame_geom_id = mj_model.geom(frame_name).id
# Look up parent body for exclusion.
self._frame_body_id = int(mj_model.geom_bodyid[self._frame_geom_id])
self._frame_type = "geom"
+ self._frame_local_pos = torch.tensor(
+ mj_model.geom_pos[self._frame_geom_id],
+ dtype=torch.float32,
+ device=device,
+ )
else:
raise ValueError(
f"RayCastSensor frame must be 'body', 'site', or 'geom', got '{frame.type}'"
@@ -624,11 +636,21 @@ def debug_vis(self, visualizer: DebugVisualizer) -> None:
frame_pos = self._data.xpos[env_indices, self._frame_body_id]
frame_mat = self._data.xmat[env_indices, self._frame_body_id]
elif self._frame_type == "site":
- frame_pos = self._data.site_xpos[env_indices, self._frame_site_id]
+ body_pos = self._data.xpos[env_indices, self._frame_body_id]
+ body_mat = self._data.xmat[env_indices, self._frame_body_id]
frame_mat = self._data.site_xmat[env_indices, self._frame_site_id]
+ body_align = self._compute_alignment_rotation(body_mat.view(-1, 3, 3))
+ frame_pos = body_pos + torch.einsum(
+ "bij,j->bi", body_align, self._frame_local_pos
+ )
else: # geom
- frame_pos = self._data.geom_xpos[env_indices, self._frame_geom_id]
+ body_pos = self._data.xpos[env_indices, self._frame_body_id]
+ body_mat = self._data.xmat[env_indices, self._frame_body_id]
frame_mat = self._data.geom_xmat[env_indices, self._frame_geom_id]
+ body_align = self._compute_alignment_rotation(body_mat.view(-1, 3, 3))
+ frame_pos = body_pos + torch.einsum(
+ "bij,j->bi", body_align, self._frame_local_pos
+ )
rot_mats = self._compute_alignment_rotation(frame_mat.view(-1, 3, 3)).cpu().numpy()
origins = frame_pos.cpu().numpy()
@@ -701,12 +723,17 @@ def prepare_rays(self) -> None:
if self._frame_type == "body":
frame_pos = self._data.xpos[:, self._frame_body_id]
frame_mat = self._data.xmat[:, self._frame_body_id].view(-1, 3, 3)
- elif self._frame_type == "site":
- frame_pos = self._data.site_xpos[:, self._frame_site_id]
- frame_mat = self._data.site_xmat[:, self._frame_site_id].view(-1, 3, 3)
- else: # geom
- frame_pos = self._data.geom_xpos[:, self._frame_geom_id]
- frame_mat = self._data.geom_xmat[:, self._frame_geom_id].view(-1, 3, 3)
+ else:
+ body_pos = self._data.xpos[:, self._frame_body_id]
+ body_mat = self._data.xmat[:, self._frame_body_id].view(-1, 3, 3)
+ if self._frame_type == "site":
+ frame_mat = self._data.site_xmat[:, self._frame_site_id].view(-1, 3, 3)
+ else:
+ frame_mat = self._data.geom_xmat[:, self._frame_geom_id].view(-1, 3, 3)
+ body_align = self._compute_alignment_rotation(body_mat)
+ frame_pos = body_pos + torch.einsum(
+ "bij,j->bi", body_align, self._frame_local_pos
+ )
num_envs = frame_pos.shape[0]
rot_mat = self._compute_alignment_rotation(frame_mat)
diff --git a/tests/test_raycast_sensor.py b/tests/test_raycast_sensor.py
index 44b78a142..1ea3fafde 100644
--- a/tests/test_raycast_sensor.py
+++ b/tests/test_raycast_sensor.py
@@ -534,11 +534,6 @@ def test_raycast_body_rotation_affects_rays(device):
), f"Expected ~{expected_distance:.2f}m, got {data_rotated.distances}"
-# ============================================================================
-# Pinhole Camera Pattern Tests
-# ============================================================================
-
-
def test_pinhole_camera_pattern_num_rays(device):
"""Verify pinhole pattern generates width * height rays."""
simple_xml = """
@@ -671,11 +666,6 @@ def test_pinhole_from_mujoco_camera_fovy_mode(device):
assert torch.all(data.distances >= 0) # Should hit floor
-# ============================================================================
-# Ray Alignment Tests
-# ============================================================================
-
-
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Likely bug on CPU MjWarp")
def test_ray_alignment_yaw(device):
"""Verify yaw alignment ignores pitch/roll."""
@@ -943,3 +933,121 @@ def test_height_scan_misses(device):
assert torch.allclose(
heights, torch.full_like(heights, raycast_cfg.max_distance), atol=1e-5
)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="Likely bug on CPU MjWarp")
+def test_ray_alignment_yaw_site_offset(device):
+ """Verify yaw alignment correctly aligns frame offset for sites.
+
+ When a site has a large Z offset from its parent body, the frame position should be
+ recomputed using the body's yaw-only rotation so that pitch/roll does not swing the
+ ray origin.
+ """
+ xml = """
+
+
+
+
+
+
+
+
+
+
+
+ """
+
+ raycast_cfg = RayCastSensorCfg(
+ name="yaw_site_scan",
+ frame=ObjRef(type="site", name="high_site", entity="robot"),
+ pattern=GridPatternCfg(size=(0.0, 0.0), resolution=0.1, direction=(0.0, 0.0, -1.0)),
+ ray_alignment="yaw",
+ max_distance=50.0,
+ )
+
+ scene, sim = _make_scene_and_sim(device, xml, sensors=(raycast_cfg,))
+ sensor = scene["yaw_site_scan"]
+
+ # Baseline: unrotated body. Site at z=22, floor at z=0 -> distance ~22m.
+ sim.step()
+ scene.update(dt=sim.cfg.mujoco.timestep)
+ sim.sense()
+ baseline_dist = sensor.data.distances.clone()
+ assert torch.allclose(baseline_dist, torch.full_like(baseline_dist, 22.0), atol=0.2)
+
+ # Tilt body 30 degrees around X axis (pitch).
+ angle = math.pi / 6
+ quat = [math.cos(angle / 2), math.sin(angle / 2), 0, 0]
+ sim.data.qpos[0, 3:7] = torch.tensor(quat, device=device)
+ sim.step()
+ scene.update(dt=sim.cfg.mujoco.timestep)
+ sim.sense()
+ tilted_dist = sensor.data.distances.clone()
+
+ # Distance should remain ~22m: the site offset is rotated by the body's yaw-only
+ # rotation, so pitching the body must not swing the ray origin.
+ assert torch.allclose(tilted_dist, baseline_dist, atol=0.5), (
+ f"Expected ~22m after pitch, got {tilted_dist}"
+ )
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="Likely bug on CPU MjWarp")
+def test_ray_alignment_world_site_offset(device):
+ """Verify world alignment correctly aligns frame offset for sites.
+
+ With world alignment the site's local offset is not rotated at all (the alignment
+ rotation is the identity), so combined pitch and roll on the body should not move
+ the ray origin away from directly above the body.
+ """
+ xml = """
+
+
+
+
+
+
+
+
+
+
+
+ """
+
+ raycast_cfg = RayCastSensorCfg(
+ name="world_site_scan",
+ frame=ObjRef(type="site", name="high_site", entity="robot"),
+ pattern=GridPatternCfg(size=(0.0, 0.0), resolution=0.1, direction=(0.0, 0.0, -1.0)),
+ ray_alignment="world",
+ max_distance=50.0,
+ )
+
+ scene, sim = _make_scene_and_sim(device, xml, sensors=(raycast_cfg,))
+ sensor = scene["world_site_scan"]
+
+ # Baseline: unrotated body. Site at z=22 -> distance ~22m.
+ sim.step()
+ scene.update(dt=sim.cfg.mujoco.timestep)
+ sim.sense()
+ baseline_dist = sensor.data.distances.clone()
+ assert torch.allclose(baseline_dist, torch.full_like(baseline_dist, 22.0), atol=0.2)
+
+ # Apply combined pitch (30 deg) + roll (20 deg).
+ pitch = math.pi / 6
+ roll = math.pi / 9
+ cp, sp = math.cos(pitch / 2), math.sin(pitch / 2)
+ cr, sr = math.cos(roll / 2), math.sin(roll / 2)
+ # q_pitch (around X) then q_roll (around Y): q = q_roll * q_pitch
+ qw = cr * cp
+ qx = cr * sp
+ qy = sr * cp
+ qz = -sr * sp
+ sim.data.qpos[0, 3:7] = torch.tensor([qw, qx, qy, qz], device=device)
+ sim.step()
+ scene.update(dt=sim.cfg.mujoco.timestep)
+ sim.sense()
+ rotated_dist = sensor.data.distances.clone()
+
+ # With world alignment, distance should remain ~22m.
+ assert torch.allclose(rotated_dist, baseline_dist, atol=0.5), (
+ f"Expected ~22m after pitch+roll, got {rotated_dist}"
+ )