Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions src/dodal/device_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,22 @@ def fixture(self, func: Callable[[], T]) -> Callable[[], T]:
self._fixtures[func.__name__] = func
return func

def include(self, other: "DeviceManager"):
common = self._factories.keys() & other._factories # noqa SLF001
common |= self._v1_factories.keys() & other._v1_factories # noqa SLF001
common |= self._factories.keys() & other._v1_factories # noqa SLF001
common |= self._v1_factories.keys() & other._factories # noqa SLF001
if common:
raise ValueError(
f"Duplicate factories in included device manager: {common}"
)

self._factories.update(other._factories) # noqa SLF001
self._v1_factories.update(other._v1_factories) # noqa SLF001

# duplicate fixtures are not checked as fixtures can be overridden
self._fixtures.update(other._fixtures) # noqa SLF001

def v1_init(
self,
factory: type[V1],
Expand Down
37 changes: 37 additions & 0 deletions tests/test_device_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,6 +693,43 @@ def foo(s, one, two):
s1().set_up_with.assert_called_once_with("one", 2)


def test_inherited_device_manager(dm: DeviceManager):
s1 = Mock(return_value=Mock(spec=OphydV2Device))
s2 = Mock(return_value=Mock(spec=OphydV2Device))

@dm.factory
def foo():
return s1()

dm2 = DeviceManager()

@dm2.factory
def bar(foo):
return s2(foo)

dm2.include(dm)

built_bar = bar.build()
s1.assert_called_once()
s2.assert_called_once_with(s1())
assert built_bar is s2(s1())


def test_inherited_device_manager_duplicate_name():
device = Mock(return_value=Mock(spec=OphydV2Device))

dm = DeviceManager()
dm2 = DeviceManager()

@dm.factory
@dm2.factory
def foo():
return device()

with pytest.raises(ValueError, match="{'foo'}"):
dm.include(dm2)


def test_lazy_fixtures_non_lazy():
lf = LazyFixtures(provided={"foo": "bar"}, factories={})
assert lf["foo"] == "bar"
Expand Down