diff --git a/src/omotes_sdk/internal/worker/worker.py b/src/omotes_sdk/internal/worker/worker.py index 0e63cde..188d70d 100644 --- a/src/omotes_sdk/internal/worker/worker.py +++ b/src/omotes_sdk/internal/worker/worker.py @@ -2,24 +2,22 @@ import logging import socket import sys -from typing import Callable, Dict, List, Any, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple from uuid import UUID import streamcapture from billiard.einfo import ExceptionInfo -from celery import Task as CeleryTask, Celery +from celery import Celery +from celery import Task as CeleryTask from celery.apps.worker import Worker as CeleryWorker -from kombu import Queue as KombuQueue from esdl import EnergySystem +from kombu import Queue as KombuQueue +from omotes_sdk_protocol.internal.task_pb2 import TaskProgressUpdate, TaskResult -from omotes_sdk.internal.orchestrator_worker_events.esdl_messages import EsdlMessage +from omotes_sdk.internal.common.broker_interface import BrokerInterface from omotes_sdk.internal.common.esdl_util import pyesdl_from_string +from omotes_sdk.internal.orchestrator_worker_events.esdl_messages import EsdlMessage from omotes_sdk.internal.worker.configs import WorkerConfig -from omotes_sdk.internal.common.broker_interface import BrokerInterface -from omotes_sdk_protocol.internal.task_pb2 import ( - TaskResult, - TaskProgressUpdate, -) from omotes_sdk.types import ProtobufDict logger = logging.getLogger("omotes_sdk_internal") @@ -59,7 +57,7 @@ def send_start(self) -> None: TaskProgressUpdate( job_id=str(self.job_id), celery_task_id=self.task.request.id, - celery_task_type=WORKER_TASK_TYPE, + celery_task_type=self.task.name, status=TaskProgressUpdate.START, message="Started job at worker.", ).SerializeToString(), @@ -84,7 +82,7 @@ def update_progress(self, fraction: float, message: str) -> None: TaskProgressUpdate( job_id=str(self.job_id), celery_task_id=self.task.request.id, - celery_task_type=WORKER_TASK_TYPE, + celery_task_type=self.task.name, progress=float(fraction), message=message, ).SerializeToString(), @@ -165,7 +163,7 @@ def after_return( result_message = TaskResult( job_id=str(job_id), celery_task_id=self.request.id, - celery_task_type=WORKER_TASK_TYPE, + celery_task_type=self.name, result_type=TaskResult.ResultType.ERROR, output_esdl="", logs=logs, @@ -181,7 +179,7 @@ def after_return( result_message = TaskResult( job_id=str(job_id), celery_task_id=self.request.id, - celery_task_type=WORKER_TASK_TYPE, + celery_task_type=self.name, result_type=TaskResult.ResultType.SUCCEEDED, output_esdl=self.output_esdl, logs=logs, @@ -258,17 +256,21 @@ def wrapped_worker_task( :param params_dict: job, non-ESDL, parameters. """ logger.info("Worker started new task %s with reference %s", job_id, job_reference) - task_util = TaskUtil(job_id, task, task.broker_if) + task_util = TaskUtil( + job_id, + task, + task.broker_if, + ) task_util.send_start() output_esdl, esdl_messages = WORKER_TASK_FUNCTION( - input_esdl, params_dict, task_util.update_progress + input_esdl, params_dict, task_util.update_progress, task.name ) if output_esdl: input_esh = pyesdl_from_string(input_esdl) input_energy_system: EnergySystem = input_esh.energy_system if job_reference is None: - new_name = f"{input_energy_system.name}_{WORKER_TASK_TYPE}" + new_name = f"{input_energy_system.name}_{task.name}" elif job_reference == "": new_name = f"{input_energy_system.name}" else: @@ -316,9 +318,21 @@ def start(self) -> None: ) # Config of celery app - self.celery_app.conf.task_queues = [KombuQueue( - WORKER_TASK_TYPE, routing_key=WORKER_TASK_TYPE, queue_arguments={"x-max-priority": 10} - )] # Tell the worker to listen to a specific queue for 1 workflow type. + queues = [] + for worker_task_type in WORKER_TASK_TYPES: + logger.info("Starting Worker to work on task %s", worker_task_type) + queues.append( + KombuQueue( + worker_task_type, + routing_key=worker_task_type, + queue_arguments={"x-max-priority": 10}, + ) + ) + self.celery_app.task( + wrapped_worker_task, base=WorkerTask, name=worker_task_type, bind=True + ) + + self.celery_app.conf.task_queues = queues self.celery_app.conf.task_acks_late = True self.celery_app.conf.task_reject_on_worker_lost = True self.celery_app.conf.task_acks_on_failure_or_timeout = False @@ -331,9 +345,6 @@ def start(self) -> None: self.celery_app.conf.worker_hijack_root_logger = False self.celery_app.conf.worker_redirect_stdouts = False - self.celery_app.task(wrapped_worker_task, base=WorkerTask, name=WORKER_TASK_TYPE, bind=True) - - logger.info("Starting Worker to work on task %s", WORKER_TASK_TYPE) logger.info( "Connected to broker rabbitmq (%s:%s/%s) as %s", rabbitmq_config.host, @@ -343,7 +354,7 @@ def start(self) -> None: ) self.celery_worker = self.celery_app.Worker( - hostname=f"worker-{WORKER_TASK_TYPE}@{socket.gethostname()}", + hostname=f"worker-{'_'.join(WORKER_TASK_TYPES)}@{socket.gethostname()}", loglevel=logging.getLevelName(self.config.log_level), autoscale=(1, 1), ) @@ -353,7 +364,7 @@ def start(self) -> None: UpdateProgressHandler = Callable[[float, str], None] WorkerTaskF = Callable[ - [str, ProtobufDict, UpdateProgressHandler], + [str, ProtobufDict, UpdateProgressHandler, str], Tuple[ Optional[str], List[EsdlMessage], @@ -362,21 +373,25 @@ def start(self) -> None: WORKER: Worker = None # type: ignore [assignment] # noqa WORKER_TASK_FUNCTION: WorkerTaskF = None # type: ignore [assignment] # noqa -WORKER_TASK_TYPE: str = None # type: ignore [assignment] # noqa +WORKER_TASK_TYPES: list[str] = None # type: ignore [assignment] # noqa def initialize_worker( - task_type: str, + task_types: list[str], task_function: WorkerTaskF, ) -> None: """Initialize and run the `Worker`. - :param task_type: Technical name of the task. Needs to be equal to the name of the celery task - to which the orchestrator forwards the task. + :param task_types: Technical name of the tasks. Needs to be equal to the name of the celery task + to which the orchestrator forwards the task. May connect to one or more tasks. :param task_function: Function which performs the Celery task. """ - global WORKER_TASK_FUNCTION, WORKER_TASK_TYPE, WORKER - WORKER_TASK_TYPE = task_type + global WORKER_TASK_FUNCTION, WORKER_TASK_TYPES, WORKER + WORKER_TASK_TYPES = task_types + if len(WORKER_TASK_TYPES) < 1: + raise RuntimeError( + f"Should connect to one or more worker task types. Only found {len(WORKER_TASK_TYPES)}" + ) WORKER_TASK_FUNCTION = task_function WORKER = Worker() WORKER.start()