Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dimos/agents/mcp/mcp_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from dimos.core import Module, rpc # noqa: I001
from dimos.core.rpc_client import RpcCall, RPCClient

from starlette.requests import Request # noqa: TC002
from starlette.requests import Request

if TYPE_CHECKING:
import concurrent.futures
Expand Down
37 changes: 17 additions & 20 deletions dimos/agents/skills/person_follow.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from threading import Event, RLock, Thread
import time
from typing import TYPE_CHECKING
from typing import Any

from langchain_core.messages import HumanMessage
import numpy as np
Expand All @@ -23,8 +23,8 @@
from dimos.agents.agent import AgentSpec
from dimos.agents.annotation import skill
from dimos.core.core import rpc
from dimos.core.global_config import GlobalConfig
from dimos.core.module import Module
from dimos.core.global_config import GlobalConfig, global_config
from dimos.core.module import Module, ModuleConfig
from dimos.core.stream import In, Out
from dimos.models.qwen.video_query import BBox
from dimos.models.segmentation.edge_tam import EdgeTAMProcessor
Expand All @@ -36,13 +36,15 @@
from dimos.navigation.visual_servoing.visual_servoing_2d import VisualServoing2D
from dimos.utils.logging_config import setup_logger

if TYPE_CHECKING:
from dimos.models.vl.base import VlModel

logger = setup_logger()


class PersonFollowSkillContainer(Module):
class Config(ModuleConfig):
camera_info: CameraInfo
use_3d_navigation: bool = False


class PersonFollowSkillContainer(Module[Config]):
"""Skill container for following a person.

This skill uses:
Expand All @@ -52,6 +54,8 @@ class PersonFollowSkillContainer(Module):
- Does not do obstacle avoidance; assumes a clear path.
"""

default_config = Config

color_image: In[Image]
global_map: In[PointCloud2]
cmd_vel: Out[Twist]
Expand All @@ -60,38 +64,31 @@ class PersonFollowSkillContainer(Module):
_frequency: float = 20.0 # Hz - control loop frequency
_max_lost_frames: int = 15 # number of frames to wait before declaring person lost

def __init__(
self,
camera_info: CameraInfo,
cfg: GlobalConfig,
use_3d_navigation: bool = False,
) -> None:
super().__init__()
self._global_config: GlobalConfig = cfg
self._use_3d_navigation: bool = use_3d_navigation
def __init__(self, global_config: GlobalConfig = global_config, **kwargs: Any) -> None:
super().__init__(global_config, **kwargs)
self._latest_image: Image | None = None
self._latest_pointcloud: PointCloud2 | None = None
self._vl_model: VlModel = QwenVlModel()
self._vl_model = QwenVlModel()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason for using : VlModel is so people use it through that Protocol. Otherwise they might call methods which exist on QwenVlModel but not VlModel.

self._tracker: EdgeTAMProcessor | None = None
self._thread: Thread | None = None
self._should_stop: Event = Event()
self._lock = RLock()

# Use MuJoCo camera intrinsics in simulation mode
camera_info = self.config.camera_info
if self._global_config.simulation:
from dimos.robot.unitree.mujoco_connection import MujocoConnection

camera_info = MujocoConnection.camera_info_static

self._camera_info = camera_info
self._visual_servo = VisualServoing2D(camera_info, self._global_config.simulation)
self._detection_navigation = DetectionNavigation(self.tf, camera_info)

@rpc
def start(self) -> None:
super().start()
self._disposables.add(Disposable(self.color_image.subscribe(self._on_color_image)))
if self._use_3d_navigation:
if self.config.use_3d_navigation:
self._disposables.add(Disposable(self.global_map.subscribe(self._on_pointcloud)))

@rpc
Expand Down Expand Up @@ -227,7 +224,7 @@ def _follow_loop(self, tracker: EdgeTAMProcessor, query: str) -> None:
lost_count = 0
best_detection = max(detections.detections, key=lambda d: d.bbox_2d_volume())

if self._use_3d_navigation:
if self.config.use_3d_navigation:
with self._lock:
pointcloud = self._latest_pointcloud
if pointcloud is None:
Expand Down
7 changes: 3 additions & 4 deletions dimos/control/coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,14 @@
TwistBaseAdapter,
)
from dimos.msgs.geometry_msgs import (
PoseStamped, # noqa: TC001 - needed at runtime for In[PoseStamped]
Twist, # noqa: TC001 - needed at runtime for In[Twist]
PoseStamped,
Twist,
)
from dimos.msgs.sensor_msgs import (
JointState,
)
from dimos.teleop.quest.quest_types import (
Buttons, # noqa: TC001 - needed at runtime for In[Buttons]
Buttons,
)
from dimos.utils.logging_config import setup_logger

