From 652a5baba9f196fc94daed7b941c60335a1689f3 Mon Sep 17 00:00:00 2001 From: Sam Bull Date: Wed, 25 Feb 2026 15:51:31 +0000 Subject: [PATCH 1/9] Adjust config definitions for future enhancements --- dimos/core/blueprints.py | 47 ++++++++----------- dimos/core/introspection/blueprint/dot.py | 10 ++-- dimos/core/module.py | 37 +++++++++------ dimos/core/module_coordinator.py | 38 +++++++-------- dimos/core/worker.py | 33 ++++++++----- dimos/core/worker_manager.py | 15 +++--- dimos/hardware/sensors/camera/module.py | 23 ++++----- dimos/hardware/sensors/camera/zed/__init__.py | 34 +------------- dimos/hardware/sensors/fake_zed_module.py | 11 ++--- dimos/manipulation/manipulation_module.py | 10 ++-- dimos/mapping/costmapper.py | 8 ++-- dimos/mapping/voxels.py | 8 ++-- dimos/perception/object_tracker.py | 23 ++++----- dimos/perception/object_tracker_2d.py | 6 ++- dimos/protocol/service/spec.py | 6 +-- dimos/robot/foxglove_bridge.py | 32 +++++-------- dimos/robot/unitree/b1/connection.py | 30 ++++++++---- dimos/robot/unitree/b1/unitree_b1.py | 4 +- dimos/visualization/rerun/bridge.py | 3 +- 19 files changed, 173 insertions(+), 205 deletions(-) diff --git a/dimos/core/blueprints.py b/dimos/core/blueprints.py index 605517e6cf..8e5fc9c627 100644 --- a/dimos/core/blueprints.py +++ b/dimos/core/blueprints.py @@ -17,14 +17,13 @@ from collections.abc import Callable, Mapping from dataclasses import dataclass, field from functools import cached_property, reduce -import inspect import operator import sys from types import MappingProxyType -from typing import Any, Literal, get_args, get_origin, get_type_hints +from typing import Any, Literal, Self, get_args, get_origin, get_type_hints from dimos.core.global_config import GlobalConfig, global_config -from dimos.core.module import Module, is_module_type +from dimos.core.module import Module, ModuleBase, ModuleSpec, is_module_type from dimos.core.module_coordinator import ModuleCoordinator from dimos.core.stream import In, Out from dimos.core.transport import LCMTransport, PubSubTransport, pLCMTransport @@ -45,21 +44,18 @@ class StreamRef: @dataclass(frozen=True) class ModuleRef: name: str - spec: type[Spec] | type[Module] + spec: type[Spec] | type[ModuleBase] @dataclass(frozen=True) class _BlueprintAtom: - module: type[Module] + kwargs: dict[str, Any] + module: type[ModuleBase[Any]] streams: tuple[StreamRef, ...] module_refs: tuple[ModuleRef, ...] - args: tuple[Any, ...] - kwargs: dict[str, Any] @classmethod - def create( - cls, module: type[Module], args: tuple[Any, ...], kwargs: dict[str, Any] - ) -> "_BlueprintAtom": + def create(cls, module: type[ModuleBase[Any]], kwargs: dict[str, Any]) -> Self: streams: list[StreamRef] = [] module_refs: list[ModuleRef] = [] @@ -93,7 +89,6 @@ def create( module=module, streams=tuple(streams), module_refs=tuple(module_refs), - args=args, kwargs=kwargs, ) @@ -105,14 +100,14 @@ class Blueprint: default_factory=lambda: MappingProxyType({}) ) global_config_overrides: Mapping[str, Any] = field(default_factory=lambda: MappingProxyType({})) - remapping_map: Mapping[tuple[type[Module], str], str | type[Module] | type[Spec]] = field( - default_factory=lambda: MappingProxyType({}) + remapping_map: Mapping[tuple[type[ModuleBase], str], str | type[ModuleBase] | type[Spec]] = ( + field(default_factory=lambda: MappingProxyType({})) ) requirement_checks: tuple[Callable[[], str | None], ...] = field(default_factory=tuple) @classmethod - def create(cls, module: type[Module], *args: Any, **kwargs: Any) -> "Blueprint": - blueprint = _BlueprintAtom.create(module, args, kwargs) + def create(cls, module: type[ModuleBase], **kwargs: Any) -> "Blueprint": + blueprint = _BlueprintAtom.create(module, kwargs) return cls(blueprints=(blueprint,)) def transports(self, transports: dict[tuple[str, type], Any]) -> "Blueprint": @@ -134,7 +129,7 @@ def global_config(self, **kwargs: Any) -> "Blueprint": ) def remappings( - self, remappings: list[tuple[type[Module], str, str | type[Module] | type[Spec]]] + self, remappings: list[tuple[type[ModuleBase], str, str | type[ModuleBase] | type[Spec]]] ) -> "Blueprint": remappings_dict = dict(self.remapping_map) for module, old, new in remappings: @@ -160,8 +155,8 @@ def requirements(self, *checks: Callable[[], str | None]) -> "Blueprint": def _check_ambiguity( self, requested_method_name: str, - interface_methods: Mapping[str, list[tuple[type[Module], Callable[..., Any]]]], - requesting_module: type[Module], + interface_methods: Mapping[str, list[tuple[type[ModuleBase], Callable[..., Any]]]], + requesting_module: type[ModuleBase], ) -> None: if ( requested_method_name in interface_methods @@ -255,13 +250,9 @@ def _verify_no_name_conflicts(self) -> None: def _deploy_all_modules( self, module_coordinator: ModuleCoordinator, global_config: GlobalConfig ) -> None: - module_specs: list[tuple[type[Module], tuple[Any, ...], dict[str, Any]]] = [] + module_specs: list[ModuleSpec] = [] for blueprint in self.blueprints: - kwargs = {**blueprint.kwargs} - sig = inspect.signature(blueprint.module.__init__) - if "cfg" in sig.parameters: - kwargs["cfg"] = global_config - module_specs.append((blueprint.module, blueprint.args, kwargs)) + module_specs.append((blueprint.module, global_config, blueprint.kwargs)) module_coordinator.deploy_parallel(module_specs) @@ -381,12 +372,12 @@ def _connect_rpc_methods(self, module_coordinator: ModuleCoordinator) -> None: rpc_methods_dot = {} # Track interface methods to detect ambiguity. - interface_methods: defaultdict[str, list[tuple[type[Module], Callable[..., Any]]]] = ( + interface_methods: defaultdict[str, list[tuple[type[ModuleBase], Callable[..., Any]]]] = ( defaultdict(list) ) # interface_name_method -> [(module_class, method)] - interface_methods_dot: defaultdict[str, list[tuple[type[Module], Callable[..., Any]]]] = ( - defaultdict(list) - ) # interface_name.method -> [(module_class, method)] + interface_methods_dot: defaultdict[ + str, list[tuple[type[ModuleBase], Callable[..., Any]]] + ] = defaultdict(list) # interface_name.method -> [(module_class, method)] for blueprint in self.blueprints: for method_name in blueprint.module.rpcs.keys(): # type: ignore[attr-defined] diff --git a/dimos/core/introspection/blueprint/dot.py b/dimos/core/introspection/blueprint/dot.py index c60ad06fc8..54684f3b51 100644 --- a/dimos/core/introspection/blueprint/dot.py +++ b/dimos/core/introspection/blueprint/dot.py @@ -31,7 +31,7 @@ color_for_string, sanitize_id, ) -from dimos.core.module import Module +from dimos.core.module import ModuleBase from dimos.utils.cli import theme @@ -83,11 +83,11 @@ def render( ignored_modules = DEFAULT_IGNORED_MODULES # Collect all outputs: (name, type) -> list of producer modules - producers: dict[tuple[str, type], list[type[Module]]] = defaultdict(list) + producers: dict[tuple[str, type], list[type[ModuleBase]]] = defaultdict(list) # Collect all inputs: (name, type) -> list of consumer modules - consumers: dict[tuple[str, type], list[type[Module]]] = defaultdict(list) + consumers: dict[tuple[str, type], list[type[ModuleBase]]] = defaultdict(list) # Module name -> module class (for getting package info) - module_classes: dict[str, type[Module]] = {} + module_classes: dict[str, type[ModuleBase]] = {} for bp in blueprint_set.blueprints: module_classes[bp.module.__name__] = bp.module @@ -118,7 +118,7 @@ def render( active_channels[key] = color_for_string(TYPE_COLORS, label) # Group modules by package - def get_group(mod_class: type[Module]) -> str: + def get_group(mod_class: type[ModuleBase]) -> str: module_path = mod_class.__module__ parts = module_path.split(".") if len(parts) >= 2 and parts[0] == "dimos": diff --git a/dimos/core/module.py b/dimos/core/module.py index d6089a8f0a..5f21af5a08 100644 --- a/dimos/core/module.py +++ b/dimos/core/module.py @@ -23,28 +23,28 @@ from typing import ( TYPE_CHECKING, Any, + Protocol, get_args, get_origin, get_type_hints, overload, ) -from typing_extensions import TypeVar as TypeVarExtension - if TYPE_CHECKING: from collections.abc import Callable + from dimos.core.blueprints import Blueprint from dimos.core.introspection.module import ModuleInfo from dimos.core.rpc_client import RPCClient -from typing import TypeVar - from dask.distributed import Actor, get_worker from langchain_core.tools import tool +from pydantic import BaseModel from reactivex.disposable import CompositeDisposable from dimos.core import colors from dimos.core.core import T, rpc +from dimos.core.global_config import GlobalConfig, global_config from dimos.core.introspection.module import extract_module_info, render_module_io from dimos.core.resource import Resource from dimos.core.rpc_client import RpcCall # noqa: TC001 @@ -54,6 +54,11 @@ from dimos.protocol.tf import LCMTF, TFSpec from dimos.utils.generic import classproperty +if sys.version_info >= (3, 13): + from typing import TypeVar +else: + from typing_extensions import TypeVar + @dataclass(frozen=True) class SkillInfo: @@ -90,15 +95,18 @@ def get_loop() -> tuple[asyncio.AbstractEventLoop, threading.Thread | None]: return loop, thr -@dataclass -class ModuleConfig: +class ModuleConfig(BaseModel): rpc_transport: type[RPCSpec] = LCMRPC tf_transport: type[TFSpec] = LCMTF frame_id_prefix: str | None = None frame_id: str | None = None -ModuleConfigT = TypeVarExtension("ModuleConfigT", bound=ModuleConfig, default=ModuleConfig) +ModuleConfigT = TypeVar("ModuleConfigT", bound=ModuleConfig, default=ModuleConfig) + + +class _BlueprintPartial(Protocol): + def __call__(self, **kwargs: Any) -> Blueprint: ... class ModuleBase(Configurable[ModuleConfigT], Resource): @@ -111,10 +119,9 @@ class ModuleBase(Configurable[ModuleConfigT], Resource): rpc_calls: list[str] = [] - default_config: type[ModuleConfigT] = ModuleConfig # type: ignore[assignment] - - def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def] - super().__init__(*args, **kwargs) + def __init__(self, config_args: dict[str, Any], global_config: GlobalConfig): + super().__init__(**config_args) + self._global_config = global_config self._loop, self._loop_thread = get_loop() self._disposables = CompositeDisposable() # we can completely override comms protocols if we want @@ -346,7 +353,7 @@ def __get__( module_info = _module_info_descriptor() @classproperty - def blueprint(self): # type: ignore[no-untyped-def] + def blueprint(self) -> _BlueprintPartial: # Here to prevent circular imports. from dimos.core.blueprints import Blueprint @@ -426,7 +433,7 @@ def __init_subclass__(cls, **kwargs: Any) -> None: if not hasattr(cls, name) or getattr(cls, name) is None: setattr(cls, name, None) - def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def] + def __init__(self, global_config: GlobalConfig = global_config, **kwargs: Any): self.ref = None # type: ignore[assignment] # Get type hints with proper namespace resolution for subclasses @@ -455,7 +462,7 @@ def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def] inner, *_ = get_args(ann) or (Any,) stream = In(inner, name, self) # type: ignore[assignment] setattr(self, name, stream) - super().__init__(*args, **kwargs) + super().__init__(global_config, **kwargs) def set_ref(self, ref) -> int: # type: ignore[no-untyped-def] worker = get_worker() @@ -505,7 +512,7 @@ def dask_register_subscriber(self, output_name: str, subscriber: RemoteIn[T]) -> getattr(self, output_name).transport.dask_register_subscriber(subscriber) -ModuleT = TypeVar("ModuleT", bound="Module") +ModuleSpec = tuple[type[ModuleBase], GlobalConfig, dict[str, Any]] def is_module_type(value: Any) -> bool: diff --git a/dimos/core/module_coordinator.py b/dimos/core/module_coordinator.py index c6d975731d..cacf8e8d41 100644 --- a/dimos/core/module_coordinator.py +++ b/dimos/core/module_coordinator.py @@ -19,7 +19,7 @@ from dimos import core from dimos.core import DimosCluster from dimos.core.global_config import GlobalConfig, global_config -from dimos.core.module import Module, ModuleT +from dimos.core.module import ModuleBase, ModuleSpec from dimos.core.resource import Resource from dimos.core.worker_manager import WorkerManager @@ -32,7 +32,7 @@ class ModuleCoordinator(Resource): # type: ignore[misc] _global_config: GlobalConfig _n: int | None = None _memory_limit: str = "auto" - _deployed_modules: dict[type[Module], "ModuleProxy"] + _deployed_modules: dict[type[ModuleBase], "ModuleProxy"] def __init__( self, @@ -56,17 +56,20 @@ def stop(self) -> None: self._client.close_all() # type: ignore[union-attr] - def deploy(self, module_class: type[ModuleT], *args, **kwargs) -> "ModuleProxy": # type: ignore[no-untyped-def] - if not self._client: - raise ValueError("Trying to dimos.deploy before dask client has started") - - module: ModuleProxy = self._client.deploy(module_class, *args, **kwargs) # type: ignore[union-attr, attr-defined, assignment] - self._deployed_modules[module_class] = module - return module - - def deploy_parallel( - self, module_specs: list[tuple[type[ModuleT], tuple[Any, ...], dict[str, Any]]] - ) -> list["ModuleProxy"]: + def deploy( + self, + module_class: type[ModuleBase[Any]], + global_config: GlobalConfig = global_config, + **kwargs: Any, + ) -> "ModuleProxy": + if not isinstance(self._client, WorkerManager): + raise RuntimeError("Trying to dimos.deploy before dask client has started") + + module = self._client.deploy(module_class, global_config, **kwargs) + self._deployed_modules[module_class] = module # type: ignore[assignment] + return module # type: ignore[return-value] + + def deploy_parallel(self, module_specs: list[ModuleSpec]) -> list["ModuleProxy"]: if not self._client: raise ValueError("Not started") @@ -75,11 +78,8 @@ def deploy_parallel( for (module_class, _, _), module in zip(module_specs, modules, strict=True): self._deployed_modules[module_class] = module # type: ignore[assignment] return modules # type: ignore[return-value] - else: - return [ - self.deploy(module_class, *args, **kwargs) - for module_class, args, kwargs in module_specs - ] + + return [self.deploy(m, c, **kw) for m, c, kw in module_specs] def start_all_modules(self) -> None: modules = list(self._deployed_modules.values()) @@ -95,7 +95,7 @@ def start_all_modules(self) -> None: if hasattr(module, "on_system_modules"): module.on_system_modules(module_list) - def get_instance(self, module: type[ModuleT]) -> "ModuleProxy": + def get_instance(self, module: type[ModuleBase]) -> "ModuleProxy": return self._deployed_modules.get(module) # type: ignore[return-value, no-any-return] def loop(self) -> None: diff --git a/dimos/core/worker.py b/dimos/core/worker.py index d6ff71918c..c70cd38431 100644 --- a/dimos/core/worker.py +++ b/dimos/core/worker.py @@ -17,7 +17,8 @@ import traceback from typing import Any -from dimos.core.module import ModuleT +from dimos.core.global_config import GlobalConfig +from dimos.core.module import Module, ModuleBase from dimos.core.rpc_client import RPCClient from dimos.utils.actor_registry import ActorRegistry from dimos.utils.logging_config import setup_logger @@ -40,7 +41,7 @@ class Actor: """Proxy that forwards method calls to the worker process.""" def __init__( - self, conn: Connection | None, module_class: type[ModuleT], worker_id: int + self, conn: Connection | None, module_class: type[ModuleBase], worker_id: int ) -> None: self._conn = conn self._cls = module_class @@ -98,13 +99,13 @@ def reset_forkserver_context() -> None: class Worker: def __init__( self, - module_class: type[ModuleT], - args: tuple[Any, ...] = (), - kwargs: dict[Any, Any] | None = None, + module_class: type[ModuleBase], + global_config: GlobalConfig, + kwargs: dict[str, Any], ) -> None: - self._module_class: type[ModuleT] = module_class - self._args: tuple[Any, ...] = args - self._kwargs: dict[Any, Any] = kwargs or {} + self._module_class = module_class + self._global_config = global_config + self._kwargs = kwargs self._process: Any = None self._conn: Connection | None = None self._actor: Actor | None = None @@ -118,7 +119,13 @@ def start_process(self) -> None: self._process = ctx.Process( target=_worker_entrypoint, - args=(child_conn, self._module_class, self._args, self._kwargs, self._worker_id), + args=( + child_conn, + self._module_class, + self._global_config, + self._kwargs, + self._worker_id, + ), daemon=True, ) self._process.start() @@ -168,15 +175,15 @@ def shutdown(self) -> None: def _worker_entrypoint( conn: Connection, - module_class: type[ModuleT], - args: tuple[Any, ...], - kwargs: dict[Any, Any], + module_class: type[Module], + global_config: GlobalConfig, + kwargs: dict[str, Any], worker_id: int, ) -> None: instance = None try: - instance = module_class(*args, **kwargs) + instance = module_class(global_config=global_config, **kwargs) instance.worker = worker_id _worker_loop(conn, instance, worker_id) diff --git a/dimos/core/worker_manager.py b/dimos/core/worker_manager.py index 175b650fd2..76b156e780 100644 --- a/dimos/core/worker_manager.py +++ b/dimos/core/worker_manager.py @@ -14,7 +14,8 @@ from typing import Any -from dimos.core.module import ModuleT +from dimos.core.global_config import GlobalConfig +from dimos.core.module import ModuleBase from dimos.core.rpc_client import RPCClient from dimos.core.worker import Worker from dimos.utils.actor_registry import ActorRegistry @@ -28,24 +29,26 @@ def __init__(self) -> None: self._workers: list[Worker] = [] self._closed = False - def deploy(self, module_class: type[ModuleT], *args: Any, **kwargs: Any) -> RPCClient: + def deploy( + self, module_class: type[ModuleBase], global_config: GlobalConfig, kwargs: dict[str, Any] + ) -> RPCClient: if self._closed: raise RuntimeError("WorkerManager is closed") - worker = Worker(module_class, args=args, kwargs=kwargs) + worker = Worker(module_class, global_config, kwargs) worker.deploy() self._workers.append(worker) return worker.get_instance() def deploy_parallel( - self, module_specs: list[tuple[type[ModuleT], tuple[Any, ...], dict[Any, Any]]] + self, module_specs: list[tuple[type[ModuleBase], GlobalConfig, dict[str, Any]]] ) -> list[RPCClient]: if self._closed: raise RuntimeError("WorkerManager is closed") workers: list[Worker] = [] - for module_class, args, kwargs in module_specs: - worker = Worker(module_class, args=args, kwargs=kwargs) + for module_class, global_config, kwargs in module_specs: + worker = Worker(module_class, global_config, kwargs) worker.start_process() workers.append(worker) diff --git a/dimos/hardware/sensors/camera/module.py b/dimos/hardware/sensors/camera/module.py index 11821d4724..659c9e07b3 100644 --- a/dimos/hardware/sensors/camera/module.py +++ b/dimos/hardware/sensors/camera/module.py @@ -14,6 +14,7 @@ from collections.abc import Callable from dataclasses import dataclass, field +import sys import time from typing import Any @@ -22,7 +23,6 @@ from dimos.agents.annotation import skill from dimos.core.blueprints import autoconnect from dimos.core.core import rpc -from dimos.core.global_config import GlobalConfig, global_config from dimos.core.module import Module, ModuleConfig from dimos.core.stream import Out from dimos.hardware.sensors.camera.spec import CameraHardware @@ -33,6 +33,11 @@ from dimos.spec import perception from dimos.visualization.rerun.bridge import rerun_bridge +if sys.version_info >= (3, 13): + from typing import TypeVar +else: + from typing_extensions import TypeVar + def default_transform() -> Transform: return Transform( @@ -51,20 +56,16 @@ class CameraModuleConfig(ModuleConfig): frequency: float = 0.0 # Hz, 0 means no limit -class CameraModule(Module[CameraModuleConfig], perception.Camera): +CameraConfigT = TypeVar("CameraConfigT", bound=CameraModuleConfig, default=CameraModuleConfig) + + +class CameraModule(Module[CameraConfigT], perception.Camera): color_image: Out[Image] camera_info: Out[CameraInfo] - hardware: CameraHardware[Any] - - config: CameraModuleConfig default_config = CameraModuleConfig - _global_config: GlobalConfig - - def __init__(self, *args: Any, cfg: GlobalConfig = global_config, **kwargs: Any) -> None: - self._global_config = cfg - self._latest_image: Image | None = None - super().__init__(*args, **kwargs) + hardware: CameraHardware[Any] + _latest_image: Image | None = None @rpc def start(self) -> None: diff --git a/dimos/hardware/sensors/camera/zed/__init__.py b/dimos/hardware/sensors/camera/zed/__init__.py index f8e73273bf..430416e12e 100644 --- a/dimos/hardware/sensors/camera/zed/__init__.py +++ b/dimos/hardware/sensors/camera/zed/__init__.py @@ -16,46 +16,14 @@ from pathlib import Path +from dimos.hardware.sensors.camera.zed.camera import ZEDCamera, ZEDModule, zed_camera from dimos.msgs.sensor_msgs.CameraInfo import CalibrationProvider -# Check if ZED SDK is available -try: - import pyzed.sl as sl # noqa: F401 - - HAS_ZED_SDK = True -except ImportError: - HAS_ZED_SDK = False - -# Only import ZED classes if SDK is available -if HAS_ZED_SDK: - from dimos.hardware.sensors.camera.zed.camera import ZEDCamera, ZEDModule, zed_camera -else: - # Provide stub classes when SDK is not available - class ZEDCamera: # type: ignore[no-redef] - def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def] - raise ImportError( - "ZED SDK not installed. Please install pyzed package to use ZED camera functionality." - ) - - class ZEDModule: # type: ignore[no-redef] - def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def] - raise ImportError( - "ZED SDK not installed. Please install pyzed package to use ZED camera functionality." - ) - - def zed_camera(*args: object, **kwargs: object) -> None: # type: ignore[no-redef] - raise ModuleNotFoundError( - "ZED SDK not installed. Please install pyzed package to use ZED camera functionality.", - name="pyzed", - ) - - # Set up camera calibration provider (always available) CALIBRATION_DIR = Path(__file__).parent CameraInfo = CalibrationProvider(CALIBRATION_DIR) __all__ = [ - "HAS_ZED_SDK", "CameraInfo", "ZEDCamera", "ZEDModule", diff --git a/dimos/hardware/sensors/fake_zed_module.py b/dimos/hardware/sensors/fake_zed_module.py index ec5613077d..f119179705 100644 --- a/dimos/hardware/sensors/fake_zed_module.py +++ b/dimos/hardware/sensors/fake_zed_module.py @@ -17,7 +17,6 @@ FakeZEDModule - Replays recorded ZED data for testing without hardware. """ -from dataclasses import dataclass import functools import logging @@ -25,6 +24,7 @@ import numpy as np from dimos.core.core import rpc +from dimos.core.global_config import GlobalConfig, global_config from dimos.core.module import Module, ModuleConfig from dimos.core.stream import Out from dimos.msgs.geometry_msgs import PoseStamped @@ -37,8 +37,8 @@ logger = setup_logger(level=logging.INFO) -@dataclass class FakeZEDModuleConfig(ModuleConfig): + recording_path: str frame_id: str = "zed_camera" @@ -54,18 +54,17 @@ class FakeZEDModule(Module[FakeZEDModuleConfig]): pose: Out[PoseStamped] default_config = FakeZEDModuleConfig - config: FakeZEDModuleConfig - def __init__(self, recording_path: str, **kwargs: object) -> None: + def __init__(self, global_config: GlobalConfig = global_config, **kwargs: object) -> None: """ Initialize FakeZEDModule with recording path. Args: recording_path: Path to recorded data directory """ - super().__init__(**kwargs) + super().__init__(global_config, **kwargs) - self.recording_path = recording_path + self.recording_path = self.config.recording_path self._running = False # Initialize TF publisher diff --git a/dimos/manipulation/manipulation_module.py b/dimos/manipulation/manipulation_module.py index 310b77d766..3318fab7f1 100644 --- a/dimos/manipulation/manipulation_module.py +++ b/dimos/manipulation/manipulation_module.py @@ -34,6 +34,7 @@ from dimos.constants import DIMOS_PROJECT_ROOT from dimos.core import In, Module, rpc from dimos.core.docker_runner import DockerModule as DockerRunner +from dimos.core.global_config import GlobalConfig, global_config from dimos.core.module import ModuleConfig from dimos.manipulation.grasping.graspgen_module import GraspGenModule from dimos.manipulation.planning import ( @@ -118,7 +119,7 @@ class ManipulationModuleConfig(ModuleConfig): ) -class ManipulationModule(Module): +class ManipulationModule(Module[ManipulationModuleConfig]): """Motion planning module with ControlCoordinator execution. - @rpc: Low-level building blocks (plan, execute, obstacles) @@ -128,17 +129,14 @@ class ManipulationModule(Module): default_config = ManipulationModuleConfig - # Type annotation for the config attribute (mypy uses this) - config: ManipulationModuleConfig - # Input: Joint state from coordinator (for world sync) joint_state: In[JointState] # Input: Objects from perception (for obstacle integration) objects: In[list[DetObject]] - def __init__(self, *args: object, **kwargs: object) -> None: - super().__init__(*args, **kwargs) + def __init__(self, global_config: GlobalConfig = global_config, **kwargs: Any) -> None: + super().__init__(global_config, **kwargs) # State machine self._state = ManipulationState.IDLE diff --git a/dimos/mapping/costmapper.py b/dimos/mapping/costmapper.py index 70cd770777..561bc72a82 100644 --- a/dimos/mapping/costmapper.py +++ b/dimos/mapping/costmapper.py @@ -38,16 +38,14 @@ class Config(ModuleConfig): config: OccupancyConfig = field(default_factory=HeightCostConfig) -class CostMapper(Module): +class CostMapper(Module[Config]): default_config = Config - config: Config global_map: In[PointCloud2] global_costmap: Out[OccupancyGrid] - def __init__(self, cfg: GlobalConfig = global_config, **kwargs: object) -> None: - super().__init__(**kwargs) - self._global_config = cfg + def __init__(self, global_config: GlobalConfig = global_config, **kwargs: object) -> None: + super().__init__(global_config, **kwargs) @rpc def start(self) -> None: diff --git a/dimos/mapping/voxels.py b/dimos/mapping/voxels.py index 4c1805e059..d079979801 100644 --- a/dimos/mapping/voxels.py +++ b/dimos/mapping/voxels.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass import time +from typing import Any import numpy as np import open3d as o3d # type: ignore[import-untyped] @@ -33,7 +33,6 @@ logger = setup_logger() -@dataclass class Config(ModuleConfig): frame_id: str = "world" # -1 never publishes, 0 publishes on every frame, >0 publishes at interval in seconds @@ -51,9 +50,8 @@ class VoxelGridMapper(Module): lidar: In[PointCloud2] global_map: Out[PointCloud2] - def __init__(self, cfg: GlobalConfig = global_config, **kwargs: object) -> None: - super().__init__(**kwargs) - self._global_config = cfg + def __init__(self, global_config: GlobalConfig = global_config, **kwargs: Any) -> None: + super().__init__(global_config, **kwargs) dev = ( o3c.Device(self.config.device) diff --git a/dimos/perception/object_tracker.py b/dimos/perception/object_tracker.py index da415ac32a..8b0ce52d06 100644 --- a/dimos/perception/object_tracker.py +++ b/dimos/perception/object_tracker.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass import threading import time +from typing import Any import cv2 @@ -29,6 +29,7 @@ from reactivex.disposable import Disposable from dimos.core.core import rpc +from dimos.core.global_config import GlobalConfig, global_config from dimos.core.module import Module, ModuleConfig from dimos.core.stream import In, Out from dimos.msgs.geometry_msgs import Pose, Quaternion, Transform, Vector3 @@ -51,9 +52,10 @@ logger = setup_logger() -@dataclass class ObjectTrackingConfig(ModuleConfig): frame_id: str = "camera_link" + reid_threshold: int = 10 + reid_fail_tolerance: int = 5 class ObjectTracking(Module[ObjectTrackingConfig]): @@ -70,11 +72,8 @@ class ObjectTracking(Module[ObjectTrackingConfig]): tracked_overlay: Out[Image] # Visualization output default_config = ObjectTrackingConfig - config: ObjectTrackingConfig - def __init__( - self, reid_threshold: int = 10, reid_fail_tolerance: int = 5, **kwargs: object - ) -> None: + def __init__(self, global_config: GlobalConfig = global_config, **kwargs: Any) -> None: """ Initialize an object tracking module using OpenCV's CSRT tracker with ORB re-ID. @@ -86,11 +85,9 @@ def __init__( tracking is stopped. """ # Call parent Module init - super().__init__(**kwargs) + super().__init__(global_config, **kwargs) self.camera_intrinsics = None - self.reid_threshold = reid_threshold - self.reid_fail_tolerance = reid_fail_tolerance self.tracker = None self.tracking_bbox = None # Stores (x, y, w, h) for tracker initialization @@ -276,7 +273,7 @@ def reid(self, frame, current_bbox) -> bool: # type: ignore[no-untyped-def] good_matches += 1 self.last_good_matches = good_matches_list # Store good matches for visualization - return good_matches >= self.reid_threshold + return good_matches >= self.config.reid_threshold def _start_tracking_thread(self) -> None: """Start the tracking thread.""" @@ -389,7 +386,7 @@ def _process_tracking(self) -> None: # Determine final success if tracker_succeeded: - if self.reid_fail_count >= self.reid_fail_tolerance: + if self.reid_fail_count >= self.config.reid_fail_tolerance: logger.warning( f"Re-ID failed consecutively {self.reid_fail_count} times. Target lost." ) @@ -589,11 +586,11 @@ def _draw_reid_matches(self, image: NDArray[np.uint8]) -> NDArray[np.uint8]: # f"REID: WARMING UP ({self.tracking_frame_count}/{self.reid_warmup_frames})" ) status_color = (255, 255, 0) # Yellow - elif len(self.last_good_matches) >= self.reid_threshold: + elif len(self.last_good_matches) >= self.config.reid_threshold: status_text = "REID: CONFIRMED" status_color = (0, 255, 0) # Green else: - status_text = f"REID: WEAK ({self.reid_fail_count}/{self.reid_fail_tolerance})" + status_text = f"REID: WEAK ({self.reid_fail_count}/{self.config.reid_fail_tolerance})" status_color = (0, 165, 255) # Orange cv2.putText( diff --git a/dimos/perception/object_tracker_2d.py b/dimos/perception/object_tracker_2d.py index 1264b0e92b..17f470cff6 100644 --- a/dimos/perception/object_tracker_2d.py +++ b/dimos/perception/object_tracker_2d.py @@ -16,6 +16,7 @@ import logging import threading import time +from typing import Any import cv2 @@ -33,6 +34,7 @@ from reactivex.disposable import Disposable from dimos.core.core import rpc +from dimos.core.global_config import GlobalConfig, global_config from dimos.core.module import Module, ModuleConfig from dimos.core.stream import In, Out from dimos.msgs.sensor_msgs import Image, ImageFormat @@ -59,9 +61,9 @@ class ObjectTracker2D(Module[ObjectTracker2DConfig]): default_config = ObjectTracker2DConfig config: ObjectTracker2DConfig - def __init__(self, **kwargs: object) -> None: + def __init__(self, global_config: GlobalConfig = global_config, **kwargs: Any) -> None: """Initialize 2D object tracking module using OpenCV's CSRT tracker.""" - super().__init__(**kwargs) + super().__init__(global_config, **kwargs) # Tracker state self.tracker = None diff --git a/dimos/protocol/service/spec.py b/dimos/protocol/service/spec.py index c4e6758614..37ded73c06 100644 --- a/dimos/protocol/service/spec.py +++ b/dimos/protocol/service/spec.py @@ -13,7 +13,7 @@ # limitations under the License. from abc import ABC -from typing import Generic, TypeVar +from typing import Any, Generic, TypeVar # Generic type for service configuration ConfigT = TypeVar("ConfigT") @@ -22,8 +22,8 @@ class Configurable(Generic[ConfigT]): default_config: type[ConfigT] - def __init__(self, **kwargs) -> None: # type: ignore[no-untyped-def] - self.config: ConfigT = self.default_config(**kwargs) + def __init__(self, **kwargs: Any) -> None: + self.config = self.default_config(**kwargs) class Service(Configurable[ConfigT], ABC): diff --git a/dimos/robot/foxglove_bridge.py b/dimos/robot/foxglove_bridge.py index 529a14c838..013cc224c7 100644 --- a/dimos/robot/foxglove_bridge.py +++ b/dimos/robot/foxglove_bridge.py @@ -13,43 +13,33 @@ # limitations under the License. import asyncio +from collections.abc import Sequence import logging import threading -from typing import TYPE_CHECKING, Any from dimos_lcm.foxglove_bridge import ( FoxgloveBridge as LCMFoxgloveBridge, ) from dimos.core import DimosCluster, Module, rpc +from dimos.core.module import ModuleConfig from dimos.utils.logging_config import setup_logger -if TYPE_CHECKING: - from dimos.core.global_config import GlobalConfig - logging.getLogger("lcm_foxglove_bridge").setLevel(logging.ERROR) logging.getLogger("FoxgloveServer").setLevel(logging.ERROR) logger = setup_logger() -class FoxgloveBridge(Module): +class FoxgloveConfig(ModuleConfig): + shm_channels: Sequence[str] = () + jpeg_shm_channels: Sequence[str] = () + + +class FoxgloveBridge(Module[FoxgloveConfig]): _thread: threading.Thread _loop: asyncio.AbstractEventLoop - _global_config: "GlobalConfig | None" = None - - def __init__( - self, - *args: Any, - shm_channels: list[str] | None = None, - jpeg_shm_channels: list[str] | None = None, - global_config: "GlobalConfig | None" = None, - **kwargs: Any, - ) -> None: - super().__init__(*args, **kwargs) - self.shm_channels = shm_channels or [] - self.jpeg_shm_channels = jpeg_shm_channels or [] - self._global_config = global_config + default_config = FoxgloveConfig @rpc def start(self) -> None: @@ -77,8 +67,8 @@ def run_bridge() -> None: port=8765, debug=False, num_threads=4, - shm_channels=self.shm_channels, - jpeg_shm_channels=self.jpeg_shm_channels, + shm_channels=self.config.shm_channels, + jpeg_shm_channels=self.config.jpeg_shm_channels, ) self._loop.run_until_complete(bridge.run()) except Exception as e: diff --git a/dimos/robot/unitree/b1/connection.py b/dimos/robot/unitree/b1/connection.py index bae4bc0844..952069d626 100644 --- a/dimos/robot/unitree/b1/connection.py +++ b/dimos/robot/unitree/b1/connection.py @@ -21,10 +21,13 @@ import socket import threading import time +from typing import Any from reactivex.disposable import Disposable from dimos.core import In, Module, Out, rpc +from dimos.core.global_config import GlobalConfig, global_config +from dimos.core.module import ModuleConfig from dimos.msgs.geometry_msgs import PoseStamped, Twist, TwistStamped from dimos.msgs.nav_msgs.Odometry import Odometry from dimos.msgs.std_msgs import Int32 @@ -46,13 +49,21 @@ class RobotMode: RECOVERY = 6 -class B1ConnectionModule(Module): +class B1ConnectionConfig(ModuleConfig): + ip: str = "192.168.12.1" + port: int = 9090 + test_mode: bool = False + + +class B1ConnectionModule(Module[B1ConnectionConfig]): """UDP connection module for B1 robot with standard Twist interface. Accepts standard ROS Twist messages on /cmd_vel and mode changes on /b1/mode, internally converts to B1Command format, and sends UDP packets at 50Hz. """ + default_config = B1ConnectionConfig + # LCM ports (inter-module communication) cmd_vel: In[TwistStamped] mode_cmd: In[Int32] @@ -65,9 +76,7 @@ class B1ConnectionModule(Module): ros_odom_in: In[Odometry] ros_tf: In[TFMessage] - def __init__( # type: ignore[no-untyped-def] - self, ip: str = "192.168.12.1", port: int = 9090, test_mode: bool = False, *args, **kwargs - ) -> None: + def __init__(self, global_config: GlobalConfig = global_config, **kwargs: Any) -> None: """Initialize B1 connection module. Args: @@ -75,11 +84,11 @@ def __init__( # type: ignore[no-untyped-def] port: UDP port for joystick server test_mode: If True, print commands instead of sending UDP """ - Module.__init__(self, *args, **kwargs) + super().__init__(global_config, **kwargs) - self.ip = ip - self.port = port - self.test_mode = test_mode + self.ip = self.config.ip + self.port = self.config.port + self.test_mode = self.config.test_mode self.current_mode = RobotMode.IDLE # Start in IDLE mode self._current_cmd = B1Command(mode=RobotMode.IDLE) self.cmd_lock = threading.Lock() # Thread lock for _current_cmd access @@ -381,9 +390,10 @@ def move(self, twist_stamped: TwistStamped, duration: float = 0.0) -> bool: class MockB1ConnectionModule(B1ConnectionModule): """Test connection module that prints commands instead of sending UDP.""" - def __init__(self, ip: str = "127.0.0.1", port: int = 9090, *args, **kwargs) -> None: # type: ignore[no-untyped-def] + def __init__(self, global_config: GlobalConfig = global_config, **kwargs: Any) -> None: # type: ignore[no-untyped-def] """Initialize test connection without creating socket.""" - super().__init__(ip, port, test_mode=True, *args, **kwargs) # type: ignore[misc] + kwargs["test_mode"] = True + super().__init__(global_config, **kwargs) def _send_loop(self) -> None: """Override to provide better test output with timeout detection.""" diff --git a/dimos/robot/unitree/b1/unitree_b1.py b/dimos/robot/unitree/b1/unitree_b1.py index a2dd6c718d..191a8388be 100644 --- a/dimos/robot/unitree/b1/unitree_b1.py +++ b/dimos/robot/unitree/b1/unitree_b1.py @@ -93,9 +93,9 @@ def start(self) -> None: logger.info("Deploying connection module...") if self.test_mode: - self.connection = self._dimos.deploy(MockB1ConnectionModule, self.ip, self.port) # type: ignore[assignment] + self.connection = self._dimos.deploy(MockB1ConnectionModule, ip=self.ip, port=self.port) # type: ignore[assignment] else: - self.connection = self._dimos.deploy(B1ConnectionModule, self.ip, self.port) # type: ignore[assignment] + self.connection = self._dimos.deploy(B1ConnectionModule, ip=self.ip, port=self.port) # type: ignore[assignment] # Configure LCM transports for connection (matching G1 pattern) self.connection.cmd_vel.transport = core.LCMTransport("/cmd_vel", TwistStamped) # type: ignore[attr-defined] diff --git a/dimos/visualization/rerun/bridge.py b/dimos/visualization/rerun/bridge.py index 1dc104f1b4..53308a1684 100644 --- a/dimos/visualization/rerun/bridge.py +++ b/dimos/visualization/rerun/bridge.py @@ -165,7 +165,7 @@ class Config(ModuleConfig): blueprint: BlueprintFactory | None = _default_blueprint -class RerunBridgeModule(Module): +class RerunBridgeModule(Module[Config]): """Bridge that logs messages from pubsubs to Rerun. Spawns its own Rerun viewer and subscribes to all topics on each provided @@ -182,7 +182,6 @@ class RerunBridgeModule(Module): """ default_config = Config - config: Config @lru_cache(maxsize=256) def _visual_override_for_entity_path( From 5f93bd93c9e23d111d158142a36343333bfda84a Mon Sep 17 00:00:00 2001 From: Sam Bull Date: Wed, 25 Feb 2026 16:22:21 +0000 Subject: [PATCH 2/9] Fixes --- dimos/control/coordinator.py | 1 - dimos/core/docker_runner.py | 3 +-- dimos/core/module.py | 2 ++ dimos/manipulation/manipulation_module.py | 3 +-- dimos/simulation/manipulators/sim_module.py | 2 -- dimos/visualization/rerun/bridge.py | 8 +++----- 6 files changed, 7 insertions(+), 12 deletions(-) diff --git a/dimos/control/coordinator.py b/dimos/control/coordinator.py index c9182e6aa8..f2fca8b301 100644 --- a/dimos/control/coordinator.py +++ b/dimos/control/coordinator.py @@ -97,7 +97,6 @@ class TaskConfig: hand: str = "" # teleop_ik only: "left" or "right" controller -@dataclass class ControlCoordinatorConfig(ModuleConfig): """Configuration for the ControlCoordinator. diff --git a/dimos/core/docker_runner.py b/dimos/core/docker_runner.py index 9be2ff6012..8203975c0c 100644 --- a/dimos/core/docker_runner.py +++ b/dimos/core/docker_runner.py @@ -15,7 +15,7 @@ import argparse from contextlib import suppress -from dataclasses import dataclass, field +from dataclasses import field import importlib import json import os @@ -46,7 +46,6 @@ LOG_TAIL_LINES = 200 # Number of log lines to include in error messages -@dataclass(kw_only=True) class DockerModuleConfig(ModuleConfig): """ Configuration for running a DimOS module inside Docker. diff --git a/dimos/core/module.py b/dimos/core/module.py index 5f21af5a08..c6452dcdff 100644 --- a/dimos/core/module.py +++ b/dimos/core/module.py @@ -96,6 +96,8 @@ def get_loop() -> tuple[asyncio.AbstractEventLoop, threading.Thread | None]: class ModuleConfig(BaseModel): + model_config = {"arbitrary_types_allowed": True} + rpc_transport: type[RPCSpec] = LCMRPC tf_transport: type[TFSpec] = LCMTF frame_id_prefix: str | None = None diff --git a/dimos/manipulation/manipulation_module.py b/dimos/manipulation/manipulation_module.py index 3318fab7f1..570124db4e 100644 --- a/dimos/manipulation/manipulation_module.py +++ b/dimos/manipulation/manipulation_module.py @@ -22,7 +22,7 @@ from __future__ import annotations -from dataclasses import dataclass, field +from dataclasses import field from enum import Enum import math from pathlib import Path @@ -96,7 +96,6 @@ class ManipulationState(Enum): FAULT = 4 -@dataclass class ManipulationModuleConfig(ModuleConfig): """Configuration for ManipulationModule.""" diff --git a/dimos/simulation/manipulators/sim_module.py b/dimos/simulation/manipulators/sim_module.py index 4f1bb986d3..df7e26dd01 100644 --- a/dimos/simulation/manipulators/sim_module.py +++ b/dimos/simulation/manipulators/sim_module.py @@ -16,7 +16,6 @@ from __future__ import annotations -from dataclasses import dataclass import threading import time from typing import TYPE_CHECKING, Any @@ -34,7 +33,6 @@ from pathlib import Path -@dataclass(kw_only=True) class SimulationModuleConfig(ModuleConfig): engine: EngineType config_path: Path | Callable[[], Path] diff --git a/dimos/visualization/rerun/bridge.py b/dimos/visualization/rerun/bridge.py index 53308a1684..37ae8fdd7b 100644 --- a/dimos/visualization/rerun/bridge.py +++ b/dimos/visualization/rerun/bridge.py @@ -16,7 +16,8 @@ from __future__ import annotations -from dataclasses import dataclass, field +from collections.abc import Callable +from dataclasses import field from functools import lru_cache from typing import ( TYPE_CHECKING, @@ -88,14 +89,12 @@ logger = setup_logger() if TYPE_CHECKING: - from collections.abc import Callable - from rerun._baseclasses import Archetype from rerun.blueprint import Blueprint from dimos.protocol.pubsub.spec import SubscribeAllCapable -BlueprintFactory: TypeAlias = "Callable[[], Blueprint]" +BlueprintFactory: TypeAlias = Callable[[], "Blueprint"] # to_rerun() can return a single archetype or a list of (entity_path, archetype) tuples RerunMulti: TypeAlias = "list[tuple[str, Archetype]]" @@ -142,7 +141,6 @@ def _default_blueprint() -> Blueprint: ) -@dataclass class Config(ModuleConfig): """Configuration for RerunBridgeModule.""" From 8b9c27e706791a6bdb4b1358d6bb78c347a86513 Mon Sep 17 00:00:00 2001 From: Sam Bull Date: Wed, 25 Feb 2026 18:41:46 +0000 Subject: [PATCH 3/9] Update module.py --- dimos/core/module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dimos/core/module.py b/dimos/core/module.py index c6452dcdff..e611cfea24 100644 --- a/dimos/core/module.py +++ b/dimos/core/module.py @@ -464,7 +464,7 @@ def __init__(self, global_config: GlobalConfig = global_config, **kwargs: Any): inner, *_ = get_args(ann) or (Any,) stream = In(inner, name, self) # type: ignore[assignment] setattr(self, name, stream) - super().__init__(global_config, **kwargs) + super().__init__(config_args=kwargs, global_config=global_config) def set_ref(self, ref) -> int: # type: ignore[no-untyped-def] worker = get_worker() From fd90f9bab947cadcaa74681bd8d285c3d2c1e387 Mon Sep 17 00:00:00 2001 From: Sam Bull Date: Wed, 25 Feb 2026 18:43:09 +0000 Subject: [PATCH 4/9] Update module.py --- dimos/hardware/sensors/camera/module.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/dimos/hardware/sensors/camera/module.py b/dimos/hardware/sensors/camera/module.py index 659c9e07b3..7c7edaaf5d 100644 --- a/dimos/hardware/sensors/camera/module.py +++ b/dimos/hardware/sensors/camera/module.py @@ -14,7 +14,6 @@ from collections.abc import Callable from dataclasses import dataclass, field -import sys import time from typing import Any @@ -33,11 +32,6 @@ from dimos.spec import perception from dimos.visualization.rerun.bridge import rerun_bridge -if sys.version_info >= (3, 13): - from typing import TypeVar -else: - from typing_extensions import TypeVar - def default_transform() -> Transform: return Transform( @@ -56,10 +50,7 @@ class CameraModuleConfig(ModuleConfig): frequency: float = 0.0 # Hz, 0 means no limit -CameraConfigT = TypeVar("CameraConfigT", bound=CameraModuleConfig, default=CameraModuleConfig) - - -class CameraModule(Module[CameraConfigT], perception.Camera): +class CameraModule(Module[CameraModuleConfig], perception.Camera): color_image: Out[Image] camera_info: Out[CameraInfo] From 4a9288c077c4181cbbefabd32d7b477781cf9630 Mon Sep 17 00:00:00 2001 From: Sam Bull Date: Wed, 25 Feb 2026 19:30:08 +0000 Subject: [PATCH 5/9] Update blueprints.py --- dimos/core/blueprints.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/dimos/core/blueprints.py b/dimos/core/blueprints.py index 8e5fc9c627..bc4f3a3627 100644 --- a/dimos/core/blueprints.py +++ b/dimos/core/blueprints.py @@ -20,7 +20,7 @@ import operator import sys from types import MappingProxyType -from typing import Any, Literal, Self, get_args, get_origin, get_type_hints +from typing import Any, Literal, get_args, get_origin, get_type_hints from dimos.core.global_config import GlobalConfig, global_config from dimos.core.module import Module, ModuleBase, ModuleSpec, is_module_type @@ -31,6 +31,11 @@ from dimos.utils.generic import short_id from dimos.utils.logging_config import setup_logger +if sys.version_info >= (3, 11): + from typing import Self +else: + from typing import Any as Self + logger = setup_logger() From 8b7f74a2b155cd6ea945a52fd339f5874e5a1442 Mon Sep 17 00:00:00 2001 From: Sam Bull Date: Thu, 26 Feb 2026 18:15:06 +0000 Subject: [PATCH 6/9] Fixes --- dimos/agents/skills/person_follow.py | 37 +++++----- dimos/core/blueprints.py | 5 +- dimos/core/module.py | 14 ++-- dimos/core/native_module.py | 42 ++++++----- dimos/core/test_blueprints.py | 9 +-- dimos/core/test_core.py | 3 - dimos/core/test_native_module.py | 2 - dimos/hardware/sensors/camera/spec.py | 8 +-- dimos/hardware/sensors/camera/test_webcam.py | 6 +- dimos/hardware/sensors/camera/zed/__init__.py | 13 +++- dimos/hardware/sensors/camera/zed/test_zed.py | 7 +- dimos/hardware/sensors/lidar/livox/module.py | 4 +- dimos/mapping/osm/current_location_map.py | 6 +- dimos/mapping/osm/query.py | 7 +- dimos/models/base.py | 6 +- dimos/models/vl/base.py | 19 +++-- dimos/models/vl/moondream.py | 8 +-- dimos/models/vl/moondream_hosted.py | 13 ++-- dimos/models/vl/openai.py | 4 +- dimos/models/vl/qwen.py | 4 +- .../test_wavefront_frontier_goal_selector.py | 2 +- .../wavefront_frontier_goal_selector.py | 70 +++++++++---------- dimos/navigation/visual/query.py | 3 +- dimos/perception/detection/conftest.py | 7 +- dimos/perception/detection/module2D.py | 25 +++---- .../temporal_memory/entity_graph_db.py | 2 +- .../temporal_memory/temporal_memory.py | 4 +- .../temporal_memory/temporal_memory_deploy.py | 4 +- .../temporal_utils/graph_utils.py | 2 +- dimos/perception/object_tracker_2d.py | 3 - dimos/protocol/pubsub/bridge.py | 6 +- dimos/protocol/pubsub/impl/lcmpubsub.py | 3 +- dimos/protocol/pubsub/impl/redispubsub.py | 9 ++- dimos/protocol/service/__init__.py | 7 +- dimos/protocol/service/ddsservice.py | 11 +-- dimos/protocol/service/lcmservice.py | 37 +++++----- dimos/protocol/service/spec.py | 9 ++- dimos/protocol/service/test_lcmservice.py | 58 ++++++++------- dimos/protocol/tf/tf.py | 28 ++++---- dimos/protocol/tf/tflcmcpp.py | 9 ++- dimos/robot/drone/connection_module.py | 39 +++++------ dimos/simulation/manipulators/sim_module.py | 6 +- .../manipulators/test_sim_module.py | 3 +- dimos/utils/cli/lcmspy/lcmspy.py | 14 ++-- 44 files changed, 298 insertions(+), 280 deletions(-) diff --git a/dimos/agents/skills/person_follow.py b/dimos/agents/skills/person_follow.py index 641055e6f6..86feb99363 100644 --- a/dimos/agents/skills/person_follow.py +++ b/dimos/agents/skills/person_follow.py @@ -14,7 +14,7 @@ from threading import Event, RLock, Thread import time -from typing import TYPE_CHECKING +from typing import Any from langchain_core.messages import HumanMessage import numpy as np @@ -23,8 +23,8 @@ from dimos.agents.agent import AgentSpec from dimos.agents.annotation import skill from dimos.core.core import rpc -from dimos.core.global_config import GlobalConfig -from dimos.core.module import Module +from dimos.core.global_config import GlobalConfig, global_config +from dimos.core.module import Module, ModuleConfig from dimos.core.stream import In, Out from dimos.models.qwen.video_query import BBox from dimos.models.segmentation.edge_tam import EdgeTAMProcessor @@ -36,13 +36,15 @@ from dimos.navigation.visual_servoing.visual_servoing_2d import VisualServoing2D from dimos.utils.logging_config import setup_logger -if TYPE_CHECKING: - from dimos.models.vl.base import VlModel - logger = setup_logger() -class PersonFollowSkillContainer(Module): +class Config(ModuleConfig): + camera_info: CameraInfo + use_3d_navigation: bool = False + + +class PersonFollowSkillContainer(Module[Config]): """Skill container for following a person. This skill uses: @@ -52,6 +54,8 @@ class PersonFollowSkillContainer(Module): - Does not do obstacle avoidance; assumes a clear path. """ + default_config = Config + color_image: In[Image] global_map: In[PointCloud2] cmd_vel: Out[Twist] @@ -60,30 +64,23 @@ class PersonFollowSkillContainer(Module): _frequency: float = 20.0 # Hz - control loop frequency _max_lost_frames: int = 15 # number of frames to wait before declaring person lost - def __init__( - self, - camera_info: CameraInfo, - cfg: GlobalConfig, - use_3d_navigation: bool = False, - ) -> None: - super().__init__() - self._global_config: GlobalConfig = cfg - self._use_3d_navigation: bool = use_3d_navigation + def __init__(self, global_config: GlobalConfig = global_config, **kwargs: Any) -> None: + super().__init__(global_config, **kwargs) self._latest_image: Image | None = None self._latest_pointcloud: PointCloud2 | None = None - self._vl_model: VlModel = QwenVlModel() + self._vl_model = QwenVlModel() self._tracker: EdgeTAMProcessor | None = None self._thread: Thread | None = None self._should_stop: Event = Event() self._lock = RLock() # Use MuJoCo camera intrinsics in simulation mode + camera_info = self.config.camera_info if self._global_config.simulation: from dimos.robot.unitree.mujoco_connection import MujocoConnection camera_info = MujocoConnection.camera_info_static - self._camera_info = camera_info self._visual_servo = VisualServoing2D(camera_info, self._global_config.simulation) self._detection_navigation = DetectionNavigation(self.tf, camera_info) @@ -91,7 +88,7 @@ def __init__( def start(self) -> None: super().start() self._disposables.add(Disposable(self.color_image.subscribe(self._on_color_image))) - if self._use_3d_navigation: + if self.config.use_3d_navigation: self._disposables.add(Disposable(self.global_map.subscribe(self._on_pointcloud))) @rpc @@ -227,7 +224,7 @@ def _follow_loop(self, tracker: EdgeTAMProcessor, query: str) -> None: lost_count = 0 best_detection = max(detections.detections, key=lambda d: d.bbox_2d_volume()) - if self._use_3d_navigation: + if self.config.use_3d_navigation: with self._lock: pointcloud = self._latest_pointcloud if pointcloud is None: diff --git a/dimos/core/blueprints.py b/dimos/core/blueprints.py index bc4f3a3627..e317ac41e2 100644 --- a/dimos/core/blueprints.py +++ b/dimos/core/blueprints.py @@ -134,7 +134,10 @@ def global_config(self, **kwargs: Any) -> "Blueprint": ) def remappings( - self, remappings: list[tuple[type[ModuleBase], str, str | type[ModuleBase] | type[Spec]]] + self, + remappings: list[ + tuple[type[ModuleBase[Any]], str, str | type[ModuleBase[Any]] | type[Spec]] + ], ) -> "Blueprint": remappings_dict = dict(self.remapping_map) for module, old, new in remappings: diff --git a/dimos/core/module.py b/dimos/core/module.py index e611cfea24..76b298d21b 100644 --- a/dimos/core/module.py +++ b/dimos/core/module.py @@ -39,7 +39,6 @@ from dask.distributed import Actor, get_worker from langchain_core.tools import tool -from pydantic import BaseModel from reactivex.disposable import CompositeDisposable from dimos.core import colors @@ -50,7 +49,7 @@ from dimos.core.rpc_client import RpcCall # noqa: TC001 from dimos.core.stream import In, Out, RemoteIn, RemoteOut, Transport from dimos.protocol.rpc import LCMRPC, RPCSpec -from dimos.protocol.service import Configurable # type: ignore[attr-defined] +from dimos.protocol.service import BaseConfig, Configurable from dimos.protocol.tf import LCMTF, TFSpec from dimos.utils.generic import classproperty @@ -95,11 +94,9 @@ def get_loop() -> tuple[asyncio.AbstractEventLoop, threading.Thread | None]: return loop, thr -class ModuleConfig(BaseModel): - model_config = {"arbitrary_types_allowed": True} - +class ModuleConfig(BaseConfig): rpc_transport: type[RPCSpec] = LCMRPC - tf_transport: type[TFSpec] = LCMTF + tf_transport: type[TFSpec] = LCMTF # type: ignore[type-arg] frame_id_prefix: str | None = None frame_id: str | None = None @@ -112,8 +109,11 @@ def __call__(self, **kwargs: Any) -> Blueprint: ... class ModuleBase(Configurable[ModuleConfigT], Resource): + # This won't type check against the TypeVar, but we need it as the default. + default_config: type[ModuleConfigT] = ModuleConfig # type: ignore[assignment] + _rpc: RPCSpec | None = None - _tf: TFSpec | None = None + _tf: TFSpec[Any] | None = None _loop: asyncio.AbstractEventLoop | None = None _loop_thread: threading.Thread | None _disposables: CompositeDisposable diff --git a/dimos/core/native_module.py b/dimos/core/native_module.py index 6a93e6453a..bec23e42e1 100644 --- a/dimos/core/native_module.py +++ b/dimos/core/native_module.py @@ -40,7 +40,6 @@ class MyCppModule(NativeModule): from __future__ import annotations -from dataclasses import dataclass, field, fields import enum import inspect import json @@ -48,13 +47,22 @@ class MyCppModule(NativeModule): from pathlib import Path import signal import subprocess +import sys import threading from typing import IO, Any +from pydantic import Field + from dimos.core.core import rpc +from dimos.core.global_config import GlobalConfig, global_config from dimos.core.module import Module, ModuleConfig from dimos.utils.logging_config import setup_logger +if sys.version_info < (3, 13): + from typing_extensions import TypeVar +else: + from typing import TypeVar + logger = setup_logger() @@ -63,15 +71,14 @@ class LogFormat(enum.Enum): JSON = "json" -@dataclass(kw_only=True) class NativeModuleConfig(ModuleConfig): """Configuration for a native (C/C++) subprocess module.""" executable: str build_command: str | None = None cwd: str | None = None - extra_args: list[str] = field(default_factory=list) - extra_env: dict[str, str] = field(default_factory=dict) + extra_args: list[str] = Field(default_factory=list) + extra_env: dict[str, str] = Field(default_factory=dict) shutdown_timeout: float = 10.0 log_format: LogFormat = LogFormat.TEXT @@ -85,26 +92,29 @@ def to_cli_args(self) -> list[str]: or its parents) and converts them to ``["--name", str(value)]`` pairs. Skips fields whose values are ``None`` and fields in ``cli_exclude``. """ - ignore_fields = {f.name for f in fields(NativeModuleConfig)} + ignore_fields = {f for f in NativeModuleConfig.model_fields} args: list[str] = [] - for f in fields(self): - if f.name in ignore_fields: + for f in self.__class__.model_fields: + if f in ignore_fields: continue - if f.name in self.cli_exclude: + if f in self.cli_exclude: continue - val = getattr(self, f.name) + val = getattr(self, f) if val is None: continue if isinstance(val, bool): - args.extend([f"--{f.name}", str(val).lower()]) + args.extend([f"--{f}", str(val).lower()]) elif isinstance(val, list): - args.extend([f"--{f.name}", ",".join(str(v) for v in val)]) + args.extend([f"--{f}", ",".join(str(v) for v in val)]) else: - args.extend([f"--{f.name}", str(val)]) + args.extend([f"--{f}", str(val)]) return args -class NativeModule(Module[NativeModuleConfig]): +_NativeConfig = TypeVar("_NativeConfig", bound=NativeModuleConfig, default=NativeModuleConfig) + + +class NativeModule(Module[_NativeConfig]): """Module that wraps a native executable as a managed subprocess. Subclass this, declare In/Out ports, and set ``default_config`` to a @@ -118,13 +128,13 @@ class NativeModule(Module[NativeModuleConfig]): LCM topics directly. On ``stop()``, the process receives SIGTERM. """ - default_config: type[NativeModuleConfig] = NativeModuleConfig + default_config: type[_NativeConfig] = NativeModuleConfig # type: ignore[assignment] _process: subprocess.Popen[bytes] | None = None _watchdog: threading.Thread | None = None _stopping: bool = False - def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, **kwargs) + def __init__(self, global_config: GlobalConfig = global_config, **kwargs: Any) -> None: + super().__init__(global_config, **kwargs) self._resolve_paths() @rpc diff --git a/dimos/core/test_blueprints.py b/dimos/core/test_blueprints.py index fd18fe72d8..f91591d919 100644 --- a/dimos/core/test_blueprints.py +++ b/dimos/core/test_blueprints.py @@ -114,14 +114,13 @@ class ModuleC(Module): def test_get_connection_set() -> None: - assert _BlueprintAtom.create(CatModule, args=("arg1",), kwargs={"k": "v"}) == _BlueprintAtom( + assert _BlueprintAtom.create(CatModule, kwargs={"k": "v"}) == _BlueprintAtom( module=CatModule, streams=( StreamRef(name="pet_cat", type=Petting, direction="in"), StreamRef(name="scratches", type=Scratch, direction="out"), ), module_refs=(), - args=("arg1",), kwargs={"k": "v"}, ) @@ -138,7 +137,6 @@ def test_autoconnect() -> None: StreamRef(name="data2", type=Data2, direction="out"), ), module_refs=(), - args=(), kwargs={}, ), _BlueprintAtom( @@ -149,7 +147,6 @@ def test_autoconnect() -> None: StreamRef(name="data3", type=Data3, direction="out"), ), module_refs=(), - args=(), kwargs={}, ), ) @@ -346,11 +343,11 @@ def test_future_annotations_support() -> None: """ # Test that streams are properly extracted from modules with future annotations - out_blueprint = _BlueprintAtom.create(FutureModuleOut, args=(), kwargs={}) + out_blueprint = _BlueprintAtom.create(FutureModuleOut, kwargs={}) assert len(out_blueprint.streams) == 1 assert out_blueprint.streams[0] == StreamRef(name="data", type=FutureData, direction="out") - in_blueprint = _BlueprintAtom.create(FutureModuleIn, args=(), kwargs={}) + in_blueprint = _BlueprintAtom.create(FutureModuleIn, kwargs={}) assert len(in_blueprint.streams) == 1 assert in_blueprint.streams[0] == StreamRef(name="data", type=FutureData, direction="in") diff --git a/dimos/core/test_core.py b/dimos/core/test_core.py index 3866d55bdb..fd18d1fa86 100644 --- a/dimos/core/test_core.py +++ b/dimos/core/test_core.py @@ -45,9 +45,6 @@ class Navigation(Module): @rpc def navigate_to(self, target: Vector3) -> bool: ... - def __init__(self) -> None: - super().__init__() - @rpc def start(self) -> None: def _odom(msg) -> None: diff --git a/dimos/core/test_native_module.py b/dimos/core/test_native_module.py index a022be0685..39dab4a7d5 100644 --- a/dimos/core/test_native_module.py +++ b/dimos/core/test_native_module.py @@ -18,7 +18,6 @@ The echo script writes received CLI args to a temp file for assertions. """ -from dataclasses import dataclass import json from pathlib import Path import time @@ -59,7 +58,6 @@ def read_json_file(path: str) -> dict[str, str]: return result -@dataclass(kw_only=True) class StubNativeConfig(NativeModuleConfig): executable: str = _ECHO log_format: LogFormat = LogFormat.TEXT diff --git a/dimos/hardware/sensors/camera/spec.py b/dimos/hardware/sensors/camera/spec.py index 23fd1a076e..c913e4bfea 100644 --- a/dimos/hardware/sensors/camera/spec.py +++ b/dimos/hardware/sensors/camera/spec.py @@ -13,19 +13,19 @@ # limitations under the License. from abc import ABC, abstractmethod -from typing import Generic, Protocol, TypeVar +from typing import TypeVar from reactivex.observable import Observable from dimos.msgs.geometry_msgs import Quaternion, Transform from dimos.msgs.sensor_msgs import CameraInfo from dimos.msgs.sensor_msgs.Image import Image -from dimos.protocol.service import Configurable # type: ignore[attr-defined] +from dimos.protocol.service.spec import BaseConfig, Configurable OPTICAL_ROTATION = Quaternion(-0.5, 0.5, -0.5, 0.5) -class CameraConfig(Protocol): +class CameraConfig(BaseConfig): frame_id_prefix: str | None width: int height: int @@ -35,7 +35,7 @@ class CameraConfig(Protocol): CameraConfigT = TypeVar("CameraConfigT", bound=CameraConfig) -class CameraHardware(ABC, Configurable[CameraConfigT], Generic[CameraConfigT]): +class CameraHardware(ABC, Configurable[CameraConfigT]): @abstractmethod def image_stream(self) -> Observable[Image]: pass diff --git a/dimos/hardware/sensors/camera/test_webcam.py b/dimos/hardware/sensors/camera/test_webcam.py index e40a73acc9..479da22267 100644 --- a/dimos/hardware/sensors/camera/test_webcam.py +++ b/dimos/hardware/sensors/camera/test_webcam.py @@ -17,7 +17,6 @@ import pytest from dimos import core -from dimos.hardware.sensors.camera import zed from dimos.hardware.sensors.camera.module import CameraModule from dimos.hardware.sensors.camera.webcam import Webcam from dimos.msgs.geometry_msgs import Quaternion, Transform, Vector3 @@ -33,6 +32,11 @@ def dimos(): @pytest.mark.tool def test_streaming_single(dimos) -> None: + try: + from dimos.hardware.sensors.camera import zed + except ModuleNotFoundError: + pytest.skip("ZED SDK not installed") + camera = dimos.deploy( CameraModule, transform=Transform( diff --git a/dimos/hardware/sensors/camera/zed/__init__.py b/dimos/hardware/sensors/camera/zed/__init__.py index 430416e12e..2b02d22052 100644 --- a/dimos/hardware/sensors/camera/zed/__init__.py +++ b/dimos/hardware/sensors/camera/zed/__init__.py @@ -16,9 +16,20 @@ from pathlib import Path -from dimos.hardware.sensors.camera.zed.camera import ZEDCamera, ZEDModule, zed_camera from dimos.msgs.sensor_msgs.CameraInfo import CalibrationProvider +try: + import pyzed.sl # noqa: F401 + + # This awkwardness is needed as pytest implicitly imports this to collect + # the test in this directory. + HAS_ZED_SDK = True +except ImportError: + HAS_ZED_SDK = False + +if HAS_ZED_SDK: + from dimos.hardware.sensors.camera.zed.camera import ZEDCamera, ZEDModule, zed_camera + # Set up camera calibration provider (always available) CALIBRATION_DIR = Path(__file__).parent CameraInfo = CalibrationProvider(CALIBRATION_DIR) diff --git a/dimos/hardware/sensors/camera/zed/test_zed.py b/dimos/hardware/sensors/camera/zed/test_zed.py index 2d912553c6..2716e809a5 100644 --- a/dimos/hardware/sensors/camera/zed/test_zed.py +++ b/dimos/hardware/sensors/camera/zed/test_zed.py @@ -13,14 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +import pytest + +from dimos.hardware.sensors.camera import zed from dimos.msgs.sensor_msgs.CameraInfo import CameraInfo +@pytest.mark.skipif(not zed.HAS_ZED_SDK, reason="ZED SDK not installed") def test_zed_import_and_calibration_access() -> None: """Test that zed module can be imported and calibrations accessed.""" - # Import zed module from camera - from dimos.hardware.sensors.camera import zed - # Test that CameraInfo is accessible assert hasattr(zed, "CameraInfo") diff --git a/dimos/hardware/sensors/lidar/livox/module.py b/dimos/hardware/sensors/lidar/livox/module.py index 672968a0eb..8d6982dcf8 100644 --- a/dimos/hardware/sensors/lidar/livox/module.py +++ b/dimos/hardware/sensors/lidar/livox/module.py @@ -26,7 +26,6 @@ from __future__ import annotations -from dataclasses import dataclass from typing import TYPE_CHECKING from dimos.core import Out # noqa: TC001 @@ -48,7 +47,6 @@ from dimos.spec import perception -@dataclass(kw_only=True) class Mid360Config(NativeModuleConfig): """Config for the C++ Mid-360 native module.""" @@ -76,7 +74,7 @@ class Mid360Config(NativeModuleConfig): host_log_data_port: int = SDK_HOST_LOG_DATA_PORT -class Mid360(NativeModule, perception.Lidar, perception.IMU): +class Mid360(NativeModule[Mid360Config], perception.Lidar, perception.IMU): """Livox Mid-360 LiDAR module backed by a native C++ binary. Ports: diff --git a/dimos/mapping/osm/current_location_map.py b/dimos/mapping/osm/current_location_map.py index ef0a832cd6..832116e25c 100644 --- a/dimos/mapping/osm/current_location_map.py +++ b/dimos/mapping/osm/current_location_map.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any + from PIL import Image as PILImage, ImageDraw from dimos.mapping.osm.osm import MapImage, get_osm_map @@ -24,11 +26,11 @@ class CurrentLocationMap: - _vl_model: VlModel + _vl_model: VlModel[Any] _position: LatLon | None _map_image: MapImage | None - def __init__(self, vl_model: VlModel) -> None: + def __init__(self, vl_model: VlModel[Any]) -> None: self._vl_model = vl_model self._position = None self._map_image = None diff --git a/dimos/mapping/osm/query.py b/dimos/mapping/osm/query.py index 410f879c20..17fbfe3d4b 100644 --- a/dimos/mapping/osm/query.py +++ b/dimos/mapping/osm/query.py @@ -13,6 +13,7 @@ # limitations under the License. import re +from typing import Any from dimos.mapping.osm.osm import MapImage from dimos.mapping.types import LatLon @@ -25,7 +26,9 @@ logger = setup_logger() -def query_for_one_position(vl_model: VlModel, map_image: MapImage, query: str) -> LatLon | None: +def query_for_one_position( + vl_model: VlModel[Any], map_image: MapImage, query: str +) -> LatLon | None: full_query = f"{_PROLOGUE} {query} {_JSON} If there's a match return the x, y coordinates from the image. Example: `[123, 321]`. If there's no match return `null`." response = vl_model.query(map_image.image, full_query) coords = tuple(map(int, re.findall(r"\d+", response))) @@ -35,7 +38,7 @@ def query_for_one_position(vl_model: VlModel, map_image: MapImage, query: str) - def query_for_one_position_and_context( - vl_model: VlModel, map_image: MapImage, query: str, robot_position: LatLon + vl_model: VlModel[Any], map_image: MapImage, query: str, robot_position: LatLon ) -> tuple[LatLon, str] | None: example = '{"coordinates": [123, 321], "description": "A Starbucks on 27th Street"}' x, y = map_image.latlon_to_pixel(robot_position) diff --git a/dimos/models/base.py b/dimos/models/base.py index 2269a6d0b8..0d1ea97f12 100644 --- a/dimos/models/base.py +++ b/dimos/models/base.py @@ -23,14 +23,13 @@ import torch from dimos.core.resource import Resource -from dimos.protocol.service import Configurable # type: ignore[attr-defined] +from dimos.protocol.service import BaseConfig, Configurable # Device string type - 'cuda', 'cpu', 'cuda:0', 'cuda:1', etc. DeviceType = Annotated[str, "Device identifier (e.g., 'cuda', 'cpu', 'cuda:0')"] -@dataclass -class LocalModelConfig: +class LocalModelConfig(BaseConfig): device: DeviceType = "cuda" if torch.cuda.is_available() else "cpu" dtype: torch.dtype = torch.float32 warmup: bool = False @@ -127,7 +126,6 @@ def _ensure_cuda_initialized(self) -> None: pass -@dataclass class HuggingFaceModelConfig(LocalModelConfig): model_name: str = "" trust_remote_code: bool = True diff --git a/dimos/models/vl/base.py b/dimos/models/vl/base.py index 93caba4de7..4cc6e75750 100644 --- a/dimos/models/vl/base.py +++ b/dimos/models/vl/base.py @@ -2,16 +2,22 @@ from dataclasses import dataclass import json import logging +import sys import warnings from dimos.core.resource import Resource from dimos.msgs.sensor_msgs import Image from dimos.perception.detection.type import Detection2DBBox, Detection2DPoint, ImageDetections2D -from dimos.protocol.service import Configurable # type: ignore[attr-defined] +from dimos.protocol.service.spec import BaseConfig, Configurable from dimos.utils.data import get_data from dimos.utils.decorators import retry from dimos.utils.llm_utils import extract_json +if sys.version_info < (3, 13): + from typing_extensions import TypeVar +else: + from typing import TypeVar + logger = logging.getLogger(__name__) @@ -149,15 +155,17 @@ def vlm_point_to_detection2d_point( ) -@dataclass -class VlModelConfig: +class VlModelConfig(BaseConfig): """Configuration for VlModel.""" auto_resize: tuple[int, int] | None = None """Optional (width, height) tuple. If set, images are resized to fit.""" -class VlModel(Captioner, Resource, Configurable[VlModelConfig]): +_VlConfig = TypeVar("_VlConfig", bound=VlModelConfig) + + +class VlModel(Captioner, Resource, Configurable[_VlConfig]): """Vision-language model that can answer questions about images. Inherits from Captioner, providing a default caption() implementation @@ -166,8 +174,7 @@ class VlModel(Captioner, Resource, Configurable[VlModelConfig]): Implements Resource interface for lifecycle management. """ - default_config = VlModelConfig - config: VlModelConfig + default_config: type[_VlConfig] = VlModelConfig # type: ignore[assignment] def _prepare_image(self, image: Image) -> tuple[Image, float]: """Prepare image for inference, applying any configured transformations. diff --git a/dimos/models/vl/moondream.py b/dimos/models/vl/moondream.py index f31611e867..7b79a6d6f2 100644 --- a/dimos/models/vl/moondream.py +++ b/dimos/models/vl/moondream.py @@ -9,7 +9,7 @@ from transformers import AutoModelForCausalLM # type: ignore[import-untyped] from dimos.models.base import HuggingFaceModel, HuggingFaceModelConfig -from dimos.models.vl.base import VlModel +from dimos.models.vl.base import VlModel, VlModelConfig from dimos.msgs.sensor_msgs import Image from dimos.perception.detection.type import Detection2DBBox, Detection2DPoint, ImageDetections2D @@ -17,8 +17,7 @@ MOONDREAM_DEFAULT_AUTO_RESIZE = (512, 512) -@dataclass -class MoondreamConfig(HuggingFaceModelConfig): +class MoondreamConfig(HuggingFaceModelConfig, VlModelConfig): """Configuration for MoondreamVlModel.""" model_name: str = "vikhyatk/moondream2" @@ -26,10 +25,9 @@ class MoondreamConfig(HuggingFaceModelConfig): auto_resize: tuple[int, int] | None = MOONDREAM_DEFAULT_AUTO_RESIZE -class MoondreamVlModel(HuggingFaceModel, VlModel): +class MoondreamVlModel(HuggingFaceModel, VlModel[MoondreamConfig]): _model_class = AutoModelForCausalLM default_config = MoondreamConfig # type: ignore[assignment] - config: MoondreamConfig # type: ignore[assignment] @cached_property def _model(self) -> AutoModelForCausalLM: diff --git a/dimos/models/vl/moondream_hosted.py b/dimos/models/vl/moondream_hosted.py index fc1f8b7a17..57df91b47e 100644 --- a/dimos/models/vl/moondream_hosted.py +++ b/dimos/models/vl/moondream_hosted.py @@ -6,20 +6,21 @@ import numpy as np from PIL import Image as PILImage -from dimos.models.vl.base import VlModel +from dimos.models.vl.base import VlModel, VlModelConfig from dimos.msgs.sensor_msgs import Image from dimos.perception.detection.type import Detection2DBBox, Detection2DPoint, ImageDetections2D -class MoondreamHostedVlModel(VlModel): - _api_key: str | None +class Config(VlModelConfig): + api_key: str | None = None - def __init__(self, api_key: str | None = None) -> None: - self._api_key = api_key + +class MoondreamHostedVlModel(VlModel[Config]): + default_config = Config @cached_property def _client(self) -> md.vl: - api_key = self._api_key or os.getenv("MOONDREAM_API_KEY") + api_key = self.config.api_key or os.getenv("MOONDREAM_API_KEY") if not api_key: raise ValueError( "Moondream API key must be provided or set in MOONDREAM_API_KEY environment variable" diff --git a/dimos/models/vl/openai.py b/dimos/models/vl/openai.py index f596f1ee1e..94f6e20b62 100644 --- a/dimos/models/vl/openai.py +++ b/dimos/models/vl/openai.py @@ -13,15 +13,13 @@ logger = setup_logger() -@dataclass class OpenAIVlModelConfig(VlModelConfig): model_name: str = "gpt-4o-mini" api_key: str | None = None -class OpenAIVlModel(VlModel): +class OpenAIVlModel(VlModel[OpenAIVlModelConfig]): default_config = OpenAIVlModelConfig - config: OpenAIVlModelConfig @cached_property def _client(self) -> OpenAI: diff --git a/dimos/models/vl/qwen.py b/dimos/models/vl/qwen.py index 93b31bf74c..dfcf3e7809 100644 --- a/dimos/models/vl/qwen.py +++ b/dimos/models/vl/qwen.py @@ -10,7 +10,6 @@ from dimos.msgs.sensor_msgs import Image -@dataclass class QwenVlModelConfig(VlModelConfig): """Configuration for Qwen VL model.""" @@ -18,9 +17,8 @@ class QwenVlModelConfig(VlModelConfig): api_key: str | None = None -class QwenVlModel(VlModel): +class QwenVlModel(VlModel[QwenVlModelConfig]): default_config = QwenVlModelConfig - config: QwenVlModelConfig @cached_property def _client(self) -> OpenAI: diff --git a/dimos/navigation/frontier_exploration/test_wavefront_frontier_goal_selector.py b/dimos/navigation/frontier_exploration/test_wavefront_frontier_goal_selector.py index 1c8082b414..419986780a 100644 --- a/dimos/navigation/frontier_exploration/test_wavefront_frontier_goal_selector.py +++ b/dimos/navigation/frontier_exploration/test_wavefront_frontier_goal_selector.py @@ -262,7 +262,7 @@ def test_frontier_ranking(explorer) -> None: # Note: Goals might be closer than safe_distance if that's the best available frontier # The safe_distance is used for scoring, not as a hard constraint print( - f"Distance to obstacles: {obstacle_dist:.2f}m (safe distance: {explorer.safe_distance}m)" + f"Distance to obstacles: {obstacle_dist:.2f}m (safe distance: {explorer.config.safe_distance}m)" ) print(f"Frontier ranking test passed - selected goal at ({goal1.x:.2f}, {goal1.y:.2f})") diff --git a/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py b/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py index 3adfc1c598..4c51fc1d97 100644 --- a/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py +++ b/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py @@ -23,6 +23,7 @@ from dataclasses import dataclass from enum import IntFlag import threading +from typing import Any from dimos_lcm.std_msgs import Bool import numpy as np @@ -30,6 +31,8 @@ from dimos.agents.annotation import skill from dimos.core import In, Module, Out, rpc +from dimos.core.global_config import GlobalConfig, global_config +from dimos.core.module import ModuleConfig from dimos.mapping.occupancy.inflation import simple_inflate from dimos.msgs.geometry_msgs import PoseStamped, Vector3 from dimos.msgs.nav_msgs import CostValues, OccupancyGrid @@ -76,7 +79,18 @@ def clear(self) -> None: self.points.clear() -class WavefrontFrontierExplorer(Module): +class WavefrontConfig(ModuleConfig): + min_frontier_perimeter: float = 0.5 + occupancy_threshold: int = 99 + safe_distance: float = 3.0 + lookahead_distance: float = 5.0 + max_explored_distance: float = 10.0 + info_gain_threshold: float = 0.03 + num_no_gain_attempts: int = 2 + goal_timeout: float = 15.0 + + +class WavefrontFrontierExplorer(Module[WavefrontConfig]): """ Wavefront frontier exploration algorithm implementation. @@ -91,6 +105,8 @@ class WavefrontFrontierExplorer(Module): - goal_request: Exploration goals sent to the navigator """ + default_config = WavefrontConfig + # LCM inputs global_costmap: In[OccupancyGrid] odom: In[PoseStamped] @@ -101,17 +117,7 @@ class WavefrontFrontierExplorer(Module): # LCM outputs goal_request: Out[PoseStamped] - def __init__( - self, - min_frontier_perimeter: float = 0.5, - occupancy_threshold: int = 99, - safe_distance: float = 3.0, - lookahead_distance: float = 5.0, - max_explored_distance: float = 10.0, - info_gain_threshold: float = 0.03, - num_no_gain_attempts: int = 2, - goal_timeout: float = 15.0, - ) -> None: + def __init__(self, global_config: GlobalConfig = global_config, **kwargs: Any) -> None: """ Initialize the frontier explorer. @@ -122,20 +128,12 @@ def __init__( info_gain_threshold: Minimum percentage increase in costmap information required to continue exploration (0.05 = 5%) num_no_gain_attempts: Maximum number of consecutive attempts with no information gain """ - super().__init__() - self.min_frontier_perimeter = min_frontier_perimeter - self.occupancy_threshold = occupancy_threshold - self.safe_distance = safe_distance - self.max_explored_distance = max_explored_distance - self.lookahead_distance = lookahead_distance - self.info_gain_threshold = info_gain_threshold - self.num_no_gain_attempts = num_no_gain_attempts + super().__init__(global_config, **kwargs) self._cache = FrontierCache() self.explored_goals = [] # type: ignore[var-annotated] # list of explored goals self.exploration_direction = Vector3(0.0, 0.0, 0.0) # current exploration direction self.last_costmap = None # store last costmap for information comparison self.no_gain_counter = 0 # track consecutive no-gain attempts - self.goal_timeout = goal_timeout # Latest data self.latest_costmap: OccupancyGrid | None = None @@ -212,7 +210,7 @@ def _count_costmap_information(self, costmap: OccupancyGrid) -> int: Number of cells that are free space or obstacles (not unknown) """ free_count = np.sum(costmap.grid == CostValues.FREE) - obstacle_count = np.sum(costmap.grid >= self.occupancy_threshold) + obstacle_count = np.sum(costmap.grid >= self.config.occupancy_threshold) return int(free_count + obstacle_count) def _get_neighbors(self, point: GridPoint, costmap: OccupancyGrid) -> list[GridPoint]: @@ -250,7 +248,7 @@ def _is_frontier_point(self, point: GridPoint, costmap: OccupancyGrid) -> bool: neighbor_cost = costmap.grid[neighbor.y, neighbor.x] # If adjacent to occupied space, not a frontier - if neighbor_cost > self.occupancy_threshold: + if neighbor_cost > self.config.occupancy_threshold: return False # Check if adjacent to free space @@ -374,7 +372,7 @@ def detect_frontiers(self, robot_pose: Vector3, costmap: OccupancyGrid) -> list[ # Check if we found a large enough frontier # Convert minimum perimeter to minimum number of cells based on resolution - min_cells = int(self.min_frontier_perimeter / costmap.resolution) + min_cells = int(self.config.min_frontier_perimeter / costmap.resolution) if len(new_frontier) >= min_cells: world_points = [] for point in new_frontier: @@ -487,7 +485,7 @@ def _compute_distance_to_obstacles(self, frontier: Vector3, costmap: OccupancyGr min_distance = float("inf") search_radius = ( - int(self.safe_distance / costmap.resolution) + 5 + int(self.config.safe_distance / costmap.resolution) + 5 ) # Search a bit beyond minimum # Search in a square around the frontier point @@ -506,14 +504,14 @@ def _compute_distance_to_obstacles(self, frontier: Vector3, costmap: OccupancyGr continue # Check if this cell is an obstacle - if costmap.grid[check_y, check_x] >= self.occupancy_threshold: + if costmap.grid[check_y, check_x] >= self.config.occupancy_threshold: # Calculate distance in meters distance = np.sqrt(dx**2 + dy**2) * costmap.resolution min_distance = min(min_distance, distance) # If no obstacles found within search radius, return the safe distance # This indicates the frontier is safely away from obstacles - return min_distance if min_distance != float("inf") else self.safe_distance + return min_distance if min_distance != float("inf") else self.config.safe_distance def _compute_comprehensive_frontier_score( self, frontier: Vector3, frontier_size: int, robot_pose: Vector3, costmap: OccupancyGrid @@ -525,25 +523,25 @@ def _compute_comprehensive_frontier_score( # Distance score: prefer moderate distances (not too close, not too far) # Normalized to 0-1 range - distance_score = 1.0 / (1.0 + abs(robot_distance - self.lookahead_distance)) + distance_score = 1.0 / (1.0 + abs(robot_distance - self.config.lookahead_distance)) # 2. Information gain (frontier size) # Normalize by a reasonable max frontier size - max_expected_frontier_size = self.min_frontier_perimeter / costmap.resolution * 10 + max_expected_frontier_size = self.config.min_frontier_perimeter / costmap.resolution * 10 info_gain_score = min(frontier_size / max_expected_frontier_size, 1.0) # 3. Distance to explored goals (bonus for being far from explored areas) # Normalize by a reasonable max distance (e.g., 10 meters) explored_goals_distance = self._compute_distance_to_explored_goals(frontier) - explored_goals_score = min(explored_goals_distance / self.max_explored_distance, 1.0) + explored_goals_score = min(explored_goals_distance / self.config.max_explored_distance, 1.0) # 4. Distance to obstacles (score based on safety) # 0 = too close to obstacles, 1 = at or beyond safe distance obstacles_distance = self._compute_distance_to_obstacles(frontier, costmap) - if obstacles_distance >= self.safe_distance: + if obstacles_distance >= self.config.safe_distance: obstacles_score = 1.0 # Fully safe else: - obstacles_score = obstacles_distance / self.safe_distance # Linear penalty + obstacles_score = obstacles_distance / self.config.safe_distance # Linear penalty # 5. Direction momentum (already in 0-1 range from dot product) momentum_score = self._compute_direction_momentum_score(frontier, robot_pose) @@ -626,15 +624,15 @@ def get_exploration_goal(self, robot_pose: Vector3, costmap: OccupancyGrid) -> V # Check if information increase meets minimum percentage threshold if last_info > 0: # Avoid division by zero info_increase_percent = (current_info - last_info) / last_info - if info_increase_percent < self.info_gain_threshold: + if info_increase_percent < self.config.info_gain_threshold: logger.info( - f"Information increase ({info_increase_percent:.2f}) below threshold ({self.info_gain_threshold:.2f})" + f"Information increase ({info_increase_percent:.2f}) below threshold ({self.config.info_gain_threshold:.2f})" ) logger.info( f"Current information: {current_info}, Last information: {last_info}" ) self.no_gain_counter += 1 - if self.no_gain_counter >= self.num_no_gain_attempts: + if self.no_gain_counter >= self.config.num_no_gain_attempts: logger.info( f"No information gain for {self.no_gain_counter} consecutive attempts" ) @@ -795,7 +793,7 @@ def _exploration_loop(self) -> None: # Wait for goal to be reached or timeout logger.info("Waiting for goal to be reached...") - goal_reached = self.goal_reached_event.wait(timeout=self.goal_timeout) + goal_reached = self.goal_reached_event.wait(timeout=self.config.goal_timeout) if goal_reached: logger.info("Goal reached, finding next frontier") diff --git a/dimos/navigation/visual/query.py b/dimos/navigation/visual/query.py index 2e0951951e..4e3931935a 100644 --- a/dimos/navigation/visual/query.py +++ b/dimos/navigation/visual/query.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any from dimos.models.qwen.video_query import BBox from dimos.models.vl.base import VlModel @@ -20,7 +21,7 @@ def get_object_bbox_from_image( - vl_model: VlModel, image: Image, object_description: str + vl_model: VlModel[Any], image: Image, object_description: str ) -> BBox | None: prompt = ( f"Look at this image and find the '{object_description}'. " diff --git a/dimos/perception/detection/conftest.py b/dimos/perception/detection/conftest.py index 3b24422c47..49582f72e3 100644 --- a/dimos/perception/detection/conftest.py +++ b/dimos/perception/detection/conftest.py @@ -15,6 +15,7 @@ from collections.abc import Callable, Generator import functools from typing import TypedDict +from unittest import mock from dimos_lcm.foxglove_msgs.ImageAnnotations import ImageAnnotations from dimos_lcm.foxglove_msgs.SceneUpdate import SceneUpdate @@ -203,7 +204,8 @@ def detection3dpc(detections3dpc) -> Detection3DPC: def get_moment_2d(get_moment) -> Generator[Callable[[], Moment2D], None, None]: from dimos.perception.detection.detectors import Yolo2DDetector - module = Detection2DModule(detector=lambda: Yolo2DDetector(device="cpu")) + c = mock.create_autospec(CameraInfo, spec_set=True, instance=True) + module = Detection2DModule(detector=lambda: Yolo2DDetector(device="cpu"), camera_info=c) @functools.lru_cache(maxsize=1) def moment_provider(**kwargs) -> Moment2D: @@ -259,7 +261,8 @@ def object_db_module(get_moment): """Create and populate an ObjectDBModule with detections from multiple frames.""" from dimos.perception.detection.detectors import Yolo2DDetector - module2d = Detection2DModule(detector=lambda: Yolo2DDetector(device="cpu")) + c = mock.create_autospec(CameraInfo, spec_set=True, instance=True) + module2d = Detection2DModule(detector=lambda: Yolo2DDetector(device="cpu"), camera_info=c) module3d = Detection3DModule(camera_info=connection._camera_info_static()) moduleDB = ObjectDBModule(camera_info=connection._camera_info_static()) diff --git a/dimos/perception/detection/module2D.py b/dimos/perception/detection/module2D.py index cfca3b2192..2a95dfb570 100644 --- a/dimos/perception/detection/module2D.py +++ b/dimos/perception/detection/module2D.py @@ -11,13 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections.abc import Callable -from dataclasses import dataclass -from typing import Any +from collections.abc import Callable, Sequence +from typing import Annotated, Any from dimos_lcm.foxglove_msgs.ImageAnnotations import ( ImageAnnotations, ) +from pydantic.experimental.pipeline import validate_as from reactivex import operators as ops from reactivex.observable import Observable from reactivex.subject import Subject @@ -36,24 +36,21 @@ from dimos.utils.reactive import backpressure -@dataclass class Config(ModuleConfig): max_freq: float = 10 detector: Callable[[Any], Detector] | None = Yolo2DDetector publish_detection_images: bool = True - camera_info: CameraInfo = None # type: ignore[assignment] - filter: list[Filter2D] | Filter2D | None = None + camera_info: CameraInfo + filter: Annotated[ + Sequence[Filter2D], + validate_as(Sequence[Filter2D] | Filter2D).transform( + lambda f: f if isinstance(f, Sequence) else (f,) + ), + ] = () - def __post_init__(self) -> None: - if self.filter is None: - self.filter = [] - elif not isinstance(self.filter, list): - self.filter = [self.filter] - -class Detection2DModule(Module): +class Detection2DModule(Module[Config]): default_config = Config - config: Config detector: Detector color_image: In[Image] diff --git a/dimos/perception/experimental/temporal_memory/entity_graph_db.py b/dimos/perception/experimental/temporal_memory/entity_graph_db.py index 7109459f40..953fd00dac 100644 --- a/dimos/perception/experimental/temporal_memory/entity_graph_db.py +++ b/dimos/perception/experimental/temporal_memory/entity_graph_db.py @@ -931,7 +931,7 @@ def estimate_and_save_distances( self, parsed: dict[str, Any], frame_image: "Image", - vlm: "VlModel", + vlm: "VlModel[Any]", timestamp_s: float, max_distance_pairs: int = 5, ) -> None: diff --git a/dimos/perception/experimental/temporal_memory/temporal_memory.py b/dimos/perception/experimental/temporal_memory/temporal_memory.py index 66b6fce911..4d54bdc514 100644 --- a/dimos/perception/experimental/temporal_memory/temporal_memory.py +++ b/dimos/perception/experimental/temporal_memory/temporal_memory.py @@ -112,7 +112,7 @@ class TemporalMemory(Module): color_image: In[Image] def __init__( - self, vlm: VlModel | None = None, config: TemporalMemoryConfig | None = None + self, vlm: VlModel[Any] | None = None, config: TemporalMemoryConfig | None = None ) -> None: super().__init__() @@ -183,7 +183,7 @@ def __init__( ) @property - def vlm(self) -> VlModel: + def vlm(self) -> VlModel[Any]: """Get or create VLM instance lazily.""" if self._vlm is None: from dimos.models.vl.openai import OpenAIVlModel diff --git a/dimos/perception/experimental/temporal_memory/temporal_memory_deploy.py b/dimos/perception/experimental/temporal_memory/temporal_memory_deploy.py index ab3cc7a0f5..755e9cc1f0 100644 --- a/dimos/perception/experimental/temporal_memory/temporal_memory_deploy.py +++ b/dimos/perception/experimental/temporal_memory/temporal_memory_deploy.py @@ -17,7 +17,7 @@ """ import os -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from dimos.core._dask_exports import DimosCluster from dimos.models.vl.base import VlModel @@ -32,7 +32,7 @@ def deploy( dimos: DimosCluster, camera: CameraSpec, - vlm: VlModel | None = None, + vlm: VlModel[Any] | None = None, config: TemporalMemoryConfig | None = None, ) -> TemporalMemory: """Deploy TemporalMemory with a camera. diff --git a/dimos/perception/experimental/temporal_memory/temporal_utils/graph_utils.py b/dimos/perception/experimental/temporal_memory/temporal_utils/graph_utils.py index 8d05f8c1e1..9d30cd3338 100644 --- a/dimos/perception/experimental/temporal_memory/temporal_utils/graph_utils.py +++ b/dimos/perception/experimental/temporal_memory/temporal_utils/graph_utils.py @@ -30,7 +30,7 @@ def extract_time_window( question: str, - vlm: "VlModel", + vlm: "VlModel[Any]", latest_frame: "Image | None" = None, ) -> float | None: """Extract time window from question using VLM with example-based learning. diff --git a/dimos/perception/object_tracker_2d.py b/dimos/perception/object_tracker_2d.py index 17f470cff6..27b7c0e93c 100644 --- a/dimos/perception/object_tracker_2d.py +++ b/dimos/perception/object_tracker_2d.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass import logging import threading import time @@ -45,7 +44,6 @@ logger = setup_logger(level=logging.INFO) -@dataclass class ObjectTracker2DConfig(ModuleConfig): frame_id: str = "camera_link" @@ -59,7 +57,6 @@ class ObjectTracker2D(Module[ObjectTracker2DConfig]): tracked_overlay: Out[Image] # Visualization output default_config = ObjectTracker2DConfig - config: ObjectTracker2DConfig def __init__(self, global_config: GlobalConfig = global_config, **kwargs: Any) -> None: """Initialize 2D object tracking module using OpenCV's CSRT tracker.""" diff --git a/dimos/protocol/pubsub/bridge.py b/dimos/protocol/pubsub/bridge.py index f312caed7b..72cbe155d9 100644 --- a/dimos/protocol/pubsub/bridge.py +++ b/dimos/protocol/pubsub/bridge.py @@ -16,10 +16,9 @@ from __future__ import annotations -from dataclasses import dataclass from typing import TYPE_CHECKING, Generic, Protocol, TypeVar -from dimos.protocol.service.spec import Service +from dimos.protocol.service.spec import BaseConfig, Service if TYPE_CHECKING: from collections.abc import Callable @@ -66,8 +65,7 @@ def pass_msg(msg: MsgFrom, topic: TopicFrom) -> None: return pubsub1.subscribe_all(pass_msg) -@dataclass -class BridgeConfig(Generic[TopicFrom, TopicTo, MsgFrom, MsgTo]): +class BridgeConfig(BaseConfig, Generic[TopicFrom, TopicTo, MsgFrom, MsgTo]): """Configuration for a one-way bridge.""" source: AllPubSub[TopicFrom, MsgFrom] diff --git a/dimos/protocol/pubsub/impl/lcmpubsub.py b/dimos/protocol/pubsub/impl/lcmpubsub.py index bf6bbd0dec..09b84ff644 100644 --- a/dimos/protocol/pubsub/impl/lcmpubsub.py +++ b/dimos/protocol/pubsub/impl/lcmpubsub.py @@ -25,7 +25,7 @@ ) from dimos.protocol.pubsub.patterns import Glob from dimos.protocol.pubsub.spec import AllPubSub -from dimos.protocol.service.lcmservice import LCMConfig, LCMService, autoconf +from dimos.protocol.service.lcmservice import LCMService, autoconf from dimos.utils.logging_config import setup_logger if TYPE_CHECKING: @@ -83,7 +83,6 @@ class LCMPubSubBase(LCMService, AllPubSub[Topic, Any]): RegexSubscribable directly without needing discovery-based fallback. """ - default_config = LCMConfig _stop_event: threading.Event _thread: threading.Thread | None diff --git a/dimos/protocol/pubsub/impl/redispubsub.py b/dimos/protocol/pubsub/impl/redispubsub.py index 6cc089e953..b299d6b883 100644 --- a/dimos/protocol/pubsub/impl/redispubsub.py +++ b/dimos/protocol/pubsub/impl/redispubsub.py @@ -14,25 +14,24 @@ from collections import defaultdict from collections.abc import Callable -from dataclasses import dataclass, field import json import threading import time from types import TracebackType from typing import Any +from pydantic import Field import redis # type: ignore[import-not-found] from dimos.protocol.pubsub.spec import PubSub -from dimos.protocol.service.spec import Service +from dimos.protocol.service.spec import BaseConfig, Service -@dataclass -class RedisConfig: +class RedisConfig(BaseConfig): host: str = "localhost" port: int = 6379 db: int = 0 - kwargs: dict[str, Any] = field(default_factory=dict) + kwargs: dict[str, Any] = Field(default_factory=dict) class Redis(PubSub[str, Any], Service[RedisConfig]): diff --git a/dimos/protocol/service/__init__.py b/dimos/protocol/service/__init__.py index fb9df08ca9..ed6caf93c2 100644 --- a/dimos/protocol/service/__init__.py +++ b/dimos/protocol/service/__init__.py @@ -1,8 +1,9 @@ from dimos.protocol.service.lcmservice import LCMService -from dimos.protocol.service.spec import Configurable as Configurable, Service as Service +from dimos.protocol.service.spec import BaseConfig, Configurable, Service -__all__ = [ +__all__ = ( + "BaseConfig", "Configurable", "LCMService", "Service", -] +) diff --git a/dimos/protocol/service/ddsservice.py b/dimos/protocol/service/ddsservice.py index 6ed04c07ad..b5562defff 100644 --- a/dimos/protocol/service/ddsservice.py +++ b/dimos/protocol/service/ddsservice.py @@ -14,9 +14,8 @@ from __future__ import annotations -from dataclasses import dataclass import threading -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING try: from cyclonedds.domain import DomainParticipant @@ -26,7 +25,7 @@ DDS_AVAILABLE = False DomainParticipant = None # type: ignore[assignment, misc] -from dimos.protocol.service.spec import Service +from dimos.protocol.service.spec import BaseConfig, Service from dimos.utils.logging_config import setup_logger if TYPE_CHECKING: @@ -38,8 +37,7 @@ _participants_lock = threading.Lock() -@dataclass -class DDSConfig: +class DDSConfig(BaseConfig): """Configuration for DDS service.""" domain_id: int = 0 @@ -49,9 +47,6 @@ class DDSConfig: class DDSService(Service[DDSConfig]): default_config = DDSConfig - def __init__(self, **kwargs: Any) -> None: - super().__init__(**kwargs) - def start(self) -> None: """Start the DDS service.""" domain_id = self.config.domain_id diff --git a/dimos/protocol/service/lcmservice.py b/dimos/protocol/service/lcmservice.py index 4655780fb3..f56a61f623 100644 --- a/dimos/protocol/service/lcmservice.py +++ b/dimos/protocol/service/lcmservice.py @@ -15,15 +15,16 @@ from __future__ import annotations from concurrent.futures import ThreadPoolExecutor -from dataclasses import dataclass import os import platform +import sys import threading import traceback +from typing import Any -import lcm +import lcm as lcm_mod -from dimos.protocol.service.spec import Service +from dimos.protocol.service.spec import BaseConfig, Service from dimos.protocol.service.system_configurator import ( BufferConfiguratorLinux, BufferConfiguratorMacOS, @@ -35,6 +36,11 @@ ) from dimos.utils.logging_config import setup_logger +if sys.version_info < (3, 13): + from typing_extensions import TypeVar +else: + from typing import TypeVar + logger = setup_logger() _DEFAULT_LCM_HOST = "239.255.76.67" @@ -66,41 +72,38 @@ def autoconf(check_only: bool = False) -> None: configure_system(checks, check_only=check_only) -@dataclass -class LCMConfig: +class LCMConfig(BaseConfig): ttl: int = 0 - url: str | None = None + url: str = _DEFAULT_LCM_URL autoconf: bool = True - lcm: lcm.LCM | None = None - - def __post_init__(self) -> None: - if self.url is None: - self.url = _DEFAULT_LCM_URL + lcm: lcm_mod.LCM | None = None +_Config = TypeVar("_Config", bound=LCMConfig, default=LCMConfig) _LCM_LOOP_TIMEOUT = 50 # this class just sets up cpp LCM instance # and runs its handle loop in a thread # higher order stuff is done by pubsub/impl/lcmpubsub.py -class LCMService(Service[LCMConfig]): - default_config = LCMConfig - l: lcm.LCM | None +class LCMService(Service[_Config]): + default_config = LCMConfig # type: ignore[assignment] + + l: lcm_mod.LCM | None _stop_event: threading.Event _l_lock: threading.Lock _thread: threading.Thread | None _call_thread_pool: ThreadPoolExecutor | None = None _call_thread_pool_lock: threading.RLock = threading.RLock() - def __init__(self, **kwargs) -> None: # type: ignore[no-untyped-def] + def __init__(self, **kwargs: Any) -> None: # type: ignore[no-untyped-def] super().__init__(**kwargs) # we support passing an existing LCM instance if self.config.lcm: self.l = self.config.lcm else: - self.l = lcm.LCM(self.config.url) if self.config.url else lcm.LCM() + self.l = lcm_mod.LCM(self.config.url) if self.config.url else lcm_mod.LCM() self._l_lock = threading.Lock() self._stop_event = threading.Event() @@ -135,7 +138,7 @@ def start(self) -> None: if self.config.lcm: self.l = self.config.lcm else: - self.l = lcm.LCM(self.config.url) if self.config.url else lcm.LCM() + self.l = lcm_mod.LCM(self.config.url) if self.config.url else lcm_mod.LCM() try: autoconf(check_only=not self.config.autoconf) diff --git a/dimos/protocol/service/spec.py b/dimos/protocol/service/spec.py index 37ded73c06..4dcb9398b6 100644 --- a/dimos/protocol/service/spec.py +++ b/dimos/protocol/service/spec.py @@ -15,8 +15,15 @@ from abc import ABC from typing import Any, Generic, TypeVar +from pydantic import BaseModel + + +class BaseConfig(BaseModel): + model_config = {"arbitrary_types_allowed": True} + + # Generic type for service configuration -ConfigT = TypeVar("ConfigT") +ConfigT = TypeVar("ConfigT", bound=BaseConfig) class Configurable(Generic[ConfigT]): diff --git a/dimos/protocol/service/test_lcmservice.py b/dimos/protocol/service/test_lcmservice.py index 4231302426..6b3dcd6e06 100644 --- a/dimos/protocol/service/test_lcmservice.py +++ b/dimos/protocol/service/test_lcmservice.py @@ -14,7 +14,9 @@ import threading import time -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, create_autospec, patch + +from lcm import LCM from dimos.protocol.pubsub.impl.lcmpubsub import Topic from dimos.protocol.service.lcmservice import ( @@ -91,10 +93,6 @@ def test_custom_url(self) -> None: config = LCMConfig(url=custom_url) assert config.url == custom_url - def test_post_init_sets_default_url_when_none(self) -> None: - config = LCMConfig(url=None) - assert config.url == _DEFAULT_LCM_URL - def test_autoconf_can_be_disabled(self) -> None: config = LCMConfig(autoconf=False) assert config.autoconf is False @@ -120,8 +118,8 @@ def test_str_with_lcm_type(self) -> None: class TestLCMService: def test_init_with_default_config(self) -> None: - with patch("dimos.protocol.service.lcmservice.lcm.LCM") as mock_lcm_class: - mock_lcm_instance = MagicMock() + with patch("dimos.protocol.service.lcmservice.lcm_mod.LCM") as mock_lcm_class: + mock_lcm_instance = create_autospec(LCM, spec_set=True, instance=True) mock_lcm_class.return_value = mock_lcm_instance service = LCMService() @@ -131,8 +129,8 @@ def test_init_with_default_config(self) -> None: def test_init_with_custom_url(self) -> None: custom_url = "udpm://192.168.1.1:7777?ttl=1" - with patch("dimos.protocol.service.lcmservice.lcm.LCM") as mock_lcm_class: - mock_lcm_instance = MagicMock() + with patch("dimos.protocol.service.lcmservice.lcm_mod.LCM") as mock_lcm_class: + mock_lcm_instance = create_autospec(LCM, spec_set=True, instance=True) mock_lcm_class.return_value = mock_lcm_instance # Pass url as kwarg, not config= @@ -140,17 +138,17 @@ def test_init_with_custom_url(self) -> None: mock_lcm_class.assert_called_once_with(custom_url) def test_init_with_existing_lcm_instance(self) -> None: - mock_lcm_instance = MagicMock() + mock_lcm_instance = create_autospec(LCM, spec_set=True, instance=True) - with patch("dimos.protocol.service.lcmservice.lcm.LCM") as mock_lcm_class: + with patch("dimos.protocol.service.lcmservice.lcm_mod.LCM") as mock_lcm_class: # Pass lcm as kwarg service = LCMService(lcm=mock_lcm_instance) mock_lcm_class.assert_not_called() assert service.l == mock_lcm_instance def test_start_and_stop(self) -> None: - with patch("dimos.protocol.service.lcmservice.lcm.LCM") as mock_lcm_class: - mock_lcm_instance = MagicMock() + with patch("dimos.protocol.service.lcmservice.lcm_mod.LCM") as mock_lcm_class: + mock_lcm_instance = create_autospec(LCM, spec_set=True, instance=True) mock_lcm_class.return_value = mock_lcm_instance with patch("dimos.protocol.service.lcmservice.autoconf"): @@ -168,8 +166,8 @@ def test_start_and_stop(self) -> None: assert not service._thread.is_alive() def test_start_calls_configure_system(self) -> None: - with patch("dimos.protocol.service.lcmservice.lcm.LCM") as mock_lcm_class: - mock_lcm_instance = MagicMock() + with patch("dimos.protocol.service.lcmservice.lcm_mod.LCM") as mock_lcm_class: + mock_lcm_instance = create_autospec(LCM, spec_set=True, instance=True) mock_lcm_class.return_value = mock_lcm_instance with patch("dimos.protocol.service.lcmservice.autoconf") as mock_configure: @@ -182,8 +180,8 @@ def test_start_calls_configure_system(self) -> None: service.stop() def test_start_with_autoconf_disabled(self) -> None: - with patch("dimos.protocol.service.lcmservice.lcm.LCM") as mock_lcm_class: - mock_lcm_instance = MagicMock() + with patch("dimos.protocol.service.lcmservice.lcm_mod.LCM") as mock_lcm_class: + mock_lcm_instance = create_autospec(LCM, spec_set=True, instance=True) mock_lcm_class.return_value = mock_lcm_instance with patch("dimos.protocol.service.lcmservice.autoconf") as mock_configure: @@ -196,8 +194,8 @@ def test_start_with_autoconf_disabled(self) -> None: service.stop() def test_getstate_excludes_unpicklable_attrs(self) -> None: - with patch("dimos.protocol.service.lcmservice.lcm.LCM") as mock_lcm_class: - mock_lcm_instance = MagicMock() + with patch("dimos.protocol.service.lcmservice.lcm_mod.LCM") as mock_lcm_class: + mock_lcm_instance = create_autospec(LCM, spec_set=True, instance=True) mock_lcm_class.return_value = mock_lcm_instance service = LCMService() @@ -211,8 +209,8 @@ def test_getstate_excludes_unpicklable_attrs(self) -> None: assert "_call_thread_pool_lock" not in state def test_setstate_reinitializes_runtime_attrs(self) -> None: - with patch("dimos.protocol.service.lcmservice.lcm.LCM") as mock_lcm_class: - mock_lcm_instance = MagicMock() + with patch("dimos.protocol.service.lcmservice.lcm_mod.LCM") as mock_lcm_class: + mock_lcm_instance = create_autospec(LCM, spec_set=True, instance=True) mock_lcm_class.return_value = mock_lcm_instance service = LCMService() @@ -231,8 +229,8 @@ def test_setstate_reinitializes_runtime_attrs(self) -> None: assert hasattr(new_service._l_lock, "release") def test_start_reinitializes_lcm_after_unpickling(self) -> None: - with patch("dimos.protocol.service.lcmservice.lcm.LCM") as mock_lcm_class: - mock_lcm_instance = MagicMock() + with patch("dimos.protocol.service.lcmservice.lcm_mod.LCM") as mock_lcm_class: + mock_lcm_instance = create_autospec(LCM, spec_set=True, instance=True) mock_lcm_class.return_value = mock_lcm_instance with patch("dimos.protocol.service.lcmservice.autoconf"): @@ -252,8 +250,8 @@ def test_start_reinitializes_lcm_after_unpickling(self) -> None: new_service.stop() def test_stop_cleans_up_lcm_instance(self) -> None: - with patch("dimos.protocol.service.lcmservice.lcm.LCM") as mock_lcm_class: - mock_lcm_instance = MagicMock() + with patch("dimos.protocol.service.lcmservice.lcm_mod.LCM") as mock_lcm_class: + mock_lcm_instance = create_autospec(LCM, spec_set=True, instance=True) mock_lcm_class.return_value = mock_lcm_instance with patch("dimos.protocol.service.lcmservice.autoconf"): @@ -265,7 +263,7 @@ def test_stop_cleans_up_lcm_instance(self) -> None: assert service.l is None def test_stop_preserves_external_lcm_instance(self) -> None: - mock_lcm_instance = MagicMock() + mock_lcm_instance = create_autospec(LCM, spec_set=True, instance=True) with patch("dimos.protocol.service.lcmservice.autoconf"): # Pass lcm as kwarg @@ -277,8 +275,8 @@ def test_stop_preserves_external_lcm_instance(self) -> None: assert service.l == mock_lcm_instance def test_get_call_thread_pool_creates_pool(self) -> None: - with patch("dimos.protocol.service.lcmservice.lcm.LCM") as mock_lcm_class: - mock_lcm_instance = MagicMock() + with patch("dimos.protocol.service.lcmservice.lcm_mod.LCM") as mock_lcm_class: + mock_lcm_instance = create_autospec(LCM, spec_set=True, instance=True) mock_lcm_class.return_value = mock_lcm_instance service = LCMService() @@ -296,8 +294,8 @@ def test_get_call_thread_pool_creates_pool(self) -> None: pool.shutdown(wait=False) def test_stop_shuts_down_thread_pool(self) -> None: - with patch("dimos.protocol.service.lcmservice.lcm.LCM") as mock_lcm_class: - mock_lcm_instance = MagicMock() + with patch("dimos.protocol.service.lcmservice.lcm_mod.LCM") as mock_lcm_class: + mock_lcm_instance = create_autospec(LCM, spec_set=True, instance=True) mock_lcm_class.return_value = mock_lcm_instance with patch("dimos.protocol.service.lcmservice.autoconf"): diff --git a/dimos/protocol/tf/tf.py b/dimos/protocol/tf/tf.py index 825e89fc8c..1b5ccadf3c 100644 --- a/dimos/protocol/tf/tf.py +++ b/dimos/protocol/tf/tf.py @@ -16,7 +16,7 @@ from abc import abstractmethod from collections import deque -from dataclasses import dataclass, field +from dataclasses import field from functools import reduce from typing import TypeVar @@ -25,23 +25,22 @@ from dimos.msgs.tf2_msgs import TFMessage from dimos.protocol.pubsub.impl.lcmpubsub import LCM, Topic from dimos.protocol.pubsub.spec import PubSub -from dimos.protocol.service.lcmservice import Service # type: ignore[attr-defined] +from dimos.protocol.service.spec import BaseConfig, Service CONFIG = TypeVar("CONFIG") # generic configuration for transform service -@dataclass -class TFConfig: +class TFConfig(BaseConfig): buffer_size: float = 10.0 # seconds rate_limit: float = 10.0 # Hz -# generic specification for transform service -class TFSpec(Service[TFConfig]): - def __init__(self, **kwargs) -> None: # type: ignore[no-untyped-def] - super().__init__(**kwargs) +_TFConfig = TypeVar("_TFConfig", bound=TFConfig) + +# generic specification for transform service +class TFSpec(Service[_TFConfig]): @abstractmethod def publish(self, *args: Transform) -> None: ... @@ -244,15 +243,17 @@ def __str__(self) -> str: return "\n".join(lines) -@dataclass class PubSubTFConfig(TFConfig): topic: Topic | None = None # Required field but needs default for dataclass inheritance pubsub: type[PubSub] | PubSub | None = None # type: ignore[type-arg] autostart: bool = True -class PubSubTF(MultiTBuffer, TFSpec): - default_config: type[PubSubTFConfig] = PubSubTFConfig +_PubSubConfig = TypeVar("_PubSubConfig", bound=PubSubTFConfig) + + +class PubSubTF(MultiTBuffer, TFSpec[_PubSubConfig]): + default_config: type[_PubSubConfig] = PubSubTFConfig # type: ignore[assignment] def __init__(self, **kwargs) -> None: # type: ignore[no-untyped-def] TFSpec.__init__(self, **kwargs) @@ -330,15 +331,14 @@ def receive_msg(self, msg: TFMessage, topic: Topic) -> None: self.receive_tfmessage(msg) -@dataclass class LCMPubsubConfig(PubSubTFConfig): topic: Topic = field(default_factory=lambda: Topic("/tf", TFMessage)) pubsub: type[PubSub] | PubSub | None = LCM # type: ignore[type-arg] autostart: bool = True -class LCMTF(PubSubTF): - default_config: type[LCMPubsubConfig] = LCMPubsubConfig +class LCMTF(PubSubTF[LCMPubsubConfig]): + default_config = LCMPubsubConfig TF = LCMTF diff --git a/dimos/protocol/tf/tflcmcpp.py b/dimos/protocol/tf/tflcmcpp.py index 158a68d3d8..bf2885958d 100644 --- a/dimos/protocol/tf/tflcmcpp.py +++ b/dimos/protocol/tf/tflcmcpp.py @@ -13,15 +13,18 @@ # limitations under the License. from datetime import datetime -from typing import Union from dimos.msgs.geometry_msgs import Transform from dimos.protocol.service.lcmservice import LCMConfig, LCMService from dimos.protocol.tf.tf import TFConfig, TFSpec +class Config(TFConfig, LCMConfig): + """Combined config""" + + # this doesn't work due to tf_lcm_py package -class TFLCM(TFSpec, LCMService): +class TFLCM(TFSpec[Config], LCMService[Config]): """A service for managing and broadcasting transforms using LCM. This is not a separete module, You can include this in your module if you need to access transforms. @@ -34,7 +37,7 @@ class TFLCM(TFSpec, LCMService): for each module. """ - default_config = Union[TFConfig, LCMConfig] + default_config = Config def __init__(self, **kwargs) -> None: # type: ignore[no-untyped-def] super().__init__(**kwargs) diff --git a/dimos/robot/drone/connection_module.py b/dimos/robot/drone/connection_module.py index db5c4ca4cc..a0522668b8 100644 --- a/dimos/robot/drone/connection_module.py +++ b/dimos/robot/drone/connection_module.py @@ -26,6 +26,8 @@ from dimos.agents.annotation import skill from dimos.core import In, Module, Out, rpc +from dimos.core.global_config import GlobalConfig, global_config +from dimos.core.module import ModuleConfig from dimos.mapping.types import LatLon from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Transform, Twist, Vector3 from dimos.msgs.sensor_msgs import Image @@ -43,9 +45,17 @@ def _add_disposable(composite: CompositeDisposable, item: Disposable | Any) -> N composite.add(Disposable(item)) -class DroneConnectionModule(Module): +class Config(ModuleConfig): + connection_string: str = "udp:0.0.0.0:14550" + video_port: int = 5600 + outdoor: bool = False + + +class DroneConnectionModule(Module[Config]): """Module that handles drone sensor data and movement commands.""" + default_config = Config + # Inputs movecmd: In[Vector3] movecmd_twist: In[Twist] # Twist commands from tracking/navigation @@ -60,9 +70,6 @@ class DroneConnectionModule(Module): video: Out[Image] follow_object_cmd: Out[Any] - # Parameters - connection_string: str - # Internal state _odom: PoseStamped | None = None _status: dict[str, Any] = {} @@ -71,14 +78,7 @@ class DroneConnectionModule(Module): _latest_status: dict[str, Any] | None = None _latest_status_lock: threading.RLock - def __init__( - self, - connection_string: str = "udp:0.0.0.0:14550", - video_port: int = 5600, - outdoor: bool = False, - *args: Any, - **kwargs: Any, - ) -> None: + def __init__(self, global_config: GlobalConfig = global_config, **kwargs: Any) -> None: """Initialize drone connection module. Args: @@ -86,9 +86,6 @@ def __init__( video_port: UDP port for video stream outdoor: Use GPS only mode (no velocity integration) """ - self.connection_string = connection_string - self.video_port = video_port - self.outdoor = outdoor self.connection: MavlinkConnection | None = None self.video_stream: DJIDroneVideoStream | None = None self._latest_video_frame = None @@ -97,23 +94,25 @@ def __init__( self._latest_status_lock = threading.RLock() self._running = False self._telemetry_thread: threading.Thread | None = None - Module.__init__(self, *args, **kwargs) + super().__init__(global_config, **kwargs) @rpc def start(self) -> None: """Start the connection and subscribe to sensor streams.""" # Check for replay mode - if self.connection_string == "replay": + if self.config.connection_string == "replay": from dimos.robot.drone.dji_video_stream import FakeDJIVideoStream from dimos.robot.drone.mavlink_connection import FakeMavlinkConnection self.connection = FakeMavlinkConnection("replay") - self.video_stream = FakeDJIVideoStream(port=self.video_port) + self.video_stream = FakeDJIVideoStream(port=self.config.video_port) else: - self.connection = MavlinkConnection(self.connection_string, outdoor=self.outdoor) + self.connection = MavlinkConnection( + self.config.connection_string, outdoor=self.config.outdoor + ) self.connection.connect() - self.video_stream = DJIDroneVideoStream(port=self.video_port) + self.video_stream = DJIDroneVideoStream(port=self.config.video_port) if not self.connection.connected: logger.error("Failed to connect to drone") diff --git a/dimos/simulation/manipulators/sim_module.py b/dimos/simulation/manipulators/sim_module.py index df7e26dd01..964ddcbee0 100644 --- a/dimos/simulation/manipulators/sim_module.py +++ b/dimos/simulation/manipulators/sim_module.py @@ -23,6 +23,7 @@ from reactivex.disposable import Disposable from dimos.core import In, Module, Out, rpc +from dimos.core.global_config import GlobalConfig, global_config from dimos.core.module import ModuleConfig from dimos.msgs.sensor_msgs import JointCommand, JointState, RobotState from dimos.simulation.engines import EngineType, get_engine @@ -43,7 +44,6 @@ class SimulationModule(Module[SimulationModuleConfig]): """Module wrapper for manipulator simulation across engines.""" default_config = SimulationModuleConfig - config: SimulationModuleConfig joint_state: Out[JointState] robot_state: Out[RobotState] @@ -52,8 +52,8 @@ class SimulationModule(Module[SimulationModuleConfig]): MIN_CONTROL_RATE = 1.0 - def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, **kwargs) + def __init__(self, global_config: GlobalConfig = global_config, **kwargs: Any) -> None: + super().__init__(global_config, **kwargs) self._backend: SimManipInterface | None = None self._control_rate = 100.0 self._monitor_rate = 100.0 diff --git a/dimos/simulation/manipulators/test_sim_module.py b/dimos/simulation/manipulators/test_sim_module.py index 334e2ce85f..72408fefed 100644 --- a/dimos/simulation/manipulators/test_sim_module.py +++ b/dimos/simulation/manipulators/test_sim_module.py @@ -17,10 +17,11 @@ import pytest +from dimos.protocol.rpc import RPCSpec from dimos.simulation.manipulators.sim_module import SimulationModule -class _DummyRPC: +class _DummyRPC(RPCSpec): def serve_module_rpc(self, _module) -> None: # type: ignore[no-untyped-def] return None diff --git a/dimos/utils/cli/lcmspy/lcmspy.py b/dimos/utils/cli/lcmspy/lcmspy.py index 5493e53024..b66b40f4dd 100755 --- a/dimos/utils/cli/lcmspy/lcmspy.py +++ b/dimos/utils/cli/lcmspy/lcmspy.py @@ -13,10 +13,10 @@ # limitations under the License. from collections import deque -from dataclasses import dataclass from enum import Enum import threading import time +from typing import Any from dimos.protocol.service.lcmservice import LCMConfig, LCMService @@ -116,20 +116,19 @@ def __str__(self) -> str: return f"topic({self.name})" -@dataclass class LCMSpyConfig(LCMConfig): topic_history_window: float = 60.0 -class LCMSpy(LCMService, Topic): +class LCMSpy(LCMService[LCMSpyConfig], Topic): default_config = LCMSpyConfig topic = dict[str, Topic] graph_log_window: float = 1.0 topic_class: type[Topic] = Topic - def __init__(self, **kwargs) -> None: # type: ignore[no-untyped-def] + def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) - Topic.__init__(self, name="total", history_window=self.config.topic_history_window) # type: ignore[attr-defined] + Topic.__init__(self, name="total", history_window=self.config.topic_history_window) self.topic = {} # type: ignore[assignment] def start(self) -> None: @@ -166,7 +165,6 @@ def update_graphs(self, step_window: float = 1.0) -> None: self.bandwidth_history.append(kbps) -@dataclass class GraphLCMSpyConfig(LCMSpyConfig): graph_log_window: float = 1.0 @@ -178,9 +176,9 @@ class GraphLCMSpy(LCMSpy, GraphTopic): graph_log_stop_event: threading.Event = threading.Event() topic_class: type[Topic] = GraphTopic - def __init__(self, **kwargs) -> None: # type: ignore[no-untyped-def] + def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) - GraphTopic.__init__(self, name="total", history_window=self.config.topic_history_window) # type: ignore[attr-defined] + GraphTopic.__init__(self, name="total", history_window=self.config.topic_history_window) def start(self) -> None: super().start() From 2f17d436f3925b5f7054cfbd23315b04458727c1 Mon Sep 17 00:00:00 2001 From: Sam Bull Date: Thu, 26 Feb 2026 19:37:37 +0000 Subject: [PATCH 7/9] Fixes --- .../temporal_memory/temporal_memory.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/dimos/perception/experimental/temporal_memory/temporal_memory.py b/dimos/perception/experimental/temporal_memory/temporal_memory.py index 4d54bdc514..6d66955a61 100644 --- a/dimos/perception/experimental/temporal_memory/temporal_memory.py +++ b/dimos/perception/experimental/temporal_memory/temporal_memory.py @@ -34,6 +34,7 @@ from dimos.agents.annotation import skill from dimos.core.core import rpc +from dimos.core.global_config import GlobalConfig, global_config from dimos.core.module import Module, ModuleConfig from dimos.core.stream import In from dimos.models.vl.base import VlModel @@ -69,6 +70,8 @@ class Frame: @dataclass class TemporalMemoryConfig(ModuleConfig): + vlm: VlModel[Any] | None = None + # Frame processing fps: float = 1.0 window_s: float = 2.0 @@ -100,7 +103,7 @@ class TemporalMemoryConfig(ModuleConfig): nearby_distance_meters: float = 5.0 # "Nearby" threshold -class TemporalMemory(Module): +class TemporalMemory(Module[TemporalMemoryConfig]): """ builds temporal understanding of video streams using vlms. @@ -110,14 +113,12 @@ class TemporalMemory(Module): """ color_image: In[Image] + default_config = TemporalMemoryConfig - def __init__( - self, vlm: VlModel[Any] | None = None, config: TemporalMemoryConfig | None = None - ) -> None: - super().__init__() + def __init__(self, global_config: GlobalConfig = global_config, **kwargs: Any) -> None: + super().__init__(global_config, **kwargs) - self._vlm = vlm # Can be None for blueprint usage - self.config: TemporalMemoryConfig = config or TemporalMemoryConfig() + self._vlm = self.config.vlm # Can be None for blueprint usage # single lock protects all state self._state_lock = threading.Lock() From 2cbe929f1a3581e44eed3d8ba56cb7f190c83bd4 Mon Sep 17 00:00:00 2001 From: Sam Bull Date: Thu, 26 Feb 2026 19:46:18 +0000 Subject: [PATCH 8/9] Fixes --- dimos/agents/mcp/mcp_server.py | 2 +- dimos/control/coordinator.py | 6 +++--- dimos/core/_test_future_annotations_helper.py | 2 +- dimos/core/module.py | 2 +- dimos/hardware/sensors/lidar/fastlio2/module.py | 6 +++--- dimos/hardware/sensors/lidar/livox/module.py | 6 +++--- dimos/models/base.py | 1 - dimos/models/vl/base.py | 1 - dimos/models/vl/moondream.py | 1 - dimos/models/vl/openai.py | 1 - dimos/models/vl/qwen.py | 1 - dimos/protocol/pubsub/impl/lcmpubsub.py | 11 ++++------- dimos/simulation/manipulators/sim_module.py | 8 +++----- pyproject.toml | 8 ++++++-- 14 files changed, 25 insertions(+), 31 deletions(-) diff --git a/dimos/agents/mcp/mcp_server.py b/dimos/agents/mcp/mcp_server.py index 1f8ce92888..6dfb15ca58 100644 --- a/dimos/agents/mcp/mcp_server.py +++ b/dimos/agents/mcp/mcp_server.py @@ -31,7 +31,7 @@ from dimos.core import Module, rpc # noqa: I001 from dimos.core.rpc_client import RpcCall, RPCClient -from starlette.requests import Request # noqa: TC002 +from starlette.requests import Request if TYPE_CHECKING: import concurrent.futures diff --git a/dimos/control/coordinator.py b/dimos/control/coordinator.py index f2fca8b301..34e7ab7cac 100644 --- a/dimos/control/coordinator.py +++ b/dimos/control/coordinator.py @@ -49,14 +49,14 @@ TwistBaseAdapter, ) from dimos.msgs.geometry_msgs import ( - PoseStamped, # noqa: TC001 - needed at runtime for In[PoseStamped] - Twist, # noqa: TC001 - needed at runtime for In[Twist] + PoseStamped, + Twist, ) from dimos.msgs.sensor_msgs import ( JointState, ) from dimos.teleop.quest.quest_types import ( - Buttons, # noqa: TC001 - needed at runtime for In[Buttons] + Buttons, ) from dimos.utils.logging_config import setup_logger diff --git a/dimos/core/_test_future_annotations_helper.py b/dimos/core/_test_future_annotations_helper.py index 08c5ec0063..a96ce0f587 100644 --- a/dimos/core/_test_future_annotations_helper.py +++ b/dimos/core/_test_future_annotations_helper.py @@ -21,7 +21,7 @@ from __future__ import annotations from dimos.core.module import Module -from dimos.core.stream import In, Out # noqa +from dimos.core.stream import In, Out class FutureData: diff --git a/dimos/core/module.py b/dimos/core/module.py index 76b298d21b..65e150d57e 100644 --- a/dimos/core/module.py +++ b/dimos/core/module.py @@ -46,7 +46,7 @@ from dimos.core.global_config import GlobalConfig, global_config from dimos.core.introspection.module import extract_module_info, render_module_io from dimos.core.resource import Resource -from dimos.core.rpc_client import RpcCall # noqa: TC001 +from dimos.core.rpc_client import RpcCall from dimos.core.stream import In, Out, RemoteIn, RemoteOut, Transport from dimos.protocol.rpc import LCMRPC, RPCSpec from dimos.protocol.service import BaseConfig, Configurable diff --git a/dimos/hardware/sensors/lidar/fastlio2/module.py b/dimos/hardware/sensors/lidar/fastlio2/module.py index ee9a0783a0..fcd20ee9dd 100644 --- a/dimos/hardware/sensors/lidar/fastlio2/module.py +++ b/dimos/hardware/sensors/lidar/fastlio2/module.py @@ -34,7 +34,7 @@ from pathlib import Path from typing import TYPE_CHECKING -from dimos.core import Out # noqa: TC001 +from dimos.core import Out from dimos.core.native_module import NativeModule, NativeModuleConfig from dimos.hardware.sensors.lidar.livox.ports import ( SDK_CMD_DATA_PORT, @@ -48,8 +48,8 @@ SDK_POINT_DATA_PORT, SDK_PUSH_MSG_PORT, ) -from dimos.msgs.nav_msgs.Odometry import Odometry # noqa: TC001 -from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 # noqa: TC001 +from dimos.msgs.nav_msgs.Odometry import Odometry +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.spec import mapping, perception _CONFIG_DIR = Path(__file__).parent / "config" diff --git a/dimos/hardware/sensors/lidar/livox/module.py b/dimos/hardware/sensors/lidar/livox/module.py index 8d6982dcf8..f87384657f 100644 --- a/dimos/hardware/sensors/lidar/livox/module.py +++ b/dimos/hardware/sensors/lidar/livox/module.py @@ -28,7 +28,7 @@ from typing import TYPE_CHECKING -from dimos.core import Out # noqa: TC001 +from dimos.core import Out from dimos.core.native_module import NativeModule, NativeModuleConfig from dimos.hardware.sensors.lidar.livox.ports import ( SDK_CMD_DATA_PORT, @@ -42,8 +42,8 @@ SDK_POINT_DATA_PORT, SDK_PUSH_MSG_PORT, ) -from dimos.msgs.sensor_msgs.Imu import Imu # noqa: TC001 -from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 # noqa: TC001 +from dimos.msgs.sensor_msgs.Imu import Imu +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.spec import perception diff --git a/dimos/models/base.py b/dimos/models/base.py index 0d1ea97f12..d03ce5c539 100644 --- a/dimos/models/base.py +++ b/dimos/models/base.py @@ -16,7 +16,6 @@ from __future__ import annotations -from dataclasses import dataclass from functools import cached_property from typing import Annotated, Any diff --git a/dimos/models/vl/base.py b/dimos/models/vl/base.py index 4cc6e75750..cc8c2081bb 100644 --- a/dimos/models/vl/base.py +++ b/dimos/models/vl/base.py @@ -1,5 +1,4 @@ from abc import ABC, abstractmethod -from dataclasses import dataclass import json import logging import sys diff --git a/dimos/models/vl/moondream.py b/dimos/models/vl/moondream.py index 7b79a6d6f2..c444d8b9ed 100644 --- a/dimos/models/vl/moondream.py +++ b/dimos/models/vl/moondream.py @@ -1,4 +1,3 @@ -from dataclasses import dataclass from functools import cached_property from typing import Any import warnings diff --git a/dimos/models/vl/openai.py b/dimos/models/vl/openai.py index 94f6e20b62..ec774189e4 100644 --- a/dimos/models/vl/openai.py +++ b/dimos/models/vl/openai.py @@ -1,4 +1,3 @@ -from dataclasses import dataclass from functools import cached_property import os from typing import Any diff --git a/dimos/models/vl/qwen.py b/dimos/models/vl/qwen.py index dfcf3e7809..014c6f73a5 100644 --- a/dimos/models/vl/qwen.py +++ b/dimos/models/vl/qwen.py @@ -1,4 +1,3 @@ -from dataclasses import dataclass from functools import cached_property import os from typing import Any diff --git a/dimos/protocol/pubsub/impl/lcmpubsub.py b/dimos/protocol/pubsub/impl/lcmpubsub.py index 09b84ff644..4e792f5965 100644 --- a/dimos/protocol/pubsub/impl/lcmpubsub.py +++ b/dimos/protocol/pubsub/impl/lcmpubsub.py @@ -14,10 +14,13 @@ from __future__ import annotations +from collections.abc import Callable from dataclasses import dataclass import re -from typing import TYPE_CHECKING, Any +import threading +from typing import Any +from dimos.msgs import DimosMsg from dimos.protocol.pubsub.encoders import ( JpegEncoderMixin, LCMEncoderMixin, @@ -28,12 +31,6 @@ from dimos.protocol.service.lcmservice import LCMService, autoconf from dimos.utils.logging_config import setup_logger -if TYPE_CHECKING: - from collections.abc import Callable - import threading - - from dimos.msgs import DimosMsg - logger = setup_logger() diff --git a/dimos/simulation/manipulators/sim_module.py b/dimos/simulation/manipulators/sim_module.py index 964ddcbee0..dbc9ce3c28 100644 --- a/dimos/simulation/manipulators/sim_module.py +++ b/dimos/simulation/manipulators/sim_module.py @@ -16,9 +16,11 @@ from __future__ import annotations +from collections.abc import Callable +from pathlib import Path import threading import time -from typing import TYPE_CHECKING, Any +from typing import Any from reactivex.disposable import Disposable @@ -29,10 +31,6 @@ from dimos.simulation.engines import EngineType, get_engine from dimos.simulation.manipulators.sim_manip_interface import SimManipInterface -if TYPE_CHECKING: - from collections.abc import Callable - from pathlib import Path - class SimulationModuleConfig(ModuleConfig): engine: EngineType diff --git a/pyproject.toml b/pyproject.toml index b34bee992c..5401d4c277 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -324,8 +324,12 @@ exclude = [ [tool.ruff.lint] extend-select = ["E", "W", "F", "B", "UP", "N", "I", "C90", "A", "RUF", "TCH"] -# TODO: All of these should be fixed, but it's easier commit autofixes first -ignore = ["A001", "A002", "B008", "B017", "B019", "B024", "B026", "B904", "C901", "E402", "E501", "E721", "E722", "E741", "F811", "F821", "F821", "F821", "N801", "N802", "N803", "N806", "N817", "N999", "RUF003", "RUF009", "RUF012", "RUF034", "RUF043", "RUF059", "UP007"] +ignore = [ + # TODO: All of these should be fixed, but it's easier commit autofixes first + "A001", "A002", "B008", "B017", "B019", "B024", "B026", "B904", "C901", "E402", "E501", "E721", "E722", "E741", "F811", "F821", "F821", "F821", "N801", "N802", "N803", "N806", "N817", "N999", "RUF003", "RUF009", "RUF012", "RUF034", "RUF043", "RUF059", "UP007", + # This breaks runtime type checking (both for us, and users introspecting our APIs) + "TC001", "TC002", "TC003" +] [tool.ruff.lint.per-file-ignores] "dimos/models/Detic/*" = ["ALL"] From 9ebc89c89f87224fc5ae32b3a84cc420cdbd974b Mon Sep 17 00:00:00 2001 From: Sam Bull Date: Fri, 27 Feb 2026 13:46:19 +0000 Subject: [PATCH 9/9] Apply suggestion from @Dreamsorcerer --- dimos/agents/skills/person_follow.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dimos/agents/skills/person_follow.py b/dimos/agents/skills/person_follow.py index 86feb99363..0a90974cce 100644 --- a/dimos/agents/skills/person_follow.py +++ b/dimos/agents/skills/person_follow.py @@ -68,7 +68,8 @@ def __init__(self, global_config: GlobalConfig = global_config, **kwargs: Any) - super().__init__(global_config, **kwargs) self._latest_image: Image | None = None self._latest_pointcloud: PointCloud2 | None = None - self._vl_model = QwenVlModel() + # Use VlModel to keep usage in this class generic + self._vl_model: VlModel = QwenVlModel() self._tracker: EdgeTAMProcessor | None = None self._thread: Thread | None = None self._should_stop: Event = Event()