diff --git a/docs/conf.py b/docs/conf.py index db4fcc0fa..ba10ee033 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -89,7 +89,7 @@ ("py:class", "p4p.nt.ndarray.NTNDArray"), ("py:class", "p4p.nt.NTTable"), # Problems in FastCS itself - ("py:class", "T"), + ("py:class", "BaseController"), ("py:class", "AttrIOUpdateCallback"), ("py:class", "fastcs.transports.epics.pva.pvi_tree._PviSignalInfo"), ("py:class", "fastcs.logging._logging.LogLevel"), diff --git a/src/fastcs/control_system.py b/src/fastcs/control_system.py index fcd13ada0..72d5bd08f 100644 --- a/src/fastcs/control_system.py +++ b/src/fastcs/control_system.py @@ -8,7 +8,7 @@ from fastcs.controllers import BaseController, Controller from fastcs.logging import bind_logger -from fastcs.methods import Command, Scan, ScanCallback +from fastcs.methods import ScanCallback from fastcs.tracer import Tracer from fastcs.transports import ControllerAPI, Transport @@ -174,23 +174,11 @@ def build_controller_api(controller: Controller) -> ControllerAPI: def _build_controller_api(controller: BaseController, path: list[str]) -> ControllerAPI: - scan_methods: dict[str, Scan] = {} - command_methods: dict[str, Command] = {} - for attr_name in dir(controller): - attr = getattr(controller, attr_name) - match attr: - case Scan(enabled=True): - scan_methods[attr_name] = attr - case Command(enabled=True): - command_methods[attr_name] = attr - case _: - pass - return ControllerAPI( path=path, attributes=controller.attributes, - scan_methods=scan_methods, - command_methods=command_methods, + command_methods=controller.command_methods, + scan_methods=controller.scan_methods, sub_apis={ name: _build_controller_api(sub_controller, path + [name]) for name, sub_controller in controller.sub_controllers.items() diff --git a/src/fastcs/controllers/base_controller.py b/src/fastcs/controllers/base_controller.py index da053b648..0e5fc1de0 100755 --- a/src/fastcs/controllers/base_controller.py +++ b/src/fastcs/controllers/base_controller.py @@ -7,6 +7,7 @@ from fastcs.attributes import AnyAttributeIO, Attribute, AttrR, AttrW, HintedAttribute from fastcs.logging import bind_logger +from fastcs.methods import Command, Scan, UnboundCommand, UnboundScan from fastcs.tracer import Tracer logger = bind_logger(logger_name=__name__) @@ -46,6 +47,8 @@ def __init__( # Internal state that should not be accessed directly by base classes self.__attributes: dict[str, Attribute] = {} self.__sub_controllers: dict[str, BaseController] = {} + self.__command_methods: dict[str, Command] = {} + self.__scan_methods: dict[str, Scan] = {} self.__hinted_attributes: dict[str, HintedAttribute] = {} self.__hinted_sub_controllers: dict[str, type[BaseController]] = {} @@ -95,10 +98,6 @@ class method and a controller instance, so that it can be called from any context with the controller instance passed as the ``self`` argument. """ - # Lazy import to avoid circular references - from fastcs.methods.command import UnboundCommand - from fastcs.methods.scan import UnboundScan - # Using a dictionary instead of a set to maintain order. class_dir = {key: None for key in dir(type(self)) if not key.startswith("_")} class_type_hints = { @@ -114,8 +113,21 @@ class method and a controller instance, so that it can be called from any attr = getattr(self, attr_name, None) if isinstance(attr, Attribute): setattr(self, attr_name, deepcopy(attr)) - elif isinstance(attr, UnboundScan | UnboundCommand): - setattr(self, attr_name, attr.bind(self)) + else: + if isinstance(attr, Command): + self.add_command(attr_name, attr) + elif isinstance(attr, Scan): + self.add_scan(attr_name, attr) + elif isinstance( + unbound_command := getattr(attr, "__unbound_command__", None), + UnboundCommand, + ): + self.add_command(attr_name, unbound_command.bind(self)) + elif isinstance( + unbound_scan := getattr(attr, "__unbound_scan__", None), + UnboundScan, + ): + self.add_scan(attr_name, unbound_scan.bind(self)) def _validate_io(self, ios: Sequence[AnyAttributeIO]): """Validate that there is exactly one AttributeIO class registered to the @@ -137,6 +149,10 @@ def __repr__(self): def __setattr__(self, name, value): if isinstance(value, Attribute): self.add_attribute(name, value) + elif isinstance(value, Command): + self.add_command(name, value) + elif isinstance(value, Scan): + self.add_scan(name, value) elif isinstance(value, BaseController): self.add_sub_controller(name, value) else: @@ -300,3 +316,19 @@ def add_sub_controller(self, name: str, sub_controller: BaseController): @property def sub_controllers(self) -> dict[str, BaseController]: return self.__sub_controllers + + def add_command(self, name: str, command: Command): + self.__command_methods[name] = command + super().__setattr__(name, command) + + @property + def command_methods(self) -> dict[str, Command]: + return self.__command_methods + + def add_scan(self, name: str, scan: Scan): + self.__scan_methods[name] = scan + super().__setattr__(name, scan) + + @property + def scan_methods(self) -> dict[str, Scan]: + return self.__scan_methods diff --git a/src/fastcs/methods/__init__.py b/src/fastcs/methods/__init__.py index e365d7e91..abe6d61f2 100644 --- a/src/fastcs/methods/__init__.py +++ b/src/fastcs/methods/__init__.py @@ -1,6 +1,8 @@ from .command import Command as Command from .command import CommandCallback as CommandCallback +from .command import UnboundCommand as UnboundCommand from .command import command as command from .scan import Scan as Scan from .scan import ScanCallback as ScanCallback +from .scan import UnboundScan as UnboundScan from .scan import scan as scan diff --git a/src/fastcs/methods/command.py b/src/fastcs/methods/command.py index e73bd4851..87d0e0b1d 100644 --- a/src/fastcs/methods/command.py +++ b/src/fastcs/methods/command.py @@ -1,16 +1,22 @@ from collections.abc import Callable, Coroutine from types import MethodType +from typing import TYPE_CHECKING -from fastcs.controllers import BaseController +from fastcs.logging import bind_logger from fastcs.methods.method import Controller_T, Method +if TYPE_CHECKING: + from fastcs.controllers import BaseController # noqa: F401 + +logger = bind_logger(logger_name=__name__) + UnboundCommandCallback = Callable[[Controller_T], Coroutine[None, None, None]] """A Command callback that is unbound and must be called with a `Controller` instance""" CommandCallback = Callable[[], Coroutine[None, None, None]] """A Command callback that is bound and can be called without `self`""" -class Command(Method[BaseController]): +class Command(Method["BaseController"]): """A `Controller` `Method` that performs a single action when called. This class contains a function that is bound to a specific `Controller` instance and @@ -28,7 +34,18 @@ def _validate(self, fn: CommandCallback) -> None: raise TypeError(f"Command method cannot have arguments: {fn}") async def __call__(self): - return await self._fn() + return await self.fn() + + @property + def fn(self) -> CommandCallback: + async def command(): + try: + return await self._fn() + except Exception: + logger.exception("Command failed", fn=self._fn) + raise + + return command class UnboundCommand(Method[Controller_T]): @@ -56,15 +73,12 @@ def _validate(self, fn: UnboundCommandCallback[Controller_T]) -> None: def bind(self, controller: Controller_T) -> Command: return Command(MethodType(self.fn, controller), group=self.group) - def __call__(self): - raise NotImplementedError( - "Method must be bound to a controller instance to be callable" - ) - def command( *, group: str | None = None -) -> Callable[[UnboundCommandCallback[Controller_T]], UnboundCommand[Controller_T]]: +) -> Callable[ + [UnboundCommandCallback[Controller_T]], UnboundCommandCallback[Controller_T] +]: """Decorator to register a `Controller` method as a `Command` The `Command` will be passed to the transport layer to expose in the API @@ -75,7 +89,9 @@ def command( def wrapper( fn: UnboundCommandCallback[Controller_T], - ) -> UnboundCommand[Controller_T]: - return UnboundCommand(fn, group=group) + ) -> UnboundCommandCallback[Controller_T]: + setattr(fn, "__unbound_command__", UnboundCommand(fn, group=group)) # noqa: B010 + + return fn return wrapper diff --git a/src/fastcs/methods/method.py b/src/fastcs/methods/method.py index f3d5cb098..52b1b4e84 100644 --- a/src/fastcs/methods/method.py +++ b/src/fastcs/methods/method.py @@ -1,14 +1,16 @@ from asyncio import iscoroutinefunction from collections.abc import Callable, Coroutine from inspect import Signature, getdoc, signature -from typing import Generic, TypeVar +from typing import TYPE_CHECKING, Generic, TypeVar -from fastcs.controllers.base_controller import BaseController from fastcs.tracer import Tracer +if TYPE_CHECKING: + from fastcs.controllers import BaseController # noqa: F401 + MethodCallback = Callable[..., Coroutine[None, None, None]] """Generic protocol for all `Controller` Method callbacks""" -Controller_T = TypeVar("Controller_T", bound=BaseController) +Controller_T = TypeVar("Controller_T", bound="BaseController") # noqa: F821 """Generic `Controller` class that an unbound method must be called with as `self`""" diff --git a/src/fastcs/methods/scan.py b/src/fastcs/methods/scan.py index c35850fe3..a20b7e944 100644 --- a/src/fastcs/methods/scan.py +++ b/src/fastcs/methods/scan.py @@ -1,10 +1,13 @@ from collections.abc import Callable, Coroutine from types import MethodType +from typing import TYPE_CHECKING -from fastcs.controllers import BaseController from fastcs.logging import bind_logger from fastcs.methods.method import Controller_T, Method +if TYPE_CHECKING: + from fastcs.controllers import BaseController # noqa: F401 + logger = bind_logger(logger_name=__name__) UnboundScanCallback = Callable[[Controller_T], Coroutine[None, None, None]] @@ -13,7 +16,7 @@ """A Scan callback that is bound and can be called without `self`""" -class Scan(Method[BaseController]): +class Scan(Method["BaseController"]): """A `Controller` `Method` that will be called periodically in the background. This class contains a function that is bound to a specific `Controller` instance and @@ -40,7 +43,7 @@ async def __call__(self): return await self._fn() @property - def fn(self): + def fn(self) -> ScanCallback: async def scan(): try: return await self._fn() @@ -80,15 +83,10 @@ def _validate(self, fn: UnboundScanCallback[Controller_T]) -> None: def bind(self, controller: Controller_T) -> Scan: return Scan(MethodType(self.fn, controller), self._period) - def __call__(self): - raise NotImplementedError( - "Method must be bound to a controller instance to be callable" - ) - def scan( period: float, -) -> Callable[[UnboundScanCallback[Controller_T]], UnboundScan[Controller_T]]: +) -> Callable[[UnboundScanCallback[Controller_T]], UnboundScanCallback[Controller_T]]: """Decorator to register a `Controller` method as a `Scan` The `Scan` method will be called periodically in the background. @@ -97,7 +95,11 @@ def scan( if period <= 0: raise ValueError("Scan method must have a positive scan period") - def wrapper(fn: UnboundScanCallback[Controller_T]) -> UnboundScan[Controller_T]: - return UnboundScan(fn, period) + def wrapper( + fn: UnboundScanCallback[Controller_T], + ) -> UnboundScanCallback[Controller_T]: + setattr(fn, "__unbound_scan__", UnboundScan(fn, period=period)) # noqa: B010 + + return fn return wrapper diff --git a/src/fastcs/transports/controller_api.py b/src/fastcs/transports/controller_api.py index 9048d5ef5..3a13fae0a 100644 --- a/src/fastcs/transports/controller_api.py +++ b/src/fastcs/transports/controller_api.py @@ -88,7 +88,7 @@ def _add_attribute_update_tasks( def _get_periodic_scan_coros( - scan_dict: dict[float, list[Scan]], + scan_dict: dict[float, list[ScanCallback]], ) -> list[ScanCallback]: periodic_scan_coros: list[ScanCallback] = [] for period, methods in scan_dict.items(): @@ -97,11 +97,13 @@ def _get_periodic_scan_coros( return periodic_scan_coros -def _create_periodic_scan_coro(period: float, scans: list[Scan]) -> ScanCallback: +def _create_periodic_scan_coro( + period: float, scans: list[ScanCallback] +) -> ScanCallback: async def _sleep(): await asyncio.sleep(period) - methods = [_sleep] + scans # Create periodic behavior + methods = [_sleep] + list(scans) # Create periodic behavior async def scan_coro() -> None: while True: diff --git a/tests/test_control_system.py b/tests/test_control_system.py index 347e2ed35..1125a3703 100644 --- a/tests/test_control_system.py +++ b/tests/test_control_system.py @@ -22,7 +22,7 @@ async def test_scan_tasks(controller): for _ in range(3): count = controller.count - await asyncio.sleep(controller.counter.period + 0.01) + await asyncio.sleep(0.1) assert controller.count > count diff --git a/tests/test_methods.py b/tests/test_methods.py index 8095200e4..a3e990996 100644 --- a/tests/test_methods.py +++ b/tests/test_methods.py @@ -39,17 +39,13 @@ async def do_nothing(self): async def do_nothing_with_arg(self, arg): pass - unbound_command = UnboundCommand(TestController.do_nothing, group="Test") - - with pytest.raises(NotImplementedError): - await unbound_command() - with pytest.raises(TypeError): UnboundCommand(TestController.do_nothing_with_arg) # type: ignore with pytest.raises(TypeError): Command(TestController().do_nothing_with_arg) # type: ignore + unbound_command = UnboundCommand(TestController.do_nothing, group="Test") command = unbound_command.bind(TestController()) # Test that group is passed when binding commands assert command.group == "Test" @@ -66,19 +62,14 @@ async def update_nothing(self): async def update_nothing_with_arg(self, arg): pass - unbound_scan = UnboundScan(TestController.update_nothing, 1.0) - - assert unbound_scan.period == 1.0 - - with pytest.raises(NotImplementedError): - await unbound_scan() - with pytest.raises(TypeError): UnboundScan(TestController.update_nothing_with_arg, 1.0) # type: ignore with pytest.raises(TypeError): Scan(TestController().update_nothing_with_arg, 1.0) # type: ignore + unbound_scan = UnboundScan(TestController.update_nothing, 1.0) + assert unbound_scan.period == 1.0 scan = unbound_scan.bind(TestController()) assert scan.period == 1.0