From a9b6c569891f8b9f45ca97354988ac843ce40b2a Mon Sep 17 00:00:00 2001 From: Paul Nechifor Date: Sun, 22 Feb 2026 03:31:19 +0200 Subject: [PATCH 01/16] fix(metric3d): remove metric3d to fix tests (#1341) --- dimos/environment/__init__.py | 0 dimos/environment/environment.py | 178 ------------------ dimos/models/depth/__init__.py | 0 dimos/models/depth/metric3d.py | 187 ------------------- dimos/models/depth/test_metric3d.py | 102 ----------- dimos/robot/all_blueprints.py | 1 - dimos/robot/drone/README.md | 9 +- dimos/robot/drone/camera_module.py | 68 +------ dimos/robot/drone/drone.py | 4 +- dimos/robot/unitree/depth_module.py | 243 ------------------------- dimos/robot/unitree_webrtc/__init__.py | 1 - dimos/types/manipulation.py | 2 +- onnx/metric3d_vit_small.onnx | 3 - 13 files changed, 8 insertions(+), 790 deletions(-) delete mode 100644 dimos/environment/__init__.py delete mode 100644 dimos/environment/environment.py delete mode 100644 dimos/models/depth/__init__.py delete mode 100644 dimos/models/depth/metric3d.py delete mode 100644 dimos/models/depth/test_metric3d.py delete mode 100644 dimos/robot/unitree/depth_module.py delete mode 100644 onnx/metric3d_vit_small.onnx diff --git a/dimos/environment/__init__.py b/dimos/environment/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dimos/environment/environment.py b/dimos/environment/environment.py deleted file mode 100644 index ba1923b765..0000000000 --- a/dimos/environment/environment.py +++ /dev/null @@ -1,178 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from abc import ABC, abstractmethod - -import numpy as np - - -class Environment(ABC): - def __init__(self) -> None: - self.environment_type = None - self.graph = None - - @abstractmethod - def label_objects(self) -> list[str]: - """ - Label all objects in the environment. - - Returns: - A list of string labels representing the objects in the environment. - """ - pass - - @abstractmethod - def get_visualization(self, format_type): # type: ignore[no-untyped-def] - """Return different visualization formats like images, NERFs, or other 3D file types.""" - pass - - @abstractmethod - def generate_segmentations( # type: ignore[no-untyped-def] - self, model: str | None = None, objects: list[str] | None = None, *args, **kwargs - ) -> list[np.ndarray]: # type: ignore[type-arg] - """ - Generate object segmentations of objects[] using neural methods. - - Args: - model (str, optional): The string of the desired segmentation model (SegmentAnything, etc.) - objects (list[str], optional): The list of strings of the specific objects to segment. - *args: Variable length argument list. - **kwargs: Arbitrary keyword arguments. - - Returns: - list of numpy.ndarray: A list where each element is a numpy array - representing a binary mask for a segmented area of an object in the environment. - - Note: - The specific arguments and their usage will depend on the subclass implementation. - """ - pass - - @abstractmethod - def get_segmentations(self) -> list[np.ndarray]: # type: ignore[type-arg] - """ - Get segmentations using a method like 'segment anything'. 
- - Returns: - list of numpy.ndarray: A list where each element is a numpy array - representing a binary mask for a segmented area of an object in the environment. - """ - pass - - @abstractmethod - def generate_point_cloud(self, object: str | None = None, *args, **kwargs) -> np.ndarray: # type: ignore[no-untyped-def, type-arg] - """ - Generate a point cloud for the entire environment or a specific object. - - Args: - object (str, optional): The string of the specific object to get the point cloud for. - If None, returns the point cloud for the entire environment. - *args: Variable length argument list. - **kwargs: Arbitrary keyword arguments. - - Returns: - np.ndarray: A numpy array representing the generated point cloud. - Shape: (n, 3) where n is the number of points and each point is [x, y, z]. - - Note: - The specific arguments and their usage will depend on the subclass implementation. - """ - pass - - @abstractmethod - def get_point_cloud(self, object: str | None = None) -> np.ndarray: # type: ignore[type-arg] - """ - Return point clouds of the entire environment or a specific object. - - Args: - object (str, optional): The string of the specific object to get the point cloud for. If None, returns the point cloud for the entire environment. - - Returns: - np.ndarray: A numpy array representing the point cloud. - Shape: (n, 3) where n is the number of points and each point is [x, y, z]. - """ - pass - - @abstractmethod - def generate_depth_map( # type: ignore[no-untyped-def] - self, - stereo: bool | None = None, - monocular: bool | None = None, - model: str | None = None, - *args, - **kwargs, - ) -> np.ndarray: # type: ignore[type-arg] - """ - Generate a depth map using monocular or stereo camera methods. - - Args: - stereo (bool, optional): Whether to stereo camera is avaliable for ground truth depth map generation. - monocular (bool, optional): Whether to use monocular camera for neural depth map generation. 
- model (str, optional): The string of the desired monocular depth model (Metric3D, ZoeDepth, etc.) - *args: Variable length argument list. - **kwargs: Arbitrary keyword arguments. - - Returns: - np.ndarray: A 2D numpy array representing the generated depth map. - Shape: (height, width) where each value represents the depth - at that pixel location. - - Note: - The specific arguments and their usage will depend on the subclass implementation. - """ - pass - - @abstractmethod - def get_depth_map(self) -> np.ndarray: # type: ignore[type-arg] - """ - Return a depth map of the environment. - - Returns: - np.ndarray: A 2D numpy array representing the depth map. - Shape: (height, width) where each value represents the depth - at that pixel location. Typically, closer objects have smaller - values and farther objects have larger values. - - Note: - The exact range and units of the depth values may vary depending on the - specific implementation and the sensor or method used to generate the depth map. - """ - pass - - def initialize_from_images(self, images): # type: ignore[no-untyped-def] - """Initialize the environment from a set of image frames or video.""" - raise NotImplementedError("This method is not implemented for this environment type.") - - def initialize_from_file(self, file_path): # type: ignore[no-untyped-def] - """Initialize the environment from a spatial file type. - - Supported file types include: - - GLTF/GLB (GL Transmission Format) - - FBX (Filmbox) - - OBJ (Wavefront Object) - - USD/USDA/USDC (Universal Scene Description) - - STL (Stereolithography) - - COLLADA (DAE) - - Alembic (ABC) - - PLY (Polygon File Format) - - 3DS (3D Studio) - - VRML/X3D (Virtual Reality Modeling Language) - - Args: - file_path (str): Path to the spatial file. - - Raises: - NotImplementedError: If the method is not implemented for this environment type. 
- """ - raise NotImplementedError("This method is not implemented for this environment type.") diff --git a/dimos/models/depth/__init__.py b/dimos/models/depth/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dimos/models/depth/metric3d.py b/dimos/models/depth/metric3d.py deleted file mode 100644 index a668ea321e..0000000000 --- a/dimos/models/depth/metric3d.py +++ /dev/null @@ -1,187 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from dataclasses import dataclass, field -from functools import cached_property -from typing import Any - -import cv2 -import torch - -from dimos.models.base import LocalModel, LocalModelConfig - - -@dataclass -class Metric3DConfig(LocalModelConfig): - """Configuration for Metric3D depth estimation model.""" - - camera_intrinsics: list[float] = field(default_factory=lambda: [500.0, 500.0, 320.0, 240.0]) - """Camera intrinsics [fx, fy, cx, cy].""" - - gt_depth_scale: float = 256.0 - """Scale factor for ground truth depth.""" - - device: str = "cuda" if torch.cuda.is_available() else "cpu" - """Device to run the model on.""" - - -class Metric3D(LocalModel): - default_config = Metric3DConfig - config: Metric3DConfig - - def __init__(self, **kwargs: object) -> None: - super().__init__(**kwargs) - self.intrinsic = self.config.camera_intrinsics - self.intrinsic_scaled: list[float] | None = None - self.gt_depth_scale = self.config.gt_depth_scale - self.pad_info: list[int] | None = None - self.rgb_origin: Any = None - - @cached_property - def _model(self) -> Any: - model = torch.hub.load( # type: ignore[no-untyped-call] - "yvanyin/metric3d", "metric3d_vit_small", pretrain=True - ) - model = model.to(self.device) - model.eval() - return model - - """ - Input: Single image in RGB format - Output: Depth map - """ - - def update_intrinsic(self, intrinsic): # type: ignore[no-untyped-def] - """ - Update the intrinsic parameters dynamically. - Ensure that the input intrinsic is valid. 
- """ - if len(intrinsic) != 4: - raise ValueError("Intrinsic must be a list or tuple with 4 values: [fx, fy, cx, cy]") - self.intrinsic = intrinsic - print(f"Intrinsics updated to: {self.intrinsic}") - - def infer_depth(self, img, debug: bool = False): # type: ignore[no-untyped-def] - if debug: - print(f"Input image: {img}") - try: - if isinstance(img, str): - print(f"Image type string: {type(img)}") - img_data = cv2.imread(img) - if img_data is None: - raise ValueError(f"Failed to load image from {img}") - self.rgb_origin = img_data[:, :, ::-1] - else: - # print(f"Image type not string: {type(img)}, cv2 conversion assumed to be handled. If not, this will throw an error") - self.rgb_origin = img - except Exception as e: - print(f"Error parsing into infer_depth: {e}") - - img = self.rescale_input(img, self.rgb_origin) # type: ignore[no-untyped-call] - - with torch.no_grad(): - pred_depth, confidence, output_dict = self._model.inference({"input": img}) - - # Convert to PIL format - depth_image = self.unpad_transform_depth(pred_depth) # type: ignore[no-untyped-call] - - return depth_image.cpu().numpy() - - def save_depth(self, pred_depth) -> None: # type: ignore[no-untyped-def] - # Save the depth map to a file - pred_depth_np = pred_depth.cpu().numpy() - output_depth_file = "output_depth_map.png" - cv2.imwrite(output_depth_file, pred_depth_np) - print(f"Depth map saved to {output_depth_file}") - - # Adjusts input size to fit pretrained ViT model - def rescale_input(self, rgb, rgb_origin): # type: ignore[no-untyped-def] - #### ajust input size to fit pretrained model - # keep ratio resize - input_size = (616, 1064) # for vit model - # input_size = (544, 1216) # for convnext model - h, w = rgb_origin.shape[:2] - scale = min(input_size[0] / h, input_size[1] / w) - rgb = cv2.resize( - rgb_origin, (int(w * scale), int(h * scale)), interpolation=cv2.INTER_LINEAR - ) - # remember to scale intrinsic, hold depth - self.intrinsic_scaled = [ - self.intrinsic[0] * scale, - 
self.intrinsic[1] * scale, - self.intrinsic[2] * scale, - self.intrinsic[3] * scale, - ] - # padding to input_size - padding = [123.675, 116.28, 103.53] - h, w = rgb.shape[:2] - pad_h = input_size[0] - h - pad_w = input_size[1] - w - pad_h_half = pad_h // 2 - pad_w_half = pad_w // 2 - rgb = cv2.copyMakeBorder( - rgb, - pad_h_half, - pad_h - pad_h_half, - pad_w_half, - pad_w - pad_w_half, - cv2.BORDER_CONSTANT, - value=padding, - ) - self.pad_info = [pad_h_half, pad_h - pad_h_half, pad_w_half, pad_w - pad_w_half] - - #### normalize - mean = torch.tensor([123.675, 116.28, 103.53]).float()[:, None, None] - std = torch.tensor([58.395, 57.12, 57.375]).float()[:, None, None] - rgb = torch.from_numpy(rgb.transpose((2, 0, 1))).float() - rgb = torch.div((rgb - mean), std) - rgb = rgb[None, :, :, :].to(self.device) - return rgb - - def unpad_transform_depth(self, pred_depth): # type: ignore[no-untyped-def] - # un pad - pred_depth = pred_depth.squeeze() - pred_depth = pred_depth[ - self.pad_info[0] : pred_depth.shape[0] - self.pad_info[1], # type: ignore[index] - self.pad_info[2] : pred_depth.shape[1] - self.pad_info[3], # type: ignore[index] - ] - - # upsample to original size - pred_depth = torch.nn.functional.interpolate( - pred_depth[None, None, :, :], - self.rgb_origin.shape[:2], - mode="bilinear", - ).squeeze() - ###################### canonical camera space ###################### - - #### de-canonical transform - canonical_to_real_scale = ( - self.intrinsic_scaled[0] / 1000.0 # type: ignore[index] - ) # 1000.0 is the focal length of canonical camera - pred_depth = pred_depth * canonical_to_real_scale # now the depth is metric - pred_depth = torch.clamp(pred_depth, 0, 1000) - return pred_depth - - def eval_predicted_depth(self, depth_file, pred_depth) -> None: # type: ignore[no-untyped-def] - if depth_file is not None: - gt_depth_np = cv2.imread(depth_file, -1) - if gt_depth_np is None: - raise ValueError(f"Failed to load depth file from {depth_file}") - gt_depth_scaled 
= gt_depth_np / self.gt_depth_scale - gt_depth = torch.from_numpy(gt_depth_scaled).float().to(self.device) - assert gt_depth.shape == pred_depth.shape - - mask = gt_depth > 1e-8 # type: ignore[operator] - abs_rel_err = (torch.abs(pred_depth[mask] - gt_depth[mask]) / gt_depth[mask]).mean() # type: ignore[index] - print("abs_rel_err:", abs_rel_err.item()) diff --git a/dimos/models/depth/test_metric3d.py b/dimos/models/depth/test_metric3d.py deleted file mode 100644 index 33e39f6a29..0000000000 --- a/dimos/models/depth/test_metric3d.py +++ /dev/null @@ -1,102 +0,0 @@ -from contextlib import contextmanager - -import numpy as np -import pytest - -from dimos.models.depth.metric3d import Metric3D -from dimos.msgs.sensor_msgs import Image -from dimos.utils.data import get_data - - -@contextmanager -def skip_xformers_unsupported(): - try: - yield - except NotImplementedError as e: - if "memory_efficient_attention" in str(e): - pytest.skip(f"xformers not supported on this GPU: {e}") - raise - - -@pytest.fixture -def sample_intrinsics() -> list[float]: - """Sample camera intrinsics [fx, fy, cx, cy].""" - return [500.0, 500.0, 320.0, 240.0] - -@pytest.mark.cuda -@pytest.mark.gpu -def test_metric3d_init(sample_intrinsics: list[float]) -> None: - """Test Metric3D initialization.""" - model = Metric3D(camera_intrinsics=sample_intrinsics) - assert model.config.camera_intrinsics == sample_intrinsics - assert model.config.gt_depth_scale == 256.0 - assert model.device == "cuda" - - -@pytest.mark.gpu -def test_metric3d_update_intrinsic(sample_intrinsics: list[float]) -> None: - """Test updating camera intrinsics.""" - model = Metric3D(camera_intrinsics=sample_intrinsics) - - new_intrinsics = [600.0, 600.0, 400.0, 300.0] - model.update_intrinsic(new_intrinsics) - assert model.intrinsic == new_intrinsics - -@pytest.mark.gpu -def test_metric3d_update_intrinsic_invalid(sample_intrinsics: list[float]) -> None: - """Test that invalid intrinsics raise an error.""" - model = 
Metric3D(camera_intrinsics=sample_intrinsics) - - with pytest.raises(ValueError, match="Intrinsic must be a list"): - model.update_intrinsic([1.0, 2.0]) # Only 2 values - - -@pytest.mark.cuda -@pytest.mark.gpu -def test_metric3d_infer_depth(sample_intrinsics: list[float]) -> None: - """Test depth inference on a sample image.""" - model = Metric3D(camera_intrinsics=sample_intrinsics) - model.start() - - # Load test image - image = Image.from_file(get_data("cafe.jpg")).to_rgb() - rgb_array = image.data - - # Run inference - with skip_xformers_unsupported(): - depth_map = model.infer_depth(rgb_array) - - # Verify output - assert isinstance(depth_map, np.ndarray) - assert depth_map.shape[:2] == rgb_array.shape[:2] # Same spatial dimensions - assert depth_map.dtype in [np.float32, np.float64] - assert depth_map.min() >= 0 # Depth should be non-negative - - print(f"Depth map shape: {depth_map.shape}") - print(f"Depth range: [{depth_map.min():.2f}, {depth_map.max():.2f}]") - - model.stop() - - -@pytest.mark.cuda -@pytest.mark.gpu -def test_metric3d_multiple_inferences(sample_intrinsics: list[float]) -> None: - """Test multiple depth inferences.""" - model = Metric3D(camera_intrinsics=sample_intrinsics) - model.start() - - image = Image.from_file(get_data("cafe.jpg")).to_rgb() - rgb_array = image.data - - # Run multiple inferences - depths = [] - for _ in range(3): - with skip_xformers_unsupported(): - depth = model.infer_depth(rgb_array) - depths.append(depth) - - # Results should be consistent - for i in range(1, len(depths)): - assert np.allclose(depths[0], depths[i], rtol=1e-5) - - model.stop() diff --git a/dimos/robot/all_blueprints.py b/dimos/robot/all_blueprints.py index 19d7e7db29..bdfd98cd17 100644 --- a/dimos/robot/all_blueprints.py +++ b/dimos/robot/all_blueprints.py @@ -95,7 +95,6 @@ "cost_mapper": "dimos.mapping.costmapper", "demo_calculator_skill": "dimos.agents.skills.demo_calculator_skill", "demo_robot": "dimos.agents.skills.demo_robot", - "depth_module": 
"dimos.robot.unitree.depth_module", "detection3d_module": "dimos.perception.detection.module3D", "detection_db_module": "dimos.perception.detection.moduleDB", "fastlio2_module": "dimos.hardware.sensors.lidar.fastlio2.module", diff --git a/dimos/robot/drone/README.md b/dimos/robot/drone/README.md index 6e8ceb4d63..100e2deadd 100644 --- a/dimos/robot/drone/README.md +++ b/dimos/robot/drone/README.md @@ -126,7 +126,7 @@ DJI Drone ← Wireless → DJI Controller ← USB → Android Device ← WiFi ``` drone.py # Main orchestrator ├── connection_module.py # MAVLink communication & skills -├── camera_module.py # Video processing & depth estimation +├── camera_module.py # Video processing ├── tracking_module.py # Visual servoing & object tracking ├── mavlink_connection.py # Low-level MAVLink protocol └── dji_video_stream.py # GStreamer video capture @@ -242,13 +242,6 @@ drone.start() - **ROS/DimOS**: X=Forward, Y=Left, Z=Up - Automatic conversion handled internally -### Depth Estimation -Camera module can generate depth maps using Metric3D: -```python -# Depth published to /drone/depth and /drone/pointcloud -# Requires GPU with 8GB+ VRAM -``` - ### Foxglove Visualization Connect Foxglove Studio to `ws://localhost:8765` to see: - Live video with tracking overlay diff --git a/dimos/robot/drone/camera_module.py b/dimos/robot/drone/camera_module.py index 8ba88fd028..248b1ceb6e 100644 --- a/dimos/robot/drone/camera_module.py +++ b/dimos/robot/drone/camera_module.py @@ -15,7 +15,7 @@ # Copyright 2025-2026 Dimensional Inc. 
-"""Camera module for drone with depth estimation.""" +"""Camera module for drone.""" import threading import time @@ -25,9 +25,8 @@ from dimos.core import In, Module, Out, rpc from dimos.msgs.geometry_msgs import PoseStamped -from dimos.msgs.sensor_msgs import Image, ImageFormat +from dimos.msgs.sensor_msgs import Image from dimos.msgs.std_msgs import Header -from dimos.perception.common.utils import colorize_depth from dimos.utils.logging_config import setup_logger logger = setup_logger() @@ -35,15 +34,13 @@ class DroneCameraModule(Module): """ - Camera module for drone that processes RGB images to generate depth using Metric3D. + Camera module for drone Subscribes to: - /video: RGB camera images from drone Publishes: - /drone/color_image: RGB camera images - - /drone/depth_image: Depth images from Metric3D - - /drone/depth_colorized: Colorized depth - /drone/camera_info: Camera calibration - /drone/camera_pose: Camera pose from TF """ @@ -53,8 +50,6 @@ class DroneCameraModule(Module): # Outputs color_image: Out[Image] - depth_image: Out[Image] - depth_colorized: Out[Image] camera_info: Out[CameraInfo] camera_pose: Out[PoseStamped] @@ -64,7 +59,6 @@ def __init__( world_frame_id: str = "world", camera_frame_id: str = "camera_link", base_frame_id: str = "base_link", - gt_depth_scale: float = 2.0, **kwargs: Any, ) -> None: """Initialize drone camera module. 
@@ -73,7 +67,6 @@ def __init__( camera_intrinsics: [fx, fy, cx, cy] camera_frame_id: TF frame for camera base_frame_id: TF frame for drone base - gt_depth_scale: Depth scale factor """ super().__init__(**kwargs) @@ -84,10 +77,6 @@ def __init__( self.camera_frame_id = camera_frame_id self.base_frame_id = base_frame_id self.world_frame_id = world_frame_id - self.gt_depth_scale = gt_depth_scale - - # Metric3D for depth - self.metric3d: Any = None # Lazy-loaded Metric3D model # Processing state self._running = False @@ -104,7 +93,6 @@ def start(self) -> None: logger.warning("Camera module already running") return - # Start processing thread for depth (which will init Metric3D and handle video) self._running = True self._stop_processing.clear() self._processing_thread = threading.Thread(target=self._processing_loop, daemon=True) @@ -121,22 +109,9 @@ def _on_video_frame(self, frame: Image) -> None: # Publish color image immediately self.color_image.publish(frame) - # Store for depth processing self._latest_frame = frame def _processing_loop(self) -> None: - """Process depth estimation in background.""" - # Initialize Metric3D in the background thread - if self.metric3d is None: - try: - from dimos.models.depth.metric3d import Metric3D - - self.metric3d = Metric3D(camera_intrinsics=self.camera_intrinsics) - logger.info("Metric3D initialized") - except Exception as e: - logger.warning(f"Metric3D not available: {e}") - self.metric3d = None - # Subscribe to video once connection is available subscribed = False while not subscribed and not self._stop_processing.is_set(): @@ -151,12 +126,10 @@ def _processing_loop(self) -> None: logger.debug(f"Waiting for video connection: {e}") time.sleep(0.1) - logger.info("Depth processing loop started") - _reported_error = False while not self._stop_processing.is_set(): - if self._latest_frame is not None and self.metric3d is not None: + if self._latest_frame is not None: try: frame = self._latest_frame self._latest_frame = None @@ -164,34 
+137,9 @@ def _processing_loop(self) -> None: # Get numpy array from Image img_array = frame.data - # Generate depth - depth_array = self.metric3d.infer_depth(img_array) / self.gt_depth_scale - # Create header header = Header(self.camera_frame_id) - # Publish depth - depth_msg = Image( - data=depth_array, - format=ImageFormat.DEPTH, - frame_id=header.frame_id, - ts=header.ts, - ) - self.depth_image.publish(depth_msg) - - # Publish colorized depth - depth_colorized_array = colorize_depth( - depth_array, max_depth=10.0, overlay_stats=True - ) - if depth_colorized_array is not None: - depth_colorized_msg = Image( - data=depth_colorized_array, - format=ImageFormat.RGB, - frame_id=header.frame_id, - ts=header.ts, - ) - self.depth_colorized.publish(depth_colorized_msg) - # Publish camera info self._publish_camera_info(header, img_array.shape) @@ -201,12 +149,10 @@ def _processing_loop(self) -> None: except Exception as e: if not _reported_error: _reported_error = True - logger.error(f"Error processing depth: {e}") + logger.error(f"Error processing frame: {e}") else: time.sleep(0.01) - logger.info("Depth processing loop stopped") - def _publish_camera_info(self, header: Header, shape: tuple[int, ...]) -> None: """Publish camera calibration info.""" try: @@ -279,8 +225,4 @@ def stop(self) -> None: if self._processing_thread and self._processing_thread.is_alive(): self._processing_thread.join(timeout=2.0) - # Cleanup Metric3D - if self.metric3d: - self.metric3d.cleanup() - logger.info("Camera module stopped") diff --git a/dimos/robot/drone/drone.py b/dimos/robot/drone/drone.py index 8e72d56ed1..6b9500804f 100644 --- a/dimos/robot/drone/drone.py +++ b/dimos/robot/drone/drone.py @@ -51,7 +51,7 @@ class Drone(Robot): - """Generic MAVLink-based drone with video and depth capabilities.""" + """Generic MAVLink-based drone with video capabilities.""" def __init__( self, @@ -164,8 +164,6 @@ def _deploy_camera(self) -> None: # Configure LCM transports 
self.camera.color_image.transport = core.LCMTransport("/drone/color_image", Image) - self.camera.depth_image.transport = core.LCMTransport("/drone/depth_image", Image) - self.camera.depth_colorized.transport = core.LCMTransport("/drone/depth_colorized", Image) self.camera.camera_info.transport = core.LCMTransport("/drone/camera_info", CameraInfo) self.camera.camera_pose.transport = core.LCMTransport("/drone/camera_pose", PoseStamped) diff --git a/dimos/robot/unitree/depth_module.py b/dimos/robot/unitree/depth_module.py deleted file mode 100644 index 07f065caea..0000000000 --- a/dimos/robot/unitree/depth_module.py +++ /dev/null @@ -1,243 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import threading -import time - -from dimos_lcm.sensor_msgs import CameraInfo -import numpy as np - -from dimos.core import In, Module, Out, rpc -from dimos.core.global_config import GlobalConfig -from dimos.msgs.sensor_msgs import Image, ImageFormat -from dimos.utils.logging_config import setup_logger - -logger = setup_logger() - - -class DepthModule(Module): - """ - Depth module for Unitree Go2 that processes RGB images to generate depth using Metric3D. 
- - Subscribes to: - - /go2/color_image: RGB camera images from Unitree - - /go2/camera_info: Camera calibration information - - Publishes: - - /go2/depth_image: Depth images generated by Metric3D - """ - - # LCM inputs - color_image: In[Image] - camera_info: In[CameraInfo] - - # LCM outputs - depth_image: Out[Image] - - def __init__( # type: ignore[no-untyped-def] - self, - gt_depth_scale: float = 0.5, - cfg: GlobalConfig | None = None, - **kwargs, - ) -> None: - """ - Initialize Depth Module. - - Args: - gt_depth_scale: Ground truth depth scaling factor - """ - super().__init__(**kwargs) - - self.camera_intrinsics = None - self.gt_depth_scale = gt_depth_scale - self.metric3d = None - self._camera_info_received = False - - # Processing state - self._running = False - self._latest_frame = None - self._last_image = None - self._last_timestamp = None - self._last_depth = None - self._cannot_process_depth = False - - # Threading - self._processing_thread: threading.Thread | None = None - self._stop_processing = threading.Event() - - if cfg: - if cfg.simulation: - self.gt_depth_scale = 1.0 - - @rpc - def start(self) -> None: - super().start() - - if self._running: - logger.warning("Camera module already running") - return - - # Set running flag before starting - self._running = True - - # Subscribe to video and camera info inputs - self.color_image.subscribe(self._on_video) - self.camera_info.subscribe(self._on_camera_info) - - # Start processing thread - self._start_processing_thread() - - logger.info("Depth module started") - - @rpc - def stop(self) -> None: - if not self._running: - return - - self._running = False - self._stop_processing.set() - - # Wait for thread to finish - if self._processing_thread and self._processing_thread.is_alive(): - self._processing_thread.join(timeout=2.0) - - super().stop() - - def _on_camera_info(self, msg: CameraInfo) -> None: - """Process camera info to extract intrinsics.""" - if self.metric3d is not None: - return # Already 
initialized - - try: - # Extract intrinsics from camera matrix K - K = msg.K - fx = K[0] - fy = K[4] - cx = K[2] - cy = K[5] - - self.camera_intrinsics = [fx, fy, cx, cy] # type: ignore[assignment] - - # Initialize Metric3D with camera intrinsics - from dimos.models.depth.metric3d import Metric3D - - self.metric3d = Metric3D(camera_intrinsics=self.camera_intrinsics) # type: ignore[assignment] - self._camera_info_received = True - - logger.info( - f"Initialized Metric3D with intrinsics from camera_info: {self.camera_intrinsics}" - ) - - except Exception as e: - logger.error(f"Error processing camera info: {e}") - - def _on_video(self, msg: Image) -> None: - """Store latest video frame for processing.""" - if not self._running: - return - - # Simply store the latest frame - processing happens in main loop - self._latest_frame = msg # type: ignore[assignment] - logger.debug( - f"Received video frame: format={msg.format}, shape={msg.data.shape if hasattr(msg.data, 'shape') else 'unknown'}" - ) - - def _start_processing_thread(self) -> None: - """Start the processing thread.""" - self._stop_processing.clear() - self._processing_thread = threading.Thread(target=self._main_processing_loop, daemon=True) - self._processing_thread.start() - logger.info("Started depth processing thread") - - def _main_processing_loop(self) -> None: - """Main processing loop that continuously processes latest frames.""" - logger.info("Starting main processing loop") - - while not self._stop_processing.is_set(): - # Process latest frame if available - if self._latest_frame is not None: - try: - msg = self._latest_frame - self._latest_frame = None # Clear to avoid reprocessing - # Store for publishing - self._last_image = msg.data - self._last_timestamp = msg.ts if msg.ts else time.time() - # Process depth - self._process_depth(self._last_image) - - except Exception as e: - logger.error(f"Error in main processing loop: {e}", exc_info=True) - else: - # Small sleep to avoid busy waiting - 
time.sleep(0.001) - - logger.info("Main processing loop stopped") - - def _process_depth(self, img_array: np.ndarray) -> None: # type: ignore[type-arg] - """Process depth estimation using Metric3D.""" - if self._cannot_process_depth: - self._last_depth = None - return - - # Wait for camera info to initialize Metric3D - if self.metric3d is None: - logger.debug("Waiting for camera_info to initialize Metric3D") - return - - try: - logger.debug(f"Processing depth for image shape: {img_array.shape}") - - # Generate depth map - depth_array = self.metric3d.infer_depth(img_array) * self.gt_depth_scale - - self._last_depth = depth_array - logger.debug(f"Generated depth map shape: {depth_array.shape}") - - self._publish_depth() - - except Exception as e: - logger.error(f"Error processing depth: {e}") - self._cannot_process_depth = True - - def _publish_depth(self) -> None: - """Publish depth image.""" - if not self._running: - return - - try: - # Publish depth image - if self._last_depth is not None: - # Convert depth to uint16 (millimeters) for more efficient storage - # Clamp to valid range [0, 65.535] meters before converting - depth_clamped = np.clip(self._last_depth, 0, 65.535) - depth_uint16 = (depth_clamped * 1000).astype(np.uint16) - depth_msg = Image( - data=depth_uint16, - format=ImageFormat.DEPTH16, # Use DEPTH16 format for uint16 depth - frame_id="camera_link", - ts=self._last_timestamp, - ) - self.depth_image.publish(depth_msg) - logger.debug(f"Published depth image (uint16): shape={depth_uint16.shape}") - - except Exception as e: - logger.error(f"Error publishing depth data: {e}", exc_info=True) - - -depth_module = DepthModule.blueprint - - -__all__ = ["DepthModule", "depth_module"] diff --git a/dimos/robot/unitree_webrtc/__init__.py b/dimos/robot/unitree_webrtc/__init__.py index 4524bba226..451aa53128 100644 --- a/dimos/robot/unitree_webrtc/__init__.py +++ b/dimos/robot/unitree_webrtc/__init__.py @@ -20,7 +20,6 @@ _ALIAS_MODULES = { 
"demo_error_on_name_conflicts": "dimos.robot.unitree.demo_error_on_name_conflicts", - "depth_module": "dimos.robot.unitree.depth_module", "keyboard_teleop": "dimos.robot.unitree.keyboard_teleop", "mujoco_connection": "dimos.robot.unitree.mujoco_connection", "type": "dimos.robot.unitree.type", diff --git a/dimos/types/manipulation.py b/dimos/types/manipulation.py index 507b9e9b85..76ad7979f2 100644 --- a/dimos/types/manipulation.py +++ b/dimos/types/manipulation.py @@ -80,7 +80,7 @@ class ObjectData(TypedDict, total=False): # Basic detection information object_id: int # Unique ID for the object bbox: list[float] # Bounding box [x1, y1, x2, y2] - depth: float # Depth in meters from Metric3d + depth: float # Depth in meters confidence: float # Detection confidence class_id: int # Class ID from the detector label: str # Semantic label (e.g., 'cup', 'table') diff --git a/onnx/metric3d_vit_small.onnx b/onnx/metric3d_vit_small.onnx deleted file mode 100644 index bfddd41628..0000000000 --- a/onnx/metric3d_vit_small.onnx +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:14805174265dd721ac3b396bd5ee7190c708cec41150ed298267f6c3126bc060 -size 151333865 From 914a0907ea168e21f29b40aebff5d8a0a62cbd27 Mon Sep 17 00:00:00 2001 From: leshy Date: Mon, 23 Feb 2026 19:51:13 +0800 Subject: [PATCH 02/16] fix(tests): replace busy-wait loops with threading.Event in pubsub tests (#1350) Busy-wait loops with no timeout could hang forever. Use threading.Event for efficient blocking with 1s timeouts that fail clearly on expiry. 
--- dimos/protocol/pubsub/test_spec.py | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/dimos/protocol/pubsub/test_spec.py b/dimos/protocol/pubsub/test_spec.py index 0bdfa62628..26c1cf0357 100644 --- a/dimos/protocol/pubsub/test_spec.py +++ b/dimos/protocol/pubsub/test_spec.py @@ -17,6 +17,7 @@ import asyncio from collections.abc import Callable, Generator from contextlib import contextmanager +import threading import time from typing import Any @@ -149,10 +150,12 @@ def test_store(pubsub_context: Callable[[], Any], topic: Any, values: list[Any]) with pubsub_context() as x: # Create a list to capture received messages received_messages: list[Any] = [] + msg_event = threading.Event() # Define callback function that stores received messages def callback(message: Any, _: Any) -> None: received_messages.append(message) + msg_event.set() # Subscribe to the topic with our callback x.subscribe(topic, callback) @@ -160,10 +163,8 @@ def callback(message: Any, _: Any) -> None: # Publish the first value to the topic x.publish(topic, values[0]) - # Give Redis time to process the message if needed - time.sleep(0.1) + assert msg_event.wait(timeout=1.0), "Timed out waiting for message" - print("RECEIVED", received_messages) # Verify the callback was called with the correct value assert len(received_messages) == 1 assert received_messages[0] == values[0] @@ -178,13 +179,17 @@ def test_multiple_subscribers( # Create lists to capture received messages for each subscriber received_messages_1: list[Any] = [] received_messages_2: list[Any] = [] + event_1 = threading.Event() + event_2 = threading.Event() # Define callback functions def callback_1(message: Any, topic: Any) -> None: received_messages_1.append(message) + event_1.set() def callback_2(message: Any, topic: Any) -> None: received_messages_2.append(message) + event_2.set() # Subscribe both callbacks to the same topic x.subscribe(topic, callback_1) @@ -193,8 +198,8 @@ def 
callback_2(message: Any, topic: Any) -> None: # Publish the first value x.publish(topic, values[0]) - # Give Redis time to process the message if needed - time.sleep(0.1) + assert event_1.wait(timeout=1.0), "Timed out waiting for subscriber 1" + assert event_2.wait(timeout=1.0), "Timed out waiting for subscriber 2" # Verify both callbacks received the message assert len(received_messages_1) == 1 @@ -238,21 +243,24 @@ def test_multiple_messages( with pubsub_context() as x: # Create a list to capture received messages received_messages: list[Any] = [] + all_received = threading.Event() + + # Publish the rest of the values (after the first one used in basic tests) + messages_to_send = values[1:] if len(values) > 1 else values # Define callback function def callback(message: Any, topic: Any) -> None: received_messages.append(message) + if len(received_messages) >= len(messages_to_send): + all_received.set() # Subscribe to the topic x.subscribe(topic, callback) - # Publish the rest of the values (after the first one used in basic tests) - messages_to_send = values[1:] if len(values) > 1 else values for msg in messages_to_send: x.publish(topic, msg) - # Give Redis time to process the messages if needed - time.sleep(0.2) + assert all_received.wait(timeout=1.0), "Timed out waiting for all messages" # Verify all messages were received in order assert len(received_messages) == len(messages_to_send) From 0d1493437c2ca8a29c49ae6a60528f0fdf8b4dbb Mon Sep 17 00:00:00 2001 From: Mustafa Bhadsorawala <39084056+mustafab0@users.noreply.github.com> Date: Mon, 23 Feb 2026 12:19:49 -0800 Subject: [PATCH 03/16] Feature: Control Coordinator support for mobile base (#1277) * added twist base protocol spec * created mock twist base adapter * added twistbase ConnectedHardware type * updated coordinator to support twist messages * added flowbase adapter * added blueprint and test script for testing * mypy test fixes * fix validates the adapter/hardware-type pairing upfront with a clear error 
* fixed redundant blueprint assignments * fixed mypy type ignore flags * added thread locks to velocity read write for flowbase adapter * fix mypy errors: portal import-untyped and adapter property override * added echo cmd vel script back for testing will be deprecated * removed echo_cmd_vel test script --- dimos/control/blueprints.py | 87 +++++++- dimos/control/components.py | 35 +++ dimos/control/coordinator.py | 114 +++++++++- .../examples/twist_base_keyboard_teleop.py | 59 +++++ dimos/control/hardware_interface.py | 103 ++++++++- dimos/hardware/drive_trains/__init__.py | 15 ++ .../drive_trains/flowbase/__init__.py | 15 ++ .../hardware/drive_trains/flowbase/adapter.py | 206 ++++++++++++++++++ dimos/hardware/drive_trains/mock/__init__.py | 30 +++ dimos/hardware/drive_trains/mock/adapter.py | 137 ++++++++++++ dimos/hardware/drive_trains/registry.py | 98 +++++++++ dimos/hardware/drive_trains/spec.py | 95 ++++++++ dimos/robot/all_blueprints.py | 2 + pyproject.toml | 1 + uv.lock | 18 ++ 15 files changed, 999 insertions(+), 16 deletions(-) create mode 100644 dimos/control/examples/twist_base_keyboard_teleop.py create mode 100644 dimos/hardware/drive_trains/__init__.py create mode 100644 dimos/hardware/drive_trains/flowbase/__init__.py create mode 100644 dimos/hardware/drive_trains/flowbase/adapter.py create mode 100644 dimos/hardware/drive_trains/mock/__init__.py create mode 100644 dimos/hardware/drive_trains/mock/adapter.py create mode 100644 dimos/hardware/drive_trains/registry.py create mode 100644 dimos/hardware/drive_trains/spec.py diff --git a/dimos/control/blueprints.py b/dimos/control/blueprints.py index 8762ebd95b..5c33928e79 100644 --- a/dimos/control/blueprints.py +++ b/dimos/control/blueprints.py @@ -30,10 +30,15 @@ from __future__ import annotations -from dimos.control.components import HardwareComponent, HardwareType, make_joints +from dimos.control.components import ( + HardwareComponent, + HardwareType, + make_joints, + make_twist_base_joints, +) from 
dimos.control.coordinator import TaskConfig, control_coordinator from dimos.core.transport import LCMTransport -from dimos.msgs.geometry_msgs import PoseStamped +from dimos.msgs.geometry_msgs import PoseStamped, Twist from dimos.msgs.sensor_msgs import JointState from dimos.teleop.quest.quest_types import Buttons from dimos.utils.data import LfsPath @@ -594,6 +599,80 @@ ) +# ============================================================================= +# Twist Base Blueprints (velocity-commanded platforms) +# ============================================================================= + +# Mock holonomic twist base (3-DOF: vx, vy, wz) +_base_joints = make_twist_base_joints("base") +coordinator_mock_twist_base = control_coordinator( + hardware=[ + HardwareComponent( + hardware_id="base", + hardware_type=HardwareType.BASE, + joints=_base_joints, + adapter_type="mock_twist_base", + ), + ], + tasks=[ + TaskConfig( + name="vel_base", + type="velocity", + joint_names=_base_joints, + priority=10, + ), + ], +).transports( + { + ("joint_state", JointState): LCMTransport("/coordinator/joint_state", JointState), + ("twist_command", Twist): LCMTransport("/cmd_vel", Twist), + } +) + + +# ============================================================================= +# Mobile Manipulation Blueprints (arm + twist base) +# ============================================================================= + +# Mock arm (7-DOF) + mock holonomic base (3-DOF) +_mm_base_joints = make_twist_base_joints("base") +coordinator_mobile_manip_mock = control_coordinator( + hardware=[ + HardwareComponent( + hardware_id="arm", + hardware_type=HardwareType.MANIPULATOR, + joints=make_joints("arm", 7), + adapter_type="mock", + ), + HardwareComponent( + hardware_id="base", + hardware_type=HardwareType.BASE, + joints=_mm_base_joints, + adapter_type="mock_twist_base", + ), + ], + tasks=[ + TaskConfig( + name="traj_arm", + type="trajectory", + joint_names=[f"arm_joint{i + 1}" for i in range(7)], + 
priority=10, + ), + TaskConfig( + name="vel_base", + type="velocity", + joint_names=_mm_base_joints, + priority=10, + ), + ], +).transports( + { + ("joint_state", JointState): LCMTransport("/coordinator/joint_state", JointState), + ("twist_command", Twist): LCMTransport("/cmd_vel", Twist), + } +) + + # ============================================================================= # Raw Blueprints (for programmatic setup) # ============================================================================= @@ -624,8 +703,12 @@ # Dual arm "coordinator_dual_mock", "coordinator_dual_xarm", + # Mobile manipulation + "coordinator_mobile_manip_mock", # Single arm "coordinator_mock", + # Twist base + "coordinator_mock_twist_base", "coordinator_piper", "coordinator_piper_xarm", # Teleop IK diff --git a/dimos/control/components.py b/dimos/control/components.py index e3022468ed..8157a288d2 100644 --- a/dimos/control/components.py +++ b/dimos/control/components.py @@ -71,7 +71,41 @@ def make_joints(hardware_id: HardwareId, dof: int) -> list[JointName]: return [f"{hardware_id}_joint{i + 1}" for i in range(dof)] +# Maps virtual joint suffix → (Twist group, Twist field) +TWIST_SUFFIX_MAP: dict[str, tuple[str, str]] = { + "vx": ("linear", "x"), + "vy": ("linear", "y"), + "vz": ("linear", "z"), + "wx": ("angular", "x"), + "wy": ("angular", "y"), + "wz": ("angular", "z"), +} + +_DEFAULT_TWIST_SUFFIXES = ["vx", "vy", "wz"] + + +def make_twist_base_joints( + hardware_id: HardwareId, + suffixes: list[str] | None = None, +) -> list[JointName]: + """Create virtual joint names for a twist base. + + Args: + hardware_id: The hardware identifier (e.g., "base") + suffixes: Velocity DOF suffixes. Defaults to ["vx", "vy", "wz"] (holonomic). + + Returns: + List of joint names like ["base_vx", "base_vy", "base_wz"] + """ + suffixes = suffixes or _DEFAULT_TWIST_SUFFIXES + for s in suffixes: + if s not in TWIST_SUFFIX_MAP: + raise ValueError(f"Unknown twist suffix '{s}'. 
Valid: {list(TWIST_SUFFIX_MAP)}") + return [f"{hardware_id}_{s}" for s in suffixes] + + __all__ = [ + "TWIST_SUFFIX_MAP", "HardwareComponent", "HardwareId", "HardwareType", @@ -79,4 +113,5 @@ def make_joints(hardware_id: HardwareId, dof: int) -> list[JointName]: "JointState", "TaskName", "make_joints", + "make_twist_base_joints", ] diff --git a/dimos/control/coordinator.py b/dimos/control/coordinator.py index 5685a9f9c7..c9182e6aa8 100644 --- a/dimos/control/coordinator.py +++ b/dimos/control/coordinator.py @@ -32,19 +32,32 @@ import time from typing import TYPE_CHECKING, Any -from dimos.control.components import HardwareComponent, HardwareId, JointName, TaskName -from dimos.control.hardware_interface import ConnectedHardware +from dimos.control.components import ( + TWIST_SUFFIX_MAP, + HardwareComponent, + HardwareId, + HardwareType, + JointName, + TaskName, +) +from dimos.control.hardware_interface import ConnectedHardware, ConnectedTwistBase from dimos.control.task import ControlTask from dimos.control.tick_loop import TickLoop from dimos.core import In, Module, Out, rpc from dimos.core.module import ModuleConfig +from dimos.hardware.drive_trains.spec import ( + TwistBaseAdapter, +) from dimos.msgs.geometry_msgs import ( PoseStamped, # noqa: TC001 - needed at runtime for In[PoseStamped] + Twist, # noqa: TC001 - needed at runtime for In[Twist] ) from dimos.msgs.sensor_msgs import ( - JointState, # noqa: TC001 - needed at runtime for Out[JointState] + JointState, +) +from dimos.teleop.quest.quest_types import ( + Buttons, # noqa: TC001 - needed at runtime for In[Buttons] ) -from dimos.teleop.quest.quest_types import Buttons # noqa: TC001 - needed for teleop buttons from dimos.utils.logging_config import setup_logger if TYPE_CHECKING: @@ -148,6 +161,9 @@ class ControlCoordinator(Module[ControlCoordinatorConfig]): # Uses frame_id as task name for routing cartesian_command: In[PoseStamped] + # Input: Streaming twist commands for velocity-commanded platforms + 
twist_command: In[Twist] + # Input: Teleop buttons for engage/disengage signaling buttons: In[Buttons] @@ -174,6 +190,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: # Subscription handles for streaming commands self._joint_command_unsub: Callable[[], None] | None = None self._cartesian_command_unsub: Callable[[], None] | None = None + self._twist_command_unsub: Callable[[], None] | None = None self._buttons_unsub: Callable[[], None] | None = None logger.info(f"ControlCoordinator initialized at {self.config.tick_rate}Hz") @@ -206,7 +223,11 @@ def _setup_from_config(self) -> None: def _setup_hardware(self, component: HardwareComponent) -> None: """Connect and add a single hardware adapter.""" - adapter = self._create_adapter(component) + adapter: ManipulatorAdapter | TwistBaseAdapter + if component.hardware_type == HardwareType.BASE: + adapter = self._create_twist_base_adapter(component) + else: + adapter = self._create_adapter(component) if not adapter.connect(): raise RuntimeError(f"Failed to connect to {component.adapter_type} adapter") @@ -230,6 +251,16 @@ def _create_adapter(self, component: HardwareComponent) -> ManipulatorAdapter: address=component.address, ) + def _create_twist_base_adapter(self, component: HardwareComponent) -> TwistBaseAdapter: + """Create a twist base adapter from component config.""" + from dimos.hardware.drive_trains.registry import twist_base_adapter_registry + + return twist_base_adapter_registry.create( + component.adapter_type, + dof=len(component.joints), + address=component.address, + ) + def _create_task_from_config(self, cfg: TaskConfig) -> ControlTask: """Create a control task from config.""" task_type = cfg.type.lower() @@ -310,19 +341,34 @@ def _create_task_from_config(self, cfg: TaskConfig) -> ControlTask: @rpc def add_hardware( self, - adapter: ManipulatorAdapter, + adapter: ManipulatorAdapter | TwistBaseAdapter, component: HardwareComponent, ) -> bool: """Register a hardware adapter with the coordinator.""" + 
is_base = component.hardware_type == HardwareType.BASE + if is_base != isinstance(adapter, TwistBaseAdapter): + raise TypeError( + f"Hardware type / adapter mismatch for '{component.hardware_id}': " + f"hardware_type={component.hardware_type.value} but got " + f"{type(adapter).__name__}" + ) + with self._hardware_lock: if component.hardware_id in self._hardware: logger.warning(f"Hardware {component.hardware_id} already registered") return False - connected = ConnectedHardware( - adapter=adapter, - component=component, - ) + if isinstance(adapter, TwistBaseAdapter): + connected: ConnectedHardware = ConnectedTwistBase( + adapter=adapter, + component=component, + ) + else: + connected = ConnectedHardware( + adapter=adapter, + component=component, + ) + self._hardware[component.hardware_id] = connected for joint_name in connected.joint_names: @@ -490,6 +536,34 @@ def _on_cartesian_command(self, msg: PoseStamped) -> None: task.on_cartesian_command(msg, t_now) + def _on_twist_command(self, msg: Twist) -> None: + """Convert Twist → virtual joint velocities and route via _on_joint_command. + + Maps Twist fields to virtual joints using suffix convention: + base_vx ← linear.x, base_vy ← linear.y, base_wz ← angular.z, etc. 
+ """ + names: list[str] = [] + velocities: list[float] = [] + + with self._hardware_lock: + for hw in self._hardware.values(): + if hw.component.hardware_type != HardwareType.BASE: + continue + for joint_name in hw.joint_names: + # Extract suffix (e.g., "base_vx" → "vx") + suffix = joint_name.rsplit("_", 1)[-1] + mapping = TWIST_SUFFIX_MAP.get(suffix) + if mapping is None: + continue + group, axis = mapping + value = getattr(getattr(msg, group), axis) + names.append(joint_name) + velocities.append(value) + + if names: + joint_state = JointState(name=names, velocity=velocities) + self._on_joint_command(joint_state) + def _on_buttons(self, msg: Buttons) -> None: """Forward button state to all tasks.""" with self._task_lock: @@ -536,6 +610,9 @@ def set_gripper_position(self, hardware_id: str, position: float) -> bool: if hw is None: logger.warning(f"Hardware '{hardware_id}' not found for gripper command") return False + if isinstance(hw, ConnectedTwistBase): + logger.warning(f"Hardware '{hardware_id}' is a twist base, no gripper support") + return False return hw.adapter.write_gripper_position(position) @rpc @@ -549,6 +626,8 @@ def get_gripper_position(self, hardware_id: str) -> float | None: hw = self._hardware.get(hardware_id) if hw is None: return None + if isinstance(hw, ConnectedTwistBase): + return None return hw.adapter.read_gripper_position() # ========================================================================= @@ -610,6 +689,18 @@ def start(self) -> None: "Use task_invoke RPC or set transport via blueprint." ) + # Subscribe to twist commands if any twist base hardware configured + has_twist_base = any(c.hardware_type == HardwareType.BASE for c in self.config.hardware) + if has_twist_base: + try: + self._twist_command_unsub = self.twist_command.subscribe(self._on_twist_command) + logger.info("Subscribed to twist_command for twist base control") + except Exception: + logger.warning( + "Twist base configured but could not subscribe to twist_command. 
" + "Use task_invoke RPC or set transport via blueprint." + ) + # Subscribe to buttons if any teleop_ik tasks configured (engage/disengage) has_teleop_ik = any(t.type == "teleop_ik" for t in self.config.tasks) if has_teleop_ik: @@ -630,6 +721,9 @@ def stop(self) -> None: if self._cartesian_command_unsub: self._cartesian_command_unsub() self._cartesian_command_unsub = None + if self._twist_command_unsub: + self._twist_command_unsub() + self._twist_command_unsub = None if self._buttons_unsub: self._buttons_unsub() self._buttons_unsub = None diff --git a/dimos/control/examples/twist_base_keyboard_teleop.py b/dimos/control/examples/twist_base_keyboard_teleop.py new file mode 100644 index 0000000000..2d7651145a --- /dev/null +++ b/dimos/control/examples/twist_base_keyboard_teleop.py @@ -0,0 +1,59 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Keyboard teleop for twist base via ControlCoordinator. + +Runs a mock holonomic twist base with pygame keyboard control. +WASD keys publish Twist → coordinator's twist_command port → virtual joints +→ tick loop → MockTwistBaseAdapter. 
+ +Controls: + W/S: Forward/backward (linear.x) + Q/E: Strafe left/right (linear.y) + A/D: Turn left/right (angular.z) + Shift: 2x boost + Ctrl: 0.5x slow + Space: Emergency stop + ESC: Quit + +Usage: + python -m dimos.control.examples.twist_base_keyboard_teleop +""" + +from __future__ import annotations + +from dimos.control.blueprints import coordinator_mock_twist_base +from dimos.robot.unitree.keyboard_teleop import keyboard_teleop + + +def main() -> None: + """Run mock twist base + keyboard teleop.""" + coord = coordinator_mock_twist_base.build() + teleop = keyboard_teleop().build() + + print("Starting mock twist base coordinator + keyboard teleop...") + print("Coordinator tick loop: 100Hz") + print("Keyboard teleop: 50Hz on /cmd_vel") + print() + + coord.start() + teleop.start() + + # Block until Ctrl+C — loop() handles KeyboardInterrupt and calls stop() + coord.loop() + teleop.stop() + + +if __name__ == "__main__": + main() diff --git a/dimos/control/hardware_interface.py b/dimos/control/hardware_interface.py index 9f6eb99851..4e5d1d634c 100644 --- a/dimos/control/hardware_interface.py +++ b/dimos/control/hardware_interface.py @@ -14,10 +14,12 @@ """Connected hardware for the ControlCoordinator. -Wraps ManipulatorAdapter with coordinator-specific features: -- Namespaced joint names (e.g., "left_joint1") -- Unified read/write interface -- Hold-last-value for partial commands +Provides two wrapper types: +- ConnectedHardware: Wraps ManipulatorAdapter for joint-controlled arms +- ConnectedTwistBase: Wraps TwistBaseAdapter for velocity-commanded platforms + +Both share the same duck-type interface (read_state, write_command, etc.) +so the tick loop treats them uniformly. 
""" from __future__ import annotations @@ -30,6 +32,7 @@ if TYPE_CHECKING: from dimos.control.components import HardwareComponent, HardwareId, JointName, JointState + from dimos.hardware.drive_trains.spec import TwistBaseAdapter logger = logging.getLogger(__name__) @@ -193,6 +196,98 @@ def _build_ordered_command(self) -> list[float]: return [self._last_commanded[name] for name in self._joint_names] +class ConnectedTwistBase(ConnectedHardware): + """Runtime wrapper for a twist base connected to the coordinator. + + Inherits from ConnectedHardware and overrides behavior for + velocity-commanded platforms (holonomic bases, drones, quadrupeds, etc.). + + Key differences from ConnectedHardware: + - Positions come from odometry (or zeros if unavailable) + - Efforts are always zero + - write_command always sends velocities regardless of mode + - No retry loop for initialization (twist bases start at zero velocity) + """ + + _twist_adapter: TwistBaseAdapter + + def __init__( + self, + adapter: TwistBaseAdapter, + component: HardwareComponent, + ) -> None: + from dimos.hardware.drive_trains.spec import TwistBaseAdapter as TwistBaseAdapterProto + + if not isinstance(adapter, TwistBaseAdapterProto): + raise TypeError("adapter must implement TwistBaseAdapter") + + self._twist_adapter = adapter + self._component = component + self._joint_names = component.joints + + # Twist bases start at zero velocity — no need to read from hardware + self._last_commanded: dict[str, float] = {name: 0.0 for name in self._joint_names} + self._initialized = True + self._warned_unknown_joints: set[str] = set() + self._current_mode: ControlMode | None = None + + @property + def adapter(self) -> TwistBaseAdapter: # type: ignore[override] + """The underlying twist base adapter.""" + return self._twist_adapter + + def disconnect(self) -> None: + """Disconnect the underlying adapter.""" + self._twist_adapter.disconnect() + + def read_state(self) -> dict[JointName, JointState]: + """Read state as 
{joint_name: JointState}. + + Positions come from odometry (zeros if unavailable). + Velocities from adapter. Efforts are always zero. + """ + from dimos.control.components import JointState + + velocities = self._twist_adapter.read_velocities() + odometry = self._twist_adapter.read_odometry() + positions = odometry if odometry is not None else [0.0] * self.dof + + return { + name: JointState( + position=positions[i], + velocity=velocities[i], + effort=0.0, + ) + for i, name in enumerate(self._joint_names) + } + + def write_command(self, commands: dict[str, float], _mode: ControlMode) -> bool: + """Write velocity commands — always sends velocities regardless of mode. + + Args: + commands: {joint_name: velocity} - can be partial + _mode: Control mode (ignored — twist bases always use velocity) + + Returns: + True if command was sent successfully + """ + # Update last commanded for joints we received + for joint_name, value in commands.items(): + if joint_name in self._last_commanded: + self._last_commanded[joint_name] = value + elif joint_name not in self._warned_unknown_joints: + logger.warning( + f"TwistBase {self.hardware_id} received command for unknown joint " + f"{joint_name}. Valid joints: {self._joint_names}" + ) + self._warned_unknown_joints.add(joint_name) + + # Build ordered velocity list and send + ordered = self._build_ordered_command() + return self._twist_adapter.write_velocities(ordered) + + __all__ = [ "ConnectedHardware", + "ConnectedTwistBase", ] diff --git a/dimos/hardware/drive_trains/__init__.py b/dimos/hardware/drive_trains/__init__.py new file mode 100644 index 0000000000..c6e843feea --- /dev/null +++ b/dimos/hardware/drive_trains/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Drive train hardware adapters for velocity-commanded platforms.""" diff --git a/dimos/hardware/drive_trains/flowbase/__init__.py b/dimos/hardware/drive_trains/flowbase/__init__.py new file mode 100644 index 0000000000..25f95e399c --- /dev/null +++ b/dimos/hardware/drive_trains/flowbase/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""FlowBase twist base adapter for holonomic base control via Portal RPC.""" diff --git a/dimos/hardware/drive_trains/flowbase/adapter.py b/dimos/hardware/drive_trains/flowbase/adapter.py new file mode 100644 index 0000000000..5b5563792d --- /dev/null +++ b/dimos/hardware/drive_trains/flowbase/adapter.py @@ -0,0 +1,206 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""FlowBase adapter — wraps Portal RPC client for holonomic base control. + +Frame convention: FlowBase uses inverted Y-axis compared to standard convention. +We negate vy and wz when sending to the hardware. + + Standard (ROS): FlowBase: + +Y -Y + ↑ ↑ + ───┼──→ +X ───┼──→ +X + | | +""" + +from __future__ import annotations + +import logging +import threading +from typing import TYPE_CHECKING + +import numpy as np + +if TYPE_CHECKING: + from dimos.hardware.drive_trains.registry import TwistBaseAdapterRegistry + +logger = logging.getLogger(__name__) + + +class FlowBaseAdapter: + """TwistBaseAdapter implementation for FlowBase holonomic platform. + + Communicates with FlowBase controller via Portal RPC over TCP. + Expects 3 DOF: [vx, vy, wz] (holonomic base). 
+ + Args: + dof: Number of velocity DOFs (must be 3 for FlowBase) + address: Portal RPC address as "host:port" (default: "172.6.2.20:11323") + """ + + def __init__(self, dof: int = 3, address: str | None = None, **_: object) -> None: + if dof != 3: + raise ValueError(f"FlowBase only supports 3 DOF (holonomic), got {dof}") + + self._address = address or "172.6.2.20:11323" + self._client = None + self._connected = False + self._enabled = False + self._lock = threading.Lock() + + # Last commanded velocities (in standard frame, before negation) + self._last_velocities = [0.0, 0.0, 0.0] + + # ========================================================================= + # Connection + # ========================================================================= + + def connect(self) -> bool: + """Connect to FlowBase controller via Portal RPC.""" + try: + import portal # type: ignore[import-untyped] + + self._client = portal.Client(self._address) + self._connected = True + logger.info(f"Connected to FlowBase at {self._address}") + return True + except Exception as e: + logger.error(f"Failed to connect to FlowBase at {self._address}: {e}") + self._connected = False + return False + + def disconnect(self) -> None: + """Disconnect and send zero velocity.""" + if self._connected and self._client: + try: + self._send_velocity(0.0, 0.0, 0.0) + except Exception: + pass + try: + self._client.close() + except Exception: + pass + self._connected = False + self._client = None + + def is_connected(self) -> bool: + """Check if connected to FlowBase.""" + return self._connected + + # ========================================================================= + # Info + # ========================================================================= + + def get_dof(self) -> int: + """FlowBase is always 3 DOF (vx, vy, wz).""" + return 3 + + # ========================================================================= + # State Reading + # 
========================================================================= + + def read_velocities(self) -> list[float]: + """Return last commanded velocities (FlowBase doesn't report actual).""" + with self._lock: + return self._last_velocities.copy() + + def read_odometry(self) -> list[float] | None: + """Read odometry from FlowBase as [x, y, theta].""" + if not self._connected or not self._client: + return None + + try: + with self._lock: + odom = self._client.get_odometry({}).result() + + if odom is None: + return None + + translation = odom["translation"] # [x, y] + rotation = odom["rotation"] # theta in radians + return [float(translation[0]), float(translation[1]), float(rotation)] + except Exception as e: + logger.error(f"Error reading FlowBase odometry: {e}") + return None + + # ========================================================================= + # Control + # ========================================================================= + + def write_velocities(self, velocities: list[float]) -> bool: + """Send velocity command to FlowBase. 
+ + Args: + velocities: [vx, vy, wz] in standard frame (m/s, rad/s) + """ + if len(velocities) != 3: + return False + + if not self._connected or not self._client: + return False + + vx, vy, wz = velocities + with self._lock: + self._last_velocities = list(velocities) + + # Negate vy and wz for FlowBase's inverted Y-axis frame + return self._send_velocity(vx, -vy, -wz) + + def write_stop(self) -> bool: + """Stop all motion.""" + with self._lock: + self._last_velocities = [0.0, 0.0, 0.0] + if not self._connected or not self._client: + return False + return self._send_velocity(0.0, 0.0, 0.0) + + # ========================================================================= + # Enable/Disable + # ========================================================================= + + def write_enable(self, enable: bool) -> bool: + """Enable/disable the platform (FlowBase is always enabled when connected).""" + self._enabled = enable + return True + + def read_enabled(self) -> bool: + """Check if platform is enabled.""" + return self._enabled + + # ========================================================================= + # Internal + # ========================================================================= + + def _send_velocity(self, vx: float, vy: float, wz: float) -> bool: + """Send raw velocity to FlowBase via Portal RPC.""" + try: + command = { + "target_velocity": np.array([vx, vy, wz]), + "frame": "local", + } + with self._lock: + assert self._client is not None + self._client.set_target_velocity(command).result() + return True + except Exception as e: + logger.error(f"Error sending FlowBase velocity: {e}") + return False + + +def register(registry: TwistBaseAdapterRegistry) -> None: + """Register this adapter with the registry.""" + registry.register("flowbase", FlowBaseAdapter) + + +__all__ = ["FlowBaseAdapter"] diff --git a/dimos/hardware/drive_trains/mock/__init__.py b/dimos/hardware/drive_trains/mock/__init__.py new file mode 100644 index 0000000000..9b6f630040 --- 
/dev/null +++ b/dimos/hardware/drive_trains/mock/__init__.py @@ -0,0 +1,30 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Mock twist base adapter for testing without hardware. + +Usage: + >>> from dimos.hardware.drive_trains.mock import MockTwistBaseAdapter + >>> adapter = MockTwistBaseAdapter(dof=3) + >>> adapter.connect() + True + >>> adapter.write_velocities([0.5, 0.0, 0.1]) + True + >>> adapter.read_velocities() + [0.5, 0.0, 0.1] +""" + +from dimos.hardware.drive_trains.mock.adapter import MockTwistBaseAdapter + +__all__ = ["MockTwistBaseAdapter"] diff --git a/dimos/hardware/drive_trains/mock/adapter.py b/dimos/hardware/drive_trains/mock/adapter.py new file mode 100644 index 0000000000..2091ec59d0 --- /dev/null +++ b/dimos/hardware/drive_trains/mock/adapter.py @@ -0,0 +1,137 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Mock twist base adapter for testing - no hardware required. + +Usage: + >>> from dimos.hardware.drive_trains.mock import MockTwistBaseAdapter + >>> adapter = MockTwistBaseAdapter(dof=3) + >>> adapter.connect() + True + >>> adapter.write_velocities([0.5, 0.0, 0.1]) + True +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from dimos.hardware.drive_trains.registry import TwistBaseAdapterRegistry + + +class MockTwistBaseAdapter: + """Fake twist base adapter for unit tests. + + Implements TwistBaseAdapter protocol with in-memory state. + Useful for: + - Unit testing coordinator logic without hardware + - Integration testing with predictable behavior + - Development without a physical base + """ + + def __init__(self, dof: int = 3, **_: object) -> None: + self._dof = dof + self._velocities = [0.0] * dof + self._odometry: list[float] | None = [0.0] * dof + self._enabled = False + self._connected = False + + # ========================================================================= + # Connection + # ========================================================================= + + def connect(self) -> bool: + """Simulate connection.""" + self._connected = True + return True + + def disconnect(self) -> None: + """Simulate disconnection.""" + self._connected = False + + def is_connected(self) -> bool: + """Check mock connection status.""" + return self._connected + + # ========================================================================= + # Info + # ========================================================================= + + def get_dof(self) -> int: + """Return DOF.""" + return self._dof + + # ========================================================================= + # State Reading + # ========================================================================= + + def read_velocities(self) -> list[float]: + """Return mock velocities.""" + return self._velocities.copy() + + def read_odometry(self) -> list[float] | 
None: + """Return mock odometry.""" + if self._odometry is None: + return None + return self._odometry.copy() + + # ========================================================================= + # Control + # ========================================================================= + + def write_velocities(self, velocities: list[float]) -> bool: + """Set mock velocities.""" + if len(velocities) != self._dof: + return False + self._velocities = list(velocities) + return True + + def write_stop(self) -> bool: + """Stop mock motion.""" + self._velocities = [0.0] * self._dof + return True + + # ========================================================================= + # Enable/Disable + # ========================================================================= + + def write_enable(self, enable: bool) -> bool: + """Enable/disable mock platform.""" + self._enabled = enable + return True + + def read_enabled(self) -> bool: + """Check mock enable state.""" + return self._enabled + + # ========================================================================= + # Test Helpers (not part of Protocol) + # ========================================================================= + + def set_odometry(self, odometry: list[float] | None) -> None: + """Set odometry directly for testing.""" + self._odometry = list(odometry) if odometry is not None else None + + def set_velocities_directly(self, velocities: list[float]) -> None: + """Set velocities directly for testing (bypasses DOF check).""" + self._velocities = list(velocities) + + +def register(registry: TwistBaseAdapterRegistry) -> None: + """Register this adapter with the registry.""" + registry.register("mock_twist_base", MockTwistBaseAdapter) + + +__all__ = ["MockTwistBaseAdapter"] diff --git a/dimos/hardware/drive_trains/registry.py b/dimos/hardware/drive_trains/registry.py new file mode 100644 index 0000000000..0a513d2bd4 --- /dev/null +++ b/dimos/hardware/drive_trains/registry.py @@ -0,0 +1,98 @@ +# Copyright 2025-2026 
Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""TwistBase adapter registry with auto-discovery. + +Automatically discovers and registers twist base adapters from subpackages. +Each adapter provides a `register()` function in its adapter.py module. + +Usage: + from dimos.hardware.drive_trains.registry import twist_base_adapter_registry + + # Create an adapter by name + adapter = twist_base_adapter_registry.create("mock_twist_base", dof=3) + adapter = twist_base_adapter_registry.create("flowbase", dof=3, address="172.6.2.20:11323") + + # List available adapters + print(twist_base_adapter_registry.available()) # ["flowbase", "mock_twist_base"] +""" + +from __future__ import annotations + +import importlib +import logging +import pkgutil +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from dimos.hardware.drive_trains.spec import TwistBaseAdapter + +logger = logging.getLogger(__name__) + + +class TwistBaseAdapterRegistry: + """Registry for twist base adapters with auto-discovery.""" + + def __init__(self) -> None: + self._adapters: dict[str, type[TwistBaseAdapter]] = {} + + def register(self, name: str, cls: type[TwistBaseAdapter]) -> None: + """Register an adapter class.""" + self._adapters[name.lower()] = cls + + def create(self, name: str, **kwargs: Any) -> TwistBaseAdapter: + """Create an adapter instance by name. 
+ + Args: + name: Adapter name (e.g., "mock_twist_base", "flowbase") + **kwargs: Arguments passed to adapter constructor + + Returns: + Configured adapter instance + + Raises: + KeyError: If adapter name is not found + """ + key = name.lower() + if key not in self._adapters: + raise KeyError(f"Unknown twist base adapter: {name}. Available: {self.available()}") + + return self._adapters[key](**kwargs) + + def available(self) -> list[str]: + """List available adapter names.""" + return sorted(self._adapters.keys()) + + def discover(self) -> None: + """Discover and register adapters from subpackages. + + Can be called multiple times to pick up newly added adapters. + """ + import dimos.hardware.drive_trains as pkg + + for _, name, ispkg in pkgutil.iter_modules(pkg.__path__): + if not ispkg: + continue + try: + module = importlib.import_module(f"dimos.hardware.drive_trains.{name}.adapter") + if hasattr(module, "register"): + module.register(self) + except ImportError as e: + logger.warning(f"Skipping twist base adapter {name}: {e}") + + +twist_base_adapter_registry = TwistBaseAdapterRegistry() +twist_base_adapter_registry.discover() + +__all__ = ["TwistBaseAdapterRegistry", "twist_base_adapter_registry"] diff --git a/dimos/hardware/drive_trains/spec.py b/dimos/hardware/drive_trains/spec.py new file mode 100644 index 0000000000..0b288edfd4 --- /dev/null +++ b/dimos/hardware/drive_trains/spec.py @@ -0,0 +1,95 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""TwistBase adapter protocol for velocity-commanded platforms. + +Lightweight protocol for mobile bases, quadrupeds, drones, RC cars, +and any other platform that accepts Twist (velocity) commands. + +Virtual joint ordering is defined by the HardwareComponent.joints list. +For a holonomic base: [vx, vy, wz] maps to joints ["base_vx", "base_vy", "base_wz"]. +""" + +from typing import Protocol, runtime_checkable + + +@runtime_checkable +class TwistBaseAdapter(Protocol): + """Protocol for velocity-commanded platform IO. + + Implement this per vendor SDK. All methods use SI units: + - Linear velocity: m/s + - Angular velocity: rad/s + - Position: meters + - Angle: radians + """ + + # --- Connection --- + + def connect(self) -> bool: + """Connect to hardware. Returns True on success.""" + ... + + def disconnect(self) -> None: + """Disconnect from hardware.""" + ... + + def is_connected(self) -> bool: + """Check if connected.""" + ... + + # --- Info --- + + def get_dof(self) -> int: + """Get number of velocity DOFs (e.g., 3 for holonomic, 2 for differential).""" + ... + + # --- State Reading --- + + def read_velocities(self) -> list[float]: + """Read current velocities in virtual joint order (m/s or rad/s).""" + ... + + def read_odometry(self) -> list[float] | None: + """Read position estimate in virtual joint order. + + For a holonomic base this would be [x, y, theta]. + Returns None if the platform doesn't provide odometry. + """ + ... + + # --- Control --- + + def write_velocities(self, velocities: list[float]) -> bool: + """Command velocities in virtual joint order. Returns success.""" + ... + + def write_stop(self) -> bool: + """Stop all motion immediately (zero velocities).""" + ... + + # --- Enable/Disable --- + + def write_enable(self, enable: bool) -> bool: + """Enable or disable the platform. Returns success.""" + ... 
+ + def read_enabled(self) -> bool: + """Check if platform is enabled.""" + ... + + +__all__ = [ + "TwistBaseAdapter", +] diff --git a/dimos/robot/all_blueprints.py b/dimos/robot/all_blueprints.py index bdfd98cd17..0e23c82065 100644 --- a/dimos/robot/all_blueprints.py +++ b/dimos/robot/all_blueprints.py @@ -27,7 +27,9 @@ "coordinator-combined-xarm6": "dimos.control.blueprints:coordinator_combined_xarm6", "coordinator-dual-mock": "dimos.control.blueprints:coordinator_dual_mock", "coordinator-dual-xarm": "dimos.control.blueprints:coordinator_dual_xarm", + "coordinator-mobile-manip-mock": "dimos.control.blueprints:coordinator_mobile_manip_mock", "coordinator-mock": "dimos.control.blueprints:coordinator_mock", + "coordinator-mock-twist-base": "dimos.control.blueprints:coordinator_mock_twist_base", "coordinator-piper": "dimos.control.blueprints:coordinator_piper", "coordinator-piper-xarm": "dimos.control.blueprints:coordinator_piper_xarm", "coordinator-teleop-dual": "dimos.control.blueprints:coordinator_teleop_dual", diff --git a/pyproject.toml b/pyproject.toml index 9dea7e1921..ee7c4778b2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -119,6 +119,7 @@ misc = [ # Hardware SDKs "xarm-python-sdk>=1.17.0", + "portal", ] visualization = [ diff --git a/uv.lock b/uv.lock index d971dcfeaa..53b2454b40 100644 --- a/uv.lock +++ b/uv.lock @@ -1964,6 +1964,7 @@ misc = [ { name = "onnx" }, { name = "open-clip-torch" }, { name = "opencv-contrib-python" }, + { name = "portal" }, { name = "python-multipart" }, { name = "scikit-learn", version = "1.7.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, { name = "scikit-learn", version = "1.8.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, @@ -2121,6 +2122,7 @@ requires-dist = [ { name = "plotly", marker = "extra == 'manipulation'", specifier = ">=5.9.0" }, { name = "plum-dispatch", specifier = "==2.5.7" }, { name = "plum-dispatch", 
marker = "extra == 'docker'", specifier = "==2.5.7" }, + { name = "portal", marker = "extra == 'misc'" }, { name = "pre-commit", marker = "extra == 'dev'", specifier = "==4.2.0" }, { name = "psycopg2-binary", marker = "extra == 'psql'", specifier = ">=2.9.11" }, { name = "py-spy", marker = "extra == 'dev'" }, @@ -6944,6 +6946,22 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/bf/18/72c216f4ab0c82b907009668f79183ae029116ff0dd245d56ef58aac48e7/polars_runtime_32-1.38.1-cp310-abi3-win_arm64.whl", hash = "sha256:6d07d0cc832bfe4fb54b6e04218c2c27afcfa6b9498f9f6bbf262a00d58cc7c4", size = 41639413, upload-time = "2026-02-06T18:12:22.044Z" }, ] +[[package]] +name = "portal" +version = "3.7.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cloudpickle" }, + { name = "msgpack" }, + { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "numpy", version = "2.3.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "psutil" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/57/11/c67a1b771901e4c941fe3dcda763b78a29b6c45308e3ebaf99bac96820d8/portal-3.7.4.tar.gz", hash = "sha256:67234267d1eb319fe790653822d4a8d0e0e5312fb29fd8f440d8287066f478b9", size = 17380, upload-time = "2026-01-12T18:17:45.727Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c3/14/0f7d227894831d2d7eb7f2c6946e8cad8e86da6135b6f902bb961d948f04/portal-3.7.4-py3-none-any.whl", hash = "sha256:3801a489766d3ec2eb73ca8cefd29c54e166d4cf5cfdf1a079ac93fe1130bedb", size = 23486, upload-time = "2026-01-12T18:17:44.326Z" }, +] + [[package]] name = "portalocker" version = "3.2.0" From b150b326278dae4ee45f43cfa038952a9b5ffc61 Mon Sep 17 00:00:00 2001 From: leshy Date: Tue, 24 Feb 2026 04:39:23 +0800 Subject: [PATCH 04/16] docs: move adding_a_custom_arm to manipulation, remove depth_camera_integration (#1351) 
Move adding_a_custom_arm.md from docs/development/ to docs/capabilities/manipulation/ where it belongs. Remove the outdated depth_camera_integration.md and fix resulting broken links. --- .../manipulation}/adding_a_custom_arm.md | 0 docs/capabilities/manipulation/readme.md | 4 + docs/development/README.md | 1 - docs/development/depth_camera_integration.md | 147 ------------------ 4 files changed, 4 insertions(+), 148 deletions(-) rename docs/{development => capabilities/manipulation}/adding_a_custom_arm.md (100%) delete mode 100644 docs/development/depth_camera_integration.md diff --git a/docs/development/adding_a_custom_arm.md b/docs/capabilities/manipulation/adding_a_custom_arm.md similarity index 100% rename from docs/development/adding_a_custom_arm.md rename to docs/capabilities/manipulation/adding_a_custom_arm.md diff --git a/docs/capabilities/manipulation/readme.md b/docs/capabilities/manipulation/readme.md index 91dada0395..4a943e6be5 100644 --- a/docs/capabilities/manipulation/readme.md +++ b/docs/capabilities/manipulation/readme.md @@ -99,6 +99,10 @@ KeyboardTeleopModule ──→ ControlCoordinator ──→ ManipulationModule | XArm6 | 6 | Y | Y | — | | XArm7 | 7 | Y | Y | Y | +## Adding a Custom Arm + +[guide is here](adding_a_custom_arm.md) + ## Key Files | File | Description | diff --git a/docs/development/README.md b/docs/development/README.md index 130e86fdaa..87936a100d 100644 --- a/docs/development/README.md +++ b/docs/development/README.md @@ -222,7 +222,6 @@ This will save the rerun data to `rerun.json` in the current directory. ## Where is `` located? (Architecture) * If you want to add a `dimos run ` command see [dimos_run.md](/docs/development/dimos_run.md) -* If you want to add a camera driver see [depth_camera_integration.md](/docs/development/depth_camera_integration.md) * For edits to manipulation see [manipulation](/dimos/hardware/manipulators/README.md) and the related modules under `dimos/manipulation/`. 
* `dimos/core/`: Is where stuff like `Module`, `In`, `Out`, and `RPC` live. * `dimos/robot/`: Robot-specific modules live here. diff --git a/docs/development/depth_camera_integration.md b/docs/development/depth_camera_integration.md deleted file mode 100644 index e152394262..0000000000 --- a/docs/development/depth_camera_integration.md +++ /dev/null @@ -1,147 +0,0 @@ -# Depth Camera Integration Guide - -This folder contains camera drivers and modules for RGB-D (depth) cameras such as RealSense and ZED. -Use this guide to add a new depth camera, wire TF correctly, and publish the required streams. - -## Add a New Depth Camera - -1) **Create a new driver module** - - Path: `dimos/hardware/sensors/camera//camera.py` - - Export a blueprint in `/__init__.py` (match the `realsense` / `zed` pattern). - -2) **Define config** - - Inherit from `ModuleConfig` and `DepthCameraConfig`: - ```python - @dataclass - class MyDepthCameraConfig(ModuleConfig, DepthCameraConfig): - width: int = 1280 - height: int = 720 - fps: int = 15 - camera_name: str = "camera" - base_frame_id: str = "base_link" - base_transform: Transform | None = field(default_factory=default_base_transform) - align_depth_to_color: bool = True - enable_depth: bool = True - enable_pointcloud: bool = False - pointcloud_fps: float = 5.0 - camera_info_fps: float = 1.0 - ``` - -3) **Implement the module** - - Inherit from `DepthCameraHardware` and `Module` (see `RealSenseCamera` / `ZEDCamera`). 
- - Provide these outputs (matching `RealSenseCamera` / `ZEDCamera`): - - `color_image: Out[Image]` - - `depth_image: Out[Image]` - - `pointcloud: Out[PointCloud2]` (optional, can be disabled by config) - - `camera_info: Out[CameraInfo]` - - `depth_camera_info: Out[CameraInfo]` - - Implement RPCs: - - `start()` / `stop()` - - `get_color_camera_info()` / `get_depth_camera_info()` - - `get_depth_scale()` (meters per depth unit) - -4) **Publish frames** - - Color images: `Image(format=ImageFormat.RGB, frame_id=_color_optical_frame)` - - Depth images: - - If `align_depth_to_color`: use `_color_optical_frame` - - Else: use `_depth_optical_frame` - - CameraInfo frame_id must match the image frame_id you publish. - -5) **Publish camera info** - - Build `CameraInfo` from camera intrinsics. - - Publish at `camera_info_fps`. - -6) **Publish pointcloud (optional)** - - Use `PointCloud2.from_rgbd(color_image, depth_image, camera_info, depth_scale)`. - - Publish at `pointcloud_fps`. - -## TF: Required Frames and Transforms - -Frame names are defined by the abstract depth camera spec (`dimos/hardware/sensors/camera/spec.py`). 
-Use the properties below to ensure consistent naming: - -- `_camera_link`: base link for the camera module (usually `{camera_name}_link`) -- `_color_frame`: non-optical color frame -- `_color_optical_frame`: optical color frame -- `_depth_frame`: non-optical depth frame -- `_depth_optical_frame`: optical depth frame - -Recommended transform chain (publish every frame or at your preferred TF rate): - -1) **Mounting transform** (from config): - - `base_frame_id -> _camera_link` - - Use `config.base_transform` if provided - -2) **Depth frame** - - `_camera_link -> _depth_frame` (identity unless the camera provides extrinsics) - - `_depth_frame -> _depth_optical_frame` using `OPTICAL_ROTATION` - -3) **Color frame** - - `_camera_link -> _color_frame` (from extrinsics, or identity if unavailable) - - `_color_frame -> _color_optical_frame` using `OPTICAL_ROTATION` - -Notes: -- If you align depth to color, keep TFs the same but publish depth images in `_color_optical_frame`. -- Ensure `color_image.frame_id` and `camera_info.frame_id` match. Same for depth. - -## Required Streams / Topics - -Use these stream names in your module and attach transports as needed. -Default LCM topics in `realsense` / `zed` demos are shown below. 
- -| Stream name | Type | Suggested topic | Frame ID source | -|-------------------|--------------|-------------------------|-----------------| -| `color_image` | `Image` | `/camera/color` | `_color_optical_frame` | -| `depth_image` | `Image` | `/camera/depth` | `_depth_optical_frame` or `_color_optical_frame` | -| `pointcloud` | `PointCloud2`| `/camera/pointcloud` | (derived from CameraInfo) | -| `camera_info` | `CameraInfo` | `/camera/color_info` | matches `color_image` | -| `depth_camera_info` | `CameraInfo` | `/camera/depth_info` | matches `depth_image` | - -For `ObjectSceneRegistrationModule`, the required inputs are: -- `color_image` -- `depth_image` -- `camera_info` -- TF tree resolving `target_frame` to `color_image.frame_id` - -## Object Scene Registration (Brief Overview) - -`ObjectSceneRegistrationModule` consumes synchronized RGB + depth + camera intrinsics and produces: -- 2D detections (YOLO‑E) -- 3D detections (projected via depth + intrinsics + TF) -- Overlay annotations and aggregated pointclouds - -See: -- `dimos/perception/object_scene_registration.py` -- `dimos/perception/demo_object_scene_registration.py` - -Quick wiring example: - -```python -from dimos.core.blueprints import autoconnect -from dimos.hardware.sensors.camera.realsense import realsense_camera -from dimos.perception.object_scene_registration import object_scene_registration_module - -pipeline = autoconnect( - realsense_camera(enable_pointcloud=False), - object_scene_registration_module(target_frame="world"), -) -``` - -Run the demo via CLI: -```bash -dimos run demo-object-scene-registration -``` - -## Foxglove (Viewer) - -Install Foxglove from: -- https://foxglove.dev/download - -## Modules and Skills (Short Intro) - -- **Modules** are typed components with `In[...]` / `Out[...]` streams and `start()` / `stop()` lifecycles. -- **Skills** are callable methods (decorated with `@skill`) on any `Module`, automatically discovered by agents. 
- -Reference: -- Modules overview: `/docs/usage/modules.md` -- TF fundamentals: `/docs/usage/transforms.md` From e7594ae55355624f9aa96eb7b32623ae70439280 Mon Sep 17 00:00:00 2001 From: Paul Nechifor Date: Tue, 24 Feb 2026 02:48:41 +0200 Subject: [PATCH 05/16] feat(mcp): add mcp client and server (#1300) * Moving towards deprecating `Agent`. It has been efectivelly split between `McpClient` and `McpServer`. * McpServer exposes all the `@skill` functions as MCP tools. * McpClient starts a langchain agent which coordinates calling tools from McpServer. * McpClient is not strictly necessary. You can register the MCP server in claude code and make it call tools. * McpClient also listens to `/human_input` so it can be used through the `humancli` like `Agent` can. --- .gitignore | 2 + dimos/agents/agent.py | 5 +- dimos/agents/conftest.py | 5 - dimos/agents/mcp/README.md | 55 ++++ dimos/{protocol => agents}/mcp/__init__.py | 0 dimos/agents/mcp/conftest.py | 103 ++++++++ .../test_can_call_again_on_error[False].json | 34 +++ .../test_can_call_again_on_error[True].json | 34 +++ .../fixtures/test_can_call_tool[False].json | 22 ++ .../fixtures/test_can_call_tool[True].json | 22 ++ dimos/agents/mcp/fixtures/test_image.json | 23 ++ ...ple_tool_calls_with_multiple_messages.json | 116 ++++++++ dimos/agents/mcp/fixtures/test_prompt.json | 8 + dimos/agents/mcp/mcp_client.py | 250 ++++++++++++++++++ dimos/agents/mcp/mcp_server.py | 197 ++++++++++++++ dimos/agents/mcp/test_mcp_client.py | 210 +++++++++++++++ dimos/agents/mcp/test_mcp_client_unit.py | 145 ++++++++++ .../mcp/test_mcp_server.py} | 73 ++--- dimos/protocol/mcp/README.md | 35 --- dimos/protocol/mcp/__main__.py | 36 --- dimos/protocol/mcp/bridge.py | 53 ---- dimos/protocol/mcp/mcp.py | 139 ---------- dimos/robot/all_blueprints.py | 1 + .../agentic/unitree_go2_agentic_mcp.py | 12 +- pyproject.toml | 3 - uv.lock | 52 ---- 26 files changed, 1260 insertions(+), 375 deletions(-) create mode 100644 dimos/agents/mcp/README.md rename 
dimos/{protocol => agents}/mcp/__init__.py (100%) create mode 100644 dimos/agents/mcp/conftest.py create mode 100644 dimos/agents/mcp/fixtures/test_can_call_again_on_error[False].json create mode 100644 dimos/agents/mcp/fixtures/test_can_call_again_on_error[True].json create mode 100644 dimos/agents/mcp/fixtures/test_can_call_tool[False].json create mode 100644 dimos/agents/mcp/fixtures/test_can_call_tool[True].json create mode 100644 dimos/agents/mcp/fixtures/test_image.json create mode 100644 dimos/agents/mcp/fixtures/test_multiple_tool_calls_with_multiple_messages.json create mode 100644 dimos/agents/mcp/fixtures/test_prompt.json create mode 100644 dimos/agents/mcp/mcp_client.py create mode 100644 dimos/agents/mcp/mcp_server.py create mode 100644 dimos/agents/mcp/test_mcp_client.py create mode 100644 dimos/agents/mcp/test_mcp_client_unit.py rename dimos/{protocol/mcp/test_mcp_module.py => agents/mcp/test_mcp_server.py} (62%) delete mode 100644 dimos/protocol/mcp/README.md delete mode 100644 dimos/protocol/mcp/__main__.py delete mode 100644 dimos/protocol/mcp/bridge.py delete mode 100644 dimos/protocol/mcp/mcp.py diff --git a/.gitignore b/.gitignore index 24a3dd8919..f97d9f906a 100644 --- a/.gitignore +++ b/.gitignore @@ -69,3 +69,5 @@ yolo11n.pt CLAUDE.MD /assets/teleop_certs/ + +/.mcp.json diff --git a/dimos/agents/agent.py b/dimos/agents/agent.py index 76195ccea0..98f23d7e8d 100644 --- a/dimos/agents/agent.py +++ b/dimos/agents/agent.py @@ -46,9 +46,8 @@ class AgentConfig(ModuleConfig): model_fixture: str | None = None -class Agent(Module): - default_config: type[AgentConfig] = AgentConfig - config: AgentConfig +class Agent(Module[AgentConfig]): + default_config = AgentConfig agent: Out[BaseMessage] human_input: In[str] agent_idle: Out[bool] diff --git a/dimos/agents/conftest.py b/dimos/agents/conftest.py index 23d888b0fe..1be2aadc0c 100644 --- a/dimos/agents/conftest.py +++ b/dimos/agents/conftest.py @@ -31,11 +31,6 @@ FIXTURE_DIR = Path(__file__).parent / 
"fixtures" -@pytest.fixture -def fixture_dir() -> Path: - return FIXTURE_DIR - - @pytest.fixture def agent_setup(request): coordinator = None diff --git a/dimos/agents/mcp/README.md b/dimos/agents/mcp/README.md new file mode 100644 index 0000000000..f9e887beb1 --- /dev/null +++ b/dimos/agents/mcp/README.md @@ -0,0 +1,55 @@ +# DimOS MCP Server + +Expose DimOS robot skills to Claude Code via Model Context Protocol. + +## Setup + +```bash +uv sync --extra base --extra unitree +``` + +Add to Claude Code (one command) + +```bash +claude mcp add --transport http --scope project dimos http://localhost:9990/mcp +``` + +Verify that it was added: + +```bash +claude mcp list +``` + +## MCP Inspector + +If you want to inspect the server manually, you can use MCP Inspector. + +Install it: + +```bash +npx -y @modelcontextprotocol/inspector +``` + +It will open a browser window. + +Change **Transport Type** to "Streamable HTTP", change **URL** to `http://localhost:9990/mcp`, and **Connection Type** to "Direct". Then click on "Connect". + +## Usage + +**Terminal 1** - Start DimOS: +```bash +uv run dimos run unitree-go2-agentic-mcp +``` + +**Claude Code** - Use robot skills: +``` +> move forward 1 meter +> go to the kitchen +> tag this location as "desk" +``` + +## How It Works + +1. `McpServer` in the blueprint starts a FastAPI server on port 9990 +2. Claude Code connects directly to `http://localhost:9990/mcp` +3. Skills are exposed as MCP tools (e.g., `relative_move`, `navigate_with_text`) diff --git a/dimos/protocol/mcp/__init__.py b/dimos/agents/mcp/__init__.py similarity index 100% rename from dimos/protocol/mcp/__init__.py rename to dimos/agents/mcp/__init__.py diff --git a/dimos/agents/mcp/conftest.py b/dimos/agents/mcp/conftest.py new file mode 100644 index 0000000000..532ef16592 --- /dev/null +++ b/dimos/agents/mcp/conftest.py @@ -0,0 +1,103 @@ +# Copyright 2026 Dimensional Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from pathlib import Path +from threading import Event + +from dotenv import load_dotenv +from langchain_core.messages.base import BaseMessage +import pytest + +from dimos.agents.agent_test_runner import AgentTestRunner +from dimos.agents.mcp.mcp_client import McpClient +from dimos.agents.mcp.mcp_server import McpServer +from dimos.core.blueprints import autoconnect +from dimos.core.global_config import global_config +from dimos.core.transport import pLCMTransport + +load_dotenv() + +FIXTURE_DIR = Path(__file__).parent / "fixtures" + + +@pytest.fixture +def agent_setup(request): + coordinator = None + transports: list[pLCMTransport] = [] + unsubs: list = [] + recording = bool(os.getenv("RECORD")) + + def fn( + *, + blueprints, + messages: list[BaseMessage], + dask: bool = False, + system_prompt: str | None = None, + fixture: str | None = None, + ) -> list[BaseMessage]: + history: list[BaseMessage] = [] + finished_event = Event() + + agent_transport: pLCMTransport = pLCMTransport("/agent") + finished_transport: pLCMTransport = pLCMTransport("/finished") + transports.extend([agent_transport, finished_transport]) + + def on_message(msg: BaseMessage) -> None: + history.append(msg) + + unsubs.append(agent_transport.subscribe(on_message)) + unsubs.append(finished_transport.subscribe(lambda _: finished_event.set())) + + # Derive fixture path from test name if not explicitly provided. 
+ if fixture is not None: + fixture_path = FIXTURE_DIR / fixture + else: + fixture_path = FIXTURE_DIR / f"{request.node.name}.json" + + client_kwargs: dict = {"system_prompt": system_prompt} + + if recording or fixture_path.exists(): + client_kwargs["model_fixture"] = str(fixture_path) + + blueprint = autoconnect( + *blueprints, + McpServer.blueprint(), + McpClient.blueprint(**client_kwargs), + AgentTestRunner.blueprint(messages=messages), + ) + + global_config.update( + viewer_backend="none", + dask=dask, + ) + + nonlocal coordinator + coordinator = blueprint.build() + + if not finished_event.wait(60): + raise TimeoutError("Timed out waiting for agent to finish processing messages.") + + return history + + yield fn + + if coordinator is not None: + coordinator.stop() + + for transport in transports: + transport.stop() + + for unsub in unsubs: + unsub() diff --git a/dimos/agents/mcp/fixtures/test_can_call_again_on_error[False].json b/dimos/agents/mcp/fixtures/test_can_call_again_on_error[False].json new file mode 100644 index 0000000000..8cfe2f69c7 --- /dev/null +++ b/dimos/agents/mcp/fixtures/test_can_call_again_on_error[False].json @@ -0,0 +1,34 @@ +{ + "responses": [ + { + "content": "", + "tool_calls": [ + { + "name": "register_user", + "args": { + "name": "Paul" + }, + "id": "call_NrrizXSIFaeCLuG9i05IwDy3", + "type": "tool_call" + } + ] + }, + { + "content": "", + "tool_calls": [ + { + "name": "register_user", + "args": { + "name": "paul" + }, + "id": "call_2QPx4GsL61Xjrggbq7afXTjn", + "type": "tool_call" + } + ] + }, + { + "content": "The user named 'paul' has been registered successfully.", + "tool_calls": [] + } + ] +} diff --git a/dimos/agents/mcp/fixtures/test_can_call_again_on_error[True].json b/dimos/agents/mcp/fixtures/test_can_call_again_on_error[True].json new file mode 100644 index 0000000000..3d3765f43a --- /dev/null +++ b/dimos/agents/mcp/fixtures/test_can_call_again_on_error[True].json @@ -0,0 +1,34 @@ +{ + "responses": [ + { + "content": "", + 
"tool_calls": [ + { + "name": "register_user", + "args": { + "name": "Paul" + }, + "id": "call_XSy1Dx1dGtQv5zPaEJtb2hd7", + "type": "tool_call" + } + ] + }, + { + "content": "", + "tool_calls": [ + { + "name": "register_user", + "args": { + "name": "paul" + }, + "id": "call_aYFug1g3TATnaYus9HUVxoQS", + "type": "tool_call" + } + ] + }, + { + "content": "The user named \"paul\" has been registered successfully.", + "tool_calls": [] + } + ] +} diff --git a/dimos/agents/mcp/fixtures/test_can_call_tool[False].json b/dimos/agents/mcp/fixtures/test_can_call_tool[False].json new file mode 100644 index 0000000000..7d1ac3075b --- /dev/null +++ b/dimos/agents/mcp/fixtures/test_can_call_tool[False].json @@ -0,0 +1,22 @@ +{ + "responses": [ + { + "content": "", + "tool_calls": [ + { + "name": "add", + "args": { + "x": 33333, + "y": 100 + }, + "id": "call_RssRDDd9apDjNoVLz4jRLVk0", + "type": "tool_call" + } + ] + }, + { + "content": "The result of 33333 + 100 is 33433.", + "tool_calls": [] + } + ] +} diff --git a/dimos/agents/mcp/fixtures/test_can_call_tool[True].json b/dimos/agents/mcp/fixtures/test_can_call_tool[True].json new file mode 100644 index 0000000000..d375c82235 --- /dev/null +++ b/dimos/agents/mcp/fixtures/test_can_call_tool[True].json @@ -0,0 +1,22 @@ +{ + "responses": [ + { + "content": "", + "tool_calls": [ + { + "name": "add", + "args": { + "x": 33333, + "y": 100 + }, + "id": "call_pzzddF9mBynGYZVdCmGHOB5V", + "type": "tool_call" + } + ] + }, + { + "content": "The result of 33333 + 100 is 33433.", + "tool_calls": [] + } + ] +} diff --git a/dimos/agents/mcp/fixtures/test_image.json b/dimos/agents/mcp/fixtures/test_image.json new file mode 100644 index 0000000000..0e4816b8ee --- /dev/null +++ b/dimos/agents/mcp/fixtures/test_image.json @@ -0,0 +1,23 @@ +{ + "responses": [ + { + "content": "", + "tool_calls": [ + { + "name": "take_a_picture", + "args": {}, + "id": "call_7Qwsr8QMLWhKRMektcGiKYf7", + "type": "tool_call" + } + ] + }, + { + "content": "I've taken a 
picture. Let me analyze and describe it for you.\nThe image features an expansive outdoor stadium. From the camera's perspective, the word 'stadium' best matches the image. Is there anything else you'd like to know or do?", + "tool_calls": [] + }, + { + "content": "The image shows a group of people sitting and enjoying their time at an outdoor cafe. Therefore, the word 'cafe' best matches the image.", + "tool_calls": [] + } + ] +} diff --git a/dimos/agents/mcp/fixtures/test_multiple_tool_calls_with_multiple_messages.json b/dimos/agents/mcp/fixtures/test_multiple_tool_calls_with_multiple_messages.json new file mode 100644 index 0000000000..5c0d551e13 --- /dev/null +++ b/dimos/agents/mcp/fixtures/test_multiple_tool_calls_with_multiple_messages.json @@ -0,0 +1,116 @@ +{ + "responses": [ + { + "content": "", + "tool_calls": [ + { + "name": "locate_person", + "args": { + "name": "John" + }, + "id": "call_eOoKTtyvvXBk171ro4bXzW5C", + "type": "tool_call" + } + ] + }, + { + "content": "", + "tool_calls": [ + { + "name": "register_person", + "args": { + "name": "John" + }, + "id": "call_tTB5A3q60teaBrdonRvCwcwM", + "type": "tool_call" + } + ] + }, + { + "content": "", + "tool_calls": [ + { + "name": "locate_person", + "args": { + "name": "John" + }, + "id": "call_uEhafkL3f7BLQKhRuZlEAany", + "type": "tool_call" + } + ] + }, + { + "content": "", + "tool_calls": [ + { + "name": "go_to_location", + "args": { + "description": "kitchen" + }, + "id": "call_oxnH4gCGi6aSeVLPrhnp31yP", + "type": "tool_call" + } + ] + }, + { + "content": "I have moved to the kitchen where John is located.", + "tool_calls": [] + }, + { + "content": "", + "tool_calls": [ + { + "name": "locate_person", + "args": { + "name": "Jane" + }, + "id": "call_2HinxBmffnafloaP4b7DkBZW", + "type": "tool_call" + } + ] + }, + { + "content": "", + "tool_calls": [ + { + "name": "register_person", + "args": { + "name": "Jane" + }, + "id": "call_XtHavMmgpzrhmVi3XB6RUFrW", + "type": "tool_call" + } + ] + }, + { + 
"content": "", + "tool_calls": [ + { + "name": "locate_person", + "args": { + "name": "Jane" + }, + "id": "call_fRHHO4cPWDXi4IvQ4qQqidwT", + "type": "tool_call" + } + ] + }, + { + "content": "", + "tool_calls": [ + { + "name": "go_to_location", + "args": { + "description": "living room" + }, + "id": "call_Hcc7C0FMWS8rfKwMP0sUL7XN", + "type": "tool_call" + } + ] + }, + { + "content": "I have moved to the living room where Jane is located.", + "tool_calls": [] + } + ] +} diff --git a/dimos/agents/mcp/fixtures/test_prompt.json b/dimos/agents/mcp/fixtures/test_prompt.json new file mode 100644 index 0000000000..acb77fe350 --- /dev/null +++ b/dimos/agents/mcp/fixtures/test_prompt.json @@ -0,0 +1,8 @@ +{ + "responses": [ + { + "content": "Hello! My name is Johnny. How can I assist you today?", + "tool_calls": [] + } + ] +} diff --git a/dimos/agents/mcp/mcp_client.py b/dimos/agents/mcp/mcp_client.py new file mode 100644 index 0000000000..7c5eda5302 --- /dev/null +++ b/dimos/agents/mcp/mcp_client.py @@ -0,0 +1,250 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from dataclasses import dataclass +from queue import Empty, Queue +from threading import Event, RLock, Thread +import time +from typing import Any +import uuid + +import httpx +from langchain.agents import create_agent +from langchain_core.messages import HumanMessage +from langchain_core.messages.base import BaseMessage +from langchain_core.tools import StructuredTool +from langgraph.graph.state import CompiledStateGraph +from reactivex.disposable import Disposable + +from dimos.agents.system_prompt import SYSTEM_PROMPT +from dimos.agents.utils import pretty_print_langchain_message +from dimos.core.core import rpc +from dimos.core.module import Module, ModuleConfig +from dimos.core.rpc_client import RPCClient +from dimos.core.stream import In, Out +from dimos.utils.logging_config import setup_logger +from dimos.utils.sequential_ids import SequentialIds + +logger = setup_logger() + + +@dataclass +class McpClientConfig(ModuleConfig): + system_prompt: str | None = SYSTEM_PROMPT + model: str = "gpt-4o" + model_fixture: str | None = None + mcp_server_url: str = "http://localhost:9990/mcp" + + +class McpClient(Module[McpClientConfig]): + default_config = McpClientConfig + agent: Out[BaseMessage] + human_input: In[str] + agent_idle: Out[bool] + + _lock: RLock + _state_graph: CompiledStateGraph[Any, Any, Any, Any] | None + _message_queue: Queue[BaseMessage] + _history: list[BaseMessage] + _thread: Thread + _stop_event: Event + _http_client: httpx.Client + _seq_ids: SequentialIds + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self._lock = RLock() + self._state_graph = None + self._message_queue = Queue() + self._history = [] + self._thread = Thread( + target=self._thread_loop, + name=f"{self.__class__.__name__}-thread", + daemon=True, + ) + self._stop_event = Event() + self._http_client = httpx.Client(timeout=120.0) + self._seq_ids = SequentialIds() + + def __reduce__(self) -> Any: + return (self.__class__, (), {}) + + 
def _mcp_request(self, method: str, params: dict[str, Any] | None = None) -> dict[str, Any]: + body: dict[str, Any] = { + "jsonrpc": "2.0", + "id": self._seq_ids.next(), + "method": method, + } + if params is not None: + body["params"] = params + + resp = self._http_client.post(self.config.mcp_server_url, json=body) + resp.raise_for_status() + data = resp.json() + + if "error" in data: + raise RuntimeError(f"MCP error {data['error']['code']}: {data['error']['message']}") + + result: dict[str, Any] = data.get("result") + return result + + def _fetch_tools(self, timeout: float = 60.0, interval: float = 1.0) -> list[StructuredTool]: + result = self._try_fetch_tools(timeout=timeout, interval=interval) + if result is None: + raise RuntimeError( + f"Failed to fetch tools from MCP server {self.config.mcp_server_url}" + ) + + tools = [self._mcp_tool_to_langchain(t) for t in result.get("tools", [])] + + if not tools: + logger.warning("No tools found from MCP server.") + else: + tool_names = [t.name for t in tools] + logger.info("Discovered tools from MCP server.", tools=tool_names, n_tools=len(tools)) + + return tools + + def _try_fetch_tools(self, timeout: float, interval: float) -> dict[str, Any] | None: + deadline = time.monotonic() + timeout + + while True: + try: + self._mcp_request("initialize") + break + except (httpx.ConnectError, httpx.RemoteProtocolError): + if time.monotonic() >= deadline: + return None + time.sleep(interval) + + return self._mcp_request("tools/list") + + def _mcp_tool_to_langchain(self, mcp_tool: dict[str, Any]) -> StructuredTool: + name = mcp_tool["name"] + description = mcp_tool.get("description", "") + input_schema = mcp_tool.get("inputSchema", {"type": "object", "properties": {}}) + + def call_tool(**kwargs: Any) -> str: + result = self._mcp_request("tools/call", {"name": name, "arguments": kwargs}) + content = result.get("content", []) + parts = [c.get("text", "") for c in content if c.get("type") == "text"] + text = "\n".join(parts) + + # 
Images need to be added to the history separately because they + # cannot be included in the tool response for OpenAI models and + # probably others. + for item in content: + if item.get("type") != "text": + uuid_ = str(uuid.uuid4()) + text += f"Tool call started with UUID: {uuid_}. You will be updated with the result soon." + _append_image_to_history(self, name, uuid_, item) + + return text + + return StructuredTool( + name=name, + description=description, + func=call_tool, + args_schema=input_schema, + ) + + @rpc + def start(self) -> None: + super().start() + + def _on_human_input(string: str) -> None: + self._message_queue.put(HumanMessage(content=string)) + + self._disposables.add(Disposable(self.human_input.subscribe(_on_human_input))) + + @rpc + def on_system_modules(self, _modules: list[RPCClient]) -> None: + tools = self._fetch_tools() + + model: str | Any = self.config.model + if self.config.model_fixture is not None: + from dimos.agents.testing import MockModel + + model = MockModel(json_path=self.config.model_fixture) + + with self._lock: + self._state_graph = create_agent( + model=model, + tools=tools, + system_prompt=self.config.system_prompt, + ) + self._thread.start() + + @rpc + def stop(self) -> None: + self._stop_event.set() + if self._thread.is_alive(): + self._thread.join(timeout=2.0) + self._http_client.close() + super().stop() + + @rpc + def add_message(self, message: BaseMessage) -> None: + self._message_queue.put(message) + + def _thread_loop(self) -> None: + while not self._stop_event.is_set(): + try: + message = self._message_queue.get(timeout=0.5) + except Empty: + continue + + with self._lock: + if not self._state_graph: + raise ValueError("No state graph initialized") + self._process_message(self._state_graph, message) + + def _process_message( + self, state_graph: CompiledStateGraph[Any, Any, Any, Any], message: BaseMessage + ) -> None: + self.agent_idle.publish(False) + self._history.append(message) + 
pretty_print_langchain_message(message) + self.agent.publish(message) + + for update in state_graph.stream({"messages": self._history}, stream_mode="updates"): + for node_output in update.values(): + for msg in node_output.get("messages", []): + self._history.append(msg) + pretty_print_langchain_message(msg) + self.agent.publish(msg) + + if self._message_queue.empty(): + self.agent_idle.publish(True) + + +def _append_image_to_history( + mcp_client: McpClient, func_name: str, uuid_: str, result: Any +) -> None: + mcp_client.add_message( + HumanMessage( + content=[ + { + "type": "text", + "text": f"This is the artefact for the '{func_name}' tool with UUID:={uuid_}.", + }, + result, + ] + ) + ) + + +mcp_client = McpClient.blueprint + +__all__ = ["McpClient", "McpClientConfig", "mcp_client"] diff --git a/dimos/agents/mcp/mcp_server.py b/dimos/agents/mcp/mcp_server.py new file mode 100644 index 0000000000..1f8ce92888 --- /dev/null +++ b/dimos/agents/mcp/mcp_server.py @@ -0,0 +1,197 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from __future__ import annotations + +import asyncio +import json +from typing import TYPE_CHECKING, Any + +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse +from starlette.responses import Response +import uvicorn + +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + + +from dimos.core import Module, rpc # noqa: I001 +from dimos.core.rpc_client import RpcCall, RPCClient + +from starlette.requests import Request # noqa: TC002 + +if TYPE_CHECKING: + import concurrent.futures + + from dimos.core.module import SkillInfo + + +app = FastAPI() +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_methods=["POST"], + allow_headers=["*"], +) +app.state.skills = [] +app.state.rpc_calls = {} + + +def _jsonrpc_result(req_id: Any, result: Any) -> dict[str, Any]: + return {"jsonrpc": "2.0", "id": req_id, "result": result} + + +def _jsonrpc_result_text(req_id: Any, text: str) -> dict[str, Any]: + return _jsonrpc_result(req_id, {"content": [{"type": "text", "text": text}]}) + + +def _jsonrpc_error(req_id: Any, code: int, message: str) -> dict[str, Any]: + return {"jsonrpc": "2.0", "id": req_id, "error": {"code": code, "message": message}} + + +def _handle_initialize(req_id: Any) -> dict[str, Any]: + return _jsonrpc_result( + req_id, + { + "protocolVersion": "2025-11-25", + "capabilities": {"tools": {}}, + "serverInfo": {"name": "dimensional", "version": "1.0.0"}, + }, + ) + + +def _handle_tools_list(req_id: Any, skills: list[SkillInfo]) -> dict[str, Any]: + tools = [] + + for skill in skills: + schema = json.loads(skill.args_schema) + description = schema.pop("description", None) + schema.pop("title", None) + tool = {"name": skill.func_name, "inputSchema": schema} + if description: + tool["description"] = description + tools.append(tool) + + return _jsonrpc_result(req_id, {"tools": tools}) + + +async def _handle_tools_call( + req_id: Any, params: dict[str, 
Any], rpc_calls: dict[str, Any] +) -> dict[str, Any]: + name = params.get("name", "") + args: dict[str, Any] = params.get("arguments") or {} + + rpc_call = rpc_calls.get(name) + if rpc_call is None: + return _jsonrpc_result_text(req_id, f"Tool not found: {name}") + + try: + result = await asyncio.get_event_loop().run_in_executor(None, lambda: rpc_call(**args)) + except Exception as e: + logger.exception("Error running tool", tool_name=name, exc_info=True) + return _jsonrpc_result_text(req_id, f"Error running tool '{name}': {e}") + + if result is None: + return _jsonrpc_result_text(req_id, "It has started. You will be updated later.") + + if hasattr(result, "agent_encode"): + return _jsonrpc_result(req_id, {"content": result.agent_encode()}) + + return _jsonrpc_result_text(req_id, str(result)) + + +async def handle_request( + request: dict[str, Any], + skills: list[SkillInfo], + rpc_calls: dict[str, Any], +) -> dict[str, Any] | None: + """Handle a single MCP JSON-RPC request. + + Returns None for JSON-RPC notifications (no ``id``), which must not + receive a response. + """ + method = request.get("method", "") + params = request.get("params", {}) or {} + req_id = request.get("id") + + # JSON-RPC notifications have no "id" – the server must not reply. 
+ if "id" not in request: + return None + + if method == "initialize": + return _handle_initialize(req_id) + if method == "tools/list": + return _handle_tools_list(req_id, skills) + if method == "tools/call": + return await _handle_tools_call(req_id, params, rpc_calls) + return _jsonrpc_error(req_id, -32601, f"Unknown: {method}") + + +@app.post("/mcp") +async def mcp_endpoint(request: Request) -> Response: + raw = await request.body() + try: + body = json.loads(raw) + except Exception: + logger.exception("POST /mcp JSON parse failed") + return JSONResponse( + {"jsonrpc": "2.0", "id": None, "error": {"code": -32700, "message": "Parse error"}}, + status_code=400, + ) + result = await handle_request(body, request.app.state.skills, request.app.state.rpc_calls) + if result is None: + return Response(status_code=204) + return JSONResponse(result) + + +class McpServer(Module): + def __init__(self) -> None: + super().__init__() + self._uvicorn_server: uvicorn.Server | None = None + self._serve_future: concurrent.futures.Future[None] | None = None + + @rpc + def start(self) -> None: + super().start() + self._start_server() + + @rpc + def stop(self) -> None: + if self._uvicorn_server: + self._uvicorn_server.should_exit = True + loop = self._loop + if loop is not None and self._serve_future is not None: + self._serve_future.result(timeout=5.0) + self._uvicorn_server = None + self._serve_future = None + super().stop() + + @rpc + def on_system_modules(self, modules: list[RPCClient]) -> None: + assert self.rpc is not None + app.state.skills = [skill for module in modules for skill in (module.get_skills() or [])] + app.state.rpc_calls = { + skill.func_name: RpcCall(None, self.rpc, skill.func_name, skill.class_name, []) + for skill in app.state.skills + } + + def _start_server(self, port: int = 9990) -> None: + config = uvicorn.Config(app, host="0.0.0.0", port=port, log_level="info") + server = uvicorn.Server(config) + self._uvicorn_server = server + loop = self._loop + assert 
loop is not None + self._serve_future = asyncio.run_coroutine_threadsafe(server.serve(), loop) diff --git a/dimos/agents/mcp/test_mcp_client.py b/dimos/agents/mcp/test_mcp_client.py new file mode 100644 index 0000000000..be4a09d5b9 --- /dev/null +++ b/dimos/agents/mcp/test_mcp_client.py @@ -0,0 +1,210 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from langchain_core.messages import HumanMessage +import pytest + +from dimos.agents.annotation import skill +from dimos.core.module import Module +from dimos.msgs.sensor_msgs import Image +from dimos.utils.data import get_data + + +class Adder(Module): + @skill + def add(self, x: int, y: int) -> str: + """adds x and y.""" + return str(x + y) + + +@pytest.mark.integration +@pytest.mark.parametrize("dask", [False, True]) +def test_can_call_tool(dask, agent_setup): + history = agent_setup( + blueprints=[Adder.blueprint()], + messages=[HumanMessage("What is 33333 + 100? Use the tool.")], + dask=dask, + ) + + assert "33433" in history[-1].content + + +class UserRegistration(Module): + def __init__(self): + super().__init__() + self._first_call = True + self._use_upper = False + + @skill + def register_user(self, name: str) -> str: + """registers a user by name.""" + + # If the agent calls with "paul" or "Paul", always say it's the wrong way + # to force it to try again. 
+ + if self._first_call: + self._first_call = False + self._use_upper = not name[0].isupper() + + if self._use_upper and not name[0].isupper(): + raise ValueError("Names must start with an uppercase letter.") + if not self._use_upper and name[0].isupper(): + raise ValueError("The names must only use lowercase letters.") + + return "User name registered successfully." + + +@pytest.mark.integration +@pytest.mark.parametrize("dask", [False, True]) +def test_can_call_again_on_error(dask, agent_setup): + history = agent_setup( + blueprints=[UserRegistration.blueprint()], + messages=[ + HumanMessage( + "Register a user named 'Paul'. If there are errors, just try again until you succeed." + ) + ], + dask=dask, + ) + + assert any(message.content == "User name registered successfully." for message in history) + + +class MultipleTools(Module): + def __init__(self): + super().__init__() + self._people = {"Ben": "office", "Bob": "garage"} + + @skill + def register_person(self, name: str) -> str: + """Registers a person by name.""" + if name.lower() == "john": + self._people[name] = "kitchen" + elif name.lower() == "jane": + self._people[name] = "living room" + return f"'{name}' has been registered." + + @skill + def locate_person(self, name: str) -> str: + """Locates a person by name.""" + if name not in self._people: + known_people = list(self._people.keys()) + return ( + f"Error: '{name}' is not registered. People cannot be located until they've " + f"been registered in the system. People known so far: {', '.join(known_people)}. " + "Use register_person to register a person." + ) + return f"'{name}' is located at '{self._people[name]}'." + + +class NavigationSkill(Module): + @skill + def go_to_location(self, description: str) -> str: + """Go to a location by a description.""" + if description.strip().lower() not in ["kitchen", "living room"]: + return f"Error: Unknown location description: '{description}'." + return f"Going to the {description}." 
+ + +@pytest.mark.integration +def test_multiple_tool_calls_with_multiple_messages(agent_setup): + history = agent_setup( + blueprints=[MultipleTools.blueprint(), NavigationSkill.blueprint()], + messages=[ + HumanMessage( + "You are a robot assistant. Move to the location where John is. Don't ask me for feedback, just go there." + ), + HumanMessage("Nice job. You did it. Now go to the location where Jane is."), + ], + ) + + # Collect all go_to_location calls from the history + go_to_location_calls = [] + for message in history: + if hasattr(message, "tool_calls"): + for tool_call in message.tool_calls: + if tool_call["name"] == "go_to_location": + go_to_location_calls.append(tool_call) + + # Find the index of the second HumanMessage to split first/second prompt + second_human_idx = None + human_count = 0 + for i, message in enumerate(history): + if isinstance(message, HumanMessage): + human_count += 1 + if human_count == 2: + second_human_idx = i + break + + # Collect go_to_location calls before and after the second prompt + calls_after_first_prompt = [] + calls_after_second_prompt = [] + for i, message in enumerate(history): + if hasattr(message, "tool_calls"): + for tool_call in message.tool_calls: + if tool_call["name"] == "go_to_location": + if i < second_human_idx: + calls_after_first_prompt.append(tool_call) + else: + calls_after_second_prompt.append(tool_call) + + # After the first prompt, go_to_location should be called with "kitchen" + assert len(calls_after_first_prompt) == 1 + assert "kitchen" in calls_after_first_prompt[0]["args"]["description"].lower() + + # After the second prompt, go_to_location should be called with "living room" + assert len(calls_after_second_prompt) == 1 + assert "living room" in calls_after_second_prompt[0]["args"]["description"].lower() + + # There should be exactly two go_to_location calls total + assert len(go_to_location_calls) == 2 + + +@pytest.mark.integration +def test_prompt(agent_setup): + history = agent_setup( + 
blueprints=[], + messages=[HumanMessage("What is your name?")], + system_prompt="You are a helpful assistant named Johnny.", + ) + + assert "Johnny" in history[-1].content + + +class Visualizer(Module): + @skill + def take_a_picture(self) -> Image: + """Takes a picture.""" + return Image.from_file(get_data("cafe-smol.jpg")).to_rgb() + + +@pytest.mark.integration +def test_image(agent_setup): + history = agent_setup( + blueprints=[Visualizer.blueprint()], + messages=[ + HumanMessage( + "What do you see? Take a picture using your camera and describe it. " + "Please mention one of the words which best match the image: " + "'stadium', 'cafe', 'battleship'." + ) + ], + system_prompt="You are a helpful assistant that can use a camera to take pictures.", + ) + + response = history[-1].content.lower() + assert "cafe" in response + assert "stadium" not in response + assert "battleship" not in response diff --git a/dimos/agents/mcp/test_mcp_client_unit.py b/dimos/agents/mcp/test_mcp_client_unit.py new file mode 100644 index 0000000000..8cd888f851 --- /dev/null +++ b/dimos/agents/mcp/test_mcp_client_unit.py @@ -0,0 +1,145 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from __future__ import annotations + +import json +from unittest.mock import MagicMock, patch + +import pytest + +from dimos.agents.mcp.mcp_client import McpClient +from dimos.utils.sequential_ids import SequentialIds + + +def _mock_post(url: str, **kwargs: object) -> MagicMock: + """Return a fake httpx response based on the JSON-RPC method.""" + body = kwargs.get("json") or (kwargs.get("content") and json.loads(kwargs["content"])) + assert isinstance(body, dict) + method = body["method"] + req_id = body["id"] + + result: object + if method == "initialize": + result = { + "protocolVersion": "2024-11-05", + "capabilities": {"tools": {}}, + "serverInfo": {"name": "dimensional", "version": "1.0.0"}, + } + elif method == "tools/list": + result = { + "tools": [ + { + "name": "add", + "description": "Add two numbers", + "inputSchema": { + "type": "object", + "properties": { + "x": {"type": "integer"}, + "y": {"type": "integer"}, + }, + "required": ["x", "y"], + }, + }, + { + "name": "greet", + "description": "Say hello", + "inputSchema": { + "type": "object", + "properties": { + "name": {"type": "string"}, + }, + }, + }, + ] + } + elif method == "tools/call": + name = body["params"]["name"] + args = body["params"].get("arguments", {}) + if name == "add": + text = str(args.get("x", 0) + args.get("y", 0)) + elif name == "greet": + text = f"Hello, {args.get('name', 'world')}!" 
+ else: + text = "Skill not found" + result = {"content": [{"type": "text", "text": text}]} + else: + resp = MagicMock() + resp.status_code = 200 + resp.raise_for_status = MagicMock() + resp.json.return_value = { + "jsonrpc": "2.0", + "id": req_id, + "error": {"code": -32601, "message": f"Unknown: {method}"}, + } + return resp + + resp = MagicMock() + resp.status_code = 200 + resp.raise_for_status = MagicMock() + resp.json.return_value = {"jsonrpc": "2.0", "id": req_id, "result": result} + return resp + + +@pytest.fixture +def mcp_client() -> McpClient: + """Build an McpClient wired to the mock MCP post handler.""" + mock_http = MagicMock() + mock_http.post.side_effect = _mock_post + + with patch("dimos.agents.mcp.mcp_client.httpx.Client", return_value=mock_http): + client = McpClient.__new__(McpClient) + + client._http_client = mock_http + client._seq_ids = SequentialIds() + client.config = MagicMock() + client.config.mcp_server_url = "http://localhost:9990/mcp" + return client + + +def test_fetch_tools_from_mcp_server(mcp_client: McpClient) -> None: + tools = mcp_client._fetch_tools() + + assert len(tools) == 2 + assert tools[0].name == "add" + assert tools[1].name == "greet" + + +def test_tool_invocation_via_mcp(mcp_client: McpClient) -> None: + tools = mcp_client._fetch_tools() + add_tool = next(t for t in tools if t.name == "add") + greet_tool = next(t for t in tools if t.name == "greet") + + assert add_tool.func(x=2, y=3) == "5" + assert greet_tool.func(name="Alice") == "Hello, Alice!" 
+ + +def test_mcp_request_error_propagation(mcp_client: McpClient) -> None: + def error_post(url: str, **kwargs: object) -> MagicMock: + resp = MagicMock() + resp.status_code = 200 + resp.raise_for_status = MagicMock() + resp.json.return_value = { + "jsonrpc": "2.0", + "id": 1, + "error": {"code": -32601, "message": "Unknown: bad/method"}, + } + return resp + + mcp_client._http_client.post.side_effect = error_post + + try: + mcp_client._mcp_request("bad/method") + raise AssertionError("Expected RuntimeError") + except RuntimeError as e: + assert "Unknown: bad/method" in str(e) diff --git a/dimos/protocol/mcp/test_mcp_module.py b/dimos/agents/mcp/test_mcp_server.py similarity index 62% rename from dimos/protocol/mcp/test_mcp_module.py rename to dimos/agents/mcp/test_mcp_server.py index 050e24f13b..1cbca9e3e4 100644 --- a/dimos/protocol/mcp/test_mcp_module.py +++ b/dimos/agents/mcp/test_mcp_server.py @@ -16,34 +16,25 @@ import asyncio import json -from pathlib import Path from unittest.mock import MagicMock +from dimos.agents.mcp.mcp_server import handle_request from dimos.core.module import SkillInfo -from dimos.protocol.mcp.mcp import MCPModule -def _make_mcp(skills: list[SkillInfo], call_results: dict[str, object]) -> MCPModule: - """Create an MCPModule with pre-populated skills and mock RPC calls.""" - mcp = MCPModule.__new__(MCPModule) - mcp._skills = skills - mcp._rpc_calls = {} +def _make_rpc_calls( + skills: list[SkillInfo], call_results: dict[str, object] +) -> dict[str, MagicMock]: + """Create mock RPC calls for the given skills.""" + rpc_calls: dict[str, MagicMock] = {} for skill in skills: mock_call = MagicMock() if skill.func_name in call_results: mock_call.return_value = call_results[skill.func_name] else: mock_call.return_value = None - mcp._rpc_calls[skill.func_name] = mock_call - return mcp - - -def test_unitree_blueprint_has_mcp() -> None: - contents = Path( - "dimos/robot/unitree/go2/blueprints/agentic/unitree_go2_agentic_mcp.py" - ).read_text() - 
assert "agentic_mcp" in contents - assert "MCPModule.blueprint()" in contents + rpc_calls[skill.func_name] = mock_call + return rpc_calls def test_mcp_module_request_flow() -> None: @@ -56,20 +47,21 @@ def test_mcp_module_request_flow() -> None: } ) skills = [SkillInfo(class_name="TestSkills", func_name="add", args_schema=schema)] + rpc_calls = _make_rpc_calls(skills, {"add": 5}) - mcp = _make_mcp(skills, {"add": 5}) - - response = asyncio.run(mcp._handle_request({"method": "tools/list", "id": 1})) + response = asyncio.run(handle_request({"method": "tools/list", "id": 1}, skills, rpc_calls)) assert response["result"]["tools"][0]["name"] == "add" assert response["result"]["tools"][0]["description"] == "Add two numbers" response = asyncio.run( - mcp._handle_request( + handle_request( { "method": "tools/call", "id": 2, "params": {"name": "add", "arguments": {"x": 2, "y": 3}}, - } + }, + skills, + rpc_calls, ) ) assert response["result"]["content"][0]["text"] == "5" @@ -82,49 +74,40 @@ def test_mcp_module_handles_errors() -> None: SkillInfo(class_name="TestSkills", func_name="fail_skill", args_schema=schema), ] - mcp = _make_mcp(skills, {"ok_skill": "done"}) - mcp._rpc_calls["fail_skill"] = MagicMock(side_effect=RuntimeError("boom")) + rpc_calls = _make_rpc_calls(skills, {"ok_skill": "done"}) + rpc_calls["fail_skill"] = MagicMock(side_effect=RuntimeError("boom")) # All skills listed - response = asyncio.run(mcp._handle_request({"method": "tools/list", "id": 1})) + response = asyncio.run(handle_request({"method": "tools/list", "id": 1}, skills, rpc_calls)) tool_names = {tool["name"] for tool in response["result"]["tools"]} assert "ok_skill" in tool_names assert "fail_skill" in tool_names # Error skill returns error text response = asyncio.run( - mcp._handle_request( - {"method": "tools/call", "id": 2, "params": {"name": "fail_skill", "arguments": {}}} + handle_request( + {"method": "tools/call", "id": 2, "params": {"name": "fail_skill", "arguments": {}}}, + skills, + 
rpc_calls, ) ) - assert "Error:" in response["result"]["content"][0]["text"] + assert "Error running tool" in response["result"]["content"][0]["text"] assert "boom" in response["result"]["content"][0]["text"] # Unknown skill returns not found response = asyncio.run( - mcp._handle_request( - {"method": "tools/call", "id": 3, "params": {"name": "no_such", "arguments": {}}} + handle_request( + {"method": "tools/call", "id": 3, "params": {"name": "no_such", "arguments": {}}}, + skills, + rpc_calls, ) ) assert "not found" in response["result"]["content"][0]["text"].lower() def test_mcp_module_initialize_and_unknown() -> None: - mcp = _make_mcp([], {}) - - response = asyncio.run(mcp._handle_request({"method": "initialize", "id": 1})) + response = asyncio.run(handle_request({"method": "initialize", "id": 1}, [], {})) assert response["result"]["serverInfo"]["name"] == "dimensional" - response = asyncio.run(mcp._handle_request({"method": "unknown/method", "id": 2})) + response = asyncio.run(handle_request({"method": "unknown/method", "id": 2}, [], {})) assert response["error"]["code"] == -32601 - - -def test_mcp_module_invalid_tool_name() -> None: - mcp = _make_mcp([], {}) - - response = asyncio.run( - mcp._handle_request( - {"method": "tools/call", "id": 1, "params": {"name": 123, "arguments": {}}} - ) - ) - assert response["error"]["code"] == -32602 diff --git a/dimos/protocol/mcp/README.md b/dimos/protocol/mcp/README.md deleted file mode 100644 index 233e852669..0000000000 --- a/dimos/protocol/mcp/README.md +++ /dev/null @@ -1,35 +0,0 @@ -# DimOS MCP Server - -Expose DimOS robot skills to Claude Code via Model Context Protocol. 
- -## Setup - -```bash -uv sync --extra base --extra unitree -``` - -Add to Claude Code (one command): -```bash -claude mcp add --transport stdio dimos --scope project -- python -m dimos.protocol.mcp -``` - - -## Usage - -**Terminal 1** - Start DimOS: -```bash -uv run dimos run unitree-go2-agentic-mcp -``` - -**Claude Code** - Use robot skills: -``` -> move forward 1 meter -> go to the kitchen -> tag this location as "desk" -``` - -## How It Works - -1. `MCPModule` in the blueprint starts a TCP server on port 9990 -2. Claude Code spawns the bridge (`--bridge`) which connects to `localhost:9990` -3. Skills are exposed as MCP tools (e.g., `relative_move`, `navigate_with_text`) diff --git a/dimos/protocol/mcp/__main__.py b/dimos/protocol/mcp/__main__.py deleted file mode 100644 index a58e59d367..0000000000 --- a/dimos/protocol/mcp/__main__.py +++ /dev/null @@ -1,36 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""CLI entry point for Dimensional MCP Bridge. - -Connects Claude Code (or other MCP clients) to a running DimOS agent. 
- -Usage: - python -m dimos.protocol.mcp # Bridge to running DimOS on default port -""" - -from __future__ import annotations - -import asyncio - -from dimos.protocol.mcp.bridge import main as bridge_main - - -def main() -> None: - """Main entry point - connects to running DimOS via bridge.""" - asyncio.run(bridge_main()) - - -if __name__ == "__main__": - main() diff --git a/dimos/protocol/mcp/bridge.py b/dimos/protocol/mcp/bridge.py deleted file mode 100644 index 0b09997798..0000000000 --- a/dimos/protocol/mcp/bridge.py +++ /dev/null @@ -1,53 +0,0 @@ -# Copyright 2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -"""MCP Bridge - Connects stdio (Claude Code) to TCP (DimOS Agent).""" - -import asyncio -import os -import sys - -DEFAULT_PORT = 9990 - - -async def main() -> None: - port = int(os.environ.get("MCP_PORT", DEFAULT_PORT)) - host = os.environ.get("MCP_HOST", "localhost") - - reader, writer = await asyncio.open_connection(host, port) - sys.stderr.write(f"MCP Bridge connected to {host}:{port}\n") - - async def stdin_to_tcp() -> None: - loop = asyncio.get_event_loop() - while True: - line = await loop.run_in_executor(None, sys.stdin.readline) - if not line: - break - writer.write(line.encode()) - await writer.drain() - - async def tcp_to_stdout() -> None: - while True: - data = await reader.readline() - if not data: - break - sys.stdout.write(data.decode()) - sys.stdout.flush() - - await asyncio.gather(stdin_to_tcp(), tcp_to_stdout()) - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/dimos/protocol/mcp/mcp.py b/dimos/protocol/mcp/mcp.py deleted file mode 100644 index 78d19c64db..0000000000 --- a/dimos/protocol/mcp/mcp.py +++ /dev/null @@ -1,139 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-from __future__ import annotations - -import asyncio -import json -from typing import TYPE_CHECKING, Any - -from dimos.core import Module, rpc -from dimos.core.rpc_client import RpcCall, RPCClient - -if TYPE_CHECKING: - from dimos.core.module import SkillInfo - - -class MCPModule(Module): - _skills: list[SkillInfo] - _rpc_calls: dict[str, RpcCall] - - def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, **kwargs) - self._skills = [] - self._rpc_calls = {} - self._server: asyncio.AbstractServer | None = None - self._server_future: object | None = None - - @rpc - def start(self) -> None: - super().start() - self._start_server() - - @rpc - def stop(self) -> None: - if self._server: - self._server.close() - loop = self._loop - assert loop is not None - asyncio.run_coroutine_threadsafe(self._server.wait_closed(), loop).result() - self._server = None - if self._server_future and hasattr(self._server_future, "cancel"): - self._server_future.cancel() - super().stop() - - @rpc - def on_system_modules(self, modules: list[RPCClient]) -> None: - assert self.rpc is not None - self._skills = [skill for module in modules for skill in (module.get_skills() or [])] - self._rpc_calls = { - skill.func_name: RpcCall(None, self.rpc, skill.func_name, skill.class_name, []) - for skill in self._skills - } - - def _start_server(self, port: int = 9990) -> None: - async def handle_client(reader, writer) -> None: # type: ignore[no-untyped-def] - while True: - if not (data := await reader.readline()): - break - response = await self._handle_request(json.loads(data.decode())) - writer.write(json.dumps(response).encode() + b"\n") - await writer.drain() - writer.close() - - async def start_server() -> None: - self._server = await asyncio.start_server(handle_client, "0.0.0.0", port) - await self._server.serve_forever() - - loop = self._loop - assert loop is not None - self._server_future = asyncio.run_coroutine_threadsafe(start_server(), loop) - - async def 
_handle_request(self, request: dict[str, Any]) -> dict[str, Any]: - method = request.get("method", "") - params = request.get("params", {}) or {} - req_id = request.get("id") - if method == "initialize": - init_result = { - "protocolVersion": "2024-11-05", - "capabilities": {"tools": {}}, - "serverInfo": {"name": "dimensional", "version": "1.0.0"}, - } - return {"jsonrpc": "2.0", "id": req_id, "result": init_result} - if method == "tools/list": - tools = [] - for skill in self._skills: - schema = json.loads(skill.args_schema) - tools.append( - { - "name": skill.func_name, - "description": schema.get("description", ""), - "inputSchema": schema, - } - ) - return {"jsonrpc": "2.0", "id": req_id, "result": {"tools": tools}} - if method == "tools/call": - name = params.get("name") - args = params.get("arguments") or {} - if not isinstance(name, str): - return { - "jsonrpc": "2.0", - "id": req_id, - "error": {"code": -32602, "message": "Missing or invalid tool name"}, - } - if not isinstance(args, dict): - args = {} - rpc_call = self._rpc_calls.get(name) - if rpc_call is None: - return { - "jsonrpc": "2.0", - "id": req_id, - "result": {"content": [{"type": "text", "text": "Skill not found"}]}, - } - try: - result = await asyncio.get_event_loop().run_in_executor( - None, lambda: rpc_call(**args) - ) - text = str(result) if result is not None else "Completed" - except Exception as e: - text = f"Error: {e}" - return { - "jsonrpc": "2.0", - "id": req_id, - "result": {"content": [{"type": "text", "text": text}]}, - } - return { - "jsonrpc": "2.0", - "id": req_id, - "error": {"code": -32601, "message": f"Unknown: {method}"}, - } diff --git a/dimos/robot/all_blueprints.py b/dimos/robot/all_blueprints.py index 0e23c82065..8e1c7fa89f 100644 --- a/dimos/robot/all_blueprints.py +++ b/dimos/robot/all_blueprints.py @@ -113,6 +113,7 @@ "keyboard_teleop_module": "dimos.teleop.keyboard.keyboard_teleop_module", "manipulation_module": "dimos.manipulation.manipulation_module", "mapper": 
"dimos.robot.unitree.type.map", + "mcp_client": "dimos.agents.mcp.mcp_client", "mid360_module": "dimos.hardware.sensors.lidar.livox.module", "navigation_skill": "dimos.agents.skills.navigation", "object_scene_registration_module": "dimos.perception.object_scene_registration", diff --git a/dimos/robot/unitree/go2/blueprints/agentic/unitree_go2_agentic_mcp.py b/dimos/robot/unitree/go2/blueprints/agentic/unitree_go2_agentic_mcp.py index bbc3e4c216..e75b31e511 100644 --- a/dimos/robot/unitree/go2/blueprints/agentic/unitree_go2_agentic_mcp.py +++ b/dimos/robot/unitree/go2/blueprints/agentic/unitree_go2_agentic_mcp.py @@ -13,13 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +from dimos.agents.mcp.mcp_client import mcp_client +from dimos.agents.mcp.mcp_server import McpServer from dimos.core.blueprints import autoconnect -from dimos.protocol.mcp.mcp import MCPModule -from dimos.robot.unitree.go2.blueprints.agentic.unitree_go2_agentic import unitree_go2_agentic +from dimos.robot.unitree.go2.blueprints.agentic._common_agentic import _common_agentic +from dimos.robot.unitree.go2.blueprints.smart.unitree_go2_spatial import unitree_go2_spatial unitree_go2_agentic_mcp = autoconnect( - unitree_go2_agentic, - MCPModule.blueprint(), + unitree_go2_spatial, + McpServer.blueprint(), + mcp_client(), + _common_agentic, ) __all__ = ["unitree_go2_agentic_mcp"] diff --git a/pyproject.toml b/pyproject.toml index ee7c4778b2..6471fd89cd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -142,9 +142,6 @@ agents = [ "openai", "openai-whisper", "sounddevice", - - # MCP Server - "mcp>=1.0.0", ] web = [ diff --git a/uv.lock b/uv.lock index 53b2454b40..47083b733e 100644 --- a/uv.lock +++ b/uv.lock @@ -1791,7 +1791,6 @@ agents = [ { name = "langchain-ollama" }, { name = "langchain-openai" }, { name = "langchain-text-splitters" }, - { name = "mcp" }, { name = "ollama" }, { name = "openai" }, { name = "openai-whisper" }, @@ -1812,7 
+1811,6 @@ base = [ { name = "langchain-openai" }, { name = "langchain-text-splitters" }, { name = "lap" }, - { name = "mcp" }, { name = "moondream" }, { name = "mujoco" }, { name = "ollama" }, @@ -2011,7 +2009,6 @@ unitree = [ { name = "langchain-openai" }, { name = "langchain-text-splitters" }, { name = "lap" }, - { name = "mcp" }, { name = "moondream" }, { name = "mujoco" }, { name = "ollama" }, @@ -2089,7 +2086,6 @@ requires-dist = [ { name = "llvmlite", specifier = ">=0.42.0" }, { name = "lxml-stubs", marker = "extra == 'dev'", specifier = ">=0.5.1,<1" }, { name = "matplotlib", marker = "extra == 'manipulation'", specifier = ">=3.7.1" }, - { name = "mcp", marker = "extra == 'agents'", specifier = ">=1.0.0" }, { name = "md-babel-py", marker = "extra == 'dev'", specifier = "==1.1.1" }, { name = "moondream", marker = "extra == 'perception'" }, { name = "mujoco", marker = "extra == 'sim'", specifier = ">=3.3.4" }, @@ -3210,15 +3206,6 @@ http2 = [ { name = "h2" }, ] -[[package]] -name = "httpx-sse" -version = "0.4.3" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/0f/4c/751061ffa58615a32c31b2d82e8482be8dd4a89154f003147acee90f2be9/httpx_sse-0.4.3.tar.gz", hash = "sha256:9b1ed0127459a66014aec3c56bebd93da3c1bc8bb6618c8082039a44889a755d", size = 15943, upload-time = "2025-10-10T21:48:22.271Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/d2/fd/6668e5aec43ab844de6fc74927e155a3b37bf40d7c3790e49fc0406b6578/httpx_sse-0.4.3-py3-none-any.whl", hash = "sha256:0ac1c9fe3c0afad2e0ebb25a934a59f4c7823b60792691f779fad2c5568830fc", size = 8960, upload-time = "2025-10-10T21:48:21.158Z" }, -] - [[package]] name = "huggingface-hub" version = "0.36.2" @@ -4889,31 +4876,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/27/1a/1f68f9ba0c207934b35b86a8ca3aad8395a3d6dd7921c0686e23853ff5a9/mccabe-0.7.0-py2.py3-none-any.whl", hash = 
"sha256:6c2d30ab6be0e4a46919781807b4f0d834ebdd6c6e3dca0bda5a15f863427b6e", size = 7350, upload-time = "2022-01-24T01:14:49.62Z" }, ] -[[package]] -name = "mcp" -version = "1.26.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "anyio" }, - { name = "httpx" }, - { name = "httpx-sse" }, - { name = "jsonschema" }, - { name = "pydantic" }, - { name = "pydantic-settings" }, - { name = "pyjwt", extra = ["crypto"] }, - { name = "python-multipart" }, - { name = "pywin32", marker = "sys_platform == 'win32'" }, - { name = "sse-starlette" }, - { name = "starlette" }, - { name = "typing-extensions" }, - { name = "typing-inspection" }, - { name = "uvicorn", marker = "sys_platform != 'emscripten'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/fc/6d/62e76bbb8144d6ed86e202b5edd8a4cb631e7c8130f3f4893c3f90262b10/mcp-1.26.0.tar.gz", hash = "sha256:db6e2ef491eecc1a0d93711a76f28dec2e05999f93afd48795da1c1137142c66", size = 608005, upload-time = "2026-01-24T19:40:32.468Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/fd/d9/eaa1f80170d2b7c5ba23f3b59f766f3a0bb41155fbc32a69adfa1adaaef9/mcp-1.26.0-py3-none-any.whl", hash = "sha256:904a21c33c25aa98ddbeb47273033c435e595bbacfdb177f4bd87f6dceebe1ca", size = 233615, upload-time = "2026-01-24T19:40:30.652Z" }, -] - [[package]] name = "md-babel-py" version = "1.1.1" @@ -7711,20 +7673,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b", size = 1225217, upload-time = "2025-06-21T13:39:07.939Z" }, ] -[[package]] -name = "pyjwt" -version = "2.11.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/5c/5a/b46fa56bf322901eee5b0454a34343cdbdae202cd421775a8ee4e42fd519/pyjwt-2.11.0.tar.gz", hash = 
"sha256:35f95c1f0fbe5d5ba6e43f00271c275f7a1a4db1dab27bf708073b75318ea623", size = 98019, upload-time = "2026-01-30T19:59:55.694Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/6f/01/c26ce75ba460d5cd503da9e13b21a33804d38c2165dec7b716d06b13010c/pyjwt-2.11.0-py3-none-any.whl", hash = "sha256:94a6bde30eb5c8e04fee991062b534071fd1439ef58d2adc9ccb823e7bcd0469", size = 28224, upload-time = "2026-01-30T19:59:54.539Z" }, -] - -[package.optional-dependencies] -crypto = [ - { name = "cryptography" }, -] - [[package]] name = "pylibsrtp" version = "1.0.0" From 6426c53eaa6dc3d1765babcf1b12b6d58709e682 Mon Sep 17 00:00:00 2001 From: leshy Date: Tue, 24 Feb 2026 09:12:40 +0800 Subject: [PATCH 06/16] docs: go2 preflight checklist (#1349) * added go2 preflight checklist * typo --------- Co-authored-by: Paul Nechifor --- docs/platforms/quadruped/go2/index.md | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/docs/platforms/quadruped/go2/index.md b/docs/platforms/quadruped/go2/index.md index 40f32bcdd2..3a00ec1e84 100644 --- a/docs/platforms/quadruped/go2/index.md +++ b/docs/platforms/quadruped/go2/index.md @@ -35,14 +35,33 @@ Opens the command center at [localhost:7779](http://localhost:7779) with Rerun 3 ## Run on Your Go2 +### Pre-flight checks + +1. Robot is reachable and low latency <10ms, 0% packet loss +```bash +ping $ROBOT_IP +``` + +2. Built-in obstacle avoidance is on. (DimOS handles path planning, but the onboard obstacle avoidance provides an extra safety layer around tight spots) + +3. If video is not in sync with lidar/robot position, sync your clock with an NTP server + +```bash +sudo ntpdate pool.ntp.org +``` +or +```bash +sudo sntp -sS pool.ntp.org +``` + +### Ready to run DimOS + ```bash export ROBOT_IP= dimos run unitree-go2 ``` -That's it. DimOS connects via WebRTC (no jailbreak required), starts the full navigation stack, and opens the command center. 
- -> **Tip:** Keep the Unitree built-in obstacle avoidance enabled on the robot for now. DimOS handles path planning, but the onboard obstacle avoidance provides an extra safety layer. +That's it. DimOS connects via WebRTC (no jailbreak required), starts the full navigation stack, and opens the command center in your browser. ### What's Running From e23080f197114d860dd11a8380673de93f5562e2 Mon Sep 17 00:00:00 2001 From: s Date: Mon, 23 Feb 2026 17:13:26 -0800 Subject: [PATCH 07/16] docs: add Unitree G1 getting started documentation (#1347) - Add docs/platforms/humanoid/g1/index.md with full G1 guide - Update README.md to link G1 to new docs instead of todo.md - Covers: installation, simulation, real robot, agentic control, arm gestures, movement modes, keyboard teleop, all blueprints Closes DIM-576 --- README.md | 2 +- docs/platforms/humanoid/g1/index.md | 167 ++++++++++++++++++++++++++++ 2 files changed, 168 insertions(+), 1 deletion(-) create mode 100644 docs/platforms/humanoid/g1/index.md diff --git a/README.md b/README.md index a84fe11b0e..35ae8546c1 100644 --- a/README.md +++ b/README.md @@ -107,7 +107,7 @@ Dimensional is agent native -- "vibecode" your robots in natural language and bu 🟥 Unitree B1
- 🟨 Unitree G1
+ 🟨 Unitree G1
🟥 Xarm
diff --git a/docs/platforms/humanoid/g1/index.md b/docs/platforms/humanoid/g1/index.md new file mode 100644 index 0000000000..2e04f3b023 --- /dev/null +++ b/docs/platforms/humanoid/g1/index.md @@ -0,0 +1,167 @@ +# Unitree G1 — Getting Started + +The Unitree G1 is a humanoid robot platform with full-body locomotion, arm gesture control, and agentic capabilities — no ROS required for basic operation. + +## Requirements + +- Unitree G1 (stock firmware) +- Ubuntu 22.04/24.04 with CUDA GPU (recommended), or macOS (experimental) +- Python 3.12 +- ZED camera (mounted at chest height) for perception blueprints +- ROS 2 for navigation (the G1 navigation stack uses ROS nav) + +## Install + +First, install system dependencies for your platform: +- [Ubuntu](../../../installation/ubuntu.md) +- [macOS](../../../installation/osx.md) +- [Nix](../../../installation/nix.md) + +Then install DimOS: + +```bash +uv venv --python "3.12" +source .venv/bin/activate +uv pip install dimos[base,unitree] +``` + +## MuJoCo Simulation + +No hardware? Start with simulation: + +```bash +uv pip install dimos[base,unitree,sim] +dimos --simulation run unitree-g1-basic-sim +``` + +This runs the G1 in MuJoCo with the native A* navigation stack — same blueprint structure, simulated robot. Opens the command center at [localhost:7779](http://localhost:7779) with Rerun 3D visualization. + +## Run on Your G1 + +```bash +export ROBOT_IP= +dimos run unitree-g1-basic +``` + +DimOS connects via WebRTC, starts the ROS navigation stack, and opens the command center. 
+ +### What's Running + +| Module | What It Does | +|--------|-------------| +| **G1Connection** | WebRTC connection to the robot — streams video, odometry | +| **Webcam** | ZED camera capture (stereo left, 15 fps) | +| **VoxelGridMapper** | Builds a 3D voxel map using column-carving (CUDA accelerated) | +| **CostMapper** | Converts 3D map → 2D costmap via terrain slope analysis | +| **WavefrontFrontierExplorer** | Autonomous exploration of unmapped areas | +| **ROSNav** | ROS 2 navigation integration for path planning | +| **RerunBridge** | 3D visualization in browser | +| **WebsocketVis** | Command center at localhost:7779 | + +### Send Goals + +From the command center ([localhost:7779](http://localhost:7779)): +- Click on the map to set navigation goals +- Toggle autonomous exploration +- Monitor robot pose, costmap, and planned path + +## Agentic Control + +Natural language control with an LLM agent that understands physical space and can command arm gestures: + +```bash +export OPENAI_API_KEY= +export ROBOT_IP= +dimos run unitree-g1-agentic +``` + +Then use the human CLI: + +```bash +humancli +> wave hello +> explore the room +> give me a high five +``` + +The agent subscribes to camera and spatial memory streams and has access to G1-specific skills including arm gestures and movement modes. 
+ +### Arm Gestures + +The G1 agent can perform expressive arm gestures: + +| Gesture | Description | +|---------|-------------| +| Handshake | Perform a handshake gesture with the right hand | +| HighFive | Give a high five with the right hand | +| Hug | Perform a hugging gesture with both arms | +| HighWave | Wave with the hand raised high | +| Clap | Clap hands together | +| FaceWave | Wave near the face level | +| LeftKiss | Blow a kiss with the left hand | +| ArmHeart | Make a heart shape with both arms overhead | +| RightHeart | Make a heart gesture with the right hand | +| HandsUp | Raise both hands up in the air | +| RightHandUp | Raise only the right hand up | +| Reject | Make a rejection or "no" gesture | +| CancelAction | Cancel any current arm action and return to neutral | + +### Movement Modes + +| Mode | Description | +|------|-------------| +| WalkMode | Normal walking | +| WalkControlWaist | Walking with waist control | +| RunMode | Running | + +## Keyboard Teleop + +Direct keyboard control via a pygame-based joystick: + +```bash +export ROBOT_IP= +dimos run unitree-g1-joystick +``` + +## Available Blueprints + +| Blueprint | Description | +|-----------|-------------| +| `unitree-g1-basic` | Connection + ROS navigation + visualization | +| `unitree-g1-basic-sim` | Simulation with A* navigation | +| `unitree-g1` | Navigation + perception + spatial memory | +| `unitree-g1-sim` | Simulation with perception + spatial memory | +| `unitree-g1-agentic` | Full stack with LLM agent and G1 skills | +| `unitree-g1-agentic-sim` | Agentic stack in simulation | +| `unitree-g1-full` | Agentic + SHM image transport + keyboard teleop | +| `unitree-g1-joystick` | Navigation + keyboard teleop | +| `unitree-g1-detection` | Navigation + YOLO person detection and tracking | +| `unitree-g1-shm` | Navigation + perception with shared memory image transport | +| `uintree-g1-primitive-no-nav` | Sensors + visualization only (no navigation, base for custom blueprints) | + +### 
Blueprint Hierarchy + +Blueprints compose incrementally: + +``` +primitive (sensors + vis) +├── basic (+ connection + navigation) +│ ├── basic-sim (sim connection + A* nav) +│ ├── joystick (+ keyboard teleop) +│ └── detection (+ YOLO person tracking) +├── perceptive (+ spatial memory + object tracking) +│ ├── sim (sim variant) +│ └── shm (+ shared memory transport) +└── agentic (+ LLM agent + G1 skills) + ├── agentic-sim (sim variant) + └── full (+ SHM + keyboard teleop) +``` + +## Deep Dive + +- [Navigation Stack](../../../capabilities/navigation/readme.md) — path planning and autonomous exploration +- [Visualization](../../../usage/visualization.md) — Rerun, Foxglove, performance tuning +- [Data Streams](../../../usage/data_streams/) — RxPY streams, backpressure, quality filtering +- [Transports](../../../usage/transports/index.md) — LCM, SHM, DDS +- [Blueprints](../../../usage/blueprints.md) — composing modules +- [Agents](../../../capabilities/agents/readme.md) — LLM agent framework From 5d7ca0102de3c5507e07b4aa9635a68d3da7beac Mon Sep 17 00:00:00 2001 From: leshy Date: Tue, 24 Feb 2026 10:09:01 +0800 Subject: [PATCH 08/16] refactor(doclinks): unify link handlers and validate all file links (#1348) * feat(doclinks): validate and resolve .md links, fix broken doc links Add Pattern 3 to doclinks that validates existing .md links in docs: - Resolves relative .md paths to absolute links - Validates absolute .md links exist on disk - Falls back to doc_index search for broken links with disambiguation - Handles index.md files by searching parent dir name - Scores candidates by directory overlap + filename match Delete duplicate docs/usage/transports.md (identical to transports/index.md). Fixes 5 broken links across docs/ (agents/docs/index.md, capabilities/ navigation/readme.md, usage/lcm.md). 
* refactor --- dimos/utils/docs/doclinks.py | 247 ++++++++-- dimos/utils/docs/test_doclinks.py | 255 ++++++++++ docs/agents/docs/index.md | 8 +- docs/capabilities/navigation/readme.md | 4 +- docs/platforms/quadruped/go2/index.md | 16 +- docs/usage/data_streams/README.md | 10 +- docs/usage/data_streams/advanced_streams.md | 4 +- docs/usage/data_streams/temporal_alignment.md | 4 +- docs/usage/lcm.md | 2 +- docs/usage/sensor_streams/README.md | 10 +- docs/usage/sensor_streams/advanced_streams.md | 4 +- .../sensor_streams/temporal_alignment.md | 4 +- docs/usage/transports.md | 437 ------------------ docs/usage/transports/index.md | 4 +- 14 files changed, 495 insertions(+), 514 deletions(-) delete mode 100644 docs/usage/transports.md diff --git a/dimos/utils/docs/doclinks.py b/dimos/utils/docs/doclinks.py index 67d5897b28..2cf5d1702f 100644 --- a/dimos/utils/docs/doclinks.py +++ b/dimos/utils/docs/doclinks.py @@ -30,6 +30,7 @@ import re import subprocess import sys +import time from typing import Any @@ -78,7 +79,7 @@ def get_git_tracked_files(root: Path) -> list[Path]: return [] -def build_file_index(root: Path) -> dict[str, list[Path]]: +def build_file_index(root: Path, tracked_files: list[Path] | None = None) -> dict[str, list[Path]]: """ Build an index mapping filename suffixes to full paths. @@ -89,7 +90,8 @@ def build_file_index(root: Path) -> dict[str, list[Path]]: - dimos/protocol/service/spec.py """ index: dict[str, list[Path]] = defaultdict(list) - tracked_files = get_git_tracked_files(root) + if tracked_files is None: + tracked_files = get_git_tracked_files(root) for rel_path in tracked_files: parts = rel_path.parts @@ -102,7 +104,7 @@ def build_file_index(root: Path) -> dict[str, list[Path]]: return index -def build_doc_index(root: Path) -> dict[str, list[Path]]: +def build_doc_index(root: Path, tracked_files: list[Path] | None = None) -> dict[str, list[Path]]: """ Build an index mapping lowercase doc names to .md file paths. 
@@ -113,7 +115,8 @@ def build_doc_index(root: Path) -> dict[str, list[Path]]: - "modules" -> [Path("docs/modules/index.md")] (if modules/index.md exists) """ index: dict[str, list[Path]] = defaultdict(list) - tracked_files = get_git_tracked_files(root) + if tracked_files is None: + tracked_files = get_git_tracked_files(root) for rel_path in tracked_files: if rel_path.suffix != ".md": @@ -144,11 +147,78 @@ def find_symbol_line(file_path: Path, symbol: str) -> int | None: return None +# Extensions that indicate a backticked term is a filename, not a symbol +_FILE_EXTENSIONS = frozenset( + ( + ".py", + ".md", + ".ts", + ".js", + ".go", + ".rs", + ".c", + ".h", + ".cpp", + ".hpp", + ".java", + ".rb", + ".yaml", + ".yml", + ".json", + ".toml", + ".sh", + ".lua", + ) +) + + def extract_other_backticks(line: str, file_ref: str) -> list[str]: """Extract other backticked terms from a line, excluding the file reference.""" pattern = r"`([^`]+)`" matches = re.findall(pattern, line) - return [m for m in matches if m != file_ref and not m.endswith(".py") and "/" not in m] + return [ + m + for m in matches + if m != file_ref and "/" not in m and not any(m.endswith(ext) for ext in _FILE_EXTENSIONS) + ] + + +def score_path_similarity(candidate: Path, original_path: str) -> int: + """Score how well a candidate matches the original link's path. + + Counts common directory names plus a bonus for matching filename. + Higher = better match. + """ + orig = Path(original_path) + orig_dirs = set(orig.parent.parts) + cand_dirs = set(candidate.parent.parts) + score = len(orig_dirs & cand_dirs) + if candidate.name == orig.name: + score += 1 + return score + + +def pick_best_candidate(candidates: list[Path], original_path: str) -> Path | None: + """Pick the best candidate by path similarity. 
Returns None if tied.""" + if not candidates: + return None + if len(candidates) == 1: + return candidates[0] + scored = sorted(candidates, key=lambda c: score_path_similarity(c, original_path), reverse=True) + top = score_path_similarity(scored[0], original_path) + second = score_path_similarity(scored[1], original_path) + if top > second: + return scored[0] + return None # Ambiguous tie + + +def resolve_candidates(candidates: list[Path], original_path: str) -> Path | None: + """Resolve candidates to a single path. Returns None if 0 or ambiguous.""" + if len(candidates) == 1: + return candidates[0] + if len(candidates) > 1: + return pick_best_candidate(candidates, original_path) + return None def generate_link( @@ -245,14 +315,32 @@ def process_markdown( Returns (new_content, changes, errors). """ - changes = [] - errors = [] + changes: list[str] = [] + errors: list[str] = [] - # Pattern 1: [`filename`](link) - code file links + # Pattern 1: [`filename`](link) - backtick code links with symbol auto-linking code_pattern = r"\[`([^`]+)`\]\(([^)]*)\)" - # Pattern 2: [Text](.md) - doc file links - doc_pattern = r"\[([^\]]+)\]\(\.md\)" + # Pattern 2: [Text](url) - all non-backtick, non-image links + # (? tuple[Path | None, list[Path]]: + """Search for a broken link's target by name in doc_index or file_index.""" + path = Path(link_path) + if path.suffix == ".md": + stem = path.stem.lower() + if stem == "index": + stem = path.parent.name.lower() + candidates = doc_index.get(stem, []) if doc_index else [] + elif path.suffix: + # Has a file extension — search file_index by filename + candidates = file_index.get(path.name, []) + else: + # No extension (likely a directory) — no fallback search + return None, [] + return resolve_candidates(candidates, original_ref), candidates def replace_code_match(match: re.Match[str]) -> str: file_ref = match.group(1) @@ -267,18 +355,19 @@ def replace_code_match(match: re.Match[str]) -> str: if "." 
not in file_ref and "/" not in file_ref: return full_match - # Look up in index + # Look up in index, with disambiguation candidates = file_index.get(file_ref, []) + resolved_path = resolve_candidates(candidates, file_ref) - if len(candidates) == 0: - errors.append(f"No file matching '{file_ref}' found in codebase") - return full_match - elif len(candidates) > 1: - errors.append(f"'{file_ref}' matches multiple files: {[str(c) for c in candidates]}") + if resolved_path is None: + if len(candidates) > 1: + errors.append( + f"'{file_ref}' matches multiple files: {[str(c) for c in candidates]}" + ) + else: + errors.append(f"No file matching '{file_ref}' found in codebase") return full_match - resolved_path = candidates[0] - # Determine line fragment line_fragment = "" @@ -313,33 +402,107 @@ def replace_code_match(match: re.Match[str]) -> str: return new_match - def replace_doc_match(match: re.Match[str]) -> str: - """Replace [Text](.md) with resolved doc path.""" - if doc_index is None: - return match.group(0) - + def replace_link_match(match: re.Match[str]) -> str: + """Handle all non-backtick links: doc placeholders, path validation.""" link_text = match.group(1) + raw_link = match.group(2) full_match = match.group(0) - lookup_key = link_text.lower() - # Look up in doc index - candidates = doc_index.get(lookup_key, []) + # Skip URLs + if raw_link.startswith(("http://", "https://", "mailto:")): + return full_match - if len(candidates) == 0: - errors.append(f"No doc matching '{link_text}' found") + # Skip anchor-only links + if raw_link.startswith("#"): return full_match - elif len(candidates) > 1: - errors.append(f"'{link_text}' matches multiple docs: {[str(c) for c in candidates]}") + + # Extract fragment if present + fragment = "" + link_path = raw_link + if "#" in raw_link: + link_path, frag = raw_link.split("#", 1) + fragment = "#" + frag + + # .md placeholder: [Text](.md) → doc_index lookup by link text + if link_path == ".md": + if doc_index is None: + return 
full_match + lookup_key = link_text.lower() + candidates = doc_index.get(lookup_key, []) + resolved = resolve_candidates(candidates, lookup_key) + if resolved is not None: + new_link = generate_link( + resolved, root, doc_path, link_mode, github_url, github_ref, fragment + ) + result = f"[{link_text}]({new_link})" + if result != full_match: + changes.append(f" {link_text}: .md -> {new_link}") + return result + if len(candidates) > 1: + errors.append( + f"'{link_text}' matches multiple docs: {[str(c) for c in candidates]}" + ) + else: + errors.append(f"No doc matching '{link_text}' found") return full_match - resolved_path = candidates[0] - new_link = generate_link(resolved_path, root, doc_path, link_mode, github_url, github_ref) - new_match = f"[{link_text}]({new_link})" + # Absolute path + if link_path.startswith("/"): + target = root / link_path.lstrip("/") + if target.exists(): + return full_match # Valid, leave as-is + + # Broken — try fallback search + resolved, candidates = _search_fallback(link_path, link_path.lstrip("/")) + if resolved is not None: + new_link = generate_link( + resolved, root, doc_path, link_mode, github_url, github_ref, fragment + ) + changes.append(f" {link_text}: {raw_link} -> {new_link} (fixed broken link)") + return f"[{link_text}]({new_link})" + if len(candidates) > 1: + errors.append( + f"Broken link '{raw_link}': ambiguous, matches {[str(c) for c in candidates]}" + ) + else: + errors.append(f"Broken link: '{raw_link}' does not exist") + return full_match - if new_match != full_match: - changes.append(f" {link_text}: .md -> {new_link}") + # Relative path — resolve from doc file's directory + doc_dir = doc_path.parent + resolved_abs = (doc_dir / link_path).resolve() - return new_match + try: + rel_to_root = resolved_abs.relative_to(root) + except ValueError: + errors.append(f"Link '{raw_link}' resolves outside repo root") + return full_match + + if resolved_abs.exists(): + # File exists — convert to appropriate link format + new_link 
= generate_link( + rel_to_root, root, doc_path, link_mode, github_url, github_ref, fragment + ) + result = f"[{link_text}]({new_link})" + if result != full_match: + changes.append(f" {link_text}: {raw_link} -> {new_link}") + return result + + # Target doesn't exist — try fallback search + resolved, candidates = _search_fallback(link_path, raw_link) + if resolved is not None: + new_link = generate_link( + resolved, root, doc_path, link_mode, github_url, github_ref, fragment + ) + changes.append(f" {link_text}: {raw_link} -> {new_link} (found by search)") + return f"[{link_text}]({new_link})" + if len(candidates) > 1: + errors.append( + f"Broken link '{raw_link}': ambiguous, matches {[str(c) for c in candidates]}" + ) + else: + errors.append(f"Broken link '{raw_link}': target not found") + return full_match # Split by ignore regions and only process non-ignored parts regions = split_by_ignore_regions(content) @@ -347,9 +510,9 @@ def replace_doc_match(match: re.Match[str]) -> str: for region_content, should_process in regions: if should_process: - # Process code links first, then doc links + # Process code links first, then all other links processed = re.sub(code_pattern, replace_code_match, region_content) - processed = re.sub(doc_pattern, replace_doc_match, processed) + processed = re.sub(link_pattern, replace_link_match, processed) result_parts.append(processed) else: result_parts.append(region_content) @@ -377,6 +540,7 @@ def collect_markdown_files(paths: list[str]) -> list[Path]: Also auto-links symbols: `Configurable` on same line adds #L fragment. Supports doc-to-doc linking: [Modules](.md) resolves to modules.md or modules/index.md. +Validates all file links and fixes broken relative/absolute paths by searching the index. 
Usage: doclinks [options] @@ -471,8 +635,9 @@ def main() -> None: sys.exit(1) print(f"Building file index from {root}...") - file_index = build_file_index(root) - doc_index = build_doc_index(root) + tracked_files = get_git_tracked_files(root) + file_index = build_file_index(root, tracked_files) + doc_index = build_doc_index(root, tracked_files) print( f"Indexed {sum(len(v) for v in file_index.values())} file paths, {len(doc_index)} doc names" ) @@ -551,8 +716,6 @@ def on_created(self, event: Any) -> None: observer.start() try: while True: - import time - time.sleep(1) except KeyboardInterrupt: observer.stop() diff --git a/dimos/utils/docs/test_doclinks.py b/dimos/utils/docs/test_doclinks.py index f1303a2245..968f465cef 100644 --- a/dimos/utils/docs/test_doclinks.py +++ b/dimos/utils/docs/test_doclinks.py @@ -21,7 +21,10 @@ build_file_index, extract_other_backticks, find_symbol_line, + pick_best_candidate, process_markdown, + resolve_candidates, + score_path_similarity, split_by_ignore_regions, ) import pytest @@ -520,5 +523,257 @@ def test_ignores_doc_links_in_region(self, file_index, doc_index): assert "[Configuration](.md) example" in new_content +class TestPathSimilarity: + def test_exact_dir_match(self): + """Same directory components should give high score.""" + candidate = Path("docs/agents/docs/codeblocks.md") + score = score_path_similarity(candidate, "docs/agents/docs_agent/codeblocks.md") + assert score >= 2 # docs, agents + + def test_partial_match(self): + """Some shared dirs should give partial score.""" + candidate = Path("docs/other/codeblocks.md") + score = score_path_similarity(candidate, "docs/agents/docs_agent/codeblocks.md") + assert score == 2 # docs dir + filename match + + def test_no_match(self): + """Unrelated dirs should give filename-only score.""" + candidate = Path("src/lib/codeblocks.md") + score = score_path_similarity(candidate, "docs/agents/docs_agent/codeblocks.md") + assert score == 1 # filename match only, no dir overlap + + def 
test_pick_best_single(self): + """Single candidate always wins.""" + candidates = [Path("docs/agents/docs/codeblocks.md")] + best = pick_best_candidate(candidates, "docs/agents/docs_agent/codeblocks.md") + assert best == candidates[0] + + def test_pick_best_disambiguates(self): + """Should pick candidate with more directory overlap.""" + candidates = [ + Path("docs/other/codeblocks.md"), + Path("docs/agents/docs/codeblocks.md"), + ] + best = pick_best_candidate(candidates, "docs/agents/docs_agent/codeblocks.md") + assert best == Path("docs/agents/docs/codeblocks.md") + + def test_pick_best_tie_returns_none(self): + """Tied scores should return None.""" + candidates = [ + Path("a/x/file.md"), + Path("b/x/file.md"), + ] + best = pick_best_candidate(candidates, "c/x/file.md") + assert best is None + + +class TestResolveCandidates: + def test_single_candidate(self): + candidates = [Path("docs/usage/modules.md")] + assert resolve_candidates(candidates, "modules.md") == candidates[0] + + def test_empty_candidates(self): + assert resolve_candidates([], "modules.md") is None + + def test_disambiguates(self): + candidates = [ + Path("docs/other/codeblocks.md"), + Path("docs/agents/docs/codeblocks.md"), + ] + result = resolve_candidates(candidates, "docs/agents/docs_agent/codeblocks.md") + assert result == Path("docs/agents/docs/codeblocks.md") + + def test_tie_returns_none(self): + candidates = [Path("a/x/file.md"), Path("b/x/file.md")] + assert resolve_candidates(candidates, "c/x/file.md") is None + + +class TestLinkResolution: + def _process(self, content, file_index, doc_index, doc_path=None, link_mode="absolute"): + if doc_path is None: + doc_path = REPO_ROOT / "docs/usage/test.md" + return process_markdown( + content, + REPO_ROOT, + doc_path, + file_index, + link_mode=link_mode, + github_url=None, + github_ref="main", + doc_index=doc_index, + ) + + def test_resolves_relative_md_link(self, file_index, doc_index): + """Should resolve a valid relative .md link to absolute 
path.""" + # docs/usage/configuration.md exists — link from docs/usage/test.md + content = "[Configuration](configuration.md)" + new_content, changes, errors = self._process(content, file_index, doc_index) + + assert len(errors) == 0 + assert "configuration.md" in new_content + + def test_validates_absolute_md_link(self, file_index, doc_index): + """Valid absolute .md link should be left unchanged.""" + content = "[Configuration](/docs/usage/configuration.md)" + new_content, changes, errors = self._process(content, file_index, doc_index) + + assert len(errors) == 0 + assert new_content == content + + def test_reports_broken_absolute_md_link(self, file_index, doc_index): + """Broken absolute .md link with no match should error.""" + content = "[Foo](/docs/nonexistent/xyzzy_no_match.md)" + new_content, changes, errors = self._process(content, file_index, doc_index) + + assert len(errors) == 1 + assert "Broken link" in errors[0] or "does not exist" in errors[0] + + def test_searches_broken_relative_link(self, file_index, doc_index): + """Broken relative .md link should be resolved by name search if unique.""" + # Link to a non-existent relative path, but stem matches a known doc + content = "[Configuration](../nonexistent/configuration.md)" + new_content, changes, errors = self._process(content, file_index, doc_index) + + # Should resolve via search fallback (configuration.md exists) + if "configuration" in doc_index and len(doc_index["configuration"]) == 1: + assert len(errors) == 0 + assert len(changes) == 1 + assert "found by search" in changes[0] + else: + # Multiple matches — disambiguation should kick in + assert len(errors) <= 1 + + def test_disambiguates_by_path_similarity(self, file_index, doc_index): + """Multiple candidates should be disambiguated by directory overlap.""" + # Build a custom doc_index with multiple candidates + from collections import defaultdict + + custom_doc_index: dict[str, list[Path]] = defaultdict(list) + custom_doc_index["testdoc"] = 
[ + Path("docs/other/testdoc.md"), + Path("docs/agents/docs/testdoc.md"), + ] + + content = "[TestDoc](../agents/docs_agent/testdoc.md)" + doc_path = REPO_ROOT / "docs/usage/test.md" + new_content, changes, errors = process_markdown( + content, + REPO_ROOT, + doc_path, + file_index, + link_mode="absolute", + github_url=None, + github_ref="main", + doc_index=custom_doc_index, + ) + + # Should pick docs/agents/docs/testdoc.md (shares "docs", "agents") + assert len(errors) == 0 + assert len(changes) == 1 + assert "agents/docs/testdoc.md" in new_content + + def test_skips_url_md_links(self, file_index, doc_index): + """HTTP(S) .md links should be left untouched.""" + content = "[External](https://example.com/guide.md)" + new_content, changes, errors = self._process(content, file_index, doc_index) + + assert len(errors) == 0 + assert len(changes) == 0 + assert new_content == content + + def test_preserves_fragment(self, file_index, doc_index): + """Fragment (#section) should be preserved in resolved link.""" + content = "[Config](configuration.md#advanced)" + new_content, changes, errors = self._process(content, file_index, doc_index) + + assert "#advanced" in new_content + + def test_skips_backtick_wrapped(self, file_index, doc_index): + """Backtick-wrapped .md link text should be skipped by md_link_pattern.""" + content = "[`configuration.md`](configuration.md)" + new_content, changes, errors = self._process(content, file_index, doc_index) + + # The code_pattern handles backtick links; md_link_pattern sees backticks and skips + # No double-processing should occur + assert "configuration.md" in new_content + + def test_md_links_in_ignore_region(self, file_index, doc_index): + """Links in ignore regions should not be processed.""" + content = ( + "[Configuration](configuration.md)\n" + "\n" + "[Configuration](broken_nonexistent.md)\n" + "\n" + "[Configuration](configuration.md)" + ) + new_content, changes, errors = self._process(content, file_index, doc_index) + + # The 
broken link in ignore region should not produce errors + assert "broken_nonexistent.md" in new_content # Preserved as-is + + def test_validates_absolute_py_link(self, file_index, doc_index): + """Valid absolute .py link (without backticks) should be left unchanged.""" + content = "[spec](/dimos/protocol/service/spec.py)" + new_content, changes, errors = self._process(content, file_index, doc_index) + + assert len(errors) == 0 + assert new_content == content + + def test_broken_py_link_searches_file_index(self, file_index, doc_index): + """Broken .py link should fall back to file_index search.""" + content = "[spec](/nonexistent/path/service/spec.py)" + new_content, changes, errors = self._process(content, file_index, doc_index) + + # service/spec.py is unique in file_index — should resolve + # But spec.py alone is ambiguous, so it depends on disambiguation + # The fallback searches by filename (spec.py) which has multiple matches + # pick_best_candidate should resolve via path similarity + if len(errors) == 0: + assert "fixed broken link" in changes[0] + # If ambiguous, at least we get an error not a silent pass + else: + assert "Broken link" in errors[0] + + def test_validates_directory_link(self, file_index, doc_index): + """Valid directory link should be left unchanged.""" + content = "[examples](/examples/)" + doc_path = REPO_ROOT / "docs/test.md" + new_content, changes, errors = process_markdown( + content, + REPO_ROOT, + doc_path, + file_index, + link_mode="absolute", + github_url=None, + github_ref="main", + doc_index=doc_index, + ) + + if (REPO_ROOT / "examples").exists(): + assert len(errors) == 0 + assert new_content == content + else: + # Directory doesn't exist — should error + assert len(errors) == 1 + + def test_skips_image_links(self, file_index, doc_index): + """Image links ![alt](path) should not be processed.""" + content = "![screenshot](assets/screenshot.png)" + new_content, changes, errors = self._process(content, file_index, doc_index) + + 
assert len(errors) == 0 + assert len(changes) == 0 + assert new_content == content + + def test_skips_mailto_links(self, file_index, doc_index): + """mailto: links should be left untouched.""" + content = "[Email](mailto:test@example.com)" + new_content, changes, errors = self._process(content, file_index, doc_index) + + assert len(errors) == 0 + assert len(changes) == 0 + assert new_content == content + + if __name__ == "__main__": pytest.main([__file__, "-v"]) diff --git a/docs/agents/docs/index.md b/docs/agents/docs/index.md index bec2ce79e6..09dabad7ee 100644 --- a/docs/agents/docs/index.md +++ b/docs/agents/docs/index.md @@ -8,7 +8,7 @@ If you're showing an API usage pattern, create a minimal working example that ac After writing a code block in your markdown file, you can run it by executing `md-babel-py run document.md` -more information on this tool is in [codeblocks](/docs/agents/docs_agent/codeblocks.md) +more information on this tool is in [codeblocks](/docs/agents/docs/codeblocks.md) # Code or Docs Links @@ -40,15 +40,15 @@ The `Configurable` class is defined in [`service/spec.py`](/dimos/protocol/servi ### Doc-to-doc references Use `.md` as the link target: ```markdown -See [Configuration](/docs/api/configuration.md) for more details. +See [Configuration](/docs/usage/configuration.md) for more details. ``` Becomes: ```markdown -See [Configuration](/docs/concepts/configuration.md) for more details. +See [Configuration](/docs/usage/configuration.md) for more details. 
``` -More information on this in [doclinks](/docs/agents/docs_agent/doclinks.md) +More information on this in [doclinks](/docs/agents/docs/doclinks.md) # Pikchr diff --git a/docs/capabilities/navigation/readme.md b/docs/capabilities/navigation/readme.md index af26c07f94..f36d795e62 100644 --- a/docs/capabilities/navigation/readme.md +++ b/docs/capabilities/navigation/readme.md @@ -3,8 +3,8 @@ ## Non-ROS -- [Go2 Navigation](native/index.md) — column-carving voxel mapping + slope-based costmap +- [Go2 Navigation](/docs/capabilities/navigation/native/index.md) — column-carving voxel mapping + slope-based costmap ## ROS -See [ROS Transports](/docs/api/transports.md) for bridging DimOS streams to ROS topics. +See [ROS Transports](/docs/usage/transports/index.md) for bridging DimOS streams to ROS topics. diff --git a/docs/platforms/quadruped/go2/index.md b/docs/platforms/quadruped/go2/index.md index 3a00ec1e84..ab9e6c202d 100644 --- a/docs/platforms/quadruped/go2/index.md +++ b/docs/platforms/quadruped/go2/index.md @@ -11,9 +11,9 @@ The Unitree Go2 is DimOS's primary reference platform. 
Full autonomous navigatio ## Install First, install system dependencies for your platform: -- [Ubuntu](../../../installation/ubuntu.md) -- [macOS](../../../installation/osx.md) -- [Nix](../../../installation/nix.md) +- [Ubuntu](/docs/installation/ubuntu.md) +- [macOS](/docs/installation/osx.md) +- [Nix](/docs/installation/nix.md) Then install DimOS: @@ -125,8 +125,8 @@ The agent subscribes to camera, LiDAR, and spatial memory streams — it sees wh ## Deep Dive -- [Navigation Stack](../../../capabilities/navigation/native/index.md) — column-carving voxel mapping, costmap generation, A* planning -- [Visualization](../../../usage/visualization.md) — Rerun, Foxglove, performance tuning -- [Data Streams](../../../usage/data_streams/) — RxPY streams, backpressure, quality filtering -- [Transports](../../../usage/transports/index.md) — LCM, SHM, DDS -- [Blueprints](../../../usage/blueprints.md) — composing modules +- [Navigation Stack](/docs/capabilities/navigation/native/index.md) — column-carving voxel mapping, costmap generation, A* planning +- [Visualization](/docs/usage/visualization.md) — Rerun, Foxglove, performance tuning +- [Data Streams](/docs/usage/data_streams) — RxPY streams, backpressure, quality filtering +- [Transports](/docs/usage/transports/index.md) — LCM, SHM, DDS +- [Blueprints](/docs/usage/blueprints.md) — composing modules diff --git a/docs/usage/data_streams/README.md b/docs/usage/data_streams/README.md index dc2ce6c91d..870c25fb34 100644 --- a/docs/usage/data_streams/README.md +++ b/docs/usage/data_streams/README.md @@ -6,11 +6,11 @@ Dimos uses reactive streams (RxPY) to handle sensor data. 
This approach naturall | Guide | Description | |----------------------------------------------|---------------------------------------------------------------| -| [ReactiveX Fundamentals](reactivex.md) | Observables, subscriptions, and disposables | -| [Advanced Streams](advanced_streams.md) | Backpressure, parallel subscribers, synchronous getters | -| [Quality-Based Filtering](quality_filter.md) | Select highest quality frames when downsampling streams | -| [Temporal Alignment](temporal_alignment.md) | Match messages from multiple sensors by timestamp | -| [Storage & Replay](storage_replay.md) | Record sensor streams to disk and replay with original timing | +| [ReactiveX Fundamentals](/docs/usage/data_streams/reactivex.md) | Observables, subscriptions, and disposables | +| [Advanced Streams](/docs/usage/data_streams/advanced_streams.md) | Backpressure, parallel subscribers, synchronous getters | +| [Quality-Based Filtering](/docs/usage/data_streams/quality_filter.md) | Select highest quality frames when downsampling streams | +| [Temporal Alignment](/docs/usage/data_streams/temporal_alignment.md) | Match messages from multiple sensors by timestamp | +| [Storage & Replay](/docs/usage/data_streams/storage_replay.md) | Record sensor streams to disk and replay with original timing | ## Quick Example diff --git a/docs/usage/data_streams/advanced_streams.md b/docs/usage/data_streams/advanced_streams.md index 187d432af2..e9d9f1d12d 100644 --- a/docs/usage/data_streams/advanced_streams.md +++ b/docs/usage/data_streams/advanced_streams.md @@ -1,6 +1,6 @@ # Advanced Stream Handling -> **Prerequisite:** Read [ReactiveX Fundamentals](reactivex.md) first for Observable basics. +> **Prerequisite:** Read [ReactiveX Fundamentals](/docs/usage/data_streams/reactivex.md) first for Observable basics. 
## Backpressure and Parallel Subscribers to Hardware @@ -126,7 +126,7 @@ class MLModel(Module): Sometimes you don't want a stream, you just want to call a function and get the latest value. -If you are doing this periodically as a part of a processing loop, it is very likely that your code will be much cleaner and safer using actual reactivex pipeline. So bias towards checking our [reactivex quick guide](reactivex.md) and [official docs](https://rxpy.readthedocs.io/) +If you are doing this periodically as a part of a processing loop, it is very likely that your code will be much cleaner and safer using actual reactivex pipeline. So bias towards checking our [reactivex quick guide](/docs/usage/data_streams/reactivex.md) and [official docs](https://rxpy.readthedocs.io/) (TODO we should actually make this example actually executable) diff --git a/docs/usage/data_streams/temporal_alignment.md b/docs/usage/data_streams/temporal_alignment.md index 66230c9d54..c428c04e2e 100644 --- a/docs/usage/data_streams/temporal_alignment.md +++ b/docs/usage/data_streams/temporal_alignment.md @@ -34,7 +34,7 @@ Below we set up replay of real camera and lidar data from the Unitree Go2 robot.
Stream Setup -You can read more about [sensor storage here](storage_replay.md) and [LFS data storage here](/docs/development/large_file_management.md). +You can read more about [sensor storage here](/docs/usage/data_streams/storage_replay.md) and [LFS data storage here](/docs/development/large_file_management.md). ```python session=align no-result from reactivex import Subject @@ -196,7 +196,7 @@ plot_alignment_timeline(video_frames, lidar_scans, aligned_pairs, '{output}') ## Combine Frame Alignment with a Quality Filter -More on [quality filtering here](quality_filter.md). +More on [quality filtering here](/docs/usage/data_streams/quality_filter.md). ```python session=align from dimos.msgs.sensor_msgs.Image import Image, sharpness_barrier diff --git a/docs/usage/lcm.md b/docs/usage/lcm.md index 99437a2458..d089cfcdd3 100644 --- a/docs/usage/lcm.md +++ b/docs/usage/lcm.md @@ -7,7 +7,7 @@ The LCM project provides pubsub clients and code generators for many languages. Our messages are ported from ROS (they are structurally compatible in order to facilitate easy communication to ROS if needed) Repo that hosts our message definitions and autogenerators is at [dimos-lcm](https://github.com/dimensionalOS/dimos-lcm/) -our LCM implementation significantly [outperforms ROS for local communication](/docs/usage/transports.md#benchmarks) +our LCM implementation significantly [outperforms ROS for local communication](/docs/usage/transports/index.md#benchmarks) ## Supported languages diff --git a/docs/usage/sensor_streams/README.md b/docs/usage/sensor_streams/README.md index dc2ce6c91d..0bf61e98ef 100644 --- a/docs/usage/sensor_streams/README.md +++ b/docs/usage/sensor_streams/README.md @@ -6,11 +6,11 @@ Dimos uses reactive streams (RxPY) to handle sensor data. 
This approach naturall | Guide | Description | |----------------------------------------------|---------------------------------------------------------------| -| [ReactiveX Fundamentals](reactivex.md) | Observables, subscriptions, and disposables | -| [Advanced Streams](advanced_streams.md) | Backpressure, parallel subscribers, synchronous getters | -| [Quality-Based Filtering](quality_filter.md) | Select highest quality frames when downsampling streams | -| [Temporal Alignment](temporal_alignment.md) | Match messages from multiple sensors by timestamp | -| [Storage & Replay](storage_replay.md) | Record sensor streams to disk and replay with original timing | +| [ReactiveX Fundamentals](/docs/usage/sensor_streams/reactivex.md) | Observables, subscriptions, and disposables | +| [Advanced Streams](/docs/usage/sensor_streams/advanced_streams.md) | Backpressure, parallel subscribers, synchronous getters | +| [Quality-Based Filtering](/docs/usage/sensor_streams/quality_filter.md) | Select highest quality frames when downsampling streams | +| [Temporal Alignment](/docs/usage/sensor_streams/temporal_alignment.md) | Match messages from multiple sensors by timestamp | +| [Storage & Replay](/docs/usage/sensor_streams/storage_replay.md) | Record sensor streams to disk and replay with original timing | ## Quick Example diff --git a/docs/usage/sensor_streams/advanced_streams.md b/docs/usage/sensor_streams/advanced_streams.md index 187d432af2..c2cd0dbfca 100644 --- a/docs/usage/sensor_streams/advanced_streams.md +++ b/docs/usage/sensor_streams/advanced_streams.md @@ -1,6 +1,6 @@ # Advanced Stream Handling -> **Prerequisite:** Read [ReactiveX Fundamentals](reactivex.md) first for Observable basics. +> **Prerequisite:** Read [ReactiveX Fundamentals](/docs/usage/sensor_streams/reactivex.md) first for Observable basics. 
## Backpressure and Parallel Subscribers to Hardware @@ -126,7 +126,7 @@ class MLModel(Module): Sometimes you don't want a stream, you just want to call a function and get the latest value. -If you are doing this periodically as a part of a processing loop, it is very likely that your code will be much cleaner and safer using actual reactivex pipeline. So bias towards checking our [reactivex quick guide](reactivex.md) and [official docs](https://rxpy.readthedocs.io/) +If you are doing this periodically as a part of a processing loop, it is very likely that your code will be much cleaner and safer using actual reactivex pipeline. So bias towards checking our [reactivex quick guide](/docs/usage/sensor_streams/reactivex.md) and [official docs](https://rxpy.readthedocs.io/) (TODO we should actually make this example actually executable) diff --git a/docs/usage/sensor_streams/temporal_alignment.md b/docs/usage/sensor_streams/temporal_alignment.md index 66230c9d54..7d1ad074f2 100644 --- a/docs/usage/sensor_streams/temporal_alignment.md +++ b/docs/usage/sensor_streams/temporal_alignment.md @@ -34,7 +34,7 @@ Below we set up replay of real camera and lidar data from the Unitree Go2 robot.
Stream Setup -You can read more about [sensor storage here](storage_replay.md) and [LFS data storage here](/docs/development/large_file_management.md). +You can read more about [sensor storage here](/docs/usage/sensor_streams/storage_replay.md) and [LFS data storage here](/docs/development/large_file_management.md). ```python session=align no-result from reactivex import Subject @@ -196,7 +196,7 @@ plot_alignment_timeline(video_frames, lidar_scans, aligned_pairs, '{output}') ## Combine Frame Alignment with a Quality Filter -More on [quality filtering here](quality_filter.md). +More on [quality filtering here](/docs/usage/sensor_streams/quality_filter.md). ```python session=align from dimos.msgs.sensor_msgs.Image import Image, sharpness_barrier diff --git a/docs/usage/transports.md b/docs/usage/transports.md deleted file mode 100644 index 4c80776531..0000000000 --- a/docs/usage/transports.md +++ /dev/null @@ -1,437 +0,0 @@ -# Transports - -Transports connect **module streams** across **process boundaries** and/or **networks**. - -* **Module**: a running component (e.g., camera, mapping, nav). -* **Stream**: a unidirectional flow of messages owned by a module (one broadcaster → many receivers). -* **Topic**: the name/identifier used by a transport or pubsub backend. -* **Message**: payload carried on a stream (often `dimos.msgs.*`, but can be bytes / images / pointclouds / etc.). - -Each edge in the graph is a **transported stream** (potentially different protocols). Each node is a **module**: - -![go2_nav](assets/go2_nav.svg) - -## What the transport layer guarantees (and what it doesn’t) - -Modules **don’t** know or care *how* data moves. They just: - -* emit messages (broadcast) -* subscribe to messages (receive) - -A transport is responsible for the mechanics of delivery (IPC, sockets, Redis, ROS 2, etc.). - -**Important:** delivery semantics depend on the backend: - -* Some are **best-effort** (e.g., UDP multicast / LCM): loss can happen. 
-* Some can be **reliable** (e.g., TCP-backed, Redis, some DDS configs) but may add latency/backpressure. - -So: treat the API as uniform, but pick a backend whose semantics match the task. - ---- - -## Benchmarks - -Quick view on performance of our pubsub backends: - -```sh skip -python -m pytest -svm tool -k "not bytes" dimos/protocol/pubsub/benchmark/test_benchmark.py -``` - -![Benchmark results](assets/pubsub_benchmark.png) - ---- - -## Abstraction layers - -
Pikchr - -```pikchr output=assets/abstraction_layers.svg fold -color = white -fill = none -linewid = 0.5in -boxwid = 1.0in -boxht = 0.4in - -# Boxes with labels -B: box "Blueprints" rad 10px -arrow -M: box "Modules" rad 5px -arrow -T: box "Transports" rad 5px -arrow -P: box "PubSub" rad 5px - -# Descriptions below -text "robot configs" at B.s + (0.1, -0.2in) -text "camera, nav" at M.s + (0, -0.2in) -text "LCM, SHM, ROS" at T.s + (0, -0.2in) -text "pub/sub API" at P.s + (0, -0.2in) -``` - -
- - -![output](assets/abstraction_layers.svg) - -We’ll go through these layers top-down. - ---- - -## Using transports with blueprints - -See [Blueprints](blueprints.md) for the blueprint API. - -From [`unitree/go2/blueprints/__init__.py`](/dimos/robot/unitree/go2/blueprints/__init__.py). - -Example: rebind a few streams from the default `LCMTransport` to `ROSTransport` (defined at [`transport.py`](/dimos/core/transport.py#L226)) so you can visualize in **rviz2**. - -```python skip -nav = autoconnect( - basic, - voxel_mapper(voxel_size=0.1), - cost_mapper(), - replanning_a_star_planner(), - wavefront_frontier_explorer(), -).global_config(n_dask_workers=6, robot_model="unitree_go2") - -ros = nav.transports( - { - ("lidar", PointCloud2): ROSTransport("lidar", PointCloud2), - ("global_map", PointCloud2): ROSTransport("global_map", PointCloud2), - ("odom", PoseStamped): ROSTransport("odom", PoseStamped), - ("color_image", Image): ROSTransport("color_image", Image), - } -) -``` - ---- - -## Using transports with modules - -Each **stream** on a module can use a different transport. Set `.transport` on the stream **before starting** modules. 
- -```python ansi=false -import time - -from dimos.core import In, Module, start -from dimos.core.transport import LCMTransport -from dimos.hardware.sensors.camera.module import CameraModule -from dimos.msgs.sensor_msgs import Image - - -class ImageListener(Module): - image: In[Image] - - def start(self): - super().start() - self.image.subscribe(lambda img: print(f"Received: {img.shape}")) - - -if __name__ == "__main__": - # Start local cluster and deploy modules to separate processes - dimos = start(2) - - camera = dimos.deploy(CameraModule, frequency=2.0) - listener = dimos.deploy(ImageListener) - - # Choose a transport for the stream (example: LCM typed channel) - camera.color_image.transport = LCMTransport("/camera/rgb", Image) - - # Connect listener input to camera output - listener.image.connect(camera.color_image) - - camera.start() - listener.start() - - time.sleep(2) - dimos.stop() -``` - - - -``` -Initialized dimos local cluster with 2 workers, memory limit: auto -2026-01-24T13:17:50.190559Z [info ] Deploying module. [dimos/core/__init__.py] module=CameraModule -2026-01-24T13:17:50.218466Z [info ] Deployed module. [dimos/core/__init__.py] module=CameraModule worker_id=1 -2026-01-24T13:17:50.229474Z [info ] Deploying module. [dimos/core/__init__.py] module=ImageListener -2026-01-24T13:17:50.250199Z [info ] Deployed module. [dimos/core/__init__.py] module=ImageListener worker_id=0 -Received: (480, 640, 3) -Received: (480, 640, 3) -Received: (480, 640, 3) -``` - -See [Modules](modules.md) for more on module architecture. - ---- - -## Inspecting LCM traffic (CLI) - -`lcmspy` shows topic frequency/bandwidth stats: - -![lcmspy](assets/lcmspy.png) - -`dimos topic echo /topic` listens on typed channels like `/topic#pkg.Msg` and decodes automatically: - -```sh skip -Listening on /camera/rgb (inferring from typed LCM channels like '/camera/rgb#pkg.Msg')... 
(Ctrl+C to stop) -Image(shape=(480, 640, 3), format=RGB, dtype=uint8, dev=cpu, ts=2026-01-24 20:28:59) -``` - ---- - -## Implementing a transport - -At the stream layer, a transport is implemented by subclassing `Transport` (see [`core/stream.py`](/dimos/core/stream.py#L83)) and implementing: - -* `broadcast(...)` -* `subscribe(...)` - -Your `Transport.__init__` args can be anything meaningful for your backend: - -* `(ip, port)` -* a shared-memory segment name -* a filesystem path -* a Redis channel - -Encoding is an implementation detail, but we encourage using LCM-compatible message types when possible. - -### Encoding helpers - -Many of our message types provide `lcm_encode` / `lcm_decode` for compact, language-agnostic binary encoding (often faster than pickle). For details, see [LCM](/docs/usage/lcm.md). - ---- - -## PubSub transports - -Even though transport can be anything (TCP connection, unix socket) for now all our transport backends implement the `PubSub` interface. - -* `publish(topic, message)` -* `subscribe(topic, callback) -> unsubscribe` - -```python -from dimos.protocol.pubsub.spec import PubSub -import inspect - -print(inspect.getsource(PubSub.publish)) -print(inspect.getsource(PubSub.subscribe)) -``` - - -```python - @abstractmethod - def publish(self, topic: TopicT, message: MsgT) -> None: - """Publish a message to a topic.""" - ... - - @abstractmethod - def subscribe( - self, topic: TopicT, callback: Callable[[MsgT, TopicT], None] - ) -> Callable[[], None]: - """Subscribe to a topic with a callback. returns unsubscribe function""" - ... -``` - -Topic/message types are flexible: bytes, JSON, or our ROS-compatible [LCM](/docs/usage/lcm.md) types. We also have pickle-based transports for arbitrary Python objects. - -### LCM (UDP multicast) - -LCM is UDP multicast. It’s very fast on a robot LAN, but it’s **best-effort** (packets can drop). 
-For local emission it autoconfigures system in a way in which it's more robust and faster then other more common protocols like ROS, DDS - -```python -from dimos.protocol.pubsub.lcmpubsub import LCM, Topic -from dimos.msgs.geometry_msgs import Vector3 - -lcm = LCM(autoconf=True) -lcm.start() - -received = [] -topic = Topic("/robot/velocity", Vector3) - -lcm.subscribe(topic, lambda msg, t: received.append(msg)) -lcm.publish(topic, Vector3(1.0, 0.0, 0.5)) - -import time -time.sleep(0.1) - -print(f"Received velocity: x={received[0].x}, y={received[0].y}, z={received[0].z}") -lcm.stop() -``` - - -``` -Received velocity: x=1.0, y=0.0, z=0.5 -``` - -### Shared memory (IPC) - -Shared memory is highest performance, but only works on the **same machine**. - -```python -from dimos.protocol.pubsub.shmpubsub import PickleSharedMemory - -shm = PickleSharedMemory(prefer="cpu") -shm.start() - -received = [] -shm.subscribe("test/topic", lambda msg, topic: received.append(msg)) -shm.publish("test/topic", {"data": [1, 2, 3]}) - -import time -time.sleep(0.1) - -print(f"Received: {received}") -shm.stop() -``` - - -``` -Received: [{'data': [1, 2, 3]}] -``` - -### DDS Transport - -For network communication, DDS uses the Data Distribution Service (DDS) protocol: - -```python session=dds_demo ansi=false -from dataclasses import dataclass -from cyclonedds.idl import IdlStruct - -from dimos.protocol.pubsub.impl.ddspubsub import DDS, Topic - -@dataclass -class SensorReading(IdlStruct): - value: float - -dds = DDS() -dds.start() - -received = [] -sensor_topic = Topic(name="sensors/temperature", data_type=SensorReading) - -dds.subscribe(sensor_topic, lambda msg, t: received.append(msg)) -dds.publish(sensor_topic, SensorReading(value=22.5)) - -import time -time.sleep(0.1) - -print(f"Received: {received}") -dds.stop() -``` - - -``` -Received: [SensorReading(value=22.5)] -``` - ---- - -## A minimal transport: `Memory` - -The simplest toy backend is `Memory` (single process). 
Start from there when implementing a new pubsub backend. - -```python -from dimos.protocol.pubsub.memory import Memory - -bus = Memory() -received = [] - -unsubscribe = bus.subscribe("sensor/data", lambda msg, topic: received.append(msg)) - -bus.publish("sensor/data", {"temperature": 22.5}) -bus.publish("sensor/data", {"temperature": 23.0}) - -print(f"Received {len(received)} messages:") -for msg in received: - print(f" {msg}") - -unsubscribe() -``` - - -``` -Received 2 messages: - {'temperature': 22.5} - {'temperature': 23.0} -``` - -See [`memory.py`](/dimos/protocol/pubsub/impl/memory.py) for the complete source. - ---- - -## Encode/decode mixins - -Transports often need to serialize messages before sending and deserialize after receiving. - -`PubSubEncoderMixin` at [`pubsub/spec.py`](/dimos/protocol/pubsub/spec.py#L95) provides a clean way to add encoding/decoding to any pubsub implementation. - -### Available mixins - -| Mixin | Encoding | Use case | -|----------------------|-----------------|------------------------------------| -| `PickleEncoderMixin` | Python pickle | Any Python object, Python-only | -| `LCMEncoderMixin` | LCM binary | Cross-language (C/C++/Python/Go/…) | -| `JpegEncoderMixin` | JPEG compressed | Image data, reduces bandwidth | - -`LCMEncoderMixin` is especially useful: you can use LCM message definitions with *any* transport (not just UDP multicast). See [LCM](/docs/usage/lcm.md) for details. 
- -### Creating a custom mixin - -```python session=jsonencoder no-result -from dimos.protocol.pubsub.spec import PubSubEncoderMixin -import json - -class JsonEncoderMixin(PubSubEncoderMixin[str, dict, bytes]): - def encode(self, msg: dict, topic: str) -> bytes: - return json.dumps(msg).encode("utf-8") - - def decode(self, msg: bytes, topic: str) -> dict: - return json.loads(msg.decode("utf-8")) -``` - -Combine with a pubsub implementation via multiple inheritance: - -```python session=jsonencoder no-result -from dimos.protocol.pubsub.memory import Memory - -class MyJsonPubSub(JsonEncoderMixin, Memory): - pass -``` - -Swap serialization by changing the mixin: - -```python session=jsonencoder no-result -from dimos.protocol.pubsub.spec import PickleEncoderMixin - -class MyPicklePubSub(PickleEncoderMixin, Memory): - pass -``` - ---- - -## Testing and benchmarks - -### Spec tests - -See [`pubsub/test_spec.py`](/dimos/protocol/pubsub/test_spec.py) for the grid tests your new backend should pass. 
- -### Benchmarks - -Add your backend to benchmarks to compare in context: - -```sh skip -python -m pytest -svm tool -k "not bytes" dimos/protocol/pubsub/benchmark/test_benchmark.py -``` - ---- - -# Available transports - -| Transport | Use case | Cross-process | Network | Notes | -|----------------|-------------------------------------|---------------|---------|--------------------------------------| -| `Memory` | Testing only, single process | No | No | Minimal reference impl | -| `SharedMemory` | Multi-process on same machine | Yes | No | Highest throughput (IPC) | -| `LCM` | Robot LAN broadcast (UDP multicast) | Yes | Yes | Best-effort; can drop packets on LAN | -| `Redis` | Network pubsub via Redis server | Yes | Yes | Central broker; adds hop | -| `ROS` | ROS 2 topic communication | Yes | Yes | Integrates with RViz/ROS tools | -| `DDS` | Cyclone DDS without ROS (WIP) | Yes | Yes | WIP | diff --git a/docs/usage/transports/index.md b/docs/usage/transports/index.md index 748cf03aa1..1c8745d117 100644 --- a/docs/usage/transports/index.md +++ b/docs/usage/transports/index.md @@ -79,7 +79,7 @@ We’ll go through these layers top-down. ## Using transports with blueprints -See [Blueprints](blueprints.md) for the blueprint API. +See [Blueprints](/docs/usage/blueprints.md) for the blueprint API. From [`unitree/go2/blueprints/__init__.py`](/dimos/robot/unitree/go2/blueprints/__init__.py). @@ -160,7 +160,7 @@ Received: (480, 640, 3) Received: (480, 640, 3) ``` -See [Modules](modules.md) for more on module architecture. +See [Modules](/docs/usage/modules.md) for more on module architecture. 
--- From f53a3ff354c10b7ac9787b18a73bd0157e304bdf Mon Sep 17 00:00:00 2001 From: Jeff Hykin Date: Mon, 23 Feb 2026 23:35:53 -0600 Subject: [PATCH 09/16] Fix `bin/gen_diagrams` that is referenced in docs (#1291) * add missing command * Apply suggestion from @greptile-apps[bot] Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> * fix * Update bin/gen-diagrams Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> --------- Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> --- bin/gen-diagrams | 32 ++++++++++++++++++++++++++++++++ docs/development/writing_docs.md | 2 +- 2 files changed, 33 insertions(+), 1 deletion(-) create mode 100755 bin/gen-diagrams diff --git a/bin/gen-diagrams b/bin/gen-diagrams new file mode 100755 index 0000000000..7bdf29933b --- /dev/null +++ b/bin/gen-diagrams @@ -0,0 +1,32 @@ +#!/usr/bin/env bash +set -euo pipefail + +REPO_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +cd "$REPO_ROOT" + +# if md-babel-py doesnt exist +if [ -z "$(command -v "md-babel-py")" ]; then + # if nix doesnt exist + if [ -z "$(command -v "nix")" ]; then + echo "Error: md-babel-py required for running gen-diagrams." 
>&2 + echo " Either install nix or install md-babel-py" >&2 + echo " https://github.com/leshy/md-babel-py" >&2 + exit 1 + # use nix if local command doesn't exist + else + md-babel-py() { + nix run github:leshy/md-babel-py -- "$@" + } + fi +fi + +diagram_langs="asymptote,pikchr,openscad,diagon" +if [[ "$#" -gt 0 ]]; then + for arg in "$@"; do + md-babel-py run "$arg" --lang "$diagram_langs" + done +else + while IFS= read -r file; do + md-babel-py run "$file" --lang "$diagram_langs" + done < <(find ./docs -type f -name '*.md' | sort) +fi diff --git a/docs/development/writing_docs.md b/docs/development/writing_docs.md index 58466d6592..8b24dc620b 100644 --- a/docs/development/writing_docs.md +++ b/docs/development/writing_docs.md @@ -3,5 +3,5 @@ 1. Where to put your docs: - If it only matters to people who contribute to dimos (like this doc), put them in `docs/development` - Otherwise put them in `docs/usage` -2. Run `bin/gen_diagrams` to generate the svg's for your diagrams. We use [pikchr](https://pikchr.org/home/doc/trunk/doc/userman.md) as a diagram language. +2. Run `bin/gen-diagrams` to generate the svg's for your diagrams. We use [mermaid](https://mermaid.js.org/intro/) (no generation needed) and [pikchr](https://pikchr.org/home/doc/trunk/doc/userman.md) as diagrams languages. 3. Use [md-babel-py](https://github.com/leshy/md-babel-py/) (`md-babel-py run thing.md`) to make sure your code examples work. 
From 3b15cde237432154c3bc236c277257f1cdd07bd6 Mon Sep 17 00:00:00 2001 From: Paul Nechifor Date: Tue, 24 Feb 2026 15:12:54 +0200 Subject: [PATCH 10/16] fix(tests): simplify testing (#1343) --- .github/workflows/docker.yml | 100 +---------------- bin/pytest-slow | 2 +- dimos/agents/mcp/test_mcp_client.py | 10 +- .../test_google_maps_skill_container.py | 4 +- dimos/agents/skills/test_gps_nav_skills.py | 4 +- dimos/agents/skills/test_navigation.py | 6 +- .../skills/test_unitree_skill_container.py | 2 +- dimos/agents/test_agent.py | 10 +- .../memory/test_image_embedding.py | 14 ++- dimos/conftest.py | 43 ++++--- dimos/core/test_blueprints.py | 12 +- dimos/core/test_core.py | 8 +- dimos/core/test_native_module.py | 2 +- dimos/core/test_stream.py | 58 +++++++--- dimos/core/test_worker.py | 6 +- dimos/e2e_tests/test_control_coordinator.py | 5 +- dimos/e2e_tests/test_dimos_cli_e2e.py | 8 +- dimos/e2e_tests/test_person_follow.py | 5 +- dimos/e2e_tests/test_simulation_module.py | 6 +- dimos/e2e_tests/test_spatial_memory.py | 5 +- .../monitor/world_obstacle_monitor.py | 2 +- .../occupancy/test_extrude_occupancy.py | 2 +- dimos/memory/timeseries/test_legacy.py | 4 + dimos/models/embedding/test_embedding.py | 9 +- dimos/models/vl/test_base.py | 3 +- dimos/models/vl/test_captioner.py | 6 +- dimos/models/vl/test_vlm.py | 14 ++- dimos/msgs/geometry_msgs/test_TwistStamped.py | 12 -- dimos/msgs/nav_msgs/test_OccupancyGrid.py | 105 ------------------ dimos/msgs/sensor_msgs/test_PointCloud2.py | 5 - .../reid/test_embedding_id_system.py | 36 ++++-- dimos/perception/detection/test_moduleDB.py | 59 ---------- .../test_temporal_memory_module.py | 10 +- dimos/perception/test_spatial_memory.py | 13 ++- .../perception/test_spatial_memory_module.py | 38 +++---- dimos/protocol/pubsub/impl/test_rospubsub.py | 7 +- dimos/protocol/pubsub/test_spec.py | 2 +- dimos/protocol/rpc/test_spec.py | 7 +- dimos/robot/drone/test_drone.py | 4 - dimos/robot/test_all_blueprints.py | 2 +- 
dimos/robot/unitree/testing/test_actors.py | 6 - dimos/types/test_weaklist.py | 4 +- dimos/utils/cli/lcmspy/test_lcmspy.py | 100 ++++++++--------- dimos/utils/docs/test_doclinks.py | 4 - dimos/utils/test_data.py | 14 +-- dimos/utils/test_reactive.py | 6 +- dimos/utils/test_transform_utils.py | 4 - docs/capabilities/manipulation/readme.md | 2 +- docs/development/testing.md | 40 +++++-- docs/platforms/humanoid/g1/index.md | 18 +-- pyproject.toml | 15 +-- 51 files changed, 300 insertions(+), 563 deletions(-) delete mode 100644 dimos/perception/detection/test_moduleDB.py diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index 03de5c3d15..361ef66bf8 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -212,70 +212,9 @@ jobs: uses: ./.github/workflows/tests.yml secrets: inherit with: - cmd: "pytest && pytest -m ros" # run tests that depend on ros as well + cmd: "pytest --durations=0 -m 'not (tool or mujoco)'" dev-image: ros-dev:${{ (needs.check-changes.outputs.python == 'true' || needs.check-changes.outputs.dev == 'true' || needs.check-changes.outputs.ros == 'true') && needs.ros-dev.result == 'success' && needs.check-changes.outputs.branch-tag || 'dev' }} - run-tests: - needs: [check-changes, dev] - if: ${{ - always() && - needs.check-changes.result == 'success' && - (needs.check-changes.outputs.tests == 'true' || - needs.check-changes.outputs.python == 'true' || - needs.check-changes.outputs.dev == 'true') - }} - uses: ./.github/workflows/tests.yml - secrets: inherit - with: - cmd: "pytest" - dev-image: dev:${{ (needs.check-changes.outputs.python == 'true' || needs.check-changes.outputs.dev == 'true') && needs.dev.result == 'success' && needs.check-changes.outputs.branch-tag || 'dev' }} - - # we run in parallel with normal tests for speed - run-heavy-tests: - needs: [check-changes, dev] - if: ${{ - always() && - needs.check-changes.result == 'success' && - (needs.check-changes.outputs.tests == 'true' || - 
needs.check-changes.outputs.python == 'true' || - needs.check-changes.outputs.dev == 'true') - }} - uses: ./.github/workflows/tests.yml - secrets: inherit - with: - cmd: "pytest -m heavy" - dev-image: dev:${{ (needs.check-changes.outputs.python == 'true' || needs.check-changes.outputs.dev == 'true') && needs.dev.result == 'success' && needs.check-changes.outputs.branch-tag || 'dev' }} - - run-lcm-tests: - needs: [check-changes, dev] - if: ${{ - always() && - needs.check-changes.result == 'success' && - (needs.check-changes.outputs.tests == 'true' || - needs.check-changes.outputs.python == 'true' || - needs.check-changes.outputs.dev == 'true') - }} - uses: ./.github/workflows/tests.yml - secrets: inherit - with: - cmd: "pytest -m lcm" - dev-image: dev:${{ (needs.check-changes.outputs.python == 'true' || needs.check-changes.outputs.dev == 'true') && needs.dev.result == 'success' && needs.check-changes.outputs.branch-tag || 'dev' }} - - run-integration-tests: - needs: [check-changes, dev] - if: ${{ - always() && - needs.check-changes.result == 'success' && - (needs.check-changes.outputs.tests == 'true' || - needs.check-changes.outputs.python == 'true' || - needs.check-changes.outputs.dev == 'true') - }} - uses: ./.github/workflows/tests.yml - secrets: inherit - with: - cmd: "pytest -m integration" - dev-image: dev:${{ (needs.check-changes.outputs.python == 'true' || needs.check-changes.outputs.dev == 'true') && needs.dev.result == 'success' && needs.check-changes.outputs.branch-tag || 'dev' }} - run-mypy: needs: [check-changes, ros-dev] if: ${{ @@ -292,43 +231,8 @@ jobs: cmd: "MYPYPATH=/opt/ros/humble/lib/python3.10/site-packages mypy dimos" dev-image: ros-dev:${{ (needs.check-changes.outputs.python == 'true' || needs.check-changes.outputs.dev == 'true' || needs.check-changes.outputs.ros == 'true') && needs.ros-dev.result == 'success' && needs.check-changes.outputs.branch-tag || 'dev' }} - # Run module tests directly to avoid pytest forking issues - # 
run-module-tests: - # needs: [check-changes, dev] - # if: ${{ - # always() && - # needs.check-changes.result == 'success' && - # ((needs.dev.result == 'success') || - # (needs.dev.result == 'skipped' && - # needs.check-changes.outputs.tests == 'true')) - # }} - # runs-on: [self-hosted, x64, 16gb] - # container: - # image: ghcr.io/dimensionalos/dev:${{ needs.check-changes.outputs.dev == 'true' && needs.dev.result == 'success' && needs.check-changes.outputs.branch-tag || 'dev' }} - # steps: - # - name: Fix permissions - # run: | - # sudo chown -R $USER:$USER ${{ github.workspace }} || true - # - # - uses: actions/checkout@v4 - # with: - # lfs: true - # - # - name: Configure Git LFS - # run: | - # git config --global --add safe.directory '*' - # git lfs install - # git lfs fetch - # git lfs checkout - # - # - name: Run module tests - # env: - # CI: "true" - # run: | - # /entrypoint.sh bash -c "pytest -m module" - ci-complete: - needs: [check-changes, ros, python, ros-python, dev, ros-dev, run-tests, run-heavy-tests, run-lcm-tests, run-integration-tests, run-ros-tests, run-mypy] + needs: [check-changes, ros, python, ros-python, dev, ros-dev, run-ros-tests, run-mypy] runs-on: [self-hosted, Linux] if: always() steps: diff --git a/bin/pytest-slow b/bin/pytest-slow index 85643d4413..9f9d5ae611 100755 --- a/bin/pytest-slow +++ b/bin/pytest-slow @@ -3,4 +3,4 @@ set -euo pipefail . 
.venv/bin/activate -exec pytest "$@" -m 'not (tool or module or neverending or mujoco)' dimos +exec pytest "$@" -m 'not (tool or mujoco)' dimos diff --git a/dimos/agents/mcp/test_mcp_client.py b/dimos/agents/mcp/test_mcp_client.py index be4a09d5b9..946bdc4eb8 100644 --- a/dimos/agents/mcp/test_mcp_client.py +++ b/dimos/agents/mcp/test_mcp_client.py @@ -29,7 +29,7 @@ def add(self, x: int, y: int) -> str: return str(x + y) -@pytest.mark.integration +@pytest.mark.slow @pytest.mark.parametrize("dask", [False, True]) def test_can_call_tool(dask, agent_setup): history = agent_setup( @@ -66,7 +66,7 @@ def register_user(self, name: str) -> str: return "User name registered successfully." -@pytest.mark.integration +@pytest.mark.slow @pytest.mark.parametrize("dask", [False, True]) def test_can_call_again_on_error(dask, agent_setup): history = agent_setup( @@ -118,7 +118,7 @@ def go_to_location(self, description: str) -> str: return f"Going to the {description}." -@pytest.mark.integration +@pytest.mark.slow def test_multiple_tool_calls_with_multiple_messages(agent_setup): history = agent_setup( blueprints=[MultipleTools.blueprint(), NavigationSkill.blueprint()], @@ -172,7 +172,7 @@ def test_multiple_tool_calls_with_multiple_messages(agent_setup): assert len(go_to_location_calls) == 2 -@pytest.mark.integration +@pytest.mark.slow def test_prompt(agent_setup): history = agent_setup( blueprints=[], @@ -190,7 +190,7 @@ def take_a_picture(self) -> Image: return Image.from_file(get_data("cafe-smol.jpg")).to_rgb() -@pytest.mark.integration +@pytest.mark.slow def test_image(agent_setup): history = agent_setup( blueprints=[Visualizer.blueprint()], diff --git a/dimos/agents/skills/test_google_maps_skill_container.py b/dimos/agents/skills/test_google_maps_skill_container.py index 84da91e886..1d8e4549b0 100644 --- a/dimos/agents/skills/test_google_maps_skill_container.py +++ b/dimos/agents/skills/test_google_maps_skill_container.py @@ -70,7 +70,7 @@ def __init__(self): 
self._max_valid_distance = 20000 -@pytest.mark.integration +@pytest.mark.slow def test_where_am_i(agent_setup) -> None: history = agent_setup( blueprints=[FakeGPS.blueprint(), MockedWhereAmISkill.blueprint()], @@ -80,7 +80,7 @@ def test_where_am_i(agent_setup) -> None: assert "bourbon" in history[-1].content.lower() -@pytest.mark.integration +@pytest.mark.slow def test_get_gps_position_for_queries(agent_setup) -> None: history = agent_setup( blueprints=[FakeGPS.blueprint(), MockedPositionSkill.blueprint()], diff --git a/dimos/agents/skills/test_gps_nav_skills.py b/dimos/agents/skills/test_gps_nav_skills.py index afcb4d36d0..d701d469ca 100644 --- a/dimos/agents/skills/test_gps_nav_skills.py +++ b/dimos/agents/skills/test_gps_nav_skills.py @@ -35,7 +35,7 @@ def __init__(self): self._max_valid_distance = 50000 -@pytest.mark.integration +@pytest.mark.slow def test_set_gps_travel_points(agent_setup) -> None: history = agent_setup( blueprints=[FakeGPS.blueprint(), MockedGpsNavSkill.blueprint()], @@ -50,7 +50,7 @@ def test_set_gps_travel_points(agent_setup) -> None: assert "success" in history[-1].content.lower() -@pytest.mark.integration +@pytest.mark.slow def test_set_gps_travel_points_multiple(agent_setup) -> None: history = agent_setup( blueprints=[FakeGPS.blueprint(), MockedGpsNavSkill.blueprint()], diff --git a/dimos/agents/skills/test_navigation.py b/dimos/agents/skills/test_navigation.py index 91737ada77..a7505b23c7 100644 --- a/dimos/agents/skills/test_navigation.py +++ b/dimos/agents/skills/test_navigation.py @@ -72,7 +72,7 @@ def _navigate_using_semantic_map(self, query): return f"Successfuly arrived at '{query}'" -@pytest.mark.integration +@pytest.mark.slow def test_stop_movement(agent_setup) -> None: history = agent_setup( blueprints=[ @@ -86,7 +86,7 @@ def test_stop_movement(agent_setup) -> None: assert "stopped" in history[-1].content.lower() -@pytest.mark.integration +@pytest.mark.slow def test_start_exploration(agent_setup) -> None: history = agent_setup( 
blueprints=[ @@ -102,7 +102,7 @@ def test_start_exploration(agent_setup) -> None: assert "explor" in history[-1].content.lower() -@pytest.mark.integration +@pytest.mark.slow def test_go_to_semantic_location(agent_setup) -> None: history = agent_setup( blueprints=[ diff --git a/dimos/agents/skills/test_unitree_skill_container.py b/dimos/agents/skills/test_unitree_skill_container.py index ea1cfba5cf..dde7239bbd 100644 --- a/dimos/agents/skills/test_unitree_skill_container.py +++ b/dimos/agents/skills/test_unitree_skill_container.py @@ -29,7 +29,7 @@ def __init__(self): self._bound_rpc_calls["GO2Connection.publish_request"] = lambda *args, **kwargs: None -@pytest.mark.integration +@pytest.mark.slow def test_pounce(agent_setup) -> None: history = agent_setup( blueprints=[MockedUnitreeSkill.blueprint()], diff --git a/dimos/agents/test_agent.py b/dimos/agents/test_agent.py index da69dfb7dc..cd571a56ae 100644 --- a/dimos/agents/test_agent.py +++ b/dimos/agents/test_agent.py @@ -29,7 +29,7 @@ def add(self, x: int, y: int) -> str: return str(x + y) -@pytest.mark.integration +@pytest.mark.slow @pytest.mark.parametrize("dask", [False, True]) def test_can_call_tool(dask, agent_setup): history = agent_setup( @@ -68,7 +68,7 @@ def register_user(self, name: str) -> str: return "User name registered successfully." -@pytest.mark.integration +@pytest.mark.slow @pytest.mark.parametrize("dask", [False, True]) def test_can_call_again_on_error(dask, agent_setup): history = agent_setup( @@ -120,7 +120,7 @@ def go_to_location(self, description: str) -> str: return f"Going to the {description}." 
-@pytest.mark.integration +@pytest.mark.slow def test_multiple_tool_calls_with_multiple_messages(agent_setup): history = agent_setup( blueprints=[MultipleTools.blueprint(), NavigationSkill.blueprint()], @@ -174,7 +174,7 @@ def test_multiple_tool_calls_with_multiple_messages(agent_setup): assert len(go_to_location_calls) == 2 -@pytest.mark.integration +@pytest.mark.slow def test_prompt(agent_setup): history = agent_setup( blueprints=[], @@ -192,7 +192,7 @@ def take_a_picture(self) -> Image: return Image.from_file(get_data("cafe-smol.jpg")).to_rgb() -@pytest.mark.integration +@pytest.mark.slow def test_image(agent_setup): history = agent_setup( blueprints=[Visualizer.blueprint()], diff --git a/dimos/agents_deprecated/memory/test_image_embedding.py b/dimos/agents_deprecated/memory/test_image_embedding.py index 89f0716e7e..fd8cef696e 100644 --- a/dimos/agents_deprecated/memory/test_image_embedding.py +++ b/dimos/agents_deprecated/memory/test_image_embedding.py @@ -22,12 +22,13 @@ import numpy as np import pytest from reactivex import operators as ops +from reactivex.scheduler import ThreadPoolScheduler from dimos.agents_deprecated.memory.image_embedding import ImageEmbeddingProvider from dimos.stream.video_provider import VideoProvider -@pytest.mark.heavy +@pytest.mark.slow class TestImageEmbedding: """Test class for CLIP image embedding functionality.""" @@ -45,6 +46,7 @@ def test_clip_embedding_initialization(self) -> None: def test_clip_embedding_process_video(self) -> None: """Test CLIP embedding provider can process video frames and return embeddings.""" + test_scheduler = ThreadPoolScheduler(max_workers=4) try: from dimos.utils.data import get_data @@ -53,7 +55,9 @@ def test_clip_embedding_process_video(self) -> None: embedding_provider = ImageEmbeddingProvider(model_name="clip", dimensions=512) assert os.path.exists(video_path), f"Test video not found: {video_path}" - video_provider = VideoProvider(dev_name="test_video", video_source=video_path) + video_provider 
= VideoProvider( + dev_name="test_video", video_source=video_path, pool_scheduler=test_scheduler + ) video_stream = video_provider.capture_video_as_observable(realtime=False, fps=15) @@ -146,6 +150,8 @@ def on_completed() -> None: except Exception as e: pytest.fail(f"Test failed with error: {e}") + finally: + test_scheduler.executor.shutdown(wait=True) def test_clip_embedding_similarity(self) -> None: """Test CLIP embedding similarity search and text-to-image queries.""" @@ -205,7 +211,3 @@ def test_clip_embedding_similarity(self) -> None: except Exception as e: pytest.fail(f"Similarity test failed with error: {e}") - - -if __name__ == "__main__": - pytest.main(["-v", "--disable-warnings", __file__]) diff --git a/dimos/conftest.py b/dimos/conftest.py index 5d1ca2b860..2e8b558e6e 100644 --- a/dimos/conftest.py +++ b/dimos/conftest.py @@ -13,6 +13,7 @@ # limitations under the License. import asyncio +import os import threading from dotenv import load_dotenv @@ -23,27 +24,28 @@ load_dotenv() -def _has_cuda(): - try: - import torch - except Exception: - return False - - try: - return bool(torch.cuda.is_available()) - except Exception: - return False +def pytest_configure(config): + config.addinivalue_line("markers", "tool: dev tooling") + config.addinivalue_line("markers", "slow: tests that are too slow for the fast loop") + config.addinivalue_line("markers", "mujoco: tests which open mujoco") + config.addinivalue_line("markers", "skipif_in_ci: skip when CI env var is set") + config.addinivalue_line("markers", "skipif_no_openai: skip when OPENAI_API_KEY is not set") + config.addinivalue_line("markers", "skipif_no_alibaba: skip when ALIBABA_API_KEY is not set") @pytest.hookimpl() def pytest_collection_modifyitems(config, items): - if not _has_cuda(): - skip_marker = pytest.mark.skip( - reason="CUDA is not available (torch.cuda.is_available() returned False)" - ) - for item in items: - if item.get_closest_marker("cuda"): - item.add_marker(skip_marker) + _skipif_markers = 
{ + "skipif_in_ci": (bool(os.getenv("CI")), "Skipped in CI"), + "skipif_no_openai": (not os.getenv("OPENAI_API_KEY"), "OPENAI_API_KEY not set"), + "skipif_no_alibaba": (not os.getenv("ALIBABA_API_KEY"), "ALIBABA_API_KEY not set"), + } + for marker_name, (condition, reason) in _skipif_markers.items(): + if condition: + skip = pytest.mark.skip(reason=reason) + for item in items: + if item.get_closest_marker(marker_name): + item.add_marker(skip) @pytest.fixture @@ -70,8 +72,6 @@ def _autoconf(request): _seen_threads_lock = threading.RLock() _before_test_threads = {} # Map test name to set of thread IDs before test -_skip_for = ["lcm", "heavy", "ros"] - @pytest.fixture(scope="module") def dimos_cluster(): @@ -107,11 +107,6 @@ def pytest_sessionfinish(session): @pytest.fixture(autouse=True) def monitor_threads(request): - # Skip monitoring for tests marked with specified markers - if any(request.node.get_closest_marker(marker) for marker in _skip_for): - yield - return - # Capture threads before test runs test_name = request.node.nodeid with _seen_threads_lock: diff --git a/dimos/core/test_blueprints.py b/dimos/core/test_blueprints.py index 09144054c1..fd18fe72d8 100644 --- a/dimos/core/test_blueprints.py +++ b/dimos/core/test_blueprints.py @@ -175,7 +175,7 @@ def test_global_config() -> None: assert blueprint_set.global_config_overrides["option2"] == 42 -@pytest.mark.integration +@pytest.mark.slow def test_build_happy_path() -> None: pubsub.lcm.autoconf() @@ -286,7 +286,7 @@ class Module3(Module): blueprint_set_remapped._verify_no_name_conflicts() -@pytest.mark.integration +@pytest.mark.slow def test_remapping() -> None: """Test that remapping streams works correctly.""" pubsub.lcm.autoconf() @@ -355,7 +355,7 @@ def test_future_annotations_support() -> None: assert in_blueprint.streams[0] == StreamRef(name="data", type=FutureData, direction="in") -@pytest.mark.integration +@pytest.mark.slow def test_future_annotations_autoconnect() -> None: """Test that autoconnect 
works with modules using `from __future__ import annotations`.""" @@ -448,7 +448,7 @@ def start(self) -> None: def stop(self) -> None: ... -@pytest.mark.integration +@pytest.mark.slow def test_module_ref_direct() -> None: coordinator = autoconnect( Calculator1.blueprint(), @@ -464,7 +464,7 @@ def test_module_ref_direct() -> None: coordinator.stop() -@pytest.mark.integration +@pytest.mark.slow def test_module_ref_spec() -> None: coordinator = autoconnect( Calculator1.blueprint(), @@ -480,7 +480,7 @@ def test_module_ref_spec() -> None: coordinator.stop() -@pytest.mark.integration +@pytest.mark.slow def test_module_ref_remap_ambiguous() -> None: coordinator = ( autoconnect( diff --git a/dimos/core/test_core.py b/dimos/core/test_core.py index c229659b84..3866d55bdb 100644 --- a/dimos/core/test_core.py +++ b/dimos/core/test_core.py @@ -101,7 +101,8 @@ def test_classmethods() -> None: nav._close_module() -@pytest.mark.module +@pytest.mark.slow +@pytest.mark.skipif_in_ci def test_basic_deployment(dimos) -> None: robot = dimos.deploy(MockRobotClient) @@ -136,4 +137,7 @@ def test_basic_deployment(dimos) -> None: assert nav.odom_msg_count >= 8 assert nav.lidar_msg_count >= 8 - dimos.shutdown() + nav.stop() + nav.stop_rpc_client() + robot.stop_rpc_client() + dimos.close_all() diff --git a/dimos/core/test_native_module.py b/dimos/core/test_native_module.py index 8af63b0bf4..a022be0685 100644 --- a/dimos/core/test_native_module.py +++ b/dimos/core/test_native_module.py @@ -131,7 +131,7 @@ def test_manual(dimos_cluster: DimosCluster, args_file: str) -> None: } -@pytest.mark.heavy +@pytest.mark.slow def test_autoconnect(args_file: str) -> None: """autoconnect passes correct topic args to the native subprocess.""" blueprint = autoconnect( diff --git a/dimos/core/test_stream.py b/dimos/core/test_stream.py index 836f879b67..7a594f64e4 100644 --- a/dimos/core/test_stream.py +++ b/dimos/core/test_stream.py @@ -13,6 +13,7 @@ # limitations under the License. 
from collections.abc import Callable +import threading import time import pytest @@ -21,6 +22,7 @@ In, LCMTransport, Module, + pLCMTransport, rpc, ) from dimos.core.testing import MockRobotClient, dimos @@ -37,14 +39,32 @@ class SubscriberBase(Module): def __init__(self) -> None: self.sub1_msgs = [] self.sub2_msgs = [] + self._sub1_received = threading.Event() + self._sub2_received = threading.Event() super().__init__() + def _sub1_callback(self, msg) -> None: + self.sub1_msgs.append(msg) + self._sub1_received.set() + + def _sub2_callback(self, msg) -> None: + self.sub2_msgs.append(msg) + self._sub2_received.set() + @rpc def sub1(self) -> None: ... @rpc def sub2(self) -> None: ... + @rpc + def wait_for_sub1_msg(self, timeout: float = 10) -> bool: + return self._sub1_received.wait(timeout) + + @rpc + def wait_for_sub2_msg(self, timeout: float = 10) -> bool: + return self._sub2_received.wait(timeout) + @rpc def active_subscribers(self): return self.odom.transport.active_subscribers @@ -65,14 +85,14 @@ class ClassicSubscriber(SubscriberBase): @rpc def sub1(self) -> None: - self.unsub = self.odom.subscribe(self.sub1_msgs.append) + self.unsub = self.odom.subscribe(self._sub1_callback) @rpc def sub2(self) -> None: - self.unsub2 = self.odom.subscribe(self.sub2_msgs.append) + self.unsub2 = self.odom.subscribe(self._sub2_callback) @rpc - def stop(self) -> None: + def unsub_all(self) -> None: if self.unsub: self.unsub() self.unsub = None @@ -90,14 +110,14 @@ class RXPYSubscriber(SubscriberBase): @rpc def sub1(self) -> None: - self.unsub = self.odom.observable().subscribe(self.sub1_msgs.append) + self.unsub = self.odom.observable().subscribe(self._sub1_callback) @rpc def sub2(self) -> None: - self.unsub2 = self.odom.observable().subscribe(self.sub2_msgs.append) + self.unsub2 = self.odom.observable().subscribe(self._sub2_callback) @rpc - def stop(self) -> None: + def unsub_all(self) -> None: if self.unsub: self.unsub.dispose() self.unsub = None @@ -153,12 +173,13 @@ def 
wrapped_unsubscribe() -> None: @pytest.mark.parametrize("subscriber_class", [ClassicSubscriber, RXPYSubscriber]) -@pytest.mark.module +@pytest.mark.slow def test_subscription(dimos, subscriber_class) -> None: robot = dimos.deploy(MockRobotClient) robot.lidar.transport = SpyLCMTransport("/lidar", PointCloud2) robot.odometry.transport = SpyLCMTransport("/odom", Odometry) + robot.mov.transport = pLCMTransport("/mov") subscriber = dimos.deploy(subscriber_class) @@ -166,16 +187,16 @@ def test_subscription(dimos, subscriber_class) -> None: robot.start() subscriber.sub1() - time.sleep(0.25) + subscriber.wait_for_sub1_msg() assert subscriber.sub1_msgs_len() > 0 assert subscriber.sub2_msgs_len() == 0 assert subscriber.active_subscribers() == 1 subscriber.sub2() + subscriber.wait_for_sub2_msg() - time.sleep(0.25) - subscriber.stop() + subscriber.unsub_all() assert subscriber.active_subscribers() == 0 assert subscriber.sub1_msgs_len() != 0 @@ -183,20 +204,24 @@ def test_subscription(dimos, subscriber_class) -> None: total_msg_n = subscriber.sub1_msgs_len() + subscriber.sub2_msgs_len() - time.sleep(0.25) + time.sleep(0.5) # ensuring no new messages have passed through assert total_msg_n == subscriber.sub1_msgs_len() + subscriber.sub2_msgs_len() robot.stop() + subscriber.stop_rpc_client() + robot.stop_rpc_client() + dimos.close_all() -@pytest.mark.module +@pytest.mark.slow def test_get_next(dimos) -> None: robot = dimos.deploy(MockRobotClient) robot.lidar.transport = SpyLCMTransport("/lidar", PointCloud2) robot.odometry.transport = SpyLCMTransport("/odom", Odometry) + robot.mov.transport = pLCMTransport("/mov") subscriber = dimos.deploy(RXPYSubscriber) subscriber.odom.connect(robot.odometry) @@ -218,14 +243,18 @@ def test_get_next(dimos) -> None: assert next_odom != odom robot.stop() + subscriber.stop_rpc_client() + robot.stop_rpc_client() + dimos.close_all() -@pytest.mark.module +@pytest.mark.slow def test_hot_getter(dimos) -> None: robot = dimos.deploy(MockRobotClient) 
robot.lidar.transport = SpyLCMTransport("/lidar", PointCloud2) robot.odometry.transport = SpyLCMTransport("/odom", Odometry) + robot.mov.transport = pLCMTransport("/mov") subscriber = dimos.deploy(RXPYSubscriber) subscriber.odom.connect(robot.odometry) @@ -254,3 +283,6 @@ def test_hot_getter(dimos) -> None: subscriber.stop_hot_getter() robot.stop() + subscriber.stop_rpc_client() + robot.stop_rpc_client() + dimos.close_all() diff --git a/dimos/core/test_worker.py b/dimos/core/test_worker.py index 98a7c5782d..6892d226fd 100644 --- a/dimos/core/test_worker.py +++ b/dimos/core/test_worker.py @@ -82,7 +82,7 @@ def worker_manager(): manager.close_all() -@pytest.mark.integration +@pytest.mark.slow def test_worker_manager_basic(worker_manager): module = worker_manager.deploy(SimpleModule) module.start() @@ -99,7 +99,7 @@ def test_worker_manager_basic(worker_manager): module.stop() -@pytest.mark.integration +@pytest.mark.slow def test_worker_manager_multiple_different_modules(worker_manager): module1 = worker_manager.deploy(SimpleModule) module2 = worker_manager.deploy(AnotherModule) @@ -120,7 +120,7 @@ def test_worker_manager_multiple_different_modules(worker_manager): module2.stop() -@pytest.mark.integration +@pytest.mark.slow def test_worker_manager_parallel_deployment(worker_manager): modules = worker_manager.deploy_parallel( [ diff --git a/dimos/e2e_tests/test_control_coordinator.py b/dimos/e2e_tests/test_control_coordinator.py index f6e520831d..5bb7a096f7 100644 --- a/dimos/e2e_tests/test_control_coordinator.py +++ b/dimos/e2e_tests/test_control_coordinator.py @@ -18,7 +18,6 @@ Unlike unit tests, these verify the full system integration. 
""" -import os import time import pytest @@ -29,8 +28,8 @@ from dimos.msgs.trajectory_msgs import JointTrajectory, TrajectoryPoint, TrajectoryState -@pytest.mark.skipif(bool(os.getenv("CI")), reason="LCM doesn't work in CI.") -@pytest.mark.e2e +@pytest.mark.skipif_in_ci +@pytest.mark.slow class TestControlCoordinatorE2E: """End-to-end tests for ControlCoordinator.""" diff --git a/dimos/e2e_tests/test_dimos_cli_e2e.py b/dimos/e2e_tests/test_dimos_cli_e2e.py index ede0ec7a3a..f27502d620 100644 --- a/dimos/e2e_tests/test_dimos_cli_e2e.py +++ b/dimos/e2e_tests/test_dimos_cli_e2e.py @@ -12,14 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os - import pytest -@pytest.mark.skipif(bool(os.getenv("CI")), reason="LCM spy doesn't work in CI.") -@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set.") -@pytest.mark.e2e +@pytest.mark.skipif_in_ci +@pytest.mark.skipif_no_openai +@pytest.mark.slow def test_dimos_skills(lcm_spy, start_blueprint, human_input) -> None: lcm_spy.save_topic("/agent") lcm_spy.save_topic("/rpc/Agent/on_system_modules/res") diff --git a/dimos/e2e_tests/test_person_follow.py b/dimos/e2e_tests/test_person_follow.py index abb9cfb4fa..090ee90f2a 100644 --- a/dimos/e2e_tests/test_person_follow.py +++ b/dimos/e2e_tests/test_person_follow.py @@ -13,7 +13,6 @@ # limitations under the License. 
from collections.abc import Callable, Generator -import os import threading import time @@ -53,8 +52,8 @@ def run_person_track() -> None: publisher.stop() -@pytest.mark.skipif(bool(os.getenv("CI")), reason="LCM spy doesn't work in CI.") -@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set.") +@pytest.mark.skipif_in_ci +@pytest.mark.skipif_no_openai @pytest.mark.mujoco def test_person_follow( lcm_spy: LcmSpy, diff --git a/dimos/e2e_tests/test_simulation_module.py b/dimos/e2e_tests/test_simulation_module.py index 6c15f62056..b5902ad7e2 100644 --- a/dimos/e2e_tests/test_simulation_module.py +++ b/dimos/e2e_tests/test_simulation_module.py @@ -14,8 +14,6 @@ """End-to-end tests for the simulation module.""" -import os - import pytest from dimos.msgs.sensor_msgs import JointCommand, JointState, RobotState @@ -31,8 +29,8 @@ def _positions_within_tolerance( return all(abs(positions[i] - target[i]) <= tolerance for i in range(len(target))) -@pytest.mark.skipif(bool(os.getenv("CI")), reason="LCM doesn't work in CI.") -@pytest.mark.e2e +@pytest.mark.skipif_in_ci +@pytest.mark.slow class TestSimulationModuleE2E: def test_xarm7_joint_state_published(self, lcm_spy, start_blueprint) -> None: joint_state_topic = "/xarm/joint_states#sensor_msgs.JointState" diff --git a/dimos/e2e_tests/test_spatial_memory.py b/dimos/e2e_tests/test_spatial_memory.py index 8b03a9915c..ad22368678 100644 --- a/dimos/e2e_tests/test_spatial_memory.py +++ b/dimos/e2e_tests/test_spatial_memory.py @@ -14,7 +14,6 @@ from collections.abc import Callable import math -import os import time import pytest @@ -23,8 +22,8 @@ from dimos.e2e_tests.lcm_spy import LcmSpy -@pytest.mark.skipif(bool(os.getenv("CI")), reason="LCM spy doesn't work in CI.") -@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set.") +@pytest.mark.skipif_in_ci +@pytest.mark.skipif_no_openai @pytest.mark.mujoco def test_spatial_memory_navigation( lcm_spy: LcmSpy, diff --git 
a/dimos/manipulation/planning/monitor/world_obstacle_monitor.py b/dimos/manipulation/planning/monitor/world_obstacle_monitor.py index 6082ab93a9..a96d3efaf6 100644 --- a/dimos/manipulation/planning/monitor/world_obstacle_monitor.py +++ b/dimos/manipulation/planning/monitor/world_obstacle_monitor.py @@ -555,7 +555,7 @@ def list_added_obstacles(self) -> list[dict[str, Any]]: entry = self._object_cache.get(oid) if entry is None: continue - obj, first_seen, last_seen = entry + obj, _first_seen, _last_seen = entry if not isinstance(obj, Object): continue result.append( diff --git a/dimos/mapping/occupancy/test_extrude_occupancy.py b/dimos/mapping/occupancy/test_extrude_occupancy.py index 88f05d7780..76d714a2a1 100644 --- a/dimos/mapping/occupancy/test_extrude_occupancy.py +++ b/dimos/mapping/occupancy/test_extrude_occupancy.py @@ -18,7 +18,7 @@ from dimos.utils.data import get_data -@pytest.mark.integration +@pytest.mark.slow def test_generate_mujoco_scene(occupancy) -> None: with open(get_data("expected_occupancy_scene.xml")) as f: expected = f.read() diff --git a/dimos/memory/timeseries/test_legacy.py b/dimos/memory/timeseries/test_legacy.py index aaad962a95..c77ec64a76 100644 --- a/dimos/memory/timeseries/test_legacy.py +++ b/dimos/memory/timeseries/test_legacy.py @@ -13,12 +13,16 @@ # limitations under the License. 
"""Tests specific to LegacyPickleStore.""" +import pytest + from dimos.memory.timeseries.legacy import LegacyPickleStore class TestLegacyPickleStoreRealData: """Test LegacyPickleStore with real recorded data.""" + @pytest.mark.skipif_in_ci + @pytest.mark.slow def test_read_lidar_recording(self) -> None: """Test reading from unitree_go2_bigoffice/lidar recording.""" store = LegacyPickleStore("unitree_go2_bigoffice/lidar") diff --git a/dimos/models/embedding/test_embedding.py b/dimos/models/embedding/test_embedding.py index a87a2f5a57..466c974b32 100644 --- a/dimos/models/embedding/test_embedding.py +++ b/dimos/models/embedding/test_embedding.py @@ -20,7 +20,8 @@ ], ids=["clip", "mobileclip", "treid"], ) -@pytest.mark.gpu +@pytest.mark.slow +@pytest.mark.skipif_in_ci def test_embedding_model(model_class: type, model_name: str, supports_text: bool) -> None: """Test embedding functionality across different model types.""" image = Image.from_file(get_data("cafe.jpg")).to_rgb() @@ -94,7 +95,8 @@ def test_embedding_model(model_class: type, model_name: str, supports_text: bool ], ids=["clip", "mobileclip"], ) -@pytest.mark.gpu +@pytest.mark.slow +@pytest.mark.skipif_in_ci def test_text_image_retrieval(model_class: type, model_name: str) -> None: """Test text-to-image retrieval using embedding similarity.""" image = Image.from_file(get_data("cafe.jpg")).to_rgb() @@ -126,7 +128,8 @@ def test_text_image_retrieval(model_class: type, model_name: str) -> None: print(f"\n{model_name} retrieval test passed!") -@pytest.mark.gpu +@pytest.mark.slow +@pytest.mark.skipif_in_ci def test_embedding_device_transfer() -> None: """Test embedding device transfer operations.""" image = Image.from_file(get_data("cafe.jpg")).to_rgb() diff --git a/dimos/models/vl/test_base.py b/dimos/models/vl/test_base.py index a7296bd87b..2e4229d944 100644 --- a/dimos/models/vl/test_base.py +++ b/dimos/models/vl/test_base.py @@ -1,4 +1,3 @@ -import os from unittest.mock import MagicMock from 
dimos_lcm.foxglove_msgs.ImageAnnotations import ImageAnnotations @@ -78,7 +77,7 @@ def test_query_detections_mocked() -> None: @pytest.mark.tool -@pytest.mark.skipif(not os.getenv("ALIBABA_API_KEY"), reason="ALIBABA_API_KEY not set") +@pytest.mark.skipif_no_alibaba def test_query_detections_real() -> None: """Test query_detections with real API calls (requires API key).""" # Load test image diff --git a/dimos/models/vl/test_captioner.py b/dimos/models/vl/test_captioner.py index 081f3bcefc..c7ebb8fc63 100644 --- a/dimos/models/vl/test_captioner.py +++ b/dimos/models/vl/test_captioner.py @@ -44,7 +44,8 @@ def florence2_model(request: pytest.FixtureRequest) -> Generator[Florence2Model, yield from generic_model_fixture(request.param) -@pytest.mark.gpu +@pytest.mark.slow +@pytest.mark.skipif_in_ci def test_captioner(captioner_model: CaptionerModel, test_image: Image) -> None: """Test captioning functionality across different model types.""" # Test single caption @@ -72,7 +73,8 @@ def test_captioner(captioner_model: CaptionerModel, test_image: Image) -> None: assert all(isinstance(c, str) and len(c) > 0 for c in captions) -@pytest.mark.gpu +@pytest.mark.slow +@pytest.mark.skipif_in_ci def test_florence2_detail_levels(florence2_model: Florence2Model, test_image: Image) -> None: """Test Florence-2 different detail levels.""" detail_levels = ["brief", "normal", "detailed", "more_detailed"] diff --git a/dimos/models/vl/test_vlm.py b/dimos/models/vl/test_vlm.py index 741e0dede2..54ceddadc5 100644 --- a/dimos/models/vl/test_vlm.py +++ b/dimos/models/vl/test_vlm.py @@ -32,7 +32,8 @@ (QwenVlModel, "Qwen"), ], ) -@pytest.mark.gpu +@pytest.mark.slow +@pytest.mark.skipif_in_ci def test_vlm_bbox_detections(model_class: "type[VlModel]", model_name: str) -> None: if model_class is MoondreamHostedVlModel and 'MOONDREAM_API_KEY' not in os.environ: pytest.skip("Need MOONDREAM_API_KEY to run") @@ -104,7 +105,8 @@ def test_vlm_bbox_detections(model_class: "type[VlModel]", model_name: str) 
-> N (QwenVlModel, "Qwen"), ], ) -@pytest.mark.gpu +@pytest.mark.slow +@pytest.mark.skipif_in_ci def test_vlm_point_detections(model_class: "type[VlModel]", model_name: str) -> None: """Test VLM point detection capabilities.""" @@ -172,7 +174,8 @@ def test_vlm_point_detections(model_class: "type[VlModel]", model_name: str) -> (MoondreamVlModel, "Moondream"), ], ) -@pytest.mark.gpu +@pytest.mark.slow +@pytest.mark.skipif_in_ci def test_vlm_query_multi(model_class: "type[VlModel]", model_name: str) -> None: """Test query_multi optimization - single image, multiple queries.""" image = Image.from_file(get_data("cafe.jpg")).to_rgb() @@ -222,7 +225,7 @@ def test_vlm_query_multi(model_class: "type[VlModel]", model_name: str) -> None: ], ) @pytest.mark.tool -@pytest.mark.gpu +@pytest.mark.slow def test_vlm_query_batch(model_class: "type[VlModel]", model_name: str) -> None: """Test query_batch optimization - multiple images, same query.""" from dimos.utils.testing import TimedSensorReplay @@ -275,7 +278,8 @@ def test_vlm_query_batch(model_class: "type[VlModel]", model_name: str) -> None: (QwenVlModel, [None, (512, 512), (256, 256)]), ], ) -@pytest.mark.gpu +@pytest.mark.slow +@pytest.mark.skipif_in_ci def test_vlm_resize( model_class: "type[VlModel]", sizes: list[tuple[int, int] | None], diff --git a/dimos/msgs/geometry_msgs/test_TwistStamped.py b/dimos/msgs/geometry_msgs/test_TwistStamped.py index afb8489032..d37b4b2717 100644 --- a/dimos/msgs/geometry_msgs/test_TwistStamped.py +++ b/dimos/msgs/geometry_msgs/test_TwistStamped.py @@ -52,15 +52,3 @@ def test_pickle_encode_decode() -> None: assert isinstance(twist_dest, TwistStamped) assert twist_dest is not twist_source assert twist_dest == twist_source - - -if __name__ == "__main__": - print("Running test_lcm_encode_decode...") - test_lcm_encode_decode() - print("test_lcm_encode_decode passed") - - print("Running test_pickle_encode_decode...") - test_pickle_encode_decode() - print("test_pickle_encode_decode passed") - - 
print("\nAll tests passed!") diff --git a/dimos/msgs/nav_msgs/test_OccupancyGrid.py b/dimos/msgs/nav_msgs/test_OccupancyGrid.py index 29ef196de8..d1ec8938b4 100644 --- a/dimos/msgs/nav_msgs/test_OccupancyGrid.py +++ b/dimos/msgs/nav_msgs/test_OccupancyGrid.py @@ -26,7 +26,6 @@ from dimos.msgs.geometry_msgs import Pose from dimos.msgs.nav_msgs import OccupancyGrid from dimos.msgs.sensor_msgs import PointCloud2 -from dimos.protocol.pubsub.impl.lcmpubsub import LCM, Topic from dimos.utils.data import get_data @@ -364,107 +363,3 @@ def test_max() -> None: assert maxed.unknown_cells == 3 # Same as original assert maxed.occupied_cells == 13 # All non-unknown cells assert maxed.free_cells == 0 # No free cells - - -@pytest.mark.lcm -def test_lcm_broadcast() -> None: - """Test broadcasting OccupancyGrid and gradient over LCM.""" - file_path = get_data("lcm_msgs") / "sensor_msgs/PointCloud2.pickle" - with open(file_path, "rb") as f: - lcm_msg = pickle.loads(f.read()) - - pointcloud = PointCloud2.lcm_decode(lcm_msg) - - # Create occupancy grid from pointcloud - occupancygrid = general_occupancy(pointcloud, resolution=0.05, min_height=0.1, max_height=2.0) - # Apply inflation separately if needed - occupancygrid = simple_inflate(occupancygrid, 0.1) - - # Create gradient field with larger max_distance for better visualization - gradient_grid = gradient(occupancygrid, obstacle_threshold=70, max_distance=2.0) - - # Debug: Print actual values to see the difference - print("\n=== DEBUG: Comparing grids ===") - print(f"Original grid unique values: {np.unique(occupancygrid.grid)}") - print(f"Gradient grid unique values: {np.unique(gradient_grid.grid)}") - - # Find an area with occupied cells to show the difference - occupied_indices = np.argwhere(occupancygrid.grid == 100) - if len(occupied_indices) > 0: - # Pick a point near an occupied cell - idx = len(occupied_indices) // 2 # Middle occupied cell - sample_y, sample_x = occupied_indices[idx] - sample_size = 15 - - # Ensure we don't 
go out of bounds - y_start = max(0, sample_y - sample_size // 2) - y_end = min(occupancygrid.height, y_start + sample_size) - x_start = max(0, sample_x - sample_size // 2) - x_end = min(occupancygrid.width, x_start + sample_size) - - print(f"\nSample area around occupied cell ({sample_x}, {sample_y}):") - print("Original occupancy grid:") - print(occupancygrid.grid[y_start:y_end, x_start:x_end]) - print("\nGradient grid (same area):") - print(gradient_grid.grid[y_start:y_end, x_start:x_end]) - else: - print("\nNo occupied cells found for sampling") - - # Check statistics - print("\nOriginal grid stats:") - print(f" Occupied (100): {np.sum(occupancygrid.grid == 100)} cells") - print(f" Inflated (99): {np.sum(occupancygrid.grid == 99)} cells") - print(f" Free (0): {np.sum(occupancygrid.grid == 0)} cells") - print(f" Unknown (-1): {np.sum(occupancygrid.grid == -1)} cells") - - print("\nGradient grid stats:") - print(f" Max gradient (100): {np.sum(gradient_grid.grid == 100)} cells") - print( - f" High gradient (80-99): {np.sum((gradient_grid.grid >= 80) & (gradient_grid.grid < 100))} cells" - ) - print( - f" Medium gradient (40-79): {np.sum((gradient_grid.grid >= 40) & (gradient_grid.grid < 80))} cells" - ) - print( - f" Low gradient (1-39): {np.sum((gradient_grid.grid >= 1) & (gradient_grid.grid < 40))} cells" - ) - print(f" Zero gradient (0): {np.sum(gradient_grid.grid == 0)} cells") - print(f" Unknown (-1): {np.sum(gradient_grid.grid == -1)} cells") - - # # Save debug images - # import matplotlib.pyplot as plt - - # fig, axes = plt.subplots(1, 2, figsize=(12, 5)) - - # # Original - # ax = axes[0] - # im1 = ax.imshow(occupancygrid.grid, origin="lower", cmap="gray_r", vmin=-1, vmax=100) - # ax.set_title(f"Original Occupancy Grid\n{occupancygrid}") - # plt.colorbar(im1, ax=ax) - - # # Gradient - # ax = axes[1] - # im2 = ax.imshow(gradient_grid.grid, origin="lower", cmap="hot", vmin=-1, vmax=100) - # ax.set_title(f"Gradient Grid\n{gradient_grid}") - # plt.colorbar(im2, 
ax=ax) - - # plt.tight_layout() - # plt.savefig("lcm_debug_grids.png", dpi=150) - # print("\nSaved debug visualization to lcm_debug_grids.png") - # plt.close() - - # Broadcast all the data - lcm = LCM() - lcm.start() - lcm.publish(Topic("/global_map", PointCloud2), pointcloud) - lcm.publish(Topic("/global_costmap", OccupancyGrid), occupancygrid) - lcm.publish(Topic("/global_gradient", OccupancyGrid), gradient_grid) - - print("\nPublished to LCM:") - print(f" /global_map: PointCloud2 with {len(pointcloud)} points") - print(f" /global_costmap: {occupancygrid}") - print(f" /global_gradient: {gradient_grid}") - print("\nGradient info:") - print(" Values: 0 (free far from obstacles) -> 100 (at obstacles)") - print(f" Unknown cells: {gradient_grid.unknown_cells} (preserved as -1)") - print(" Max distance for gradient: 5.0 meters") diff --git a/dimos/msgs/sensor_msgs/test_PointCloud2.py b/dimos/msgs/sensor_msgs/test_PointCloud2.py index 501a4cd441..f48802ab7a 100644 --- a/dimos/msgs/sensor_msgs/test_PointCloud2.py +++ b/dimos/msgs/sensor_msgs/test_PointCloud2.py @@ -153,8 +153,3 @@ def test_bounding_box_intersects() -> None: pass print("✓ All bounding box intersection tests passed!") - - -if __name__ == "__main__": - test_lcm_encode_decode() - test_bounding_box_intersects() diff --git a/dimos/perception/detection/reid/test_embedding_id_system.py b/dimos/perception/detection/reid/test_embedding_id_system.py index b9e6f591ee..cc8632627f 100644 --- a/dimos/perception/detection/reid/test_embedding_id_system.py +++ b/dimos/perception/detection/reid/test_embedding_id_system.py @@ -46,7 +46,8 @@ def test_image(): return Image.from_file(get_data("cafe.jpg")).to_rgb() -@pytest.mark.gpu +@pytest.mark.slow +@pytest.mark.skipif_in_ci def test_update_embedding_single(track_associator, mobileclip_model, test_image) -> None: """Test updating embedding for a single track.""" embedding = mobileclip_model.embed(test_image) @@ -64,7 +65,8 @@ def 
test_update_embedding_single(track_associator, mobileclip_model, test_image) assert abs(norm - 1.0) < 0.01, "Embedding should be normalized" -@pytest.mark.gpu +@pytest.mark.slow +@pytest.mark.skipif_in_ci def test_update_embedding_multiple(track_associator, mobileclip_model, test_image) -> None: """Test storing multiple embeddings per track.""" embedding1 = mobileclip_model.embed(test_image) @@ -91,7 +93,8 @@ def test_update_embedding_multiple(track_associator, mobileclip_model, test_imag assert similarity > 0.99, "Same image should produce very similar embeddings" -@pytest.mark.gpu +@pytest.mark.slow +@pytest.mark.skipif_in_ci def test_negative_constraints(track_associator) -> None: """Test negative constraint recording.""" # Simulate frame with 3 tracks @@ -107,7 +110,8 @@ def test_negative_constraints(track_associator) -> None: assert 2 in track_associator.negative_pairs[3] -@pytest.mark.gpu +@pytest.mark.slow +@pytest.mark.skipif_in_ci def test_associate_new_track(track_associator, mobileclip_model, test_image) -> None: """Test associating a new track creates new long_term_id.""" embedding = mobileclip_model.embed(test_image) @@ -121,7 +125,8 @@ def test_associate_new_track(track_associator, mobileclip_model, test_image) -> assert track_associator.long_term_counter == 1 -@pytest.mark.gpu +@pytest.mark.slow +@pytest.mark.skipif_in_ci def test_associate_similar_tracks(track_associator, mobileclip_model, test_image) -> None: """Test associating similar tracks to same long_term_id.""" # Create embeddings from same image (should be very similar) @@ -141,7 +146,8 @@ def test_associate_similar_tracks(track_associator, mobileclip_model, test_image assert track_associator.long_term_counter == 1, "Only one long_term_id should be created" -@pytest.mark.gpu +@pytest.mark.slow +@pytest.mark.skipif_in_ci def test_associate_with_negative_constraint(track_associator, mobileclip_model, test_image) -> None: """Test that negative constraints prevent association.""" # Create 
similar embeddings @@ -166,7 +172,8 @@ def test_associate_with_negative_constraint(track_associator, mobileclip_model, assert track_associator.long_term_counter == 2, "Two long_term_ids should be created" -@pytest.mark.gpu +@pytest.mark.slow +@pytest.mark.skipif_in_ci def test_associate_different_objects(track_associator, mobileclip_model, test_image) -> None: """Test that dissimilar embeddings get different long_term_ids.""" # Create embeddings for image and text (very different) @@ -186,7 +193,8 @@ def test_associate_different_objects(track_associator, mobileclip_model, test_im assert track_associator.long_term_counter == 2 -@pytest.mark.gpu +@pytest.mark.slow +@pytest.mark.skipif_in_ci def test_associate_returns_cached(track_associator, mobileclip_model, test_image) -> None: """Test that repeated calls return same long_term_id.""" embedding = mobileclip_model.embed(test_image) @@ -202,7 +210,8 @@ def test_associate_returns_cached(track_associator, mobileclip_model, test_image assert track_associator.long_term_counter == 1, "Should not create new ID" -@pytest.mark.gpu +@pytest.mark.slow +@pytest.mark.skipif_in_ci def test_associate_no_embedding(track_associator) -> None: """Test that associate creates new ID for track without embedding.""" # Track with no embedding gets assigned a new ID @@ -211,7 +220,8 @@ def test_associate_no_embedding(track_associator) -> None: assert track_associator.long_term_counter == 1 -@pytest.mark.gpu +@pytest.mark.slow +@pytest.mark.skipif_in_ci def test_embeddings_stored_as_numpy(track_associator, mobileclip_model, test_image) -> None: """Test that embeddings are stored as numpy arrays for efficient CPU comparisons.""" embedding = mobileclip_model.embed(test_image) @@ -232,7 +242,8 @@ def test_embeddings_stored_as_numpy(track_associator, mobileclip_model, test_ima assert isinstance(emb, np.ndarray) -@pytest.mark.gpu +@pytest.mark.slow +@pytest.mark.skipif_in_ci def test_similarity_threshold_configurable(mobileclip_model) -> None: 
"""Test that similarity threshold is configurable.""" associator_strict = EmbeddingIDSystem(model=lambda: mobileclip_model, similarity_threshold=0.95) @@ -242,7 +253,8 @@ def test_similarity_threshold_configurable(mobileclip_model) -> None: assert associator_loose.similarity_threshold == 0.50 -@pytest.mark.gpu +@pytest.mark.slow +@pytest.mark.skipif_in_ci def test_multi_track_scenario(track_associator, mobileclip_model, test_image) -> None: """Test realistic scenario with multiple tracks across frames.""" # Frame 1: Track 1 appears diff --git a/dimos/perception/detection/test_moduleDB.py b/dimos/perception/detection/test_moduleDB.py deleted file mode 100644 index 23885a1c60..0000000000 --- a/dimos/perception/detection/test_moduleDB.py +++ /dev/null @@ -1,59 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-import time - -from lcm_msgs.foxglove_msgs import SceneUpdate -import pytest - -from dimos.core import LCMTransport -from dimos.msgs.foxglove_msgs import ImageAnnotations -from dimos.msgs.geometry_msgs import PoseStamped -from dimos.msgs.sensor_msgs import Image, PointCloud2 -from dimos.msgs.vision_msgs import Detection2DArray -from dimos.perception.detection.moduleDB import ObjectDBModule -from dimos.robot.unitree.go2 import connection as go2_connection - - -@pytest.mark.module -def test_moduleDB(dimos_cluster) -> None: - connection = go2_connection.deploy(dimos_cluster, "fake") - - moduleDB = dimos_cluster.deploy( - ObjectDBModule, - camera_info=go2_connection._camera_info_static(), - goto=lambda obj_id: print(f"Going to {obj_id}"), - ) - moduleDB.image.connect(connection.color_image) - moduleDB.pointcloud.connect(connection.lidar) - - moduleDB.annotations.transport = LCMTransport("/annotations", ImageAnnotations) - moduleDB.detections.transport = LCMTransport("/detections", Detection2DArray) - - moduleDB.detected_pointcloud_0.transport = LCMTransport("/detected/pointcloud/0", PointCloud2) - moduleDB.detected_pointcloud_1.transport = LCMTransport("/detected/pointcloud/1", PointCloud2) - moduleDB.detected_pointcloud_2.transport = LCMTransport("/detected/pointcloud/2", PointCloud2) - - moduleDB.detected_image_0.transport = LCMTransport("/detected/image/0", Image) - moduleDB.detected_image_1.transport = LCMTransport("/detected/image/1", Image) - moduleDB.detected_image_2.transport = LCMTransport("/detected/image/2", Image) - - moduleDB.scene_update.transport = LCMTransport("/scene_update", SceneUpdate) - moduleDB.target.transport = LCMTransport("/target", PoseStamped) - - connection.start() - moduleDB.start() - - time.sleep(4) - print("VLM RES", moduleDB.navigate_to_object_in_view("white floor")) - time.sleep(30) diff --git a/dimos/perception/experimental/temporal_memory/test_temporal_memory_module.py 
b/dimos/perception/experimental/temporal_memory/test_temporal_memory_module.py index 1d0dab007b..ef584a2527 100644 --- a/dimos/perception/experimental/temporal_memory/test_temporal_memory_module.py +++ b/dimos/perception/experimental/temporal_memory/test_temporal_memory_module.py @@ -76,9 +76,9 @@ def stop(self) -> None: logger.info("VideoReplayModule stopped") -@pytest.mark.skipif(bool(os.getenv("CI")), reason="LCM replay + dataset not CI-safe.") -@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set.") -@pytest.mark.neverending +@pytest.mark.skipif_in_ci +@pytest.mark.skipif_no_openai +@pytest.mark.slow class TestTemporalMemoryModule: @pytest.fixture(scope="function") def temp_dir(self): @@ -221,7 +221,3 @@ async def test_temporal_memory_module_with_replay( assert (output_path / "frames_index.jsonl").exists(), "frames_index.jsonl should exist" logger.info("All temporal memory module tests passed!") - - -if __name__ == "__main__": - pytest.main(["-v", "-s", __file__]) diff --git a/dimos/perception/test_spatial_memory.py b/dimos/perception/test_spatial_memory.py index d4b188ced3..433896aefe 100644 --- a/dimos/perception/test_spatial_memory.py +++ b/dimos/perception/test_spatial_memory.py @@ -20,13 +20,14 @@ import numpy as np import pytest from reactivex import operators as ops +from reactivex.scheduler import ThreadPoolScheduler from dimos.msgs.geometry_msgs import Pose from dimos.perception.spatial_perception import SpatialMemory from dimos.stream.video_provider import VideoProvider -@pytest.mark.heavy +@pytest.mark.slow class TestSpatialMemory: @pytest.fixture(scope="class") def temp_dir(self): @@ -87,6 +88,7 @@ def test_image_embedding(self, spatial_memory) -> None: def test_spatial_memory_processing(self, spatial_memory, temp_dir) -> None: """Test processing video frames and building spatial memory with CLIP embeddings.""" + test_scheduler = ThreadPoolScheduler(max_workers=4) try: # Use the shared spatial_memory fixture memory = 
spatial_memory @@ -95,7 +97,9 @@ def test_spatial_memory_processing(self, spatial_memory, temp_dir) -> None: video_path = get_data("assets") / "trimmed_video_office.mov" assert os.path.exists(video_path), f"Test video not found: {video_path}" - video_provider = VideoProvider(dev_name="test_video", video_source=video_path) + video_provider = VideoProvider( + dev_name="test_video", video_source=video_path, pool_scheduler=test_scheduler + ) video_stream = video_provider.capture_video_as_observable(realtime=False, fps=15) # Create a frame counter for position generation @@ -196,7 +200,4 @@ def on_completed() -> None: pytest.fail(f"Error in test: {e}") finally: video_provider.dispose_all() - - -if __name__ == "__main__": - pytest.main(["-v", __file__]) + test_scheduler.executor.shutdown(wait=True) diff --git a/dimos/perception/test_spatial_memory_module.py b/dimos/perception/test_spatial_memory_module.py index 98ec7a1212..a8c42d4f0e 100644 --- a/dimos/perception/test_spatial_memory_module.py +++ b/dimos/perception/test_spatial_memory_module.py @@ -22,6 +22,7 @@ from dimos import core from dimos.core import Module, Out, rpc +from dimos.msgs.geometry_msgs import Transform from dimos.msgs.sensor_msgs import Image from dimos.perception.spatial_perception import SpatialMemory from dimos.robot.unitree.type.odometry import Odometry @@ -70,29 +71,31 @@ def stop(self) -> None: class OdometryReplayModule(Module): - """Module that replays odometry data from TimedSensorReplay.""" - - odom_out: Out[Odometry] + """Module that replays odometry data and publishes to the tf system.""" def __init__(self, odom_path: str) -> None: super().__init__() self.odom_path = odom_path self._subscription = None + def _publish_tf(self, odom: Odometry) -> None: + """Convert odometry to TF transforms and publish.""" + self.tf.publish(Transform.from_pose("base_link", odom)) + @rpc def start(self) -> None: """Start replaying odometry data.""" # Use TimedSensorReplay to replay odometry odom_replay = 
TimedSensorReplay(self.odom_path, autocast=Odometry.from_msg) - # Subscribe to the replay stream and publish to LCM + # Subscribe to the replay stream and publish to tf self._subscription = ( odom_replay.stream() .pipe( ops.sample(0.5), # Sample every 500ms ops.take(10), # Only take 10 odometry updates total ) - .subscribe(self.odom_out.publish) + .subscribe(self._publish_tf) ) logger.info("OdometryReplayModule started") @@ -106,8 +109,8 @@ def stop(self) -> None: logger.info("OdometryReplayModule stopped") -@pytest.mark.gpu -@pytest.mark.neverending +@pytest.mark.slow +@pytest.mark.skipif_in_ci class TestSpatialMemoryModule: @pytest.fixture(scope="function") def temp_dir(self): @@ -135,9 +138,8 @@ async def test_spatial_memory_module_with_replay(self, temp_dir): video_module = dimos.deploy(VideoReplayModule, video_path) video_module.video_out.transport = core.LCMTransport("/test_video", Image) - # Odometry replay module + # Odometry replay module (publishes to tf system directly) odom_module = dimos.deploy(OdometryReplayModule, odom_path) - odom_module.odom_out.transport = core.LCMTransport("/test_odom", Odometry) # Spatial memory module spatial_memory = dimos.deploy( @@ -153,9 +155,8 @@ async def test_spatial_memory_module_with_replay(self, temp_dir): output_dir=os.path.join(temp_dir, "images"), ) - # Connect streams - spatial_memory.video.connect(video_module.video_out) - spatial_memory.odom.connect(odom_module.odom_out) + # Connect video stream + spatial_memory.color_image.connect(video_module.video_out) # Start all modules video_module.start() @@ -209,19 +210,12 @@ async def test_spatial_memory_module_with_replay(self, temp_dir): video_module.stop() odom_module.stop() - logger.info("Stopped replay modules") + spatial_memory.stop() + logger.info("Stopped all modules") logger.info("All spatial memory module tests passed!") finally: # Cleanup if "dimos" in locals(): - dimos.close() - - -if __name__ == "__main__": - pytest.main(["-v", "-s", __file__]) - # test = 
TestSpatialMemoryModule() - # asyncio.run( - # test.test_spatial_memory_module_with_replay(tempfile.mkdtemp(prefix="spatial_memory_test_")) - # ) + dimos.close_all() diff --git a/dimos/protocol/pubsub/impl/test_rospubsub.py b/dimos/protocol/pubsub/impl/test_rospubsub.py index 6cf49c37b2..9add4ef893 100644 --- a/dimos/protocol/pubsub/impl/test_rospubsub.py +++ b/dimos/protocol/pubsub/impl/test_rospubsub.py @@ -50,7 +50,6 @@ def subscriber() -> Generator[DimosROS, None, None]: yield from ros_node() -@pytest.mark.ros def test_basic_conversion(publisher, subscriber): """Test Vector3 publish/subscribe through ROS. @@ -76,7 +75,7 @@ def callback(msg, t): assert msg.z == 3.0 -@pytest.mark.ros +@pytest.mark.slow def test_pointcloud2_pubsub(publisher, subscriber): """Test PointCloud2 publish/subscribe through ROS. @@ -133,7 +132,6 @@ def callback(msg, t): assert abs(original.ts - converted.ts) < 0.001 -@pytest.mark.ros def test_pointcloud2_empty_pubsub(publisher, subscriber): """Test empty PointCloud2 publish/subscribe. @@ -162,7 +160,6 @@ def callback(msg, t): assert len(received[0]) == 0 -@pytest.mark.ros def test_posestamped_pubsub(publisher, subscriber): """Test PoseStamped publish/subscribe through ROS. @@ -203,7 +200,6 @@ def callback(msg, t): np.testing.assert_allclose(converted.orientation.w, original.orientation.w, rtol=1e-5) -@pytest.mark.ros def test_pointstamped_pubsub(publisher, subscriber): """Test PointStamped publish/subscribe through ROS. @@ -246,7 +242,6 @@ def callback(msg, t): assert converted.point.z == original.point.z -@pytest.mark.ros def test_twist_pubsub(publisher, subscriber): """Test Twist publish/subscribe through ROS. 
diff --git a/dimos/protocol/pubsub/test_spec.py b/dimos/protocol/pubsub/test_spec.py index 26c1cf0357..f79145642a 100644 --- a/dimos/protocol/pubsub/test_spec.py +++ b/dimos/protocol/pubsub/test_spec.py @@ -319,7 +319,7 @@ async def consume_messages() -> None: assert received_messages == messages_to_send -@pytest.mark.integration +@pytest.mark.slow @pytest.mark.parametrize("pubsub_context, topic, values", testdata) def test_high_volume_messages( pubsub_context: Callable[[], Any], topic: Any, values: list[Any] diff --git a/dimos/protocol/rpc/test_spec.py b/dimos/protocol/rpc/test_spec.py index c29db13703..b5189c04bf 100644 --- a/dimos/protocol/rpc/test_spec.py +++ b/dimos/protocol/rpc/test_spec.py @@ -293,7 +293,7 @@ def callback(val) -> None: unsub_server() -@pytest.mark.integration +@pytest.mark.slow @pytest.mark.parametrize("rpc_context, impl_name", testdata) def test_timeout(rpc_context, impl_name: str) -> None: """Test that RPC calls properly timeout.""" @@ -392,8 +392,3 @@ def make_call(a, b) -> None: finally: unsub() - - -if __name__ == "__main__": - # Run tests for debugging - pytest.main([__file__, "-v"]) diff --git a/dimos/robot/drone/test_drone.py b/dimos/robot/drone/test_drone.py index d9075beae3..7381359f5a 100644 --- a/dimos/robot/drone/test_drone.py +++ b/dimos/robot/drone/test_drone.py @@ -1029,7 +1029,3 @@ def test_velocity_from_bbox_center_error(self) -> None: self.assertGreater(vy, 0) # No vertical offset -> vx should be ~0 self.assertAlmostEqual(vx, 0, places=1) - - -if __name__ == "__main__": - unittest.main() diff --git a/dimos/robot/test_all_blueprints.py b/dimos/robot/test_all_blueprints.py index 16f657393b..6c2d000ca8 100644 --- a/dimos/robot/test_all_blueprints.py +++ b/dimos/robot/test_all_blueprints.py @@ -25,7 +25,7 @@ } -@pytest.mark.integration +@pytest.mark.slow @pytest.mark.parametrize("blueprint_name", all_blueprints.keys()) def test_all_blueprints_are_valid(blueprint_name: str) -> None: """Test that all blueprints in all_blueprints 
are valid Blueprint instances.""" diff --git a/dimos/robot/unitree/testing/test_actors.py b/dimos/robot/unitree/testing/test_actors.py index 9366092eb6..0fee2175fc 100644 --- a/dimos/robot/unitree/testing/test_actors.py +++ b/dimos/robot/unitree/testing/test_actors.py @@ -102,12 +102,6 @@ def test_mapper_start(dimos) -> None: print("start res", mapper.start().result()) -if __name__ == "__main__": - dimos = core.start(2) - test_basic(dimos) - test_mapper_start(dimos) - - @pytest.mark.tool def test_counter(dimos) -> None: counter = dimos.deploy(Counter) diff --git a/dimos/types/test_weaklist.py b/dimos/types/test_weaklist.py index 06f9f851ce..447b2fdd9a 100644 --- a/dimos/types/test_weaklist.py +++ b/dimos/types/test_weaklist.py @@ -54,7 +54,7 @@ def test_weaklist_basic_operations() -> None: assert SampleObject(4) not in wl -@pytest.mark.integration +@pytest.mark.slow def test_weaklist_auto_removal() -> None: """Test that objects are automatically removed when garbage collected.""" wl = WeakList() @@ -137,7 +137,7 @@ def test_weaklist_clear() -> None: assert obj1 not in wl -@pytest.mark.integration +@pytest.mark.slow def test_weaklist_iteration_during_modification() -> None: """Test that iteration works even if objects are deleted during iteration.""" wl = WeakList() diff --git a/dimos/utils/cli/lcmspy/test_lcmspy.py b/dimos/utils/cli/lcmspy/test_lcmspy.py index 530f081f29..13e6306c10 100644 --- a/dimos/utils/cli/lcmspy/test_lcmspy.py +++ b/dimos/utils/cli/lcmspy/test_lcmspy.py @@ -20,28 +20,46 @@ from dimos.utils.cli.lcmspy.lcmspy import GraphLCMSpy, GraphTopic, LCMSpy, Topic as TopicSpy -@pytest.mark.lcm -def test_spy_basic() -> None: +@pytest.fixture +def pickle_lcm(): lcm = PickleLCM(autoconf=True) lcm.start() + yield lcm + lcm.stop() - lcmspy = LCMSpy(autoconf=True) - lcmspy.start() +@pytest.fixture +def lcmspy_instance(): + spy = LCMSpy(autoconf=True) + spy.start() + yield spy + spy.stop() + + +@pytest.fixture +def graph_lcmspy_instance(): + spy = 
GraphLCMSpy(autoconf=True, graph_log_window=0.1) + spy.start() + time.sleep(0.2) # Wait for thread to start + yield spy + spy.stop() + + +def test_spy_basic(pickle_lcm, lcmspy_instance) -> None: video_topic = Topic(topic="/video") odom_topic = Topic(topic="/odom") for i in range(5): - lcm.publish(video_topic, f"video frame {i}") + pickle_lcm.publish(video_topic, f"video frame {i}") time.sleep(0.1) if i % 2 == 0: - lcm.publish(odom_topic, f"odometry data {i / 2}") + pickle_lcm.publish(odom_topic, f"odometry data {i / 2}") # Wait a bit for messages to be processed time.sleep(0.5) # Test statistics for video topic - video_topic_spy = lcmspy.topic["/video"] + video_topic_spy = lcmspy_instance.topic["/video"] assert video_topic_spy is not None # Test frequency (should be around 10 Hz for 5 messages in ~0.5 seconds) @@ -60,7 +78,7 @@ def test_spy_basic() -> None: print(f"Video topic average message size: {avg_size:.2f} bytes") # Test statistics for odom topic - odom_topic_spy = lcmspy.topic["/odom"] + odom_topic_spy = lcmspy_instance.topic["/odom"] assert odom_topic_spy is not None freq = odom_topic_spy.freq(1.0) @@ -79,7 +97,6 @@ def test_spy_basic() -> None: print(f"Odom topic: {odom_topic_spy}") -@pytest.mark.lcm def test_topic_statistics_direct() -> None: """Test Topic statistics directly without LCM""" @@ -128,7 +145,6 @@ def test_topic_cleanup() -> None: assert topic.message_history[0][0] > time.time() - 10 # Recent message -@pytest.mark.lcm def test_graph_topic_basic() -> None: """Test GraphTopic basic functionality""" topic = GraphTopic("/test_graph") @@ -144,79 +160,59 @@ def test_graph_topic_basic() -> None: assert topic.bandwidth_history[0] > 0 -@pytest.mark.lcm -def test_graph_lcmspy_basic() -> None: +def test_graph_lcmspy_basic(graph_lcmspy_instance) -> None: """Test GraphLCMSpy basic functionality""" - spy = GraphLCMSpy(autoconf=True, graph_log_window=0.1) - spy.start() - time.sleep(0.2) # Wait for thread to start - # Simulate a message - spy.msg("/test", 
b"test data") + graph_lcmspy_instance.msg("/test", b"test data") time.sleep(0.2) # Wait for graph update # Should create GraphTopic with history - topic = spy.topic["/test"] + topic = graph_lcmspy_instance.topic["/test"] assert isinstance(topic, GraphTopic) assert len(topic.freq_history) > 0 assert len(topic.bandwidth_history) > 0 - spy.stop() - -@pytest.mark.lcm -def test_lcmspy_global_totals() -> None: +def test_lcmspy_global_totals(lcmspy_instance) -> None: """Test that LCMSpy tracks global totals as a Topic itself""" - spy = LCMSpy(autoconf=True) - spy.start() - # Send messages to different topics - spy.msg("/video", b"video frame data") - spy.msg("/odom", b"odometry data") - spy.msg("/imu", b"imu data") + lcmspy_instance.msg("/video", b"video frame data") + lcmspy_instance.msg("/odom", b"odometry data") + lcmspy_instance.msg("/imu", b"imu data") # Verify each test topic received exactly one message (ignore LCM discovery packets) for t in ("/video", "/odom", "/imu"): - assert len(spy.topic[t].message_history) == 1 + assert len(lcmspy_instance.topic[t].message_history) == 1 # Check global statistics - global_freq = spy.freq(1.0) - global_kbps = spy.kbps(1.0) - global_size = spy.size(1.0) + global_freq = lcmspy_instance.freq(1.0) + global_kbps = lcmspy_instance.kbps(1.0) + global_size = lcmspy_instance.size(1.0) assert global_freq > 0 assert global_kbps > 0 assert global_size > 0 print(f"Global frequency: {global_freq:.2f} Hz") - print(f"Global bandwidth: {spy.kbps_hr(1.0)}") + print(f"Global bandwidth: {lcmspy_instance.kbps_hr(1.0)}") print(f"Global avg message size: {global_size:.0f} bytes") - spy.stop() - -@pytest.mark.lcm -def test_graph_lcmspy_global_totals() -> None: +def test_graph_lcmspy_global_totals(graph_lcmspy_instance) -> None: """Test that GraphLCMSpy tracks global totals with history""" - spy = GraphLCMSpy(autoconf=True, graph_log_window=0.1) - spy.start() - time.sleep(0.2) - # Send messages - spy.msg("/video", b"video frame data") - 
spy.msg("/odom", b"odometry data") + graph_lcmspy_instance.msg("/video", b"video frame data") + graph_lcmspy_instance.msg("/odom", b"odometry data") time.sleep(0.2) # Wait for graph update # Update global graphs - spy.update_graphs(1.0) + graph_lcmspy_instance.update_graphs(1.0) # Should have global history - assert len(spy.freq_history) == 1 - assert len(spy.bandwidth_history) == 1 - assert spy.freq_history[0] > 0 - assert spy.bandwidth_history[0] > 0 - - print(f"Global frequency history: {spy.freq_history[0]:.2f} Hz") - print(f"Global bandwidth history: {spy.bandwidth_history[0]:.2f} kB/s") + assert len(graph_lcmspy_instance.freq_history) == 1 + assert len(graph_lcmspy_instance.bandwidth_history) == 1 + assert graph_lcmspy_instance.freq_history[0] > 0 + assert graph_lcmspy_instance.bandwidth_history[0] > 0 - spy.stop() + print(f"Global frequency history: {graph_lcmspy_instance.freq_history[0]:.2f} Hz") + print(f"Global bandwidth history: {graph_lcmspy_instance.bandwidth_history[0]:.2f} kB/s") diff --git a/dimos/utils/docs/test_doclinks.py b/dimos/utils/docs/test_doclinks.py index 968f465cef..7da6a6281b 100644 --- a/dimos/utils/docs/test_doclinks.py +++ b/dimos/utils/docs/test_doclinks.py @@ -773,7 +773,3 @@ def test_skips_mailto_links(self, file_index, doc_index): assert len(errors) == 0 assert len(changes) == 0 assert new_content == content - - -if __name__ == "__main__": - pytest.main([__file__, "-v"]) diff --git a/dimos/utils/test_data.py b/dimos/utils/test_data.py index e5be4307c7..e55c8b20f3 100644 --- a/dimos/utils/test_data.py +++ b/dimos/utils/test_data.py @@ -23,7 +23,7 @@ from dimos.utils.data import LfsPath -@pytest.mark.heavy +@pytest.mark.slow def test_pull_file() -> None: repo_root = data._get_repo_root() test_file_name = "cafe.jpg" @@ -79,7 +79,7 @@ def test_pull_file() -> None: ) -@pytest.mark.heavy +@pytest.mark.slow def test_pull_dir() -> None: repo_root = data._get_repo_root() test_dir_name = "ab_lidar_frames" @@ -187,7 +187,7 @@ def 
test_lfs_path_no_download_on_creation() -> None: assert cache is None -@pytest.mark.heavy +@pytest.mark.slow def test_lfs_path_with_real_file() -> None: """Test LfsPath with a real small LFS file.""" # Use a small existing LFS file @@ -221,7 +221,7 @@ def test_lfs_path_with_real_file() -> None: assert content.startswith(b"\x89PNG") -@pytest.mark.heavy +@pytest.mark.slow def test_lfs_path_unload_and_reload() -> None: """Test unloading and reloading an LFS file.""" filename = "three_paths.png" @@ -266,7 +266,7 @@ def test_lfs_path_unload_and_reload() -> None: assert content_first == content_second -@pytest.mark.heavy +@pytest.mark.slow def test_lfs_path_operations() -> None: """Test various Path operations with LfsPath.""" filename = "three_paths.png" @@ -295,7 +295,7 @@ def test_lfs_path_operations() -> None: assert filename in fspath_result -@pytest.mark.heavy +@pytest.mark.slow def test_lfs_path_division_operator() -> None: """Test path division operator with LfsPath.""" # Use a directory for testing @@ -309,7 +309,7 @@ def test_lfs_path_division_operator() -> None: assert "three_paths.png" in str(result) -@pytest.mark.heavy +@pytest.mark.slow def test_lfs_path_multiple_instances() -> None: """Test that multiple LfsPath instances for same file work correctly.""" filename = "three_paths.png" diff --git a/dimos/utils/test_reactive.py b/dimos/utils/test_reactive.py index 5bfc0a590f..f6f1340059 100644 --- a/dimos/utils/test_reactive.py +++ b/dimos/utils/test_reactive.py @@ -82,7 +82,7 @@ def _dispose() -> None: return proxy -@pytest.mark.integration +@pytest.mark.slow def test_backpressure_handling() -> None: # Create a dedicated scheduler for this test to avoid thread leaks test_scheduler = ThreadPoolScheduler(max_workers=8) @@ -142,7 +142,7 @@ def test_backpressure_handling() -> None: test_scheduler.executor.shutdown(wait=True) -@pytest.mark.integration +@pytest.mark.slow def test_getter_streaming_blocking() -> None: source = dispose_spy( 
rx.interval(0.2).pipe(ops.map(lambda i: np.array([i, i + 1, i + 2])), ops.take(50)) @@ -177,7 +177,7 @@ def test_getter_streaming_blocking_timeout() -> None: assert source.is_disposed() -@pytest.mark.integration +@pytest.mark.slow def test_getter_streaming_nonblocking() -> None: source = dispose_spy(rx.interval(0.2).pipe(ops.take(50))) diff --git a/dimos/utils/test_transform_utils.py b/dimos/utils/test_transform_utils.py index b404579598..7923124c9f 100644 --- a/dimos/utils/test_transform_utils.py +++ b/dimos/utils/test_transform_utils.py @@ -672,7 +672,3 @@ def test_retract_arbitrary_pose(self) -> None: assert np.isclose(retracted.position.x, expected_x) assert np.isclose(retracted.position.y, expected_y) assert np.isclose(retracted.position.z, expected_z) - - -if __name__ == "__main__": - pytest.main([__file__, "-v"]) diff --git a/docs/capabilities/manipulation/readme.md b/docs/capabilities/manipulation/readme.md index 4a943e6be5..0d6539b75c 100644 --- a/docs/capabilities/manipulation/readme.md +++ b/docs/capabilities/manipulation/readme.md @@ -101,7 +101,7 @@ KeyboardTeleopModule ──→ ControlCoordinator ──→ ManipulationModule ## Adding a Custom Arm -[guide is here](adding_a_custom_arm.md) +[guide is here](/docs/capabilities/manipulation/adding_a_custom_arm.md) ## Key Files diff --git a/docs/development/testing.md b/docs/development/testing.md index c27a8c5dec..c8a226b7ad 100644 --- a/docs/development/testing.md +++ b/docs/development/testing.md @@ -8,24 +8,28 @@ uv sync --all-extras --no-extra dds ## Types of tests -There are different types of tests based on what their goal is: +In general, there are different types of tests based on what their goal is: | Type | Description | Mocking | Speed | |------|-------------|---------|-------| -| Unit | Test a small individual piece of code | All dependencies | Very fast | -| Integration | Test the integration between multiple units of code | Most dependencies | Some fast, some slow | -| Functional | Test a particular 
desired functionality | Some dependencies | Some fast, some slow | +| Unit | Test a small individual piece of code | All external systems | Very fast | +| Integration | Test the integration between multiple units of code | Most external systems | Some fast, some slow | +| Functional | Test a particular desired functionality | Some external systems | Some fast, some slow | | End-to-end | Test the entire system as a whole from the perspective of the user | None | Very slow | The distinction between unit, integration, and functional tests is often debated and rarely productive. Rather than waste time on classifying tests, it's better to separate tests by how they are used: -* **fast tests**: tests which you can run after each code change (people often run them with filesystem watchers: whenever a file is saved, automatically run the tests) -* **slow tests**: tests which you run every once in a while to make sure you haven't broken anything (maybe every commit, but definitely before publishing a PR) +| Test Group | When to run | Typical usage | +|------------|-------------|---------------| +| **fast tests** | after each code change | often run with filesystem watchers so tests rerun whenever a file is saved | +| **slow tests** | every once in a while to make sure you haven't broken anything | maybe every commit, but definitely before publishing a PR | The purpose of running tests in a loop is to get immediate feedback. The faster the loop, the easier it is to identify a problem since the source is the tiny bit of code you changed. +For the purposes of DimOS, slow tests are marked with `@pytest.mark.slow` and fast tests are all the remaining ones. + ## Usage ### Fast tests @@ -42,7 +46,7 @@ This is the same as: pytest dimos ``` -The default `addopts` in `pyproject.toml` includes a `-m` filter that excludes slow markers (like `integration`, `heavy`, `e2e`, etc.), so plain `pytest dimos` only runs fast tests. 
+The default `addopts` in `pyproject.toml` includes a `-m` filter that excludes the `slow`/`mujoco`/`tool` markers. So plain `pytest dimos` only runs fast tests. ### Slow tests @@ -52,14 +56,14 @@ Run the slow tests: ``` ./bin/pytest-slow ``` -This overrides the default `-m` filter to include most markers. When writing or debugging a specific slow test, override `-m` yourself: +(This is just a shortcut for `pytest -m 'not (tool or mujoco)' dimos`. I.e., run both fast tests and slow tests, but not `tool` or `mujoco`.) + +When writing or debugging a specific slow test, override `-m` yourself to run it: ```bash -pytest -m integration dimos/path/to/test_something.py +pytest -m slow dimos/path/to/test_something.py ``` -Note: passing `-m` on the command line overrides the default from `addopts`, so you get exactly the marker set you asked for. - ## Writing tests Test files live next to the code they test. If you have `dimos/core/pubsub.py`, its tests go in `dimos/core/test_pubsub.py`. @@ -120,3 +124,17 @@ There are other useful things in `mocker`, like `mocker.MagicMock()` for creatin | `--pdb` | Drop into the debugger when a test fails | | `--tb=short` | Shorter tracebacks | | `--durations=0` | Measure the speed of each test | + +## Markers + +We have a few markers in use now. + +* `slow`: used to mark tests that take more than 1 second to finish. +* `tool`: tests which require human interaction. I don't like this. Please don't use them. +* `mujoco`: tests which use `MuJoCo`. These are very slow and don't work in CI currently. + +If a test needs to be skipped for some reason, please use one of these markers, or add another one. 
+ +* `skipif_in_ci`: tests which cannot run in GitHub Actions +* `skipif_no_openai`: tests which require an `OPENAI_API_KEY` key in the env +* `skipif_no_alibaba`: tests which require an `ALIBABA_API_KEY` key in the env diff --git a/docs/platforms/humanoid/g1/index.md b/docs/platforms/humanoid/g1/index.md index 2e04f3b023..797c865b20 100644 --- a/docs/platforms/humanoid/g1/index.md +++ b/docs/platforms/humanoid/g1/index.md @@ -13,9 +13,9 @@ The Unitree G1 is a humanoid robot platform with full-body locomotion, arm gestu ## Install First, install system dependencies for your platform: -- [Ubuntu](../../../installation/ubuntu.md) -- [macOS](../../../installation/osx.md) -- [Nix](../../../installation/nix.md) +- [Ubuntu](/docs/installation/ubuntu.md) +- [macOS](/docs/installation/osx.md) +- [Nix](/docs/installation/nix.md) Then install DimOS: @@ -159,9 +159,9 @@ primitive (sensors + vis) ## Deep Dive -- [Navigation Stack](../../../capabilities/navigation/readme.md) — path planning and autonomous exploration -- [Visualization](../../../usage/visualization.md) — Rerun, Foxglove, performance tuning -- [Data Streams](../../../usage/data_streams/) — RxPY streams, backpressure, quality filtering -- [Transports](../../../usage/transports/index.md) — LCM, SHM, DDS -- [Blueprints](../../../usage/blueprints.md) — composing modules -- [Agents](../../../capabilities/agents/readme.md) — LLM agent framework +- [Navigation Stack](/docs/capabilities/navigation/readme.md) — path planning and autonomous exploration +- [Visualization](/docs/usage/visualization.md) — Rerun, Foxglove, performance tuning +- [Data Streams](/docs/usage/data_streams) — RxPY streams, backpressure, quality filtering +- [Transports](/docs/usage/transports/index.md) — LCM, SHM, DDS +- [Blueprints](/docs/usage/blueprints.md) — composing modules +- [Agents](/docs/capabilities/agents/readme.md) — LLM agent framework diff --git a/pyproject.toml b/pyproject.toml index 6471fd89cd..f81f69887c 100644 --- a/pyproject.toml 
+++ b/pyproject.toml @@ -388,24 +388,11 @@ follow_imports = "skip" [tool.pytest.ini_options] testpaths = ["dimos"] -markers = [ - "heavy: resource heavy test", - "tool: dev tooling", - "ros: depend on ros", - "lcm: tests that run actual LCM bus (can't execute in CI)", - "module: tests that need to run directly as modules", - "gpu: tests that require GPU", - "cuda: tests which require CUDA (specifically CUDA not just GPU acceleration)", - "e2e: end to end tests", - "integration: slower integration tests", - "neverending: they don't finish", - "mujoco: tests which open mujoco", -] env = [ "GOOGLE_MAPS_API_KEY=AIzafake_google_key", "PYTHONWARNINGS=ignore:cupyx.jit.rawkernel is experimental:FutureWarning", ] -addopts = "-v -s -p no:warnings -ra --color=yes -m 'not (vis or exclude or tool or lcm or ros or heavy or gpu or module or e2e or integration or neverending or mujoco)'" +addopts = "-v -r a -p no:warnings --color=yes -m 'not (tool or slow or mujoco)'" asyncio_mode = "auto" asyncio_default_fixture_loop_scope = "function" From 6b73b42354922ed09ddbe908b656fb11fa2d1b7a Mon Sep 17 00:00:00 2001 From: Paul Nechifor Date: Tue, 24 Feb 2026 15:13:05 +0200 Subject: [PATCH 11/16] feat(blueprints): running modules by themselves (#1342) * feat(blueprints): add scratch * feat(blueprints): be able to run modules too * remove scratch --- dimos/robot/all_blueprints.py | 100 +++++++++--------- dimos/robot/cli/dimos.py | 21 +--- dimos/robot/get_all_blueprints.py | 36 ++++++- dimos/robot/test_all_blueprints_generation.py | 11 +- docs/development/dimos_run.md | 2 +- pyproject.toml | 5 +- 6 files changed, 100 insertions(+), 75 deletions(-) diff --git a/dimos/robot/all_blueprints.py b/dimos/robot/all_blueprints.py index 8e1c7fa89f..ed90982801 100644 --- a/dimos/robot/all_blueprints.py +++ b/dimos/robot/all_blueprints.py @@ -90,57 +90,57 @@ all_modules = { "agent": "dimos.agents.agent", - "arm_teleop_module": "dimos.teleop.quest.quest_extensions", - "camera_module": 
"dimos.hardware.sensors.camera.module", - "cartesian_motion_controller": "dimos.manipulation.control.servo_control.cartesian_motion_controller", - "control_coordinator": "dimos.control.coordinator", - "cost_mapper": "dimos.mapping.costmapper", - "demo_calculator_skill": "dimos.agents.skills.demo_calculator_skill", - "demo_robot": "dimos.agents.skills.demo_robot", - "detection3d_module": "dimos.perception.detection.module3D", - "detection_db_module": "dimos.perception.detection.moduleDB", - "fastlio2_module": "dimos.hardware.sensors.lidar.fastlio2.module", - "foxglove_bridge": "dimos.robot.foxglove_bridge", - "g1_connection": "dimos.robot.unitree.g1.connection", - "g1_sim_connection": "dimos.robot.unitree.g1.sim", - "g1_skills": "dimos.robot.unitree.g1.skill_container", - "go2_connection": "dimos.robot.unitree.go2.connection", - "google_maps_skill": "dimos.agents.skills.google_maps_skill_container", - "gps_nav_skill": "dimos.agents.skills.gps_nav_skill", - "grasping_module": "dimos.manipulation.grasping.grasping", - "joint_trajectory_controller": "dimos.manipulation.control.trajectory_controller.joint_trajectory_controller", - "keyboard_teleop": "dimos.robot.unitree.keyboard_teleop", - "keyboard_teleop_module": "dimos.teleop.keyboard.keyboard_teleop_module", - "manipulation_module": "dimos.manipulation.manipulation_module", + "arm-teleop-module": "dimos.teleop.quest.quest_extensions", + "camera-module": "dimos.hardware.sensors.camera.module", + "cartesian-motion-controller": "dimos.manipulation.control.servo_control.cartesian_motion_controller", + "control-coordinator": "dimos.control.coordinator", + "cost-mapper": "dimos.mapping.costmapper", + "demo-calculator-skill": "dimos.agents.skills.demo_calculator_skill", + "demo-robot": "dimos.agents.skills.demo_robot", + "detection-db-module": "dimos.perception.detection.moduleDB", + "detection3d-module": "dimos.perception.detection.module3D", + "fastlio2-module": "dimos.hardware.sensors.lidar.fastlio2.module", + 
"foxglove-bridge": "dimos.robot.foxglove_bridge", + "g1-connection": "dimos.robot.unitree.g1.connection", + "g1-sim-connection": "dimos.robot.unitree.g1.sim", + "g1-skills": "dimos.robot.unitree.g1.skill_container", + "go2-connection": "dimos.robot.unitree.go2.connection", + "google-maps-skill": "dimos.agents.skills.google_maps_skill_container", + "gps-nav-skill": "dimos.agents.skills.gps_nav_skill", + "grasping-module": "dimos.manipulation.grasping.grasping", + "joint-trajectory-controller": "dimos.manipulation.control.trajectory_controller.joint_trajectory_controller", + "keyboard-teleop": "dimos.robot.unitree.keyboard_teleop", + "keyboard-teleop-module": "dimos.teleop.keyboard.keyboard_teleop_module", + "manipulation-module": "dimos.manipulation.manipulation_module", "mapper": "dimos.robot.unitree.type.map", - "mcp_client": "dimos.agents.mcp.mcp_client", - "mid360_module": "dimos.hardware.sensors.lidar.livox.module", - "navigation_skill": "dimos.agents.skills.navigation", - "object_scene_registration_module": "dimos.perception.object_scene_registration", - "object_tracking": "dimos.perception.object_tracker", - "osm_skill": "dimos.agents.skills.osm", - "person_follow_skill": "dimos.agents.skills.person_follow", - "person_tracker_module": "dimos.perception.detection.person_tracker", - "phone_teleop_module": "dimos.teleop.phone.phone_teleop_module", - "quest_teleop_module": "dimos.teleop.quest.quest_teleop_module", - "realsense_camera": "dimos.hardware.sensors.camera.realsense.camera", - "replanning_a_star_planner": "dimos.navigation.replanning_a_star.module", - "rerun_bridge": "dimos.visualization.rerun.bridge", - "ros_nav": "dimos.navigation.rosnav", - "simple_phone_teleop_module": "dimos.teleop.phone.phone_extensions", + "mcp-client": "dimos.agents.mcp.mcp_client", + "mid360-module": "dimos.hardware.sensors.lidar.livox.module", + "navigation-skill": "dimos.agents.skills.navigation", + "object-scene-registration-module": 
"dimos.perception.object_scene_registration", + "object-tracking": "dimos.perception.object_tracker", + "osm-skill": "dimos.agents.skills.osm", + "person-follow-skill": "dimos.agents.skills.person_follow", + "person-tracker-module": "dimos.perception.detection.person_tracker", + "phone-teleop-module": "dimos.teleop.phone.phone_teleop_module", + "quest-teleop-module": "dimos.teleop.quest.quest_teleop_module", + "realsense-camera": "dimos.hardware.sensors.camera.realsense.camera", + "replanning-a-star-planner": "dimos.navigation.replanning_a_star.module", + "rerun-bridge": "dimos.visualization.rerun.bridge", + "ros-nav": "dimos.navigation.rosnav", + "simple-phone-teleop-module": "dimos.teleop.phone.phone_extensions", "simulation": "dimos.simulation.manipulators.sim_module", - "spatial_memory": "dimos.perception.spatial_perception", - "speak_skill": "dimos.agents.skills.speak_skill", - "temporal_memory": "dimos.perception.experimental.temporal_memory.temporal_memory", - "twist_teleop_module": "dimos.teleop.quest.quest_extensions", - "unitree_skills": "dimos.robot.unitree.unitree_skill_container", + "spatial-memory": "dimos.perception.spatial_perception", + "speak-skill": "dimos.agents.skills.speak_skill", + "temporal-memory": "dimos.perception.experimental.temporal_memory.temporal_memory", + "twist-teleop-module": "dimos.teleop.quest.quest_extensions", + "unitree-skills": "dimos.robot.unitree.unitree_skill_container", "utilization": "dimos.utils.monitoring", - "visualizing_teleop_module": "dimos.teleop.quest.quest_extensions", - "vlm_agent": "dimos.agents.vlm_agent", - "vlm_stream_tester": "dimos.agents.vlm_stream_tester", - "voxel_mapper": "dimos.mapping.voxels", - "wavefront_frontier_explorer": "dimos.navigation.frontier_exploration.wavefront_frontier_goal_selector", - "web_input": "dimos.agents.web_human_input", - "websocket_vis": "dimos.web.websocket_vis.websocket_vis_module", - "zed_camera": "dimos.hardware.sensors.camera.zed.camera", + 
"visualizing-teleop-module": "dimos.teleop.quest.quest_extensions", + "vlm-agent": "dimos.agents.vlm_agent", + "vlm-stream-tester": "dimos.agents.vlm_stream_tester", + "voxel-mapper": "dimos.mapping.voxels", + "wavefront-frontier-explorer": "dimos.navigation.frontier_exploration.wavefront_frontier_goal_selector", + "web-input": "dimos.agents.web_human_input", + "websocket-vis": "dimos.web.websocket_vis.websocket_vis_module", + "zed-camera": "dimos.hardware.sensors.camera.zed.camera", } diff --git a/dimos/robot/cli/dimos.py b/dimos/robot/cli/dimos.py index c390d3b76c..3979138216 100644 --- a/dimos/robot/cli/dimos.py +++ b/dimos/robot/cli/dimos.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from enum import Enum import inspect import sys from typing import Any, get_args, get_origin @@ -21,9 +20,6 @@ import typer from dimos.core.global_config import GlobalConfig, global_config -from dimos.robot.all_blueprints import all_blueprints - -RobotType = Enum("RobotType", {key.replace("-", "_").upper(): key for key in all_blueprints.keys()}) # type: ignore[misc] main = typer.Typer( help="Dimensional CLI", @@ -102,15 +98,12 @@ def callback(**kwargs) -> None: # type: ignore[no-untyped-def] @main.command() def run( ctx: typer.Context, - robot_type: RobotType = typer.Argument(..., help="Type of robot to run"), - extra_modules: list[str] = typer.Option( # type: ignore[valid-type] - [], "--extra-module", help="Extra modules to add to the blueprint" - ), + robot_types: list[str] = typer.Argument(..., help="Blueprints or modules to run"), ) -> None: """Start a robot blueprint""" from dimos.core.blueprints import autoconnect from dimos.protocol import pubsub - from dimos.robot.get_all_blueprints import get_blueprint_by_name, get_module_by_name + from dimos.robot.get_all_blueprints import get_by_name from dimos.utils.logging_config import setup_exception_handler setup_exception_handler() @@ -118,12 +111,8 @@ 
def run( cli_config_overrides: dict[str, Any] = ctx.obj global_config.update(**cli_config_overrides) pubsub.lcm.autoconf() # type: ignore[attr-defined] - blueprint = get_blueprint_by_name(robot_type.value) - - if extra_modules: - loaded_modules = [get_module_by_name(mod_name) for mod_name in extra_modules] # type: ignore[attr-defined] - blueprint = autoconnect(blueprint, *loaded_modules) + blueprint = autoconnect(*map(get_by_name, robot_types)) dimos = blueprint.build(cli_config_overrides=cli_config_overrides) dimos.loop() @@ -139,8 +128,8 @@ def show_config(ctx: typer.Context) -> None: typer.echo(f"{field_name}: {value}") -@main.command() -def list() -> None: +@main.command(name="list") +def list_blueprints() -> None: """List all available blueprints.""" from dimos.robot.all_blueprints import all_blueprints diff --git a/dimos/robot/get_all_blueprints.py b/dimos/robot/get_all_blueprints.py index 8658e4f4ec..f7a79fb8d7 100644 --- a/dimos/robot/get_all_blueprints.py +++ b/dimos/robot/get_all_blueprints.py @@ -12,13 +12,31 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import difflib +import sys +from typing import NoReturn + +import typer + from dimos.core.blueprints import Blueprint from dimos.robot.all_blueprints import all_blueprints, all_modules +all_names = sorted(set(all_blueprints.keys()) | set(all_modules.keys())) + + +def _fail_unknown(name: str, candidates: list[str]) -> NoReturn: + typer.echo(typer.style(f"Unknown blueprint or module: {name}", fg=typer.colors.RED), err=True) + suggestions = difflib.get_close_matches(name, candidates, n=5, cutoff=0.4) + if suggestions: + typer.echo("Did you mean one of these?", err=True) + for s in suggestions: + typer.echo(f" {s}", err=True) + sys.exit(1) + def get_blueprint_by_name(name: str) -> Blueprint: if name not in all_blueprints: - raise ValueError(f"Unknown blueprint set name: {name}") + _fail_unknown(name, list(all_blueprints.keys())) module_path, attr = all_blueprints[name].split(":") module = __import__(module_path, fromlist=[attr]) return getattr(module, attr) # type: ignore[no-any-return] @@ -26,6 +44,16 @@ def get_blueprint_by_name(name: str) -> Blueprint: def get_module_by_name(name: str) -> Blueprint: if name not in all_modules: - raise ValueError(f"Unknown module name: {name}") - python_module = __import__(all_modules[name], fromlist=[name]) - return getattr(python_module, name)() # type: ignore[no-any-return] + _fail_unknown(name, list(all_modules.keys())) + attr_name = name.replace("-", "_") + python_module = __import__(all_modules[name], fromlist=[attr_name]) + return getattr(python_module, attr_name)() # type: ignore[no-any-return] + + +def get_by_name(name: str) -> Blueprint: + if name in all_blueprints: + return get_blueprint_by_name(name) + elif name in all_modules: + return get_module_by_name(name) + else: + _fail_unknown(name, all_names) diff --git a/dimos/robot/test_all_blueprints_generation.py b/dimos/robot/test_all_blueprints_generation.py index e7ba79a404..74c8534820 100644 --- a/dimos/robot/test_all_blueprints_generation.py +++ 
b/dimos/robot/test_all_blueprints_generation.py @@ -37,6 +37,13 @@ def test_all_blueprints_is_current() -> None: root = DIMOS_PROJECT_ROOT / "dimos" all_blueprints, all_modules = _scan_for_blueprints(root) + + common = set(all_blueprints.keys()) & set(all_modules.keys()) + assert not common, ( + f"Names must be unique across blueprints and modules, " + f"but these appear in both: {sorted(common)}" + ) + generated_content = _generate_all_blueprints_content(all_blueprints, all_modules) file_path = root / "robot" / "all_blueprints.py" @@ -83,8 +90,8 @@ def _scan_for_blueprints(root: Path) -> tuple[dict[str, str], dict[str, str]]: all_blueprints[cli_name] = full_path for var_name in module_vars: - full_path = f"{module_name}:{var_name}" - all_modules[var_name] = module_name + cli_name = var_name.replace("_", "-") + all_modules[cli_name] = module_name return all_blueprints, all_modules diff --git a/docs/development/dimos_run.md b/docs/development/dimos_run.md index 3e6bee65e6..604724a68d 100644 --- a/docs/development/dimos_run.md +++ b/docs/development/dimos_run.md @@ -31,7 +31,7 @@ dimos run unitree-go2-agentic You can dynamically connect additional modules. 
For example: ```bash -dimos run unitree-go2 --extra-module agent --extra-module navigation_skill +dimos run unitree-go2 keyboard-teleop ``` ## Adding your own diff --git a/pyproject.toml b/pyproject.toml index f81f69887c..be56bbeb41 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -359,11 +359,12 @@ module = [ "pinocchio", "piper_sdk.*", "plotext", - "pydrake", - "pydrake.*", "plum.*", + "portal", "pycuda", "pycuda.*", + "pydrake", + "pydrake.*", "pyzed", "pyzed.*", "rclpy.*", From 83732afa5e39b9406b6176854c1e5694ef2db776 Mon Sep 17 00:00:00 2001 From: mustafab0 Date: Tue, 24 Feb 2026 19:24:32 -0800 Subject: [PATCH 12/16] Resolve annotations using namespaces from the full MRO chain --- dimos/core/blueprints.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/dimos/core/blueprints.py b/dimos/core/blueprints.py index 605517e6cf..17475e9034 100644 --- a/dimos/core/blueprints.py +++ b/dimos/core/blueprints.py @@ -63,9 +63,16 @@ def create( streams: list[StreamRef] = [] module_refs: list[ModuleRef] = [] - # Use get_type_hints() to properly resolve string annotations. + # Resolve annotations using namespaces from the full MRO chain so that + # In/Out behind TYPE_CHECKING + `from __future__ import annotations` work. + # Iterate reversed MRO so the most specific class's namespace wins when + # parent modules shadow names (e.g. spec.perception.Image vs sensor_msgs.Image). + globalns: dict[str, Any] = {} + for c in reversed(module.__mro__): + if c.__module__ in sys.modules: + globalns.update(sys.modules[c.__module__].__dict__) try: - all_annotations = get_type_hints(module) + all_annotations = get_type_hints(module, globalns=globalns) except Exception: # Fallback to raw annotations if get_type_hints fails. 
all_annotations = {} From 2ff79b94a7dacccdb5a8d4a4f5d1a356409a41a3 Mon Sep 17 00:00:00 2001 From: mustafab0 Date: Tue, 24 Feb 2026 19:25:20 -0800 Subject: [PATCH 13/16] refactored manipulation module to make it generic --- dimos/manipulation/manipulation_module.py | 592 +++------------------- 1 file changed, 64 insertions(+), 528 deletions(-) diff --git a/dimos/manipulation/manipulation_module.py b/dimos/manipulation/manipulation_module.py index 310b77d766..93f25281b9 100644 --- a/dimos/manipulation/manipulation_module.py +++ b/dimos/manipulation/manipulation_module.py @@ -14,28 +14,25 @@ """Manipulation Module - Motion planning with ControlCoordinator execution. -Interface layers: +Base module providing core manipulation infrastructure: - @rpc: Low-level building blocks (plan_to_pose, plan_to_joints, preview_path, execute) -- @skill (short-horizon): Single-step actions (move_to_pose, open_gripper, scan_objects, go_init) -- @skill (long-horizon): Multi-step composed behaviors (pick, place, place_back, pick_and_place) +- @skill (short-horizon): Single-step actions (move_to_pose, open_gripper, go_home, go_init) + +Subclass PickAndPlaceModule (pick_and_place_module.py) adds perception integration +(scan_objects, get_scene_info) and long-horizon skills (pick, place, pick_and_place). 
""" from __future__ import annotations from dataclasses import dataclass, field from enum import Enum -import math -from pathlib import Path import threading import time from typing import TYPE_CHECKING, Any, TypeAlias from dimos.agents.annotation import skill -from dimos.constants import DIMOS_PROJECT_ROOT from dimos.core import In, Module, rpc -from dimos.core.docker_runner import DockerModule as DockerRunner from dimos.core.module import ModuleConfig -from dimos.manipulation.grasping.graspgen_module import GraspGenModule from dimos.manipulation.planning import ( JointPath, JointTrajectoryGenerator, @@ -55,15 +52,10 @@ # These must be imported at runtime (not TYPE_CHECKING) for In/Out port creation from dimos.msgs.sensor_msgs import JointState from dimos.msgs.trajectory_msgs import JointTrajectory -from dimos.perception.detection.type.detection3d.object import Object as DetObject -from dimos.utils.data import get_data from dimos.utils.logging_config import setup_logger if TYPE_CHECKING: from dimos.core.rpc_client import RPCClient - from dimos.msgs.geometry_msgs import PoseArray - from dimos.msgs.sensor_msgs import PointCloud2 - from dimos.perception.detection.type.detection3d.object import Object as DetObject logger = setup_logger() @@ -80,10 +72,6 @@ PlannedTrajectories: TypeAlias = dict[RobotName, JointTrajectory] """Maps robot_name -> planned trajectory""" -# The host-side path (graspgen_visualization_output_path) is volume-mounted here. 
-_GRASPGEN_VIZ_CONTAINER_DIR = "/output/graspgen" -_GRASPGEN_VIZ_CONTAINER_PATH = f"{_GRASPGEN_VIZ_CONTAINER_DIR}/visualization.json" - class ManipulationState(Enum): """State machine for manipulation module.""" @@ -105,25 +93,14 @@ class ManipulationModuleConfig(ModuleConfig): planner_name: str = "rrt_connect" # "rrt_connect" kinematics_name: str = "jacobian" # "jacobian" or "drake_optimization" - # GraspGen Docker settings (optional) - graspgen_docker_image: str = "dimos-graspgen:latest" - graspgen_gripper_type: str = "robotiq_2f_140" - graspgen_num_grasps: int = 400 - graspgen_topk_num_grasps: int = 100 - graspgen_grasp_threshold: float = -1.0 - graspgen_filter_collisions: bool = False - graspgen_save_visualization_data: bool = False - graspgen_visualization_output_path: Path = field( - default_factory=lambda: Path.home() / ".dimos" / "graspgen" / "visualization.json" - ) - class ManipulationModule(Module): - """Motion planning module with ControlCoordinator execution. + """Base motion planning module with ControlCoordinator execution. + + - @rpc: Low-level building blocks (plan, execute, gripper) + - @skill (short-horizon): Single-step actions (move_to_pose, open_gripper, go_home) - - @rpc: Low-level building blocks (plan, execute, obstacles) - - @skill (short-horizon): Single-step actions (move_to_pose, open_gripper, scan_objects) - - @skill (long-horizon): Multi-step behaviors (pick, place, pick_and_place) + Subclass PickAndPlaceModule adds perception integration and long-horizon skills. 
""" default_config = ManipulationModuleConfig @@ -134,9 +111,6 @@ class ManipulationModule(Module): # Input: Joint state from coordinator (for world sync) joint_state: In[JointState] - # Input: Objects from perception (for obstacle integration) - objects: In[list[DetObject]] - def __init__(self, *args: object, **kwargs: object) -> None: super().__init__(*args, **kwargs) @@ -160,19 +134,9 @@ def __init__(self, *args: object, **kwargs: object) -> None: # Coordinator integration (lazy initialized) self._coordinator_client: RPCClient | None = None - # GraspGen Docker runner (lazy initialized on first generate_grasps call) - self._graspgen: DockerRunner | None = None # Init joints: captured from first joint state received, used by go_init self._init_joints: JointState | None = None - # Last pick position: stored during pick so place_back() can return the object - self._last_pick_position: Vector3 | None = None - - # Snapshotted detections from the last scan_objects/refresh call. - # The live detection cache is volatile (labels change every frame), - # so pick/place use this stable snapshot instead. 
- self._detection_snapshot: list[DetObject] = [] - # TF publishing thread self._tf_stop_event = threading.Event() self._tf_thread: threading.Thread | None = None @@ -192,11 +156,6 @@ def start(self) -> None: self.joint_state.subscribe(self._on_joint_state) logger.info("Subscribed to joint_state port") - # Subscribe to objects port for perception obstacle integration - if self.objects is not None: - self.objects.observable().subscribe(self._on_objects) # type: ignore[no-untyped-call] - logger.info("Subscribed to objects port (async)") - logger.info("ManipulationModule started") def _initialize_planning(self) -> None: @@ -221,9 +180,6 @@ def _initialize_planning(self) -> None: for _, (robot_id, _, _) in self._robots.items(): self._world_monitor.start_state_monitor(robot_id) - # Start obstacle monitor for perception integration - self._world_monitor.start_obstacle_monitor() - if self.config.enable_viz: self._world_monitor.start_visualization_thread(rate_hz=10.0) if url := self._world_monitor.get_visualization_url(): @@ -294,14 +250,6 @@ def _on_joint_state(self, msg: JointState) -> None: logger.error(traceback.format_exc()) - def _on_objects(self, objects: list[DetObject]) -> None: - """Callback when objects received from perception (runs on RxPY thread pool).""" - try: - if self._world_monitor is not None: - self._world_monitor.on_objects(objects) - except Exception as e: - logger.error(f"Exception in _on_objects: {e}") - def _tf_publish_loop(self) -> None: """Publish TF transforms at 10Hz for EE and extra links.""" from dimos.msgs.geometry_msgs import Transform @@ -362,14 +310,18 @@ def cancel(self) -> bool: logger.info("Motion cancelled") return True - @rpc - def reset(self) -> bool: - """Reset to IDLE state (fails if EXECUTING).""" + @skill + def reset(self) -> str: + """Reset the robot module to IDLE state, clearing any fault. + + Use this after an error or fault to allow new commands. + Cannot reset while a motion is executing — cancel first. 
+ """ if self._state == ManipulationState.EXECUTING: - return False + return "Error: Cannot reset while executing — cancel the motion first" self._state = ManipulationState.IDLE self._error_message = "" - return True + return "Reset to IDLE — ready for new commands" @rpc def get_current_joints(self, robot_name: RobotName | None = None) -> list[float] | None: @@ -779,70 +731,6 @@ def get_trajectory_status(self, robot_name: RobotName | None = None) -> dict[str except Exception: return None - def _get_graspgen(self) -> DockerRunner: - """Get or create GraspGen Docker module (lazy init, thread-safe).""" - # Fast path: already initialized (no lock needed for read) - if self._graspgen is not None: - return self._graspgen - - # Slow path: need to initialize (acquire lock to prevent race condition) - with self._lock: - # Double-check: another thread may have initialized while we waited for lock - if self._graspgen is not None: - return self._graspgen - - # Ensure GraspGen model checkpoints are pulled from LFS - get_data("models_graspgen") - - docker_file = ( - DIMOS_PROJECT_ROOT - / "dimos" - / "manipulation" - / "grasping" - / "docker_context" - / "Dockerfile" - ) - - # Auto-mount host directory for visualization output when enabled. 
- docker_volumes: list[tuple[str, str, str]] = [] - if self.config.graspgen_save_visualization_data: - host_dir = self.config.graspgen_visualization_output_path.parent - host_dir.mkdir(parents=True, exist_ok=True) - docker_volumes.append((str(host_dir), _GRASPGEN_VIZ_CONTAINER_DIR, "rw")) - - graspgen = DockerRunner( - GraspGenModule, # type: ignore[arg-type] - docker_file=docker_file, - docker_build_context=DIMOS_PROJECT_ROOT, - docker_image=self.config.graspgen_docker_image, - docker_env={"CI": "1"}, # skip interactive system config prompt in container - docker_volumes=docker_volumes, - gripper_type=self.config.graspgen_gripper_type, - num_grasps=self.config.graspgen_num_grasps, - topk_num_grasps=self.config.graspgen_topk_num_grasps, - grasp_threshold=self.config.graspgen_grasp_threshold, - filter_collisions=self.config.graspgen_filter_collisions, - save_visualization_data=self.config.graspgen_save_visualization_data, - visualization_output_path=_GRASPGEN_VIZ_CONTAINER_PATH, - ) - graspgen.start() - self._graspgen = graspgen # cache only after successful start - return self._graspgen - - @rpc - def generate_grasps( - self, - pointcloud: PointCloud2, - scene_pointcloud: PointCloud2 | None = None, - ) -> PoseArray | None: - """Generate grasp poses for the given point cloud via GraspGen Docker module.""" - try: - graspgen = self._get_graspgen() - return graspgen.generate_grasps(pointcloud, scene_pointcloud) # type: ignore[no-any-return] - except Exception as e: - logger.error(f"Grasp generation failed: {e}") - return None - @property def world_monitor(self) -> WorldMonitor | None: """Access the world monitor for advanced obstacle/world operations.""" @@ -897,52 +785,6 @@ def remove_obstacle(self, obstacle_id: str) -> bool: return False return self._world_monitor.remove_obstacle(obstacle_id) - # ========================================================================= - # Perception RPC Methods - # 
========================================================================= - - @rpc - def refresh_obstacles(self, min_duration: float = 0.0) -> list[dict[str, Any]]: - """Refresh perception obstacles. Returns the list of obstacles added. - - Also snapshots the current detections so pick/place can use stable labels. - """ - if self._world_monitor is None: - return [] - result = self._world_monitor.refresh_obstacles(min_duration) - # Snapshot detections at refresh time — the live cache is volatile - self._detection_snapshot = self._world_monitor.get_cached_objects() - logger.info(f"Detection snapshot: {[d.name for d in self._detection_snapshot]}") - return result - - @rpc - def clear_perception_obstacles(self) -> int: - """Remove all perception obstacles. Returns count removed.""" - if self._world_monitor is None: - return 0 - return self._world_monitor.clear_perception_obstacles() - - @rpc - def get_perception_status(self) -> dict[str, int]: - """Get perception obstacle status (cached/added counts).""" - if self._world_monitor is None: - return {"cached": 0, "added": 0} - return self._world_monitor.get_perception_status() - - @rpc - def list_cached_detections(self) -> list[dict[str, Any]]: - """List cached detections from perception.""" - if self._world_monitor is None: - return [] - return self._world_monitor.list_cached_detections() - - @rpc - def list_added_obstacles(self) -> list[dict[str, Any]]: - """List perception obstacles currently in the planning world.""" - if self._world_monitor is None: - return [] - return self._world_monitor.list_added_obstacles() - # ========================================================================= # Gripper Methods # ========================================================================= @@ -1107,76 +949,41 @@ def _preview_execute_wait( return None - def _compute_pre_grasp_pose(self, grasp_pose: Pose, offset: float = 0.10) -> Pose: - """Compute a pre-grasp pose offset along the approach direction (local -Z). 
- - Args: - grasp_pose: The final grasp pose - offset: Distance to retract along the approach direction (meters) - - Returns: - Pre-grasp pose offset from the grasp pose - """ - from dimos.utils.transform_utils import offset_distance - - return offset_distance(grasp_pose, offset) - - def _find_object_in_detections( - self, object_name: str, object_id: str | None = None - ) -> DetObject | None: - """Find an object in the detection snapshot by name or ID. + # ========================================================================= + # Short-Horizon Skills — Single-step actions + # ========================================================================= - Uses the snapshot taken during the last scan_objects/refresh call, - not the volatile live cache (which changes labels every frame). + @skill + def get_robot_state(self, robot_name: str | None = None) -> str: + """Get current robot state: joint positions, end-effector pose, and gripper. Args: - object_name: Name/label to search for - object_id: Optional specific object ID - - Returns: - Matching DetObject, or None + robot_name: Robot to query (only needed for multi-arm setups). """ - if not self._detection_snapshot: - logger.warning("No detection snapshot — call scan_objects() first") - return None - - for det in self._detection_snapshot: - if object_id and det.object_id == object_id: - return det - if object_name.lower() in det.name.lower() or det.name.lower() in object_name.lower(): - return det - - available = [det.name for det in self._detection_snapshot] - logger.warning(f"Object '{object_name}' not found in snapshot. Available: {available}") - return None - - def _generate_grasps_for_pick( - self, object_name: str, object_id: str | None = None - ) -> list[Pose] | None: - """Generate grasp poses for an object. + lines: list[str] = [] - Computes a top-down approach grasp from the object's detected position. 
+ joints = self.get_current_joints(robot_name) + if joints is not None: + lines.append(f"Joints: [{', '.join(f'{j:.3f}' for j in joints)}]") + else: + lines.append("Joints: unavailable (no state received)") - Args: - object_name: Name of the object - object_id: Optional object ID + ee_pose = self.get_ee_pose(robot_name) + if ee_pose is not None: + p = ee_pose.position + lines.append(f"EE pose: ({p.x:.4f}, {p.y:.4f}, {p.z:.4f})") + else: + lines.append("EE pose: unavailable") - Returns: - List of grasp poses (best first), or None if object not found - """ - det = self._find_object_in_detections(object_name, object_id) - if det is None: - logger.warning(f"Object '{object_name}' not found in detections") - return None + gripper_pos = self.get_gripper(robot_name) + if gripper_pos is not None: + lines.append(f"Gripper: {gripper_pos:.3f}m") + else: + lines.append("Gripper: not configured") - c = det.center - grasp_pose = Pose(Vector3(c.x, c.y, c.z), Quaternion.from_euler(Vector3(0.0, math.pi, 0.0))) - logger.info(f"Heuristic grasp for '{object_name}' at ({c.x:.3f}, {c.y:.3f}, {c.z:.3f})") - return [grasp_pose] + lines.append(f"State: {self.get_state()}") - # ========================================================================= - # Short-Horizon Skills — Single-step actions - # ========================================================================= + return "\n".join(lines) @skill def move_to_pose( @@ -1184,26 +991,38 @@ def move_to_pose( x: float, y: float, z: float, - roll: float = 0.0, - pitch: float = 0.0, - yaw: float = 0.0, + roll: float | None = None, + pitch: float | None = None, + yaw: float | None = None, robot_name: str | None = None, ) -> str: """Move the robot end-effector to a target pose. Plans a collision-free trajectory and executes it. + If roll/pitch/yaw are omitted, the current EE orientation is preserved. Args: x: Target X position in meters. y: Target Y position in meters. z: Target Z position in meters. 
- roll: Target roll in radians (default 0). - pitch: Target pitch in radians (default 0). - yaw: Target yaw in radians (default 0). + roll: Target roll in radians (omit to keep current orientation). + pitch: Target pitch in radians (omit to keep current orientation). + yaw: Target yaw in radians (omit to keep current orientation). robot_name: Robot to move (only needed for multi-arm setups). """ logger.info(f"Planning motion to ({x:.3f}, {y:.3f}, {z:.3f})...") - pose = Pose(Vector3(x, y, z), Quaternion.from_euler(Vector3(roll, pitch, yaw))) + + # If no orientation specified, preserve the current EE orientation + if roll is None and pitch is None and yaw is None: + current_pose = self.get_ee_pose(robot_name) + if current_pose is not None: + orientation = current_pose.orientation + else: + orientation = Quaternion(0, 0, 0, 1) # identity fallback + else: + orientation = Quaternion.from_euler(Vector3(roll or 0.0, pitch or 0.0, yaw or 0.0)) + + pose = Pose(Vector3(x, y, z), orientation) if not self.plan_to_pose(pose, robot_name): return f"Error: Planning failed — pose ({x:.3f}, {y:.3f}, {z:.3f}) may be unreachable or in collision" @@ -1249,95 +1068,6 @@ def move_to_joints( return "Reached target joint configuration" - @skill - def get_scene_info(self, robot_name: str | None = None) -> str: - """Get current robot state, detected objects, and scene information. - - Returns a summary of the robot's joint positions, end-effector pose, - gripper state, detected objects, and obstacle count. - - Args: - robot_name: Robot to query (only needed for multi-arm setups). 
- """ - lines: list[str] = [] - - # Robot state - joints = self.get_current_joints(robot_name) - if joints is not None: - lines.append(f"Joints: [{', '.join(f'{j:.3f}' for j in joints)}]") - else: - lines.append("Joints: unavailable (no state received)") - - ee_pose = self.get_ee_pose(robot_name) - if ee_pose is not None: - p = ee_pose.position - lines.append(f"EE pose: ({p.x:.4f}, {p.y:.4f}, {p.z:.4f})") - else: - lines.append("EE pose: unavailable") - - # Gripper - gripper_pos = self.get_gripper(robot_name) - if gripper_pos is not None: - lines.append(f"Gripper: {gripper_pos:.3f}m") - else: - lines.append("Gripper: not configured") - - # Perception - perception = self.get_perception_status() - lines.append( - f"Perception: {perception.get('cached', 0)} cached, {perception.get('added', 0)} obstacles added" - ) - - detections = self._detection_snapshot - if detections: - lines.append(f"Detected objects ({len(detections)}):") - for det in detections: - c = det.center - lines.append(f" - {det.name}: ({c.x:.3f}, {c.y:.3f}, {c.z:.3f})") - else: - lines.append("Detected objects: none") - - # Visualization - url = self.get_visualization_url() - if url: - lines.append(f"Visualization: {url}") - - # State - lines.append(f"State: {self.get_state()}") - - return "\n".join(lines) - - @skill - def scan_objects(self, min_duration: float = 1.0, robot_name: str | None = None) -> str: - """Scan the scene and list detected objects with their 3D positions. - - Refreshes perception obstacles from the latest sensor data and returns - a formatted list of all detected objects. - - Args: - min_duration: Minimum time in seconds to wait for stable detections. - robot_name: Robot context (only needed for multi-arm setups). 
- """ - obstacles = self.refresh_obstacles(min_duration) - - detections = self._detection_snapshot - if not detections: - return "No objects detected in scene" - - lines = [f"Detected {len(detections)} object(s):"] - for det in detections: - c = det.center - lines.append(f" - {det.name}: ({c.x:.3f}, {c.y:.3f}, {c.z:.3f})") - - if obstacles: - lines.append(f"\n{len(obstacles)} obstacle(s) added to planning world") - - return "\n".join(lines) - - # ========================================================================= - # Long-Horizon Skills — Multi-step composed behaviors - # ========================================================================= - @skill def go_home(self, robot_name: str | None = None) -> str: """Move the robot to its home/observe joint configuration. @@ -1395,194 +1125,6 @@ def go_init(self, robot_name: str | None = None) -> str: return "Reached init position" - @skill - def pick( - self, - object_name: str, - object_id: str | None = None, - robot_name: str | None = None, - ) -> str: - """Pick up an object by name using grasp planning and motion execution. - - Generates grasp poses, plans collision-free approach/grasp/retract motions, - and executes them. - - Args: - object_name: Name of the object to pick (e.g. "cup", "bottle", "can"). - object_id: Optional unique object ID from perception for precise identification. - robot_name: Robot to use (only needed for multi-arm setups). - """ - robot = self._get_robot(robot_name) - if robot is None: - return "Error: Robot not found" - rname, _, config, _ = robot - pre_grasp_offset = config.pre_grasp_offset - - # 1. Generate grasps (uses already-cached detections — call scan_objects first) - logger.info(f"Generating grasp poses for '{object_name}'...") - grasp_poses = self._generate_grasps_for_pick(object_name, object_id) - if not grasp_poses: - return f"Error: No grasp poses found for '{object_name}'. Object may not be detected." - - # 2. 
Try each grasp candidate - max_attempts = min(len(grasp_poses), 5) - for i, grasp_pose in enumerate(grasp_poses[:max_attempts]): - pre_grasp_pose = self._compute_pre_grasp_pose(grasp_pose, pre_grasp_offset) - - logger.info(f"Planning approach to pre-grasp (attempt {i + 1}/{max_attempts})...") - if not self.plan_to_pose(pre_grasp_pose, rname): - logger.info(f"Grasp candidate {i + 1} approach planning failed, trying next") - continue # Try next candidate - - # Open gripper before approach - logger.info("Opening gripper...") - self._set_gripper_position(0.85, rname) - time.sleep(0.5) - - # 3. Preview + execute approach - err = self._preview_execute_wait(rname) - if err: - return err - - # 4. Move to grasp pose - logger.info("Moving to grasp position...") - if not self.plan_to_pose(grasp_pose, rname): - return "Error: Grasp pose planning failed" - err = self._preview_execute_wait(rname) - if err: - return err - - # 5. Close gripper - logger.info("Closing gripper...") - self._set_gripper_position(0.0, rname) - time.sleep(1.5) # Wait for gripper to close - - # 6. Retract to pre-grasp - logger.info("Retracting with object...") - if not self.plan_to_pose(pre_grasp_pose, rname): - return "Error: Retract planning failed" - err = self._preview_execute_wait(rname) - if err: - return err - - # Store pick position so place_back() can return the object - self._last_pick_position = grasp_pose.position - - return f"Pick complete — grasped '{object_name}' successfully" - - return f"Error: All {max_attempts} grasp attempts failed for '{object_name}'" - - @skill - def place( - self, - x: float, - y: float, - z: float, - robot_name: str | None = None, - ) -> str: - """Place a held object at the specified position. - - Plans and executes an approach, lowers to the target, releases the gripper, - and retracts. - - Args: - x: Target X position in meters. - y: Target Y position in meters. - z: Target Z position in meters. - robot_name: Robot to use (only needed for multi-arm setups). 
- """ - robot = self._get_robot(robot_name) - if robot is None: - return "Error: Robot not found" - rname, _, config, _ = robot - pre_place_offset = config.pre_grasp_offset - - # Compute place pose (top-down approach) - place_pose = Pose(Vector3(x, y, z), Quaternion.from_euler(Vector3(0.0, math.pi, 0.0))) - pre_place_pose = self._compute_pre_grasp_pose(place_pose, pre_place_offset) - - # 1. Move to pre-place - logger.info(f"Planning approach to place position ({x:.3f}, {y:.3f}, {z:.3f})...") - if not self.plan_to_pose(pre_place_pose, rname): - return "Error: Pre-place approach planning failed" - - err = self._preview_execute_wait(rname) - if err: - return err - - # 2. Lower to place position - logger.info("Lowering to place position...") - if not self.plan_to_pose(place_pose, rname): - return "Error: Place pose planning failed" - err = self._preview_execute_wait(rname) - if err: - return err - - # 3. Release - logger.info("Releasing object...") - self._set_gripper_position(0.85, rname) - time.sleep(1.0) - - # 4. Retract - logger.info("Retracting...") - if not self.plan_to_pose(pre_place_pose, rname): - return "Error: Retract planning failed" - err = self._preview_execute_wait(rname) - if err: - return err - - return f"Place complete — object released at ({x:.3f}, {y:.3f}, {z:.3f})" - - @skill - def place_back(self, robot_name: str | None = None) -> str: - """Place the held object back at its original pick position. - - Uses the position stored from the last successful pick operation. - - Args: - robot_name: Robot to use (only needed for multi-arm setups). 
- """ - if self._last_pick_position is None: - return "Error: No previous pick position stored — run pick() first" - - p = self._last_pick_position - logger.info(f"Placing back at original position ({p.x:.3f}, {p.y:.3f}, {p.z:.3f})...") - return self.place(p.x, p.y, p.z, robot_name) - - @skill - def pick_and_place( - self, - object_name: str, - place_x: float, - place_y: float, - place_z: float, - object_id: str | None = None, - robot_name: str | None = None, - ) -> str: - """Pick up an object and place it at a target location. - - Combines the pick and place skills into a single end-to-end operation. - - Args: - object_name: Name of the object to pick (e.g. "cup", "bottle"). - place_x: Target X position to place the object (meters). - place_y: Target Y position to place the object (meters). - place_z: Target Z position to place the object (meters). - object_id: Optional unique object ID from perception. - robot_name: Robot to use (only needed for multi-arm setups). - """ - logger.info( - f"Starting pick and place: pick '{object_name}' → place at ({place_x:.3f}, {place_y:.3f}, {place_z:.3f})" - ) - - # Pick phase - result = self.pick(object_name, object_id, robot_name) - if result.startswith("Error:"): - return result - - # Place phase - return self.place(place_x, place_y, place_z, robot_name) - # ========================================================================= # Lifecycle # ========================================================================= @@ -1592,12 +1134,6 @@ def stop(self) -> None: """Stop the manipulation module.""" logger.info("Stopping ManipulationModule") - # Stop GraspGen Docker container (thread-safe access to shared state) - with self._lock: - if self._graspgen is not None: - self._graspgen.stop() - self._graspgen = None - # Stop TF thread if self._tf_thread is not None: self._tf_stop_event.set() From 035709d513060ef54583da38e84eb0ab3af2d248 Mon Sep 17 00:00:00 2001 From: mustafab0 Date: Tue, 24 Feb 2026 19:31:32 -0800 Subject: [PATCH 
14/16] pick and place module added --- dimos/manipulation/__init__.py | 8 + dimos/manipulation/pick_and_place_module.py | 625 ++++++++++++++++++++ 2 files changed, 633 insertions(+) create mode 100644 dimos/manipulation/pick_and_place_module.py diff --git a/dimos/manipulation/__init__.py b/dimos/manipulation/__init__.py index 3ed1863092..d2a511d146 100644 --- a/dimos/manipulation/__init__.py +++ b/dimos/manipulation/__init__.py @@ -20,10 +20,18 @@ ManipulationState, manipulation_module, ) +from dimos.manipulation.pick_and_place_module import ( + PickAndPlaceModule, + PickAndPlaceModuleConfig, + pick_and_place_module, +) __all__ = [ "ManipulationModule", "ManipulationModuleConfig", "ManipulationState", + "PickAndPlaceModule", + "PickAndPlaceModuleConfig", "manipulation_module", + "pick_and_place_module", ] diff --git a/dimos/manipulation/pick_and_place_module.py b/dimos/manipulation/pick_and_place_module.py new file mode 100644 index 0000000000..7032b17222 --- /dev/null +++ b/dimos/manipulation/pick_and_place_module.py @@ -0,0 +1,625 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Pick-and-place manipulation module. 
+ +Extends ManipulationModule with perception integration and long-horizon skills: +- Perception: objects port, obstacle monitor, scan_objects, get_scene_info +- @rpc: generate_grasps (GraspGen Docker), refresh_obstacles, perception status +- @skill: pick, place, place_back, pick_and_place, scan_objects, get_scene_info +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +import math +from pathlib import Path +import time +from typing import TYPE_CHECKING, Any + +from dimos.agents.annotation import skill +from dimos.constants import DIMOS_PROJECT_ROOT +from dimos.core import In, rpc +from dimos.core.docker_runner import DockerModule as DockerRunner +from dimos.manipulation.grasping.graspgen_module import GraspGenModule +from dimos.manipulation.manipulation_module import ( + ManipulationModule, + ManipulationModuleConfig, +) +from dimos.msgs.geometry_msgs import Pose, Quaternion, Vector3 +from dimos.perception.detection.type.detection3d.object import ( + Object as DetObject, # noqa: TC001 +) +from dimos.utils.data import get_data +from dimos.utils.logging_config import setup_logger + +if TYPE_CHECKING: + from dimos.msgs.geometry_msgs import PoseArray + from dimos.msgs.sensor_msgs import PointCloud2 + +logger = setup_logger() + +# The host-side path (graspgen_visualization_output_path) is volume-mounted here. 
+_GRASPGEN_VIZ_CONTAINER_DIR = "/output/graspgen" +_GRASPGEN_VIZ_CONTAINER_PATH = f"{_GRASPGEN_VIZ_CONTAINER_DIR}/visualization.json" + + +@dataclass +class PickAndPlaceModuleConfig(ManipulationModuleConfig): + """Configuration for PickAndPlaceModule (adds GraspGen settings).""" + + # GraspGen Docker settings + graspgen_docker_image: str = "dimos-graspgen:latest" + graspgen_gripper_type: str = "robotiq_2f_140" + graspgen_num_grasps: int = 400 + graspgen_topk_num_grasps: int = 100 + graspgen_grasp_threshold: float = -1.0 + graspgen_filter_collisions: bool = False + graspgen_save_visualization_data: bool = False + graspgen_visualization_output_path: Path = field( + default_factory=lambda: Path.home() / ".dimos" / "graspgen" / "visualization.json" + ) + + +class PickAndPlaceModule(ManipulationModule): + """Manipulation module with perception integration and pick-and-place skills. + + Extends ManipulationModule with: + - Perception: objects port, obstacle monitor, scan_objects, get_scene_info + - @rpc: generate_grasps (GraspGen Docker), refresh_obstacles, perception status + - @skill: pick, place, place_back, pick_and_place, scan_objects, get_scene_info + """ + + default_config = PickAndPlaceModuleConfig + + # Type annotation for the config attribute (mypy uses this) + config: PickAndPlaceModuleConfig + + # Input: Objects from perception (for obstacle integration) + objects: In[list[DetObject]] + + def __init__(self, *args: object, **kwargs: object) -> None: + super().__init__(*args, **kwargs) + + # GraspGen Docker runner (lazy initialized on first generate_grasps call) + self._graspgen: DockerRunner | None = None + + # Last pick position: stored during pick so place_back() can return the object + self._last_pick_position: Vector3 | None = None + + # Snapshotted detections from the last scan_objects/refresh call. + # The live detection cache is volatile (labels change every frame), + # so pick/place use this stable snapshot instead. 
+ self._detection_snapshot: list[DetObject] = [] + + # ========================================================================= + # Lifecycle (perception integration) + # ========================================================================= + + @rpc + def start(self) -> None: + """Start the pick-and-place module (adds perception subscriptions).""" + super().start() + + # Subscribe to objects port for perception obstacle integration + if self.objects is not None: + self.objects.observable().subscribe(self._on_objects) # type: ignore[no-untyped-call] + logger.info("Subscribed to objects port (async)") + + # Start obstacle monitor for perception integration + if self._world_monitor is not None: + self._world_monitor.start_obstacle_monitor() + + logger.info("PickAndPlaceModule started") + + def _on_objects(self, objects: list[DetObject]) -> None: + """Callback when objects received from perception (runs on RxPY thread pool).""" + try: + if self._world_monitor is not None: + self._world_monitor.on_objects(objects) + except Exception as e: + logger.error(f"Exception in _on_objects: {e}") + + # ========================================================================= + # Perception RPC Methods + # ========================================================================= + + @rpc + def refresh_obstacles(self, min_duration: float = 0.0) -> list[dict[str, Any]]: + """Refresh perception obstacles. Returns the list of obstacles added. + + Also snapshots the current detections so pick/place can use stable labels. 
+ """ + if self._world_monitor is None: + return [] + result = self._world_monitor.refresh_obstacles(min_duration) + # Snapshot detections at refresh time — the live cache is volatile + self._detection_snapshot = self._world_monitor.get_cached_objects() + logger.info(f"Detection snapshot: {[d.name for d in self._detection_snapshot]}") + return result + + @skill + def clear_perception_obstacles(self) -> str: + """Clear all perception obstacles from the planning world. + + Use this when the planner reports COLLISION_AT_START — detected objects + may overlap the robot's current position and block planning. + """ + if self._world_monitor is None: + return "No world monitor available" + count = self._world_monitor.clear_perception_obstacles() + self._detection_snapshot = [] + return f"Cleared {count} perception obstacle(s) from planning world" + + @rpc + def get_perception_status(self) -> dict[str, int]: + """Get perception obstacle status (cached/added counts).""" + if self._world_monitor is None: + return {"cached": 0, "added": 0} + return self._world_monitor.get_perception_status() + + @rpc + def list_cached_detections(self) -> list[dict[str, Any]]: + """List cached detections from perception.""" + if self._world_monitor is None: + return [] + return self._world_monitor.list_cached_detections() + + @rpc + def list_added_obstacles(self) -> list[dict[str, Any]]: + """List perception obstacles currently in the planning world.""" + if self._world_monitor is None: + return [] + return self._world_monitor.list_added_obstacles() + + # ========================================================================= + # GraspGen + # ========================================================================= + + def _get_graspgen(self) -> DockerRunner: + """Get or create GraspGen Docker module (lazy init, thread-safe).""" + # Fast path: already initialized (no lock needed for read) + if self._graspgen is not None: + return self._graspgen + + # Slow path: need to initialize (acquire 
lock to prevent race condition) + with self._lock: + # Double-check: another thread may have initialized while we waited for lock + if self._graspgen is not None: + return self._graspgen + + # Ensure GraspGen model checkpoints are pulled from LFS + get_data("models_graspgen") + + docker_file = ( + DIMOS_PROJECT_ROOT + / "dimos" + / "manipulation" + / "grasping" + / "docker_context" + / "Dockerfile" + ) + + # Auto-mount host directory for visualization output when enabled. + docker_volumes: list[tuple[str, str, str]] = [] + if self.config.graspgen_save_visualization_data: + host_dir = self.config.graspgen_visualization_output_path.parent + host_dir.mkdir(parents=True, exist_ok=True) + docker_volumes.append((str(host_dir), _GRASPGEN_VIZ_CONTAINER_DIR, "rw")) + + graspgen = DockerRunner( + GraspGenModule, # type: ignore[arg-type] + docker_file=docker_file, + docker_build_context=DIMOS_PROJECT_ROOT, + docker_image=self.config.graspgen_docker_image, + docker_env={"CI": "1"}, # skip interactive system config prompt in container + docker_volumes=docker_volumes, + gripper_type=self.config.graspgen_gripper_type, + num_grasps=self.config.graspgen_num_grasps, + topk_num_grasps=self.config.graspgen_topk_num_grasps, + grasp_threshold=self.config.graspgen_grasp_threshold, + filter_collisions=self.config.graspgen_filter_collisions, + save_visualization_data=self.config.graspgen_save_visualization_data, + visualization_output_path=_GRASPGEN_VIZ_CONTAINER_PATH, + ) + graspgen.start() + self._graspgen = graspgen # cache only after successful start + return self._graspgen + + @rpc + def generate_grasps( + self, + pointcloud: PointCloud2, + scene_pointcloud: PointCloud2 | None = None, + ) -> PoseArray | None: + """Generate grasp poses for the given point cloud via GraspGen Docker module.""" + try: + graspgen = self._get_graspgen() + return graspgen.generate_grasps(pointcloud, scene_pointcloud) # type: ignore[no-any-return] + except Exception as e: + logger.error(f"Grasp generation 
failed: {e}") + return None + + # ========================================================================= + # Pick/Place Helpers + # ========================================================================= + + def _compute_pre_grasp_pose(self, grasp_pose: Pose, offset: float = 0.10) -> Pose: + """Compute a pre-grasp pose offset along the approach direction (local -Z). + + Args: + grasp_pose: The final grasp pose + offset: Distance to retract along the approach direction (meters) + + Returns: + Pre-grasp pose offset from the grasp pose + """ + from dimos.utils.transform_utils import offset_distance + + return offset_distance(grasp_pose, offset) + + def _find_object_in_detections( + self, object_name: str, object_id: str | None = None + ) -> DetObject | None: + """Find an object in the detection snapshot by name or ID. + + Uses the snapshot taken during the last scan_objects/refresh call, + not the volatile live cache (which changes labels every frame). + + Args: + object_name: Name/label to search for + object_id: Optional specific object ID + + Returns: + Matching DetObject, or None + """ + if not self._detection_snapshot: + logger.warning("No detection snapshot — call scan_objects() first") + return None + + for det in self._detection_snapshot: + if object_id and det.object_id == object_id: + return det + if object_name.lower() in det.name.lower() or det.name.lower() in object_name.lower(): + return det + + available = [det.name for det in self._detection_snapshot] + logger.warning(f"Object '{object_name}' not found in snapshot. Available: {available}") + return None + + def _generate_grasps_for_pick( + self, object_name: str, object_id: str | None = None + ) -> list[Pose] | None: + """Generate grasp poses for an object. + + Computes a top-down approach grasp from the object's detected position. 
+ + Args: + object_name: Name of the object + object_id: Optional object ID + + Returns: + List of grasp poses (best first), or None if object not found + """ + det = self._find_object_in_detections(object_name, object_id) + if det is None: + logger.warning(f"Object '{object_name}' not found in detections") + return None + + c = det.center + grasp_pose = Pose(Vector3(c.x, c.y, c.z), Quaternion.from_euler(Vector3(0.0, math.pi, 0.0))) + logger.info(f"Heuristic grasp for '{object_name}' at ({c.x:.3f}, {c.y:.3f}, {c.z:.3f})") + return [grasp_pose] + + # ========================================================================= + # Perception Skills + # ========================================================================= + + @skill + def get_scene_info(self, robot_name: str | None = None) -> str: + """Get current robot state, detected objects, and scene information. + + Returns a summary of the robot's joint positions, end-effector pose, + gripper state, detected objects, and obstacle count. + + Args: + robot_name: Robot to query (only needed for multi-arm setups). 
+ """ + lines: list[str] = [] + + # Robot state + joints = self.get_current_joints(robot_name) + if joints is not None: + lines.append(f"Joints: [{', '.join(f'{j:.3f}' for j in joints)}]") + else: + lines.append("Joints: unavailable (no state received)") + + ee_pose = self.get_ee_pose(robot_name) + if ee_pose is not None: + p = ee_pose.position + lines.append(f"EE pose: ({p.x:.4f}, {p.y:.4f}, {p.z:.4f})") + else: + lines.append("EE pose: unavailable") + + # Gripper + gripper_pos = self.get_gripper(robot_name) + if gripper_pos is not None: + lines.append(f"Gripper: {gripper_pos:.3f}m") + else: + lines.append("Gripper: not configured") + + # Perception + perception = self.get_perception_status() + lines.append( + f"Perception: {perception.get('cached', 0)} cached, {perception.get('added', 0)} obstacles added" + ) + + detections = self._detection_snapshot + if detections: + lines.append(f"Detected objects ({len(detections)}):") + for det in detections: + c = det.center + lines.append(f" - {det.name}: ({c.x:.3f}, {c.y:.3f}, {c.z:.3f})") + else: + lines.append("Detected objects: none") + + # Visualization + url = self.get_visualization_url() + if url: + lines.append(f"Visualization: {url}") + + # State + lines.append(f"State: {self.get_state()}") + + return "\n".join(lines) + + @skill + def scan_objects(self, min_duration: float = 1.0, robot_name: str | None = None) -> str: + """Scan the scene and list detected objects with their 3D positions. + + Refreshes perception obstacles from the latest sensor data and returns + a formatted list of all detected objects. + + Args: + min_duration: Minimum time in seconds to wait for stable detections. + robot_name: Robot context (only needed for multi-arm setups). 
+ """ + obstacles = self.refresh_obstacles(min_duration) + + detections = self._detection_snapshot + if not detections: + return "No objects detected in scene" + + lines = [f"Detected {len(detections)} object(s):"] + for det in detections: + c = det.center + lines.append(f" - {det.name}: ({c.x:.3f}, {c.y:.3f}, {c.z:.3f})") + + if obstacles: + lines.append(f"\n{len(obstacles)} obstacle(s) added to planning world") + + return "\n".join(lines) + + # ========================================================================= + # Long-Horizon Skills — Pick and Place + # ========================================================================= + + @skill + def pick( + self, + object_name: str, + object_id: str | None = None, + robot_name: str | None = None, + ) -> str: + """Pick up an object by name using grasp planning and motion execution. + + Generates grasp poses, plans collision-free approach/grasp/retract motions, + and executes them. + + Args: + object_name: Name of the object to pick (e.g. "cup", "bottle", "can"). + object_id: Optional unique object ID from perception for precise identification. + robot_name: Robot to use (only needed for multi-arm setups). + """ + robot = self._get_robot(robot_name) + if robot is None: + return "Error: Robot not found" + rname, _, config, _ = robot + pre_grasp_offset = config.pre_grasp_offset + + # 1. Generate grasps (uses already-cached detections — call scan_objects first) + logger.info(f"Generating grasp poses for '{object_name}'...") + grasp_poses = self._generate_grasps_for_pick(object_name, object_id) + if not grasp_poses: + return f"Error: No grasp poses found for '{object_name}'. Object may not be detected." + + # 2. 
Try each grasp candidate + max_attempts = min(len(grasp_poses), 5) + for i, grasp_pose in enumerate(grasp_poses[:max_attempts]): + pre_grasp_pose = self._compute_pre_grasp_pose(grasp_pose, pre_grasp_offset) + + logger.info(f"Planning approach to pre-grasp (attempt {i + 1}/{max_attempts})...") + if not self.plan_to_pose(pre_grasp_pose, rname): + logger.info(f"Grasp candidate {i + 1} approach planning failed, trying next") + continue # Try next candidate + + # Open gripper before approach + logger.info("Opening gripper...") + self._set_gripper_position(0.85, rname) + time.sleep(0.5) + + # 3. Preview + execute approach + err = self._preview_execute_wait(rname) + if err: + return err + + # 4. Move to grasp pose + logger.info("Moving to grasp position...") + if not self.plan_to_pose(grasp_pose, rname): + return "Error: Grasp pose planning failed" + err = self._preview_execute_wait(rname) + if err: + return err + + # 5. Close gripper + logger.info("Closing gripper...") + self._set_gripper_position(0.0, rname) + time.sleep(1.5) # Wait for gripper to close + + # 6. Retract to pre-grasp + logger.info("Retracting with object...") + if not self.plan_to_pose(pre_grasp_pose, rname): + return "Error: Retract planning failed" + err = self._preview_execute_wait(rname) + if err: + return err + + # Store pick position so place_back() can return the object + self._last_pick_position = grasp_pose.position + + return f"Pick complete — grasped '{object_name}' successfully" + + return f"Error: All {max_attempts} grasp attempts failed for '{object_name}'" + + @skill + def place( + self, + x: float, + y: float, + z: float, + robot_name: str | None = None, + ) -> str: + """Place a held object at the specified position. + + Plans and executes an approach, lowers to the target, releases the gripper, + and retracts. + + Args: + x: Target X position in meters. + y: Target Y position in meters. + z: Target Z position in meters. + robot_name: Robot to use (only needed for multi-arm setups). 
+ """ + robot = self._get_robot(robot_name) + if robot is None: + return "Error: Robot not found" + rname, _, config, _ = robot + pre_place_offset = config.pre_grasp_offset + + # Compute place pose (top-down approach) + place_pose = Pose(Vector3(x, y, z), Quaternion.from_euler(Vector3(0.0, math.pi, 0.0))) + pre_place_pose = self._compute_pre_grasp_pose(place_pose, pre_place_offset) + + # 1. Move to pre-place + logger.info(f"Planning approach to place position ({x:.3f}, {y:.3f}, {z:.3f})...") + if not self.plan_to_pose(pre_place_pose, rname): + return "Error: Pre-place approach planning failed" + + err = self._preview_execute_wait(rname) + if err: + return err + + # 2. Lower to place position + logger.info("Lowering to place position...") + if not self.plan_to_pose(place_pose, rname): + return "Error: Place pose planning failed" + err = self._preview_execute_wait(rname) + if err: + return err + + # 3. Release + logger.info("Releasing object...") + self._set_gripper_position(0.85, rname) + time.sleep(1.0) + + # 4. Retract + logger.info("Retracting...") + if not self.plan_to_pose(pre_place_pose, rname): + return "Error: Retract planning failed" + err = self._preview_execute_wait(rname) + if err: + return err + + return f"Place complete — object released at ({x:.3f}, {y:.3f}, {z:.3f})" + + @skill + def place_back(self, robot_name: str | None = None) -> str: + """Place the held object back at its original pick position. + + Uses the position stored from the last successful pick operation. + + Args: + robot_name: Robot to use (only needed for multi-arm setups). 
+ """ + if self._last_pick_position is None: + return "Error: No previous pick position stored — run pick() first" + + p = self._last_pick_position + logger.info(f"Placing back at original position ({p.x:.3f}, {p.y:.3f}, {p.z:.3f})...") + return self.place(p.x, p.y, p.z, robot_name) + + @skill + def pick_and_place( + self, + object_name: str, + place_x: float, + place_y: float, + place_z: float, + object_id: str | None = None, + robot_name: str | None = None, + ) -> str: + """Pick up an object and place it at a target location. + + Combines the pick and place skills into a single end-to-end operation. + + Args: + object_name: Name of the object to pick (e.g. "cup", "bottle"). + place_x: Target X position to place the object (meters). + place_y: Target Y position to place the object (meters). + place_z: Target Z position to place the object (meters). + object_id: Optional unique object ID from perception. + robot_name: Robot to use (only needed for multi-arm setups). + """ + logger.info( + f"Starting pick and place: pick '{object_name}' → place at ({place_x:.3f}, {place_y:.3f}, {place_z:.3f})" + ) + + # Pick phase + result = self.pick(object_name, object_id, robot_name) + if result.startswith("Error:"): + return result + + # Place phase + return self.place(place_x, place_y, place_z, robot_name) + + # ========================================================================= + # Lifecycle + # ========================================================================= + + @rpc + def stop(self) -> None: + """Stop the pick-and-place module (cleanup GraspGen + delegate to base).""" + logger.info("Stopping PickAndPlaceModule") + + # Stop GraspGen Docker container (thread-safe access to shared state) + with self._lock: + if self._graspgen is not None: + self._graspgen.stop() + self._graspgen = None + + super().stop() + + +# Expose blueprint for declarative composition +pick_and_place_module = PickAndPlaceModule.blueprint From 9e53cf765658e6c707373a51310936cba8a356a0 Mon Sep 
17 00:00:00 2001 From: mustafab0 Date: Tue, 24 Feb 2026 19:33:11 -0800 Subject: [PATCH 15/16] updated blueprints --- dimos/manipulation/manipulation_blueprints.py | 57 +++++++++++++++++-- dimos/robot/all_blueprints.py | 2 + 2 files changed, 54 insertions(+), 5 deletions(-) diff --git a/dimos/manipulation/manipulation_blueprints.py b/dimos/manipulation/manipulation_blueprints.py index e95e415373..396f2046e4 100644 --- a/dimos/manipulation/manipulation_blueprints.py +++ b/dimos/manipulation/manipulation_blueprints.py @@ -33,6 +33,7 @@ from dimos.core.transport import LCMTransport from dimos.hardware.sensors.camera.realsense import realsense_camera from dimos.manipulation.manipulation_module import manipulation_module +from dimos.manipulation.pick_and_place_module import pick_and_place_module from dimos.manipulation.planning.spec import RobotModelConfig from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Transform, Vector3 from dimos.msgs.sensor_msgs import JointState @@ -328,6 +329,43 @@ def _make_piper_config( ) +# XArm7 planner + LLM agent for testing base ManipulationModule skills +# No perception — uses the base module's planning + gripper skills only. +# Usage: dimos run coordinator-mock, then dimos run xarm7-planner-coordinator-agent +_BASE_MANIPULATION_AGENT_SYSTEM_PROMPT = """\ +You are a robotic manipulation assistant controlling an xArm7 robot arm. + +Available skills: +- get_robot_state: Get current joint positions, end-effector pose, and gripper state. +- move_to_pose: Move end-effector to ABSOLUTE x, y, z (meters) with optional roll, pitch, yaw (radians). +- move_to_joints: Move to a joint configuration (comma-separated radians). +- open_gripper / close_gripper / set_gripper: Control the gripper. +- go_home: Move to the home/observe position. +- go_init: Return to the startup position. +- reset: Clear a FAULT state and return to IDLE. Use this when a motion fails. 
+ +COORDINATE SYSTEM (world frame, meters): +- X axis = forward (away from the robot base) +- Y axis = left +- Z axis = up +- Z=0 is the robot base level; typical working height is Z = 0.2-0.5 + +CRITICAL WORKFLOW for relative movement requests (e.g. "move 20cm forward"): +1. Call get_robot_state to get the current EE pose. +2. Add the requested offset to the CURRENT position. Example: if EE is at \ +(0.3, 0.0, 0.4) and user says "move 20cm forward", target is (0.5, 0.0, 0.4). +3. Call move_to_pose with the computed ABSOLUTE target. +NEVER pass only the offset as coordinates — that would send the robot to near-origin. + +ERROR RECOVERY: If a motion fails or the state becomes FAULT, call reset before retrying. +""" + +xarm7_planner_coordinator_agent = autoconnect( + xarm7_planner_coordinator, + Agent.blueprint(system_prompt=_BASE_MANIPULATION_AGENT_SYSTEM_PROMPT), +) + + # XArm7 with eye-in-hand RealSense camera for perception-based manipulation # TF chain: world → link7 (ManipulationModule) → camera_link (RealSense) # Usage: dimos run coordinator-mock, then dimos run xarm-perception @@ -338,7 +376,7 @@ def _make_piper_config( xarm_perception = ( autoconnect( - manipulation_module( + pick_and_place_module( robots=[ _make_xarm7_config( "arm", @@ -375,22 +413,30 @@ def _make_piper_config( _MANIPULATION_AGENT_SYSTEM_PROMPT = """\ You are a robotic manipulation assistant controlling an xArm7 robot arm. -Use ONLY these ManipulationModule skills for manipulation tasks: +Available skills: +- get_robot_state: Get current joint positions, end-effector pose, and gripper state. - scan_objects: Scan scene and list detected objects with 3D positions. Always call this first. - pick: Pick up an object by name. Requires scan_objects first. - place: Place a held object at x, y, z position. - place_back: Place a held object back at its original pick position. - pick_and_place: Pick an object and place it at a target location. 
-- move_to_pose: Move end-effector to x, y, z with optional roll, pitch, yaw. +- move_to_pose: Move end-effector to ABSOLUTE x, y, z (meters) with optional roll, pitch, yaw (radians). - move_to_joints: Move to a joint configuration (comma-separated radians). - open_gripper / close_gripper / set_gripper: Control the gripper. - go_home: Move to the home/observe position. - go_init: Return to the startup position. - get_scene_info: Get full robot state, detected objects, and scene info. +- reset: Clear a FAULT state and return to IDLE. +- clear_perception_obstacles: Clear detected obstacles from the planning world. \ +Use when planning fails with COLLISION_AT_START. + +COORDINATE SYSTEM (world frame, meters): X=forward, Y=left, Z=up. Z=0 is robot base. + +ERROR RECOVERY: If planning fails with COLLISION_AT_START, call clear_perception_obstacles \ +then reset, then retry. Detected objects may overlap the robot's current position. -Do NOT use the 'detect' or 'select' skills — use scan_objects instead. -For robot_name parameters, always omit or pass None (single-arm setup). After pick or place, return to init with go_init unless another action follows immediately. +Do NOT use the 'detect' or 'select' skills — use scan_objects instead. 
""" xarm_perception_agent = autoconnect( @@ -405,6 +451,7 @@ def _make_piper_config( "dual_xarm6_planner", "xarm6_planner_only", "xarm7_planner_coordinator", + "xarm7_planner_coordinator_agent", "xarm_perception", "xarm_perception_agent", ] diff --git a/dimos/robot/all_blueprints.py b/dimos/robot/all_blueprints.py index ed90982801..3119b0ead0 100644 --- a/dimos/robot/all_blueprints.py +++ b/dimos/robot/all_blueprints.py @@ -84,6 +84,7 @@ "xarm-perception-agent": "dimos.manipulation.manipulation_blueprints:xarm_perception_agent", "xarm6-planner-only": "dimos.manipulation.manipulation_blueprints:xarm6_planner_only", "xarm7-planner-coordinator": "dimos.manipulation.manipulation_blueprints:xarm7_planner_coordinator", + "xarm7-planner-coordinator-agent": "dimos.manipulation.manipulation_blueprints:xarm7_planner_coordinator_agent", "xarm7-trajectory-sim": "dimos.simulation.sim_blueprints:xarm7_trajectory_sim", } @@ -122,6 +123,7 @@ "person-follow-skill": "dimos.agents.skills.person_follow", "person-tracker-module": "dimos.perception.detection.person_tracker", "phone-teleop-module": "dimos.teleop.phone.phone_teleop_module", + "pick-and-place-module": "dimos.manipulation.pick_and_place_module", "quest-teleop-module": "dimos.teleop.quest.quest_teleop_module", "realsense-camera": "dimos.hardware.sensors.camera.realsense.camera", "replanning-a-star-planner": "dimos.navigation.replanning_a_star.module", From 0a87c29002e5809af531d14de5ef6152ffc162c7 Mon Sep 17 00:00:00 2001 From: mustafab0 Date: Tue, 24 Feb 2026 19:33:23 -0800 Subject: [PATCH 16/16] added unit test --- dimos/manipulation/test_manipulation_unit.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/dimos/manipulation/test_manipulation_unit.py b/dimos/manipulation/test_manipulation_unit.py index de551d99cd..4aa232c74f 100644 --- a/dimos/manipulation/test_manipulation_unit.py +++ b/dimos/manipulation/test_manipulation_unit.py @@ -100,7 +100,6 @@ def _make_module(): module._planner = None 
module._kinematics = None module._coordinator_client = None - module._graspgen = None return module @@ -129,12 +128,14 @@ def test_reset_not_during_execution(self): module._state = ManipulationState.FAULT module._error_message = "Error" - assert module.reset() is True + result = module.reset() + assert "IDLE" in result assert module._state == ManipulationState.IDLE assert module._error_message == "" module._state = ManipulationState.EXECUTING - assert module.reset() is False + result = module.reset() + assert "Error" in result def test_fail_sets_fault_state(self): """_fail helper sets FAULT state and message."""