diff --git a/data/.lfs/xarm7.tar.gz b/data/.lfs/xarm7.tar.gz new file mode 100644 index 0000000000..b19d8d919a --- /dev/null +++ b/data/.lfs/xarm7.tar.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dc5c96439cc415d7d7b1296363b5684354aaef22c7dbe8e50bce81183401511c +size 6297600 diff --git a/dimos/e2e_tests/test_simulation_module.py b/dimos/e2e_tests/test_simulation_module.py new file mode 100644 index 0000000000..c1b5e68539 --- /dev/null +++ b/dimos/e2e_tests/test_simulation_module.py @@ -0,0 +1,110 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""End-to-end tests for the simulation module.""" + +import os + +import pytest + +from dimos.msgs.sensor_msgs import JointCommand, JointState, RobotState + + +def _positions_within_tolerance( + positions: list[float], + target: list[float], + tolerance: float, +) -> bool: + if len(positions) < len(target): + return False + return all(abs(positions[i] - target[i]) <= tolerance for i in range(len(target))) + + +@pytest.mark.skipif(bool(os.getenv("CI")), reason="LCM doesn't work in CI.") +@pytest.mark.e2e +class TestSimulationModuleE2E: + def test_xarm7_joint_state_published( + self, + lcm_spy, + start_blueprint, + monkeypatch, + ) -> None: + monkeypatch.setenv("DIMOS_HEADLESS", "1") + monkeypatch.delenv("DIMOS_MUJOCO_FORCE_SUBPROCESS", raising=False) + + joint_state_topic = "/xarm/joint_states#sensor_msgs.JointState" + lcm_spy.save_topic(joint_state_topic) + + start_blueprint("xarm7-trajectory-sim") + lcm_spy.wait_for_saved_topic(joint_state_topic, timeout=30.0) + + with lcm_spy._messages_lock: + raw_joint_state = lcm_spy.messages[joint_state_topic][0] + + joint_state = JointState.lcm_decode(raw_joint_state) + assert len(joint_state.name) == 8 + assert len(joint_state.position) == 8 + + def test_xarm7_robot_state_published( + self, + lcm_spy, + start_blueprint, + monkeypatch, + ) -> None: + monkeypatch.setenv("DIMOS_HEADLESS", "1") + monkeypatch.delenv("DIMOS_MUJOCO_FORCE_SUBPROCESS", raising=False) + + robot_state_topic = "/xarm/robot_state#sensor_msgs.RobotState" + lcm_spy.save_topic(robot_state_topic) + + start_blueprint("xarm7-trajectory-sim") + lcm_spy.wait_for_saved_topic(robot_state_topic, timeout=30.0) + + with lcm_spy._messages_lock: + raw_robot_state = lcm_spy.messages[robot_state_topic][0] + + robot_state = RobotState.lcm_decode(raw_robot_state) + assert robot_state.mt_able in (0, 1) + + def test_xarm7_joint_command_updates_joint_state( + self, + lcm_spy, + start_blueprint, + monkeypatch, + ) -> None: + monkeypatch.setenv("DIMOS_HEADLESS", "1") + monkeypatch.delenv("DIMOS_MUJOCO_FORCE_SUBPROCESS", raising=False) + + joint_state_topic = "/xarm/joint_states#sensor_msgs.JointState" + joint_command_topic = "/xarm/joint_position_command#sensor_msgs.JointCommand" + lcm_spy.save_topic(joint_state_topic) + + start_blueprint("xarm7-trajectory-sim") + lcm_spy.wait_for_saved_topic(joint_state_topic, timeout=30.0) + + target_positions = [0.2, -0.2, 0.1, -0.1, 0.15, -0.15, 0.05] + lcm_spy.publish(joint_command_topic, JointCommand(positions=target_positions)) + + tolerance = 0.03 + lcm_spy.wait_for_message_result( + joint_state_topic, + JointState, + predicate=lambda msg: _positions_within_tolerance( + list(msg.position), + target_positions, + tolerance, + ), + fail_message=("joint_state did not reach commanded positions within tolerance"), + timeout=10.0, + ) diff --git a/dimos/robot/all_blueprints.py b/dimos/robot/all_blueprints.py index 461f602ccb..29a103e55e 100644 --- a/dimos/robot/all_blueprints.py +++ b/dimos/robot/all_blueprints.py @@ -53,6 +53,7 @@ "unitree-go2-spatial": "dimos.robot.unitree_webrtc.unitree_go2_blueprints:unitree_go2_spatial", "unitree-go2-temporal-memory": "dimos.robot.unitree_webrtc.unitree_go2_blueprints:unitree_go2_temporal_memory", "unitree-go2-vlm-stream-test": "dimos.robot.unitree_webrtc.unitree_go2_blueprints:unitree_go2_vlm_stream_test", + "xarm7-trajectory-sim": "dimos.simulation.sim_blueprints:xarm7_trajectory_sim", } @@ -88,6 +89,7 @@ "replanning_a_star_planner": "dimos.navigation.replanning_a_star.module", "rerun_scene_wiring": "dimos.dashboard.rerun_scene_wiring", "ros_nav": "dimos.navigation.rosnav", + "simulation": "dimos.simulation.manipulators.sim_module", "spatial_memory": "dimos.perception.spatial_perception", "speak_skill": "dimos.agents.skills.speak_skill", "temporal_memory": "dimos.perception.experimental.temporal_memory.temporal_memory", diff --git a/dimos/simulation/engines/__init__.py b/dimos/simulation/engines/__init__.py new file mode 100644 index 0000000000..d437f9a7cd --- /dev/null +++ b/dimos/simulation/engines/__init__.py @@ -0,0 +1,25 @@ +"""Simulation engines for manipulator backends.""" + +from __future__ import annotations + +from typing import Literal + +from dimos.simulation.engines.base import SimulationEngine +from dimos.simulation.engines.mujoco_engine import MujocoEngine + +EngineType = Literal["mujoco"] + +_ENGINES: dict[EngineType, type[SimulationEngine]] = { + "mujoco": MujocoEngine, +} + + +def get_engine(engine_name: EngineType) -> type[SimulationEngine]: + return _ENGINES[engine_name] + + +__all__ = [ + "EngineType", + "SimulationEngine", + "get_engine", +] diff --git a/dimos/simulation/engines/base.py b/dimos/simulation/engines/base.py new file mode 100644 index 0000000000..d450614c62 --- /dev/null +++ b/dimos/simulation/engines/base.py @@ -0,0 +1,84 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Base interfaces for simulator engines.""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from pathlib import Path + + from dimos.msgs.sensor_msgs import JointState + + +class SimulationEngine(ABC): + """Abstract base class for a simulator engine instance.""" + + def __init__(self, config_path: Path, headless: bool) -> None: + self._config_path = config_path + self._headless = headless + + @property + def config_path(self) -> Path: + return self._config_path + + @property + def headless(self) -> bool: + return self._headless + + @abstractmethod + def connect(self) -> bool: + """Connect to simulation and start the engine.""" + + @abstractmethod + def disconnect(self) -> bool: + """Disconnect from simulation and stop the engine.""" + + @property + @abstractmethod + def connected(self) -> bool: + """Whether the engine is connected.""" + + @property + @abstractmethod + def num_joints(self) -> int: + """Number of joints for the loaded robot.""" + + @property + @abstractmethod + def joint_names(self) -> list[str]: + """Joint names for the loaded robot.""" + + @abstractmethod + def read_joint_positions(self) -> list[float]: + """Read joint positions in radians.""" + + @abstractmethod + def read_joint_velocities(self) -> list[float]: + """Read joint velocities in rad/s.""" + + @abstractmethod + def read_joint_efforts(self) -> list[float]: + """Read joint efforts in Nm.""" + + @abstractmethod + def write_joint_command(self, command: JointState) -> None: + """Command joints using a JointState message.""" + + @abstractmethod + def hold_current_position(self) -> None: + """Hold current joint positions.""" diff --git a/dimos/simulation/engines/mujoco_engine.py b/dimos/simulation/engines/mujoco_engine.py new file mode 100644 index 0000000000..3522552bd7 --- /dev/null +++ b/dimos/simulation/engines/mujoco_engine.py @@ -0,0 +1,408 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""MuJoCo simulation engine implementation.""" + +from __future__ import annotations + +import json +from pathlib import Path +import subprocess +import sys +import threading +import time +from typing import TYPE_CHECKING + +import mujoco +import mujoco.viewer as viewer # type: ignore[import-untyped] + +from dimos.simulation.engines.base import SimulationEngine +from dimos.simulation.manipulators.mujoco_subprocess.constants import LAUNCHER_PATH +from dimos.simulation.manipulators.mujoco_subprocess.shared_memory import ShmWriter +from dimos.simulation.utils.xml_parser import JointMapping, build_joint_mappings +from dimos.utils.logging_config import setup_logger + +if TYPE_CHECKING: + from dimos.msgs.sensor_msgs import JointState + + pass + +logger = setup_logger() + +_MODE_POSITION = 0 +_MODE_VELOCITY = 1 +_MODE_EFFORT = 2 + + +class MujocoEngine(SimulationEngine): + """ + MuJoCo simulation engine. + + - starts MuJoCo simulation engine + - loads robot/environment into simulation + - applies control commands + """ + + def __init__(self, config_path: Path, headless: bool) -> None: + super().__init__(config_path=config_path, headless=headless) + + xml_path = self._resolve_xml_path(config_path) + self._model = mujoco.MjModel.from_xml_path(str(xml_path)) + self._xml_path = xml_path + + self._data = mujoco.MjData(self._model) + self._joint_mappings = build_joint_mappings(self._xml_path, self._model) + self._joint_names = [mapping.name for mapping in self._joint_mappings] + self._num_joints = len(self._joint_names) + timestep = float(self._model.opt.timestep) + self._control_frequency = 1.0 / timestep if timestep > 0.0 else 100.0 + + self._connected = False + self._use_subprocess = sys.platform == "darwin" and not headless + self._process: subprocess.Popen[bytes] | None = None + self._shm: ShmWriter | None = None + self._lock = threading.Lock() + self._stop_event = threading.Event() + self._sim_thread: threading.Thread | None = None + + self._joint_positions = [0.0] * self._num_joints + self._joint_velocities = [0.0] * self._num_joints + self._joint_efforts = [0.0] * self._num_joints + + self._joint_position_targets = [0.0] * self._num_joints + self._joint_velocity_targets = [0.0] * self._num_joints + self._joint_effort_targets = [0.0] * self._num_joints + self._command_mode = "position" + for i, mapping in enumerate(self._joint_mappings): + current_pos = self._current_position(mapping) + self._joint_position_targets[i] = current_pos + self._joint_positions[i] = current_pos + + def _resolve_xml_path(self, config_path: Path) -> Path: + resolved = config_path.expanduser() + xml_path = resolved / "scene.xml" if resolved.is_dir() else resolved + if not xml_path.exists(): + raise FileNotFoundError(f"MuJoCo XML not found: {xml_path}") + return xml_path + + def _current_position(self, mapping: JointMapping) -> float: + if mapping.joint_id is not None and mapping.qpos_adr is not None: + return float(self._data.qpos[mapping.qpos_adr]) + if mapping.tendon_qpos_adrs: + return float( + sum(self._data.qpos[adr] for adr in mapping.tendon_qpos_adrs) + / len(mapping.tendon_qpos_adrs) + ) + if mapping.actuator_id is not None: + return float(self._data.actuator_length[mapping.actuator_id]) + return 0.0 + + def _apply_control(self) -> None: + with self._lock: + if self._command_mode == "effort": + targets = list(self._joint_effort_targets) + elif self._command_mode == "velocity": + targets = list(self._joint_velocity_targets) + elif self._command_mode == "position": + targets = list(self._joint_position_targets) + for i, mapping in enumerate(self._joint_mappings): + if mapping.actuator_id is None: + continue + if i < len(targets): + self._data.ctrl[mapping.actuator_id] = targets[i] + + def _update_joint_state(self) -> None: + with self._lock: + for i, mapping in enumerate(self._joint_mappings): + if mapping.joint_id is not None: + if mapping.qpos_adr is not None: + self._joint_positions[i] = float(self._data.qpos[mapping.qpos_adr]) + if mapping.dof_adr is not None: + self._joint_velocities[i] = float(self._data.qvel[mapping.dof_adr]) + self._joint_efforts[i] = float(self._data.qfrc_actuator[mapping.dof_adr]) + continue + + if mapping.tendon_qpos_adrs: + pos_sum = sum(self._data.qpos[adr] for adr in mapping.tendon_qpos_adrs) + count = len(mapping.tendon_qpos_adrs) + self._joint_positions[i] = float(pos_sum / count) + if mapping.tendon_dof_adrs: + vel_sum = sum(self._data.qvel[adr] for adr in mapping.tendon_dof_adrs) + self._joint_velocities[i] = float(vel_sum / len(mapping.tendon_dof_adrs)) + else: + self._joint_velocities[i] = 0.0 + elif mapping.actuator_id is not None: + self._joint_positions[i] = float( + self._data.actuator_length[mapping.actuator_id] + ) + self._joint_velocities[i] = 0.0 + + if mapping.actuator_id is not None: + self._joint_efforts[i] = float(self._data.actuator_force[mapping.actuator_id]) + + def connect(self) -> bool: + try: + logger.info(f"{self.__class__.__name__}: connect()") + if self._use_subprocess: + if self._connected: + return True + + self._shm = ShmWriter(self._num_joints) + if sys.platform == "darwin": + mjpython = Path(sys.executable).with_name("mjpython") + executable = str(mjpython) if mjpython.exists() else "mjpython" + else: + executable = sys.executable + args = [ + executable, + str(LAUNCHER_PATH), + str(self._xml_path), + "1" if self._headless else "0", + str(self._num_joints), + json.dumps(self._shm.shm.to_names()), + ] + self._process = subprocess.Popen(args) + + ready_timeout = 30.0 + start_time = time.time() + while time.time() - start_time < ready_timeout: + if self._process.poll() is not None: + exit_code = self._process.returncode + self.disconnect() + raise RuntimeError( + f"MuJoCo subprocess failed to start (exit code {exit_code})" + ) + if self._shm.is_ready(): + self._connected = True + return True + time.sleep(0.1) + + self.disconnect() + raise RuntimeError("MuJoCo subprocess failed to start (timeout)") + with self._lock: + self._connected = True + self._stop_event.clear() + + if self._sim_thread is None or not self._sim_thread.is_alive(): + self._sim_thread = threading.Thread( + target=self._sim_loop, + name=f"{self.__class__.__name__}Sim", + daemon=True, + ) + self._sim_thread.start() + return True + except Exception as e: + logger.error(f"{self.__class__.__name__}: connect() failed: {e}") + return False + + def disconnect(self) -> bool: + try: + logger.info(f"{self.__class__.__name__}: disconnect()") + if self._use_subprocess: + self._connected = False + if self._shm: + self._shm.signal_stop() + + if self._process: + try: + self._process.terminate() + self._process.wait(timeout=5) + except subprocess.TimeoutExpired: + logger.warning("MuJoCo subprocess did not stop gracefully, killing") + self._process.kill() + self._process.wait(timeout=2) + self._process = None + + if self._shm: + self._shm.cleanup() + self._shm = None + return True + with self._lock: + self._connected = False + self._stop_event.set() + if self._sim_thread and self._sim_thread.is_alive(): + self._sim_thread.join(timeout=2.0) + self._sim_thread = None + return True + except Exception as e: + logger.error(f"{self.__class__.__name__}: disconnect() failed: {e}") + return False + + def _sim_loop(self) -> None: + logger.info(f"{self.__class__.__name__}: sim loop started") + dt = 1.0 / self._control_frequency + + def _step_once(sync_viewer: bool) -> None: + loop_start = time.time() + self._apply_control() + mujoco.mj_step(self._model, self._data) + if sync_viewer: + m_viewer.sync() + self._update_joint_state() + + elapsed = time.time() - loop_start + sleep_time = dt - elapsed + if sleep_time > 0: + time.sleep(sleep_time) + + if self._headless: + while not self._stop_event.is_set(): + _step_once(sync_viewer=False) + else: + with viewer.launch_passive( + self._model, self._data, show_left_ui=False, show_right_ui=False + ) as m_viewer: + while m_viewer.is_running() and not self._stop_event.is_set(): + _step_once(sync_viewer=True) + + logger.info(f"{self.__class__.__name__}: sim loop stopped") + + @property + def connected(self) -> bool: + if self._use_subprocess: + if not self._connected: + return False + if self._process and self._process.poll() is not None: + self._connected = False + return False + return True + with self._lock: + return self._connected + + @property + def num_joints(self) -> int: + return self._num_joints + + @property + def joint_names(self) -> list[str]: + return list(self._joint_names) + + @property + def model(self) -> mujoco.MjModel: + return self._model + + @property + def joint_positions(self) -> list[float]: + with self._lock: + return list(self._joint_positions) + + @property + def joint_velocities(self) -> list[float]: + with self._lock: + return list(self._joint_velocities) + + @property + def joint_efforts(self) -> list[float]: + with self._lock: + return list(self._joint_efforts) + + @property + def control_frequency(self) -> float: + return self._control_frequency + + def read_joint_positions(self) -> list[float]: + if self._use_subprocess: + if not self._shm: + return [] + positions, _, _ = self._shm.read_state() + return positions + return self.joint_positions + + def read_joint_velocities(self) -> list[float]: + if self._use_subprocess: + if not self._shm: + return [] + _, velocities, _ = self._shm.read_state() + return velocities + return self.joint_velocities + + def read_joint_efforts(self) -> list[float]: + if self._use_subprocess: + if not self._shm: + return [] + _, _, efforts = self._shm.read_state() + return efforts + return self.joint_efforts + + def write_joint_command(self, command: JointState) -> None: + if self._use_subprocess: + if not self._shm: + return + if command.position: + self._shm.write_command(_MODE_POSITION, positions=list(command.position)) + return + if command.velocity: + self._shm.write_command(_MODE_VELOCITY, velocities=list(command.velocity)) + return + if command.effort: + self._shm.write_command(_MODE_EFFORT, efforts=list(command.effort)) + return + return + if command.position: + self._command_mode = "position" + self._set_position_targets(command.position) + return + if command.velocity: + self._command_mode = "velocity" + self._set_velocity_targets(command.velocity) + return + if command.effort: + self._command_mode = "effort" + self._set_effort_targets(command.effort) + return + + def _set_position_targets(self, positions: list[float]) -> None: + if len(positions) > self._num_joints: + raise ValueError( + f"Position command has {len(positions)} joints, expected at most {self._num_joints}" + ) + with self._lock: + for i in range(len(positions)): + self._joint_position_targets[i] = float(positions[i]) + + def _set_velocity_targets(self, velocities: list[float]) -> None: + if len(velocities) > self._num_joints: + raise ValueError( + f"Velocity command has {len(velocities)} joints, expected at most {self._num_joints}" + ) + with self._lock: + for i in range(len(velocities)): + self._joint_velocity_targets[i] = float(velocities[i]) + + def _set_effort_targets(self, efforts: list[float]) -> None: + if len(efforts) > self._num_joints: + raise ValueError( + f"Effort command has {len(efforts)} joints, expected at most {self._num_joints}" + ) + with self._lock: + for i in range(len(efforts)): + self._joint_effort_targets[i] = float(efforts[i]) + + def hold_current_position(self) -> None: + if self._use_subprocess: + if not self._shm: + return + positions = self.read_joint_positions() + if positions: + self._shm.write_command(_MODE_POSITION, positions=positions) + return + with self._lock: + self._command_mode = "position" + for i, mapping in enumerate(self._joint_mappings): + self._joint_position_targets[i] = self._current_position(mapping) + + +__all__ = [ + "MujocoEngine", +] diff --git a/dimos/simulation/manipulators/__init__.py b/dimos/simulation/manipulators/__init__.py new file mode 100644 index 0000000000..816de0a18d --- /dev/null +++ b/dimos/simulation/manipulators/__init__.py @@ -0,0 +1,54 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Simulation manipulator utilities.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from dimos.simulation.manipulators.sim_manip_interface import SimManipInterface + from dimos.simulation.manipulators.sim_module import ( + SimulationModule, + SimulationModuleConfig, + simulation, + ) + +__all__ = [ + "SimManipInterface", + "SimulationModule", + "SimulationModuleConfig", + "simulation", +] + + +def __getattr__(name: str): # type: ignore[no-untyped-def] + if name == "SimManipInterface": + from dimos.simulation.manipulators.sim_manip_interface import SimManipInterface + + return SimManipInterface + if name in {"SimulationModule", "SimulationModuleConfig", "simulation"}: + from dimos.simulation.manipulators.sim_module import ( + SimulationModule, + SimulationModuleConfig, + simulation, + ) + + return { + "SimulationModule": SimulationModule, + "SimulationModuleConfig": SimulationModuleConfig, + "simulation": simulation, + }[name] + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/dimos/simulation/manipulators/mujoco_subprocess/__init__.py b/dimos/simulation/manipulators/mujoco_subprocess/__init__.py new file mode 100644 index 0000000000..bc1a2ce5cc --- /dev/null +++ b/dimos/simulation/manipulators/mujoco_subprocess/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/dimos/simulation/manipulators/mujoco_subprocess/constants.py b/dimos/simulation/manipulators/mujoco_subprocess/constants.py new file mode 100644 index 0000000000..5073333647 --- /dev/null +++ b/dimos/simulation/manipulators/mujoco_subprocess/constants.py @@ -0,0 +1,17 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path + +LAUNCHER_PATH = Path(__file__).parent / "mujoco_subprocess.py" diff --git a/dimos/simulation/manipulators/mujoco_subprocess/mujoco_subprocess.py b/dimos/simulation/manipulators/mujoco_subprocess/mujoco_subprocess.py new file mode 100644 index 0000000000..bd9b1e0167 --- /dev/null +++ b/dimos/simulation/manipulators/mujoco_subprocess/mujoco_subprocess.py @@ -0,0 +1,175 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from pathlib import Path +import signal +import sys +import time +from typing import Any + +import mujoco +import mujoco.viewer as viewer # type: ignore[import-untyped] + +from dimos.simulation.manipulators.mujoco_subprocess.shared_memory import ShmReader +from dimos.simulation.utils.xml_parser import JointMapping, build_joint_mappings + +_MODE_POSITION = 0 +_MODE_VELOCITY = 1 +_MODE_EFFORT = 2 + + +def _resolve_xml_path(config_path: Path) -> Path: + resolved = config_path.expanduser() + xml_path = resolved / "scene.xml" if resolved.is_dir() else resolved + if not xml_path.exists(): + raise FileNotFoundError(f"MuJoCo XML not found: {xml_path}") + return xml_path + + +def _current_position(data: mujoco.MjData, mapping: JointMapping) -> float: + if mapping.joint_id is not None and mapping.qpos_adr is not None: + return float(data.qpos[mapping.qpos_adr]) + if mapping.tendon_qpos_adrs: + return float( + sum(data.qpos[adr] for adr in mapping.tendon_qpos_adrs) / len(mapping.tendon_qpos_adrs) + ) + if mapping.actuator_id is not None: + return float(data.actuator_length[mapping.actuator_id]) + return 0.0 + + +def _run_simulation(xml_path: Path, headless: bool, shm: ShmReader, dof: int) -> None: + model = mujoco.MjModel.from_xml_path(str(xml_path)) + data = mujoco.MjData(model) + + joint_mappings = build_joint_mappings(xml_path, model) + num_joints = len(joint_mappings) + + if num_joints != dof: + raise ValueError(f"Shared memory DOF mismatch: shm={dof} model={num_joints}") + + joint_position_targets = [0.0] * num_joints + joint_velocity_targets = [0.0] * num_joints + joint_effort_targets = [0.0] * num_joints + command_mode = _MODE_POSITION + + for i, mapping in enumerate(joint_mappings): + current_pos = _current_position(data, mapping) + joint_position_targets[i] = current_pos + + control_frequency = ( + 1.0 / float(model.opt.timestep) if float(model.opt.timestep) > 0.0 else 100.0 + ) + dt = 1.0 / control_frequency + + def apply_control() -> None: + if command_mode == _MODE_EFFORT: + targets = joint_effort_targets + elif command_mode == _MODE_VELOCITY: + targets = joint_velocity_targets + else: + targets = joint_position_targets + for i, mapping in enumerate(joint_mappings): + if mapping.actuator_id is None or i >= len(targets): + continue + data.ctrl[mapping.actuator_id] = targets[i] + + def update_joint_state() -> tuple[list[float], list[float], list[float]]: + positions = [0.0] * num_joints + velocities = [0.0] * num_joints + efforts = [0.0] * num_joints + for i, mapping in enumerate(joint_mappings): + if mapping.joint_id is not None: + if mapping.qpos_adr is not None: + positions[i] = float(data.qpos[mapping.qpos_adr]) + if mapping.dof_adr is not None: + velocities[i] = float(data.qvel[mapping.dof_adr]) + efforts[i] = float(data.qfrc_actuator[mapping.dof_adr]) + continue + + if mapping.tendon_qpos_adrs: + pos_sum = sum(data.qpos[adr] for adr in mapping.tendon_qpos_adrs) + positions[i] = float(pos_sum / len(mapping.tendon_qpos_adrs)) + if mapping.tendon_dof_adrs: + vel_sum = sum(data.qvel[adr] for adr in mapping.tendon_dof_adrs) + velocities[i] = float(vel_sum / len(mapping.tendon_dof_adrs)) + elif mapping.actuator_id is not None: + positions[i] = float(data.actuator_length[mapping.actuator_id]) + + if mapping.actuator_id is not None: + efforts[i] = float(data.actuator_force[mapping.actuator_id]) + return positions, velocities, efforts + + def step_once(sync_viewer: bool) -> None: + nonlocal command_mode + loop_start = time.time() + cmd = shm.read_command() + if cmd is not None: + mode, cmd_pos, cmd_vel, cmd_eff = cmd + if mode == _MODE_POSITION: + joint_position_targets[:] = cmd_pos.tolist() + command_mode = _MODE_POSITION + elif mode == _MODE_VELOCITY: + joint_velocity_targets[:] = cmd_vel.tolist() + command_mode = _MODE_VELOCITY + elif mode == _MODE_EFFORT: + joint_effort_targets[:] = cmd_eff.tolist() + command_mode = _MODE_EFFORT + else: + command_mode = _MODE_POSITION + + apply_control() + mujoco.mj_step(model, data) + if sync_viewer: + m_viewer.sync() + + positions, velocities, efforts = update_joint_state() + shm.write_state(positions, velocities, efforts) + + elapsed = time.time() - loop_start + sleep_time = dt - elapsed + if sleep_time > 0: + time.sleep(sleep_time) + + shm.signal_ready() + + if headless: + while not shm.should_stop(): + step_once(sync_viewer=False) + return + + with viewer.launch_passive(model, data, show_left_ui=False, show_right_ui=False) as m_viewer: + while m_viewer.is_running() and not shm.should_stop(): + step_once(sync_viewer=True) + + +if __name__ == "__main__": + + def signal_handler(_signum: int, _frame: Any) -> None: + sys.exit(0) + + signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGTERM, signal_handler) + + xml_path = Path(sys.argv[1]) + headless = bool(int(sys.argv[2])) + dof = int(sys.argv[3]) + shm_names = json.loads(sys.argv[4]) + + shm = ShmReader(shm_names, dof) + try: + _run_simulation(_resolve_xml_path(xml_path), headless, shm, dof) + finally: + shm.cleanup() diff --git a/dimos/simulation/manipulators/mujoco_subprocess/shared_memory.py b/dimos/simulation/manipulators/mujoco_subprocess/shared_memory.py new file mode 100644 index 0000000000..08cfd01c9a --- /dev/null +++ b/dimos/simulation/manipulators/mujoco_subprocess/shared_memory.py @@ -0,0 +1,236 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from multiprocessing import resource_tracker +from multiprocessing.shared_memory import SharedMemory +from typing import Any + +import numpy as np +from numpy.typing import NDArray + +_FLOAT_BYTES = 8 +_INT_BYTES = 4 +_INT64_BYTES = 8 + + +def _shm_sizes(dof: int) -> dict[str, int]: + return { + "control": 2 * _INT_BYTES, # ready, stop + "seq": 2 * _INT64_BYTES, # cmd, state + "mode": _INT_BYTES, # command mode + "cmd_pos": dof * _FLOAT_BYTES, + "cmd_vel": dof * _FLOAT_BYTES, + "cmd_eff": dof * _FLOAT_BYTES, + "state_pos": dof * _FLOAT_BYTES, + "state_vel": dof * _FLOAT_BYTES, + "state_eff": dof * _FLOAT_BYTES, + } + + +def _unregister(shm: SharedMemory) -> SharedMemory: + try: + resource_tracker.unregister(shm._name, "shared_memory") # type: ignore[attr-defined] + except Exception: + pass + return shm + + +@dataclass(frozen=True) +class ShmSet: + control: SharedMemory + seq: SharedMemory + mode: SharedMemory + cmd_pos: SharedMemory + cmd_vel: SharedMemory + cmd_eff: SharedMemory + state_pos: SharedMemory + state_vel: SharedMemory + state_eff: SharedMemory + + @classmethod + def from_names(cls, shm_names: dict[str, str]) -> "ShmSet": + return cls(**{k: _unregister(SharedMemory(name=shm_names[k])) for k in shm_names}) + + @classmethod + def from_sizes(cls, sizes: dict[str, int]) -> "ShmSet": + return cls(**{k: _unregister(SharedMemory(create=True, size=sizes[k])) for k in sizes}) + + def to_names(self) -> dict[str, str]: + return {name: getattr(self, name).name for name in self.__dataclass_fields__} + + def as_list(self) -> list[SharedMemory]: + return [getattr(self, name) for name in self.__dataclass_fields__] + + +class ShmReader: + shm: ShmSet + _last_cmd_seq: int + _dof: int + + def __init__(self, shm_names: dict[str, str], dof: int) -> None: + self.shm = ShmSet.from_names(shm_names) + self._last_cmd_seq = 0 + self._dof = dof + + def signal_ready(self) -> None: + control_array: NDArray[Any] = np.ndarray((2,), dtype=np.int32, buffer=self.shm.control.buf) + control_array[0] = 1 + + def should_stop(self) -> bool: + control_array: NDArray[Any] = np.ndarray((2,), dtype=np.int32, buffer=self.shm.control.buf) + return bool(control_array[1] == 1) + + def read_command(self) -> tuple[int, NDArray[Any], NDArray[Any], NDArray[Any]] | None: + seq = self._get_seq(0) + if seq <= self._last_cmd_seq: + return None + self._last_cmd_seq = seq + mode = int(np.ndarray((1,), dtype=np.int32, buffer=self.shm.mode.buf)[0]) + cmd_pos: NDArray[Any] = np.ndarray( + (self._dof,), dtype=np.float64, buffer=self.shm.cmd_pos.buf + ).copy() + cmd_vel: NDArray[Any] = np.ndarray( + (self._dof,), dtype=np.float64, buffer=self.shm.cmd_vel.buf + ).copy() + cmd_eff: NDArray[Any] = np.ndarray( + (self._dof,), dtype=np.float64, buffer=self.shm.cmd_eff.buf + ).copy() + return mode, cmd_pos, cmd_vel, cmd_eff + + def write_state( + self, + positions: list[float], + velocities: list[float], + efforts: list[float], + ) -> None: + pos_array: NDArray[Any] = np.ndarray( + (self._dof,), dtype=np.float64, buffer=self.shm.state_pos.buf + ) + vel_array: NDArray[Any] = np.ndarray( + (self._dof,), dtype=np.float64, buffer=self.shm.state_vel.buf + ) + eff_array: NDArray[Any] = np.ndarray( + (self._dof,), dtype=np.float64, buffer=self.shm.state_eff.buf + ) + pos_array[:] = positions + vel_array[:] = velocities + eff_array[:] = efforts + self._increment_seq(1) + + def _increment_seq(self, index: int) -> None: + seq_array: NDArray[Any] = np.ndarray((2,), dtype=np.int64, buffer=self.shm.seq.buf) + seq_array[index] += 1 + + def _get_seq(self, index: int) -> int: + seq_array: NDArray[Any] = np.ndarray((2,), dtype=np.int64, buffer=self.shm.seq.buf) + return int(seq_array[index]) + + def cleanup(self) -> None: + for shm in self.shm.as_list(): + try: + shm.close() + except Exception: + pass + + +class ShmWriter: + shm: ShmSet + _dof: int + + def __init__(self, dof: int) -> None: + self._dof = dof + sizes = _shm_sizes(dof) + self.shm = ShmSet.from_sizes(sizes) + + control_array: NDArray[Any] = np.ndarray((2,), dtype=np.int32, buffer=self.shm.control.buf) + control_array[:] = 0 + + seq_array: NDArray[Any] = np.ndarray((2,), dtype=np.int64, buffer=self.shm.seq.buf) + seq_array[:] = 0 + + mode_array: NDArray[Any] = np.ndarray((1,), dtype=np.int32, buffer=self.shm.mode.buf) + mode_array[0] = 0 + + for name in ("cmd_pos", "cmd_vel", "cmd_eff", "state_pos", "state_vel", "state_eff"): + arr: NDArray[Any] = np.ndarray( + (dof,), dtype=np.float64, buffer=getattr(self.shm, name).buf + ) + arr[:] = 0.0 + + def is_ready(self) -> bool: + control_array: NDArray[Any] = np.ndarray((2,), dtype=np.int32, buffer=self.shm.control.buf) + return bool(control_array[0] == 1) + + def signal_stop(self) -> None: + control_array: NDArray[Any] = np.ndarray((2,), dtype=np.int32, buffer=self.shm.control.buf) + control_array[1] = 1 + + def write_command( + self, + mode: int, + positions: list[float] | None = None, + velocities: list[float] | None = None, + efforts: list[float] | None = None, + ) -> None: + mode_array: NDArray[Any] = np.ndarray((1,), dtype=np.int32, buffer=self.shm.mode.buf) + mode_array[0] = mode + + if positions is not None: + pos_array: NDArray[Any] = np.ndarray( + (self._dof,), dtype=np.float64, buffer=self.shm.cmd_pos.buf + ) + count = min(len(positions), self._dof) + pos_array[:count] = positions[:count] + if velocities is not None: + vel_array: NDArray[Any] = np.ndarray( + (self._dof,), dtype=np.float64, buffer=self.shm.cmd_vel.buf + ) + count = min(len(velocities), self._dof) + vel_array[:count] = velocities[:count] + if efforts is not None: + eff_array: NDArray[Any] = np.ndarray( + (self._dof,), dtype=np.float64, buffer=self.shm.cmd_eff.buf + ) + count = min(len(efforts), self._dof) + eff_array[:count] = efforts[:count] + + self._increment_seq(0) + + def read_state(self) -> tuple[list[float], list[float], list[float]]: + pos_array: NDArray[Any] = np.ndarray( + (self._dof,), dtype=np.float64, buffer=self.shm.state_pos.buf + ) + vel_array: NDArray[Any] = np.ndarray( + (self._dof,), dtype=np.float64, buffer=self.shm.state_vel.buf + ) + eff_array: NDArray[Any] = np.ndarray( + (self._dof,), dtype=np.float64, buffer=self.shm.state_eff.buf + ) + return pos_array.tolist(), vel_array.tolist(), eff_array.tolist() + + def _increment_seq(self, index: int) -> None: + seq_array: NDArray[Any] = np.ndarray((2,), dtype=np.int64, buffer=self.shm.seq.buf) + seq_array[index] += 1 + + def cleanup(self) -> None: + for shm in self.shm.as_list(): + try: + shm.close() + except Exception: + pass + try: + shm.unlink() + except Exception: + pass diff --git a/dimos/simulation/manipulators/sim_manip_interface.py b/dimos/simulation/manipulators/sim_manip_interface.py new file mode 100644 index 0000000000..c829f0c864 --- /dev/null +++ b/dimos/simulation/manipulators/sim_manip_interface.py @@ -0,0 +1,200 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Simulation-agnostic manipulator interface.""" + +from __future__ import annotations + +import logging +import math +from typing import TYPE_CHECKING + +from dimos.hardware.manipulators.spec import ControlMode, JointLimits, ManipulatorInfo +from dimos.msgs.sensor_msgs import JointState + +if TYPE_CHECKING: + from dimos.simulation.engines.base import SimulationEngine + + +class SimManipInterface: + """Adapter wrapper around a simulation engine to provide a uniform manipulator API.""" + + def __init__(self, engine: SimulationEngine) -> None: + self.logger = logging.getLogger(self.__class__.__name__) + self._engine = engine + self._joint_names = list(engine.joint_names) + self._dof = len(self._joint_names) + self._connected = False + self._servos_enabled = False + self._control_mode = ControlMode.POSITION + self._error_code = 0 + self._error_message = "" + + def connect(self) -> bool: + """Connect to the simulation engine.""" + try: + self.logger.info("Connecting to simulation engine...") + if not self._engine.connect(): + self.logger.error("Failed to connect to simulation engine") + return False + if self._engine.connected: + self._connected = True + self._servos_enabled = True + self._joint_names = list(self._engine.joint_names) + self._dof = len(self._joint_names) + self.logger.info( + "Successfully connected to simulation", + extra={"dof": self._dof}, + ) + return True + self.logger.error("Failed to connect to simulation engine") + return False + except Exception as exc: + self.logger.error(f"Sim connection failed: {exc}") + return False + + def disconnect(self) -> bool: + """Disconnect from simulation.""" + try: + return self._engine.disconnect() + except Exception as exc: + self._connected = False + self.logger.error(f"Sim disconnection failed: {exc}") + return False + + def is_connected(self) -> bool: + return bool(self._connected and self._engine.connected) + + def get_info(self) -> ManipulatorInfo: + vendor = "Simulation" + model = "Simulation" + dof = self._dof + return ManipulatorInfo( + vendor=vendor, + model=model, + dof=dof, + firmware_version=None, + serial_number=None, + ) + + def get_dof(self) -> int: + return self._dof + + def get_joint_names(self) -> list[str]: + return list(self._joint_names) + + def get_limits(self) -> JointLimits: + lower = [-math.pi] * self._dof + upper = [math.pi] * self._dof + max_vel_rad = math.radians(180.0) + return JointLimits( + position_lower=lower, + position_upper=upper, + velocity_max=[max_vel_rad] * self._dof, + ) + + def set_control_mode(self, mode: ControlMode) -> bool: + self._control_mode = mode + return True + + def get_control_mode(self) -> ControlMode: + return self._control_mode + + def read_joint_positions(self) -> list[float]: + positions = self._engine.read_joint_positions() + return positions[: self._dof] + + def read_joint_velocities(self) -> list[float]: + velocities = self._engine.read_joint_velocities() + return velocities[: self._dof] + + def read_joint_efforts(self) -> list[float]: + efforts = self._engine.read_joint_efforts() + return efforts[: self._dof] + + def read_state(self) -> dict[str, int]: + velocities = self.read_joint_velocities() + is_moving = any(abs(v) > 1e-4 for v in velocities) + mode_int = list(ControlMode).index(self._control_mode) + return { + "state": 1 if is_moving else 0, + "mode": mode_int, + } + + def read_error(self) -> tuple[int, str]: + return self._error_code, self._error_message + + def write_joint_positions(self, positions: list[float]) -> bool: + if not self._servos_enabled: + return False + self._control_mode = ControlMode.POSITION + self._engine.write_joint_command(JointState(position=positions[: self._dof])) + return True + + def write_joint_velocities(self, velocities: list[float]) -> bool: + if not self._servos_enabled: + return False + self._control_mode = ControlMode.VELOCITY + self._engine.write_joint_command(JointState(velocity=velocities[: self._dof])) + return True + + def write_joint_efforts(self, efforts: list[float]) -> bool: + if not self._servos_enabled: + return False + self._control_mode = ControlMode.TORQUE + self._engine.write_joint_command(JointState(effort=efforts[: self._dof])) + return True + + def write_stop(self) -> bool: + self._engine.hold_current_position() + return True + + def write_enable(self, enable: bool) -> bool: + self._servos_enabled = enable + return True + + def read_enabled(self) -> bool: + return self._servos_enabled + + def write_clear_errors(self) -> bool: + self._error_code = 0 + self._error_message = "" + return True + + def read_cartesian_position(self) -> dict[str, float] | None: + return None + + def write_cartesian_position( + self, + pose: dict[str, float], + velocity: float = 1.0, + ) -> bool: + _pose = pose + _velocity = velocity + return False + + def read_gripper_position(self) -> float | None: + return None + + def write_gripper_position(self, position: float) -> bool: + _ = position + return False + + def read_force_torque(self) -> list[float] | None: + return None + + +__all__ = [ + "SimManipInterface", +] diff --git a/dimos/simulation/manipulators/sim_module.py b/dimos/simulation/manipulators/sim_module.py new file mode 100644 index 0000000000..4f1bb986d3 --- /dev/null +++ b/dimos/simulation/manipulators/sim_module.py @@ -0,0 +1,247 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Simulator-agnostic manipulator simulation module.""" + +from __future__ import annotations + +from dataclasses import dataclass +import threading +import time +from typing import TYPE_CHECKING, Any + +from reactivex.disposable import Disposable + +from dimos.core import In, Module, Out, rpc +from dimos.core.module import ModuleConfig +from dimos.msgs.sensor_msgs import JointCommand, JointState, RobotState +from dimos.simulation.engines import EngineType, get_engine +from dimos.simulation.manipulators.sim_manip_interface import SimManipInterface + +if TYPE_CHECKING: + from collections.abc import Callable + from pathlib import Path + + +@dataclass(kw_only=True) +class SimulationModuleConfig(ModuleConfig): + engine: EngineType + config_path: Path | Callable[[], Path] + headless: bool = False + + +class SimulationModule(Module[SimulationModuleConfig]): + """Module wrapper for manipulator simulation across engines.""" + + default_config = SimulationModuleConfig + config: SimulationModuleConfig + + joint_state: Out[JointState] + robot_state: Out[RobotState] + joint_position_command: In[JointCommand] + joint_velocity_command: In[JointCommand] + + MIN_CONTROL_RATE = 1.0 + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self._backend: SimManipInterface | None = None + self._control_rate = 100.0 + self._monitor_rate = 100.0 + self._joint_prefix = "joint" + self._stop_event = threading.Event() + self._control_thread: threading.Thread | None = None + self._monitor_thread: threading.Thread | None = None + self._command_lock = threading.Lock() + self._pending_positions: list[float] | None = None + self._pending_velocities: list[float] | None = None + + def _create_backend(self) -> SimManipInterface: + engine_cls = get_engine(self.config.engine) + config_path = ( + self.config.config_path() + if callable(self.config.config_path) + else self.config.config_path + ) + engine = engine_cls( + config_path=config_path, + headless=self.config.headless, + ) + return SimManipInterface(engine=engine) + + @rpc + def start(self) -> None: + super().start() + if self._backend is None: + self._backend = self._create_backend() + if not self._backend.connect(): + raise RuntimeError("Failed to connect to simulation backend") + self._backend.write_enable(True) + + self._disposables.add( + Disposable(self.joint_position_command.subscribe(self._on_joint_position_command)) + ) + self._disposables.add( + Disposable(self.joint_velocity_command.subscribe(self._on_joint_velocity_command)) + ) + + self._stop_event.clear() + self._control_thread = threading.Thread( + target=self._control_loop, + daemon=True, + name=f"{self.__class__.__name__}-control", + ) + self._monitor_thread = threading.Thread( + target=self._monitor_loop, + daemon=True, + name=f"{self.__class__.__name__}-monitor", + ) + self._control_thread.start() + self._monitor_thread.start() + + @rpc + def stop(self) -> None: + self._stop_event.set() + if self._control_thread and self._control_thread.is_alive(): + self._control_thread.join(timeout=2.0) + if self._monitor_thread and self._monitor_thread.is_alive(): + self._monitor_thread.join(timeout=2.0) + if self._backend: + self._backend.disconnect() + super().stop() + + @rpc + def enable_servos(self) -> bool: + if not self._backend: + return False + return self._backend.write_enable(True) + + @rpc + def disable_servos(self) -> bool: + if not self._backend: + return False + return self._backend.write_enable(False) + + @rpc + def clear_errors(self) -> bool: + if not self._backend: + return False + return self._backend.write_clear_errors() + + @rpc + def emergency_stop(self) -> bool: + if not self._backend: + return False + return self._backend.write_stop() + + def _on_joint_position_command(self, msg: JointCommand) -> None: + with self._command_lock: + self._pending_positions = list(msg.positions) + self._pending_velocities = None + + def _on_joint_velocity_command(self, msg: JointCommand) -> None: + with self._command_lock: + self._pending_velocities = list(msg.positions) + self._pending_positions = None + + def _control_loop(self) -> None: + period = 1.0 / max(self._control_rate, self.MIN_CONTROL_RATE) + next_tick = time.monotonic() # monotonic time used to avoid time drift + while not self._stop_event.is_set(): + with self._command_lock: + positions = ( + None if self._pending_positions is None else list(self._pending_positions) + ) + velocities = ( + None if self._pending_velocities is None else list(self._pending_velocities) + ) + + if self._backend: + if positions is not None: + self._backend.write_joint_positions(positions) + elif velocities is not None: + self._backend.write_joint_velocities(velocities) + dof = self._backend.get_dof() + names = self._resolve_joint_names(dof) + positions = self._backend.read_joint_positions() + velocities = self._backend.read_joint_velocities() + efforts = self._backend.read_joint_efforts() + self.joint_state.publish( + JointState( + frame_id=self.frame_id, + name=names, + position=positions, + velocity=velocities, + effort=efforts, + ) + ) + next_tick += period + sleep_for = next_tick - time.monotonic() + if sleep_for > 0: + if self._stop_event.wait(sleep_for): + break + else: + next_tick = time.monotonic() + + def _monitor_loop(self) -> None: + period = 1.0 / max(self._monitor_rate, self.MIN_CONTROL_RATE) + next_tick = time.monotonic() # monotonic time used to avoid time drift + while not self._stop_event.is_set(): + if not self._backend: + pass + else: + dof = self._backend.get_dof() + self._resolve_joint_names(dof) + positions = self._backend.read_joint_positions() + self._backend.read_joint_velocities() + self._backend.read_joint_efforts() + state = self._backend.read_state() + error_code, _ = self._backend.read_error() + self.robot_state.publish( + RobotState( + state=state.get("state", 0), + mode=state.get("mode", 0), + error_code=error_code, + warn_code=0, + cmdnum=0, + mt_brake=0, + mt_able=1 if self._backend.read_enabled() else 0, + tcp_pose=[], + tcp_offset=[], + joints=[float(p) for p in positions], + ) + ) + next_tick += period + sleep_for = next_tick - time.monotonic() + if sleep_for > 0: + if self._stop_event.wait(sleep_for): + break + else: + next_tick = time.monotonic() + + def _resolve_joint_names(self, dof: int) -> list[str]: + if self._backend: + names = self._backend.get_joint_names() + if len(names) >= dof: + return list(names[:dof]) + return [f"{self._joint_prefix}{i + 1}" for i in range(dof)] + + +simulation = SimulationModule.blueprint + +__all__ = [ + "SimulationModule", + "SimulationModuleConfig", + "simulation", +] diff --git a/dimos/simulation/manipulators/test_sim_module.py b/dimos/simulation/manipulators/test_sim_module.py new file mode 100644 index 0000000000..77e7e93c46 --- /dev/null +++ b/dimos/simulation/manipulators/test_sim_module.py @@ -0,0 +1,160 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path +import threading + +import pytest + +from dimos.simulation.manipulators.mujoco_subprocess.shared_memory import ShmReader, ShmWriter +from dimos.simulation.manipulators.sim_module import SimulationModule + + +class _DummyRPC: + def serve_module_rpc(self, _module) -> None: # type: ignore[no-untyped-def] + return None + + def start(self) -> None: + return None + + def stop(self) -> None: + return None + + +class _FakeBackend: + def __init__(self) -> None: + self._names = ["joint1", "joint2", "joint3"] + + def get_dof(self) -> int: + return len(self._names) + + def get_joint_names(self) -> list[str]: + return list(self._names) + + def read_joint_positions(self) -> list[float]: + return [0.1, 0.2, 0.3] + + def read_joint_velocities(self) -> list[float]: + return [0.0, 0.0, 0.0] + + def read_joint_efforts(self) -> list[float]: + return [0.0, 0.0, 0.0] + + def read_state(self) -> dict[str, int]: + return {"state": 1, "mode": 2} + + def read_error(self) -> tuple[int, str]: + return 0, "" + + def read_enabled(self) -> bool: + return True + + def disconnect(self) -> None: + return None + + +def _run_single_monitor_iteration(module: SimulationModule, monkeypatch) -> None: # type: ignore[no-untyped-def] + def _wait_once(_: float) -> bool: + module._stop_event.set() + raise StopIteration + + monkeypatch.setattr(module._stop_event, "wait", _wait_once) + with pytest.raises(StopIteration): + module._monitor_loop() + + +def _run_single_control_iteration(module: SimulationModule, monkeypatch) -> None: # type: ignore[no-untyped-def] + def _wait_once(_: float) -> bool: + module._stop_event.set() + raise StopIteration + + monkeypatch.setattr(module._stop_event, "wait", _wait_once) + with pytest.raises(StopIteration): + module._control_loop() + + +def test_simulation_module_publishes_joint_state(monkeypatch) -> None: + module = SimulationModule( + engine="mujoco", + config_path=Path("."), + rpc_transport=_DummyRPC, + ) + module._backend = _FakeBackend() # type: ignore[assignment] + module._stop_event = threading.Event() + + joint_states: list[object] = [] + module.joint_state.subscribe(joint_states.append) + try: + _run_single_control_iteration(module, monkeypatch) + finally: + module.stop() + + assert len(joint_states) >= 1 + assert joint_states[0].name == ["joint1", "joint2", "joint3"] + + +def test_simulation_module_publishes_robot_state(monkeypatch) -> None: + module = SimulationModule( + engine="mujoco", + config_path=Path("."), + rpc_transport=_DummyRPC, + ) + module._backend = _FakeBackend() # type: ignore[assignment] + module._stop_event = threading.Event() + + robot_states: list[object] = [] + module.robot_state.subscribe(robot_states.append) + try: + _run_single_monitor_iteration(module, monkeypatch) + finally: + module.stop() + + assert len(robot_states) == 1 + assert robot_states[0].state == 1 + + +def test_mujoco_subprocess_shared_memory_roundtrip() -> None: + writer = ShmWriter(dof=2) + reader = ShmReader(writer.shm.to_names(), dof=2) + try: + assert writer.is_ready() is False + reader.signal_ready() + assert writer.is_ready() is True + + writer.write_command( + mode=1, + positions=[1.0, 2.0], + velocities=[3.0, 4.0], + efforts=[5.0, 6.0], + ) + cmd = reader.read_command() + assert cmd is not None + mode, positions, velocities, efforts = cmd + assert mode == 1 + assert positions.tolist() == [1.0, 2.0] + assert velocities.tolist() == [3.0, 4.0] + assert efforts.tolist() == [5.0, 6.0] + assert reader.read_command() is None + + reader.write_state([0.1, 0.2], [0.0, 0.1], [0.2, 0.3]) + positions, velocities, efforts = writer.read_state() + assert positions == pytest.approx([0.1, 0.2]) + assert velocities == pytest.approx([0.0, 0.1]) + assert efforts == pytest.approx([0.2, 0.3]) + + writer.signal_stop() + assert reader.should_stop() is True + finally: + reader.cleanup() + writer.cleanup() diff --git a/dimos/simulation/sim_blueprints.py b/dimos/simulation/sim_blueprints.py new file mode 100644 index 0000000000..836c3b28aa --- /dev/null +++ b/dimos/simulation/sim_blueprints.py @@ -0,0 +1,68 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys + +from dimos.agents.cli.human import human_input +from dimos.core.blueprints import autoconnect +from dimos.core.transport import LCMTransport +from dimos.msgs.sensor_msgs import ( # type: ignore[attr-defined] + JointCommand, + JointState, + RobotState, +) +from dimos.msgs.trajectory_msgs import JointTrajectory +from dimos.simulation.manipulators.sim_module import simulation +from dimos.utils.data import get_data + + +def _parse_headless_env(default: bool) -> bool: + raw = os.getenv("DIMOS_HEADLESS") + if raw is None: + return default + normalized = raw.strip().lower() + if normalized in {"1", "true", "yes", "y", "on"}: + return True + if normalized in {"0", "false", "no", "n", "off"}: + return False + return default + + +_headless = _parse_headless_env(sys.platform != "darwin") + +xarm7_trajectory_sim = simulation( + engine="mujoco", + config_path=lambda: get_data("xarm7") + / "scene.xml", # avoid triggering LFS downloads during tests + headless=_headless, +).transports( + { + ("joint_state", JointState): LCMTransport("/xarm/joint_states", JointState), + ("robot_state", RobotState): LCMTransport("/xarm/robot_state", RobotState), + ("joint_position_command", JointCommand): LCMTransport( + "/xarm/joint_position_command", JointCommand + ), + ("trajectory", JointTrajectory): LCMTransport("/trajectory", JointTrajectory), + } +) + + +__all__ = [ + "simulation", + "xarm7_trajectory_sim", +] + +if __name__ == "__main__": + xarm7_trajectory_sim.build().loop() diff --git a/dimos/simulation/utils/xml_parser.py b/dimos/simulation/utils/xml_parser.py new file mode 100644 index 0000000000..052657ea95 --- /dev/null +++ b/dimos/simulation/utils/xml_parser.py @@ -0,0 +1,277 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""MuJoCo XML parsing helpers for joint/actuator metadata.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING +import xml.etree.ElementTree as ET + +import mujoco + +if TYPE_CHECKING: + from pathlib import Path + + +@dataclass(frozen=True) +class JointMapping: + name: str + joint_id: int | None + actuator_id: int | None + qpos_adr: int | None + dof_adr: int | None + tendon_qpos_adrs: tuple[int, ...] + tendon_dof_adrs: tuple[int, ...] + + +@dataclass(frozen=True) +class _ActuatorSpec: + name: str + joint: str | None + tendon: str | None + + +def build_joint_mappings(xml_path: Path, model: mujoco.MjModel) -> list[JointMapping]: + specs = _parse_actuator_specs(xml_path) + if specs: + return _build_joint_mappings_from_specs(specs, model) + if int(model.nu) > 0: + return _build_joint_mappings_from_actuators(model) + return _build_joint_mappings_from_model(model) + + +def _parse_actuator_specs(xml_path: Path) -> list[_ActuatorSpec]: + return _collect_actuator_specs(xml_path.resolve(), seen=set()) + + +def _collect_actuator_specs(xml_path: Path, seen: set[Path]) -> list[_ActuatorSpec]: + if xml_path in seen: + return [] + seen.add(xml_path) + + root = ET.parse(xml_path).getroot() + base_dir = xml_path.parent + specs: list[_ActuatorSpec] = [] + + def walk(node: ET.Element) -> None: + for child in node: + if child.tag == "include": + include_file = child.attrib.get("file") + if include_file: + include_path = (base_dir / include_file).resolve() + specs.extend(_collect_actuator_specs(include_path, seen)) + continue + if child.tag == "actuator": + specs.extend(_parse_actuator_block(child)) + continue + walk(child) + + walk(root) + return specs + + +def _parse_actuator_block(actuator_elem: ET.Element) -> list[_ActuatorSpec]: + specs: list[_ActuatorSpec] = [] + for child in actuator_elem: + joint = child.attrib.get("joint") + tendon = child.attrib.get("tendon") + if not joint and not tendon: + continue + name = child.attrib.get("name") or joint or tendon or "actuator" + specs.append(_ActuatorSpec(name=name, joint=joint, tendon=tendon)) + return specs + + +def _build_joint_mappings_from_specs( + specs: list[_ActuatorSpec], + model: mujoco.MjModel, +) -> list[JointMapping]: + mappings: list[JointMapping] = [] + for spec in specs: + if spec.joint: + mappings.append(_mapping_for_joint(spec, model)) + elif spec.tendon: + mappings.append(_mapping_for_tendon(spec, model)) + return mappings + + +def _mapping_for_joint(spec: _ActuatorSpec, model: mujoco.MjModel) -> JointMapping: + joint_id = mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_JOINT, spec.joint) + if joint_id < 0: + raise ValueError(f"Unknown joint '{spec.joint}' in MuJoCo model") + actuator_id = _find_actuator_id_for_joint(model, joint_id, spec.name) + joint_name = mujoco.mj_id2name(model, mujoco.mjtObj.mjOBJ_JOINT, joint_id) or spec.name + return JointMapping( + name=joint_name, + joint_id=joint_id, + actuator_id=actuator_id, + qpos_adr=int(model.jnt_qposadr[joint_id]), + dof_adr=int(model.jnt_dofadr[joint_id]), + tendon_qpos_adrs=(), + tendon_dof_adrs=(), + ) + + +def _mapping_for_tendon(spec: _ActuatorSpec, model: mujoco.MjModel) -> JointMapping: + name = spec.name or spec.tendon + if not name: + raise ValueError("Tendon actuator is missing a name and tendon reference") + tendon_id = mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_TENDON, spec.tendon) + if tendon_id < 0: + raise ValueError(f"Unknown tendon '{spec.tendon}' in MuJoCo model") + actuator_id = _find_actuator_id_for_tendon(model, tendon_id, spec.name) + joint_ids = _tendon_joint_ids(model, tendon_id) + return JointMapping( + name=name, + joint_id=None, + actuator_id=actuator_id, + qpos_adr=None, + dof_adr=None, + tendon_qpos_adrs=tuple(int(model.jnt_qposadr[joint_id]) for joint_id in joint_ids), + tendon_dof_adrs=tuple(int(model.jnt_dofadr[joint_id]) for joint_id in joint_ids), + ) + + +def _find_actuator_id_for_joint( + model: mujoco.MjModel, + joint_id: int, + actuator_name: str | None, +) -> int | None: + if actuator_name: + act_id = mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_ACTUATOR, actuator_name) + if act_id >= 0: + return int(act_id) + for act_id in range(int(model.nu)): + trn_type = int(model.actuator_trntype[act_id]) + if trn_type != int(mujoco.mjtTrn.mjTRN_JOINT): + continue + if int(model.actuator_trnid[act_id, 0]) == joint_id: + return act_id + return None + + +def _find_actuator_id_for_tendon( + model: mujoco.MjModel, + tendon_id: int, + actuator_name: str | None, +) -> int | None: + if actuator_name: + act_id = mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_ACTUATOR, actuator_name) + if act_id >= 0: + return int(act_id) + for act_id in range(int(model.nu)): + trn_type = int(model.actuator_trntype[act_id]) + if trn_type != int(mujoco.mjtTrn.mjTRN_TENDON): + continue + if int(model.actuator_trnid[act_id, 0]) == tendon_id: + return act_id + return None + + +def _tendon_joint_ids(model: mujoco.MjModel, tendon_id: int) -> tuple[int, ...]: + adr = int(model.tendon_adr[tendon_id]) + num = int(model.tendon_num[tendon_id]) + joint_ids: list[int] = [] + for wrap_id in range(adr, adr + num): + wrap_type = int(model.wrap_type[wrap_id]) + if wrap_type == int(mujoco.mjtWrap.mjWRAP_JOINT): + joint_ids.append(int(model.wrap_objid[wrap_id])) + return tuple(joint_ids) + + +def _build_joint_mappings_from_actuators(model: mujoco.MjModel) -> list[JointMapping]: + mappings: list[JointMapping] = [] + for actuator_id in range(int(model.nu)): + actuator_name = mujoco.mj_id2name(model, mujoco.mjtObj.mjOBJ_ACTUATOR, actuator_id) + name = actuator_name or f"actuator{actuator_id}" + trn_type = int(model.actuator_trntype[actuator_id]) + if trn_type == int(mujoco.mjtTrn.mjTRN_JOINT): + joint_id = int(model.actuator_trnid[actuator_id, 0]) + joint_name = mujoco.mj_id2name(model, mujoco.mjtObj.mjOBJ_JOINT, joint_id) + mappings.append( + JointMapping( + name=joint_name or name, + joint_id=joint_id, + actuator_id=actuator_id, + qpos_adr=int(model.jnt_qposadr[joint_id]), + dof_adr=int(model.jnt_dofadr[joint_id]), + tendon_qpos_adrs=(), + tendon_dof_adrs=(), + ) + ) + continue + + if trn_type == int(mujoco.mjtTrn.mjTRN_TENDON): + tendon_id = int(model.actuator_trnid[actuator_id, 0]) + tendon_name = mujoco.mj_id2name(model, mujoco.mjtObj.mjOBJ_TENDON, tendon_id) + if not actuator_name and tendon_name: + name = tendon_name + joint_ids = _tendon_joint_ids(model, tendon_id) + mappings.append( + JointMapping( + name=name, + joint_id=None, + actuator_id=actuator_id, + qpos_adr=None, + dof_adr=None, + tendon_qpos_adrs=tuple( + int(model.jnt_qposadr[joint_id]) for joint_id in joint_ids + ), + tendon_dof_adrs=tuple( + int(model.jnt_dofadr[joint_id]) for joint_id in joint_ids + ), + ) + ) + continue + + mappings.append( + JointMapping( + name=name, + joint_id=None, + actuator_id=actuator_id, + qpos_adr=None, + dof_adr=None, + tendon_qpos_adrs=(), + tendon_dof_adrs=(), + ) + ) + + return mappings + + +def _build_joint_mappings_from_model(model: mujoco.MjModel) -> list[JointMapping]: + mappings: list[JointMapping] = [] + for joint_id in range(int(model.njnt)): + jnt_type = int(model.jnt_type[joint_id]) + if jnt_type not in ( + int(mujoco.mjtJoint.mjJNT_HINGE), + int(mujoco.mjtJoint.mjJNT_SLIDE), + ): + continue + joint_name = mujoco.mj_id2name(model, mujoco.mjtObj.mjOBJ_JOINT, joint_id) + name = joint_name or f"joint{joint_id}" + mappings.append( + JointMapping( + name=name, + joint_id=joint_id, + actuator_id=None, + qpos_adr=int(model.jnt_qposadr[joint_id]), + dof_adr=int(model.jnt_dofadr[joint_id]), + tendon_qpos_adrs=(), + tendon_dof_adrs=(), + ) + ) + return mappings