Expand Down Expand Up @@ -97,7 +97,6 @@ class TaskConfig:
hand: str = "" # teleop_ik only: "left" or "right" controller


@dataclass
class ControlCoordinatorConfig(ModuleConfig):
"""Configuration for the ControlCoordinator.
Expand Down
2 changes: 1 addition & 1 deletion dimos/core/_test_future_annotations_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from __future__ import annotations

from dimos.core.module import Module
from dimos.core.stream import In, Out # noqa
from dimos.core.stream import In, Out


class FutureData:
Expand Down
53 changes: 26 additions & 27 deletions dimos/core/blueprints.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from collections.abc import Callable, Mapping
from dataclasses import dataclass, field
from functools import cached_property, reduce
import inspect
import operator
import sys
from types import MappingProxyType
Expand All @@ -27,14 +26,19 @@
from dimos.protocol.service.system_configurator.base import SystemConfigurator

from dimos.core.global_config import GlobalConfig, global_config
from dimos.core.module import Module, is_module_type
from dimos.core.module import Module, ModuleBase, ModuleSpec, is_module_type
from dimos.core.module_coordinator import ModuleCoordinator
from dimos.core.stream import In, Out
from dimos.core.transport import LCMTransport, PubSubTransport, pLCMTransport
from dimos.spec.utils import Spec, is_spec, spec_annotation_compliance, spec_structural_compliance
from dimos.utils.generic import short_id
from dimos.utils.logging_config import setup_logger

if sys.version_info >= (3, 11):
from typing import Self
else:
from typing import Any as Self

logger = setup_logger()


Expand All @@ -48,21 +52,18 @@ class StreamRef:
@dataclass(frozen=True)
class ModuleRef:
name: str
spec: type[Spec] | type[Module]
spec: type[Spec] | type[ModuleBase]


@dataclass(frozen=True)
class _BlueprintAtom:
module: type[Module]
kwargs: dict[str, Any]
module: type[ModuleBase[Any]]
streams: tuple[StreamRef, ...]
module_refs: tuple[ModuleRef, ...]
args: tuple[Any, ...]
kwargs: dict[str, Any]

@classmethod
def create(
cls, module: type[Module], args: tuple[Any, ...], kwargs: dict[str, Any]
) -> "_BlueprintAtom":
def create(cls, module: type[ModuleBase[Any]], kwargs: dict[str, Any]) -> Self:
streams: list[StreamRef] = []
module_refs: list[ModuleRef] = []

Expand Down Expand Up @@ -96,7 +97,6 @@ def create(
module=module,
streams=tuple(streams),
module_refs=tuple(module_refs),
args=args,
kwargs=kwargs,
)

Expand All @@ -108,15 +108,15 @@ class Blueprint:
default_factory=lambda: MappingProxyType({})
)
global_config_overrides: Mapping[str, Any] = field(default_factory=lambda: MappingProxyType({}))
remapping_map: Mapping[tuple[type[Module], str], str | type[Module] | type[Spec]] = field(
default_factory=lambda: MappingProxyType({})
remapping_map: Mapping[tuple[type[ModuleBase], str], str | type[ModuleBase] | type[Spec]] = (
field(default_factory=lambda: MappingProxyType({}))
)
requirement_checks: tuple[Callable[[], str | None], ...] = field(default_factory=tuple)
configurator_checks: "tuple[SystemConfigurator, ...]" = field(default_factory=tuple)

@classmethod
def create(cls, module: type[Module], *args: Any, **kwargs: Any) -> "Blueprint":
blueprint = _BlueprintAtom.create(module, args, kwargs)
def create(cls, module: type[ModuleBase], **kwargs: Any) -> "Blueprint":
blueprint = _BlueprintAtom.create(module, kwargs)
return cls(blueprints=(blueprint,))

def transports(self, transports: dict[tuple[str, type], Any]) -> "Blueprint":
Expand All @@ -140,7 +140,10 @@ def global_config(self, **kwargs: Any) -> "Blueprint":
)

def remappings(
self, remappings: list[tuple[type[Module], str, str | type[Module] | type[Spec]]]
self,
remappings: list[
tuple[type[ModuleBase[Any]], str, str | type[ModuleBase[Any]] | type[Spec]]
],
) -> "Blueprint":
remappings_dict = dict(self.remapping_map)
for module, old, new in remappings:
Expand Down Expand Up @@ -178,8 +181,8 @@ def configurators(self, *checks: "SystemConfigurator") -> "Blueprint":
def _check_ambiguity(
self,
requested_method_name: str,
interface_methods: Mapping[str, list[tuple[type[Module], Callable[..., Any]]]],
requesting_module: type[Module],
interface_methods: Mapping[str, list[tuple[type[ModuleBase], Callable[..., Any]]]],
requesting_module: type[ModuleBase],
) -> None:
if (
requested_method_name in interface_methods
Expand Down Expand Up @@ -288,13 +291,9 @@ def _verify_no_name_conflicts(self) -> None:
def _deploy_all_modules(
self, module_coordinator: ModuleCoordinator, global_config: GlobalConfig
) -> None:
module_specs: list[tuple[type[Module], tuple[Any, ...], dict[str, Any]]] = []
module_specs: list[ModuleSpec] = []
for blueprint in self.blueprints:
kwargs = {**blueprint.kwargs}
sig = inspect.signature(blueprint.module.__init__)
if "cfg" in sig.parameters:
kwargs["cfg"] = global_config
module_specs.append((blueprint.module, blueprint.args, kwargs))
module_specs.append((blueprint.module, global_config, blueprint.kwargs))

module_coordinator.deploy_parallel(module_specs)

Expand Down Expand Up @@ -414,12 +413,12 @@ def _connect_rpc_methods(self, module_coordinator: ModuleCoordinator) -> None:
rpc_methods_dot = {}

# Track interface methods to detect ambiguity.
interface_methods: defaultdict[str, list[tuple[type[Module], Callable[..., Any]]]] = (
interface_methods: defaultdict[str, list[tuple[type[ModuleBase], Callable[..., Any]]]] = (
defaultdict(list)
) # interface_name_method -> [(module_class, method)]
interface_methods_dot: defaultdict[str, list[tuple[type[Module], Callable[..., Any]]]] = (
defaultdict(list)
) # interface_name.method -> [(module_class, method)]
interface_methods_dot: defaultdict[
str, list[tuple[type[ModuleBase], Callable[..., Any]]]
] = defaultdict(list) # interface_name.method -> [(module_class, method)]

for blueprint in self.blueprints:
for method_name in blueprint.module.rpcs.keys(): # type: ignore[attr-defined]
Expand Down
3 changes: 1 addition & 2 deletions dimos/core/docker_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import argparse
from contextlib import suppress
from dataclasses import dataclass, field
from dataclasses import field
import importlib
import json
import os
Expand Down Expand Up @@ -46,7 +46,6 @@
LOG_TAIL_LINES = 200 # Number of log lines to include in error messages


@dataclass(kw_only=True)
class DockerModuleConfig(ModuleConfig):
"""
Configuration for running a DimOS module inside Docker.
Expand Down
10 changes: 5 additions & 5 deletions dimos/core/introspection/blueprint/dot.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
color_for_string,
sanitize_id,
)
from dimos.core.module import Module
from dimos.core.module import ModuleBase
from dimos.utils.cli import theme


