diff --git a/src/blueapi/core/context.py b/src/blueapi/core/context.py index 26ed8da78..11ed477b1 100644 --- a/src/blueapi/core/context.py +++ b/src/blueapi/core/context.py @@ -1,6 +1,6 @@ import logging from collections.abc import Callable -from dataclasses import dataclass, field +from dataclasses import InitVar, dataclass, field from importlib import import_module from inspect import Parameter, signature from types import ModuleType, NoneType, UnionType @@ -8,6 +8,7 @@ from bluesky.protocols import HasName from bluesky.run_engine import RunEngine +from dodal.common.beamlines.beamline_utils import get_path_provider, set_path_provider from dodal.utils import make_all_devices from ophyd_async.core import NotConnected from pydantic import GetCoreSchemaHandler, GetJsonSchemaHandler, create_model @@ -16,12 +17,15 @@ from pydantic_core import CoreSchema, core_schema from blueapi import utils -from blueapi.config import EnvironmentConfig, SourceKind +from blueapi.client.numtracker import NumtrackerClient +from blueapi.config import ApplicationConfig, EnvironmentConfig, SourceKind from blueapi.utils import ( BlueapiPlanModelConfig, is_function_sourced_from_module, load_module_all, ) +from blueapi.utils.invalid_config_error import InvalidConfigError +from blueapi.utils.path_provider import StartDocumentPathProvider from .bluesky_types import ( BLUESKY_PROTOCOLS, @@ -86,15 +90,57 @@ class BlueskyContext: The context holds the RunEngine and any plans/devices that you may want to use. """ + configuration: InitVar[ApplicationConfig | None] = None + run_engine: RunEngine = field( default_factory=lambda: RunEngine(context_managers=[]) ) + numtracker: NumtrackerClient | None = field(default=None, init=False, repr=False) plans: dict[str, Plan] = field(default_factory=dict) devices: dict[str, Device] = field(default_factory=dict) plan_functions: dict[str, PlanGenerator] = field(default_factory=dict) _reference_cache: dict[type, type] = field(default_factory=dict) + def __post_init__(self, configuration: ApplicationConfig | None): + if not configuration: + return + + if configuration.numtracker is not None: + if configuration.env.metadata is not None: + self.numtracker = NumtrackerClient(url=configuration.numtracker.url) + else: + raise InvalidConfigError( + "Numtracker url has been configured, but there is no instrument or" + " instrument_session in the environment metadata" + ) + + if self.numtracker is not None: + numtracker = self.numtracker + + path_provider = StartDocumentPathProvider() + set_path_provider(path_provider) + self.run_engine.subscribe(path_provider.update_run, "start") + + def _update_scan_num(md: dict[str, Any]) -> int: + scan = numtracker.create_scan( + md["instrument_session"], md["instrument"] + ) + md["data_session_directory"] = str(scan.scan.directory.path) + md["scan_file"] = scan.scan.scan_file + return scan.scan.scan_number + + self.run_engine.scan_id_source = _update_scan_num + + self.with_config(configuration.env) + if self.numtracker and not isinstance( + get_path_provider(), StartDocumentPathProvider + ): + raise InvalidConfigError( + "Numtracker has been configured but a path provider was imported with " + "the devices. Remove this path provider to use numtracker." + ) + def find_device(self, addr: str | list[str]) -> Device | None: """ Find a device in this context, allows for recursive search. diff --git a/src/blueapi/service/interface.py b/src/blueapi/service/interface.py index 6ac3ccc72..de4ff8a09 100644 --- a/src/blueapi/service/interface.py +++ b/src/blueapi/service/interface.py @@ -4,13 +4,8 @@ from bluesky_stomp.messaging import StompClient from bluesky_stomp.models import Broker, DestinationBase, MessageTopic -from dodal.common.beamlines.beamline_utils import ( - get_path_provider, - set_path_provider, -) from blueapi.cli.scratch import get_python_environment -from blueapi.client.numtracker import NumtrackerClient from blueapi.config import ApplicationConfig, OIDCConfig, StompConfig from blueapi.core.context import BlueskyContext from blueapi.core.event import EventStream @@ -23,8 +18,6 @@ TaskRequest, WorkerTask, ) -from blueapi.utils.invalid_config_error import InvalidConfigError -from blueapi.utils.path_provider import StartDocumentPathProvider from blueapi.worker.event import TaskStatusEnum, WorkerState from blueapi.worker.task import Task from blueapi.worker.task_worker import TaskWorker, TrackableTask @@ -48,14 +41,10 @@ def set_config(new_config: ApplicationConfig): @cache def context() -> BlueskyContext: - ctx = BlueskyContext() + ctx = BlueskyContext(config()) return ctx -def configure_context() -> None: - context().with_config(config().env) - - @cache def worker() -> TaskWorker: worker = TaskWorker( @@ -96,35 +85,6 @@ def stomp_client() -> StompClient | None: return None -@cache -def numtracker_client() -> NumtrackerClient | None: - conf = config() - if conf.numtracker is not None: - if conf.env.metadata is not None: - return NumtrackerClient(url=conf.numtracker.url) - else: - raise InvalidConfigError( - "Numtracker url has been configured, but there is no instrument or" - " instrument_session in the environment metadata" - ) - else: - return None - - -def _update_scan_num(md: dict[str, Any]) -> int: - numtracker = numtracker_client() - if numtracker is not None: - scan = numtracker.create_scan(md["instrument_session"], md["instrument"]) - md["data_session_directory"] = str(scan.scan.directory.path) - md["scan_file"] = scan.scan.scan_file - return scan.scan.scan_number - else: - raise InvalidConfigError( - "Blueapi was configured to talk to numtracker but numtracker is not" - "configured, this should not happen, please contact the DAQ team" - ) - - def setup(config: ApplicationConfig) -> None: """Creates and starts a worker with supplied config""" set_config(config) @@ -132,32 +92,9 @@ def setup(config: ApplicationConfig) -> None: # Eagerly initialize worker and messaging connection worker() - - # if numtracker is configured, use a StartDocumentPathProvider - if numtracker_client() is not None: - context().run_engine.scan_id_source = _update_scan_num - _hook_run_engine_and_path_provider() - - configure_context() - - if numtracker_client() is not None and not isinstance( - get_path_provider(), StartDocumentPathProvider - ): - raise InvalidConfigError( - "Numtracker has been configured but a path provider was imported" - " with the devices. Remove this path provider to use numtracker." - ) - stomp_client() -def _hook_run_engine_and_path_provider() -> None: - path_provider = StartDocumentPathProvider() - set_path_provider(path_provider) - run_engine = context().run_engine - run_engine.subscribe(path_provider.update_run, "start") - - def teardown() -> None: worker().stop() if (stomp_client_ref := stomp_client()) is not None: @@ -165,7 +102,6 @@ def teardown() -> None: context.cache_clear() worker.cache_clear() stomp_client.cache_clear() - numtracker_client.cache_clear() def _publish_event_streams( @@ -224,19 +160,13 @@ def begin_task( task: WorkerTask, pass_through_headers: Mapping[str, str] | None = None ) -> WorkerTask: """Trigger a task. Will fail if the worker is busy""" - _try_configure_numtracker(pass_through_headers or {}) - + if nt := context().numtracker: + nt.set_headers(pass_through_headers or {}) if task.task_id is not None: worker().begin_task(task.task_id) return task -def _try_configure_numtracker(pass_through_headers: Mapping[str, str]) -> None: - numtracker = numtracker_client() - if numtracker is not None: - numtracker.set_headers(pass_through_headers) - - def get_tasks_by_status(status: TaskStatusEnum) -> list[TrackableTask]: """Retrieve a list of tasks based on their status.""" return worker().get_tasks_by_status(status) diff --git a/tests/unit_tests/service/test_interface.py b/tests/unit_tests/service/test_interface.py index f15859d98..bc3a93279 100644 --- a/tests/unit_tests/service/test_interface.py +++ b/tests/unit_tests/service/test_interface.py @@ -256,15 +256,15 @@ def mock_tasks_by_status(status: TaskStatusEnum) -> list[TrackableTask]: assert interface.get_tasks_by_status(TaskStatusEnum.COMPLETE) == [] -@patch("blueapi.service.interface._try_configure_numtracker") +@patch("blueapi.service.interface.BlueskyContext.numtracker") @patch("blueapi.service.interface.TaskWorker.begin_task") -def test_begin_task_with_headers(worker_mock: MagicMock, mock_configure: MagicMock): +def test_begin_task_with_headers(worker_mock: MagicMock, mock_numtracker: MagicMock): uuid_value = "350043fd-597e-41a7-9a92-5d5478232cf7" task = WorkerTask(task_id=uuid_value) headers = {"a": "b"} returned_task = interface.begin_task(task, headers) - mock_configure.assert_called_once_with(headers) + mock_numtracker.set_headers.assert_called_once_with(headers) assert task == returned_task worker_mock.assert_called_once_with(uuid_value) @@ -406,10 +406,10 @@ def test_configure_numtracker(): ) interface.set_config(conf) headers = {"a": "b"} - interface._try_configure_numtracker(headers) - nt = interface.numtracker_client() + nt = interface.context().numtracker assert isinstance(nt, NumtrackerClient) + nt.set_headers(headers) assert nt._headers == {"a": "b"} assert nt._url.unicode_string() == "https://numtracker-example.com/graphql" @@ -443,37 +443,36 @@ def test_headers_are_cleared(mock_post): headers = {"foo": "bar"} interface.begin_task(task=WorkerTask(task_id=None), pass_through_headers=headers) - interface._update_scan_num({"instrument_session": "cm12345-1", "instrument": "p46"}) + ctx = interface.context() + assert ctx.run_engine.scan_id_source is not None + ctx.run_engine.scan_id_source( + {"instrument_session": "cm12345-1", "instrument": "p46"} + ) mock_post.assert_called_once() assert mock_post.call_args.kwargs["headers"] == headers interface.begin_task(task=WorkerTask(task_id=None)) - interface._update_scan_num({"instrument_session": "cm12345-1", "instrument": "p46"}) + ctx.run_engine.scan_id_source( + {"instrument_session": "cm12345-1", "instrument": "p46"} + ) assert mock_post.call_count == 2 assert mock_post.call_args.kwargs["headers"] == {} -def test_configure_numtracker_with_no_numtracker_config_fails(): +def test_numtracker_requires_instrument_metadata(): conf = ApplicationConfig( - env=EnvironmentConfig(metadata=MetadataConfig(instrument="p46")), + numtracker=NumtrackerConfig( + url=HttpUrl("https://numtracker-example.com/graphql"), + ) ) interface.set_config(conf) - headers = {"a": "b"} - interface._try_configure_numtracker(headers) - nt = interface.numtracker_client() - - assert nt is None - - -def test_configure_numtracker_with_no_metadata_fails(): - conf = ApplicationConfig(numtracker=NumtrackerConfig()) - interface.set_config(conf) - headers = {"a": "b"} - - assert conf.env.metadata is None - + print("Post config") with pytest.raises(InvalidConfigError): - interface._try_configure_numtracker(headers) + interface.context() + + # Clearing the config here prevents the same exception as above being + # raised in the ensure_worker_stopped fixture + interface.set_config(ApplicationConfig()) def test_setup_without_numtracker_with_existing_provider_does_not_overwrite_provider(): @@ -506,7 +505,6 @@ def test_setup_with_numtracker_makes_start_document_provider(): path_provider = get_path_provider() assert isinstance(path_provider, StartDocumentPathProvider) - assert interface.context().run_engine.scan_id_source == interface._update_scan_num clear_path_provider() @@ -545,12 +543,15 @@ def test_numtracker_create_scan_called_with_arguments_from_metadata(mock_create_ ) interface.set_config(conf) ctx = interface.context() - interface.configure_context() headers = {"a": "b"} - interface._try_configure_numtracker(headers) + + assert ctx.numtracker is not None + assert ctx.run_engine.scan_id_source is not None + + ctx.numtracker.set_headers(headers) ctx.run_engine.md["instrument_session"] = "ab123" - interface._update_scan_num(ctx.run_engine.md) + ctx.run_engine.scan_id_source(ctx.run_engine.md) mock_create_scan.assert_called_once_with("ab123", "p46") @@ -567,8 +568,10 @@ def test_update_scan_num_side_effect_sets_data_session_directory_in_re_md( interface.setup(conf) ctx = interface.context() + assert ctx.run_engine.scan_id_source is not None + ctx.run_engine.md["instrument_session"] = "ab123" - interface._update_scan_num(ctx.run_engine.md) + ctx.run_engine.scan_id_source(ctx.run_engine.md) assert ( ctx.run_engine.md["data_session_directory"] == "/exports/mybeamline/data/2025" @@ -587,7 +590,9 @@ def test_update_scan_num_side_effect_sets_scan_file_in_re_md( interface.setup(conf) ctx = interface.context() + assert ctx.run_engine.scan_id_source is not None + ctx.run_engine.md["instrument_session"] = "ab123" - interface._update_scan_num(ctx.run_engine.md) + ctx.run_engine.scan_id_source(ctx.run_engine.md) assert ctx.run_engine.md["scan_file"] == "p46-11"