Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@
("py:class", "p4p.nt.ndarray.NTNDArray"),
("py:class", "p4p.nt.NTTable"),
# Problems in FastCS itself
("py:class", "T"),
("py:class", "BaseController"),
("py:class", "AttrIOUpdateCallback"),
("py:class", "fastcs.transports.epics.pva.pvi_tree._PviSignalInfo"),
("py:class", "fastcs.logging._logging.LogLevel"),
Expand Down
18 changes: 3 additions & 15 deletions src/fastcs/control_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from fastcs.controllers import BaseController, Controller
from fastcs.logging import bind_logger
from fastcs.methods import Command, Scan, ScanCallback
from fastcs.methods import ScanCallback
from fastcs.tracer import Tracer
from fastcs.transports import ControllerAPI, Transport

Expand Down Expand Up @@ -174,23 +174,11 @@ def build_controller_api(controller: Controller) -> ControllerAPI:


def _build_controller_api(controller: BaseController, path: list[str]) -> ControllerAPI:
scan_methods: dict[str, Scan] = {}
command_methods: dict[str, Command] = {}
for attr_name in dir(controller):
attr = getattr(controller, attr_name)
match attr:
case Scan(enabled=True):
scan_methods[attr_name] = attr
case Command(enabled=True):
command_methods[attr_name] = attr
case _:
pass