Expand Down Expand Up @@ -83,11 +83,11 @@ def render(
ignored_modules = DEFAULT_IGNORED_MODULES

# Collect all outputs: (name, type) -> list of producer modules
producers: dict[tuple[str, type], list[type[Module]]] = defaultdict(list)
producers: dict[tuple[str, type], list[type[ModuleBase]]] = defaultdict(list)
# Collect all inputs: (name, type) -> list of consumer modules
consumers: dict[tuple[str, type], list[type[Module]]] = defaultdict(list)
consumers: dict[tuple[str, type], list[type[ModuleBase]]] = defaultdict(list)
# Module name -> module class (for getting package info)
module_classes: dict[str, type[Module]] = {}
module_classes: dict[str, type[ModuleBase]] = {}

for bp in blueprint_set.blueprints:
module_classes[bp.module.__name__] = bp.module
Expand Down Expand Up @@ -118,7 +118,7 @@ def render(
active_channels[key] = color_for_string(TYPE_COLORS, label)

# Group modules by package
def get_group(mod_class: type[Module]) -> str:
def get_group(mod_class: type[ModuleBase]) -> str:
module_path = mod_class.__module__
parts = module_path.split(".")
if len(parts) >= 2 and parts[0] == "dimos":
Expand Down
Loading
Loading