return ControllerAPI(
path=path,
attributes=controller.attributes,
scan_methods=scan_methods,
command_methods=command_methods,
command_methods=controller.command_methods,
scan_methods=controller.scan_methods,
sub_apis={
name: _build_controller_api(sub_controller, path + [name])
for name, sub_controller in controller.sub_controllers.items()
Expand Down
44 changes: 38 additions & 6 deletions src/fastcs/controllers/base_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from fastcs.attributes import AnyAttributeIO, Attribute, AttrR, AttrW, HintedAttribute
from fastcs.logging import bind_logger
from fastcs.methods import Command, Scan, UnboundCommand, UnboundScan
from fastcs.tracer import Tracer

logger = bind_logger(logger_name=__name__)
Expand Down Expand Up @@ -46,6 +47,8 @@ def __init__(
# Internal state that should not be accessed directly by base classes
self.__attributes: dict[str, Attribute] = {}
self.__sub_controllers: dict[str, BaseController] = {}
self.__command_methods: dict[str, Command] = {}
self.__scan_methods: dict[str, Scan] = {}

self.__hinted_attributes: dict[str, HintedAttribute] = {}
self.__hinted_sub_controllers: dict[str, type[BaseController]] = {}
Expand Down Expand Up @@ -95,10 +98,6 @@ class method and a controller instance, so that it can be called from any
context with the controller instance passed as the ``self`` argument.

"""
# Lazy import to avoid circular references
from fastcs.methods.command import UnboundCommand
from fastcs.methods.scan import UnboundScan

# Using a dictionary instead of a set to maintain order.
class_dir = {key: None for key in dir(type(self)) if not key.startswith("_")}
class_type_hints = {
Expand All @@ -114,8 +113,21 @@ class method and a controller instance, so that it can be called from any
attr = getattr(self, attr_name, None)
if isinstance(attr, Attribute):
setattr(self, attr_name, deepcopy(attr))
elif isinstance(attr, UnboundScan | UnboundCommand):
setattr(self, attr_name, attr.bind(self))
else:
if isinstance(attr, Command):
self.add_command(attr_name, attr)
elif isinstance(attr, Scan):
self.add_scan(attr_name, attr)
elif isinstance(
unbound_command := getattr(attr, "__unbound_command__", None),
UnboundCommand,
):
self.add_command(attr_name, unbound_command.bind(self))
elif isinstance(
unbound_scan := getattr(attr, "__unbound_scan__", None),
UnboundScan,
):
self.add_scan(attr_name, unbound_scan.bind(self))

def _validate_io(self, ios: Sequence[AnyAttributeIO]):
"""Validate that there is exactly one AttributeIO class registered to the
Expand All @@ -137,6 +149,10 @@ def __repr__(self):
def __setattr__(self, name, value):
if isinstance(value, Attribute):
self.add_attribute(name, value)
elif isinstance(value, Command):
self.add_command(name, value)
elif isinstance(value, Scan):
self.add_scan(name, value)
elif isinstance(value, BaseController):
self.add_sub_controller(name, value)
else:
Expand Down Expand Up @@ -300,3 +316,19 @@ def add_sub_controller(self, name: str, sub_controller: BaseController):
@property
def sub_controllers(self) -> dict[str, BaseController]:
return self.__sub_controllers

def add_command(self, name: str, command: Command):
self.__command_methods[name] = command
super().__setattr__(name, command)

@property
def command_methods(self) -> dict[str, Command]:
return self.__command_methods

def add_scan(self, name: str, scan: Scan):
self.__scan_methods[name] = scan
super().__setattr__(name, scan)

@property
def scan_methods(self) -> dict[str, Scan]:
return self.__scan_methods
2 changes: 2 additions & 0 deletions src/fastcs/methods/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from .command import Command as Command
from .command import CommandCallback as CommandCallback
from .command import UnboundCommand as UnboundCommand
from .command import command as command
from .scan import Scan as Scan
from .scan import ScanCallback as ScanCallback
from .scan import UnboundScan as UnboundScan
from .scan import scan as scan
38 changes: 27 additions & 11 deletions src/fastcs/methods/command.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@
from collections.abc import Callable, Coroutine
from types import MethodType
from typing import TYPE_CHECKING

from fastcs.controllers import BaseController
from fastcs.logging import bind_logger
from fastcs.methods.method import Controller_T, Method

if TYPE_CHECKING:
from fastcs.controllers import BaseController # noqa: F401

logger = bind_logger(logger_name=__name__)

UnboundCommandCallback = Callable[[Controller_T], Coroutine[None, None, None]]
"""A Command callback that is unbound and must be called with a `Controller` instance"""
CommandCallback = Callable[[], Coroutine[None, None, None]]
"""A Command callback that is bound and can be called without `self`"""


class Command(Method[BaseController]):
class Command(Method["BaseController"]):
"""A `Controller` `Method` that performs a single action when called.

This class contains a function that is bound to a specific `Controller` instance and
Expand All @@ -28,7 +34,18 @@ def _validate(self, fn: CommandCallback) -> None:
raise TypeError(f"Command method cannot have arguments: {fn}")

async def __call__(self):
return await self._fn()
return await self.fn()

@property
def fn(self) -> CommandCallback:
async def command():
try:
return await self._fn()
except Exception:
logger.exception("Command failed", fn=self._fn)
raise

return command


class UnboundCommand(Method[Controller_T]):
Expand Down Expand Up @@ -56,15 +73,12 @@ def _validate(self, fn: UnboundCommandCallback[Controller_T]) -> None:
def bind(self, controller: Controller_T) -> Command:
return Command(MethodType(self.fn, controller), group=self.group)

def __call__(self):
raise NotImplementedError(
"Method must be bound to a controller instance to be callable"
)


def command(
*, group: str | None = None
) -> Callable[[UnboundCommandCallback[Controller_T]], UnboundCommand[Controller_T]]:
) -> Callable[
[UnboundCommandCallback[Controller_T]], UnboundCommandCallback[Controller_T]
]:
"""Decorator to register a `Controller` method as a `Command`

The `Command` will be passed to the transport layer to expose in the API
Expand All @@ -75,7 +89,9 @@ def command(

def wrapper(
fn: UnboundCommandCallback[Controller_T],
) -> UnboundCommand[Controller_T]:
return UnboundCommand(fn, group=group)
) -> UnboundCommandCallback[Controller_T]:
setattr(fn, "__unbound_command__", UnboundCommand(fn, group=group)) # noqa: B010

return fn

return wrapper
8 changes: 5 additions & 3 deletions src/fastcs/methods/method.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
from asyncio import iscoroutinefunction
from collections.abc import Callable, Coroutine
from inspect import Signature, getdoc, signature
from typing import Generic, TypeVar
from typing import TYPE_CHECKING, Generic, TypeVar

from fastcs.controllers.base_controller import BaseController
from fastcs.tracer import Tracer

if TYPE_CHECKING:
from fastcs.controllers import BaseController # noqa: F401

MethodCallback = Callable[..., Coroutine[None, None, None]]
"""Generic protocol for all `Controller` Method callbacks"""
Controller_T = TypeVar("Controller_T", bound=BaseController)
Controller_T = TypeVar("Controller_T", bound="BaseController") # noqa: F821
"""Generic `Controller` class that an unbound method must be called with as `self`"""


Expand Down
24 changes: 13 additions & 11 deletions src/fastcs/methods/scan.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from collections.abc import Callable, Coroutine
from types import MethodType
from typing import TYPE_CHECKING

from fastcs.controllers import BaseController
from fastcs.logging import bind_logger
from fastcs.methods.method import Controller_T, Method

if TYPE_CHECKING:
from fastcs.controllers import BaseController # noqa: F401

logger = bind_logger(logger_name=__name__)

UnboundScanCallback = Callable[[Controller_T], Coroutine[None, None, None]]
Expand All @@ -13,7 +16,7 @@
"""A Scan callback that is bound and can be called without `self`"""


class Scan(Method[BaseController]):
class Scan(Method["BaseController"]):
"""A `Controller` `Method` that will be called periodically in the background.

This class contains a function that is bound to a specific `Controller` instance and
Expand All @@ -40,7 +43,7 @@ async def __call__(self):
return await self._fn()

@property
def fn(self):
def fn(self) -> ScanCallback:
async def scan():
try:
return await self._fn()
Expand Down Expand Up @@ -80,15 +83,10 @@ def _validate(self, fn: UnboundScanCallback[Controller_T]) -> None:
def bind(self, controller: Controller_T) -> Scan:
return Scan(MethodType(self.fn, controller), self._period)

def __call__(self):
raise NotImplementedError(
"Method must be bound to a controller instance to be callable"
)


def scan(
period: float,
) -> Callable[[UnboundScanCallback[Controller_T]], UnboundScan[Controller_T]]:
) -> Callable[[UnboundScanCallback[Controller_T]], UnboundScanCallback[Controller_T]]:
"""Decorator to register a `Controller` method as a `Scan`

The `Scan` method will be called periodically in the background.
Expand All @@ -97,7 +95,11 @@ def scan(
if period <= 0:
raise ValueError("Scan method must have a positive scan period")

def wrapper(fn: UnboundScanCallback[Controller_T]) -> UnboundScan[Controller_T]:
return UnboundScan(fn, period)
def wrapper(
fn: UnboundScanCallback[Controller_T],
) -> UnboundScanCallback[Controller_T]:
setattr(fn, "__unbound_scan__", UnboundScan(fn, period=period)) # noqa: B010

return fn

return wrapper
8 changes: 5 additions & 3 deletions src/fastcs/transports/controller_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def _add_attribute_update_tasks(


def _get_periodic_scan_coros(
scan_dict: dict[float, list[Scan]],
scan_dict: dict[float, list[ScanCallback]],
) -> list[ScanCallback]:
periodic_scan_coros: list[ScanCallback] = []
for period, methods in scan_dict.items():
Expand All @@ -97,11 +97,13 @@ def _get_periodic_scan_coros(
return periodic_scan_coros


def _create_periodic_scan_coro(period: float, scans: list[Scan]) -> ScanCallback:
def _create_periodic_scan_coro(
period: float, scans: list[ScanCallback]
) -> ScanCallback:
async def _sleep():
await asyncio.sleep(period)

methods = [_sleep] + scans # Create periodic behavior
methods = [_sleep] + list(scans) # Create periodic behavior

async def scan_coro() -> None:
while True:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_control_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ async def test_scan_tasks(controller):

for _ in range(3):
count = controller.count
await asyncio.sleep(controller.counter.period + 0.01)
await asyncio.sleep(0.1)
assert controller.count > count


Expand Down
15 changes: 3 additions & 12 deletions tests/test_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,17 +39,13 @@ async def do_nothing(self):
async def do_nothing_with_arg(self, arg):
pass

unbound_command = UnboundCommand(TestController.do_nothing, group="Test")

with pytest.raises(NotImplementedError):
await unbound_command()

with pytest.raises(TypeError):
UnboundCommand(TestController.do_nothing_with_arg) # type: ignore

with pytest.raises(TypeError):
Command(TestController().do_nothing_with_arg) # type: ignore

unbound_command = UnboundCommand(TestController.do_nothing, group="Test")
command = unbound_command.bind(TestController())
# Test that group is passed when binding commands
assert command.group == "Test"
Expand All @@ -66,19 +62,14 @@ async def update_nothing(self):
async def update_nothing_with_arg(self, arg):
pass

unbound_scan = UnboundScan(TestController.update_nothing, 1.0)

assert unbound_scan.period == 1.0

with pytest.raises(NotImplementedError):
await unbound_scan()

with pytest.raises(TypeError):
UnboundScan(TestController.update_nothing_with_arg, 1.0) # type: ignore

with pytest.raises(TypeError):
Scan(TestController().update_nothing_with_arg, 1.0) # type: ignore

unbound_scan = UnboundScan(TestController.update_nothing, 1.0)
assert unbound_scan.period == 1.0
scan = unbound_scan.bind(TestController())

assert scan.period == 1.0
Expand Down