Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 45 additions & 30 deletions src/omotes_sdk/internal/worker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,22 @@
import logging
import socket
import sys
from typing import Callable, Dict, List, Any, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple
from uuid import UUID

import streamcapture
from billiard.einfo import ExceptionInfo
from celery import Task as CeleryTask, Celery
from celery import Celery
from celery import Task as CeleryTask
from celery.apps.worker import Worker as CeleryWorker
from kombu import Queue as KombuQueue
from esdl import EnergySystem
from kombu import Queue as KombuQueue
from omotes_sdk_protocol.internal.task_pb2 import TaskProgressUpdate, TaskResult

from omotes_sdk.internal.orchestrator_worker_events.esdl_messages import EsdlMessage
from omotes_sdk.internal.common.broker_interface import BrokerInterface
from omotes_sdk.internal.common.esdl_util import pyesdl_from_string
from omotes_sdk.internal.orchestrator_worker_events.esdl_messages import EsdlMessage
from omotes_sdk.internal.worker.configs import WorkerConfig
from omotes_sdk.internal.common.broker_interface import BrokerInterface
from omotes_sdk_protocol.internal.task_pb2 import (
TaskResult,
TaskProgressUpdate,
)
from omotes_sdk.types import ProtobufDict

logger = logging.getLogger("omotes_sdk_internal")
Expand Down Expand Up @@ -59,7 +57,7 @@ def send_start(self) -> None:
TaskProgressUpdate(
job_id=str(self.job_id),
celery_task_id=self.task.request.id,
celery_task_type=WORKER_TASK_TYPE,
celery_task_type=self.task.name,
status=TaskProgressUpdate.START,
message="Started job at worker.",
).SerializeToString(),
Expand All @@ -84,7 +82,7 @@ def update_progress(self, fraction: float, message: str) -> None:
TaskProgressUpdate(
job_id=str(self.job_id),
celery_task_id=self.task.request.id,
celery_task_type=WORKER_TASK_TYPE,
celery_task_type=self.task.name,
progress=float(fraction),
message=message,
).SerializeToString(),
Expand Down Expand Up @@ -165,7 +163,7 @@ def after_return(
result_message = TaskResult(
job_id=str(job_id),
celery_task_id=self.request.id,
celery_task_type=WORKER_TASK_TYPE,
celery_task_type=self.name,
result_type=TaskResult.ResultType.ERROR,
output_esdl="",
logs=logs,
Expand All @@ -181,7 +179,7 @@ def after_return(
result_message = TaskResult(
job_id=str(job_id),
celery_task_id=self.request.id,
celery_task_type=WORKER_TASK_TYPE,
celery_task_type=self.name,
result_type=TaskResult.ResultType.SUCCEEDED,
output_esdl=self.output_esdl,
logs=logs,
Expand Down Expand Up @@ -258,17 +256,21 @@ def wrapped_worker_task(
:param params_dict: job, non-ESDL, parameters.
"""
logger.info("Worker started new task %s with reference %s", job_id, job_reference)
task_util = TaskUtil(job_id, task, task.broker_if)
task_util = TaskUtil(
job_id,
task,
task.broker_if,
)
task_util.send_start()
output_esdl, esdl_messages = WORKER_TASK_FUNCTION(
input_esdl, params_dict, task_util.update_progress
input_esdl, params_dict, task_util.update_progress, task.name
)

if output_esdl:
input_esh = pyesdl_from_string(input_esdl)
input_energy_system: EnergySystem = input_esh.energy_system
if job_reference is None:
new_name = f"{input_energy_system.name}_{WORKER_TASK_TYPE}"
new_name = f"{input_energy_system.name}_{task.name}"
elif job_reference == "":
new_name = f"{input_energy_system.name}"
else:
Expand Down Expand Up @@ -316,9 +318,21 @@ def start(self) -> None:
)

# Config of celery app
self.celery_app.conf.task_queues = [KombuQueue(
WORKER_TASK_TYPE, routing_key=WORKER_TASK_TYPE, queue_arguments={"x-max-priority": 10}
)] # Tell the worker to listen to a specific queue for 1 workflow type.
queues = []
for worker_task_type in WORKER_TASK_TYPES:
logger.info("Starting Worker to work on task %s", worker_task_type)
queues.append(
KombuQueue(
worker_task_type,
routing_key=worker_task_type,
queue_arguments={"x-max-priority": 10},
)
)
self.celery_app.task(
wrapped_worker_task, base=WorkerTask, name=worker_task_type, bind=True
)

self.celery_app.conf.task_queues = queues
self.celery_app.conf.task_acks_late = True
self.celery_app.conf.task_reject_on_worker_lost = True
self.celery_app.conf.task_acks_on_failure_or_timeout = False
Expand All @@ -331,9 +345,6 @@ def start(self) -> None:
self.celery_app.conf.worker_hijack_root_logger = False
self.celery_app.conf.worker_redirect_stdouts = False

self.celery_app.task(wrapped_worker_task, base=WorkerTask, name=WORKER_TASK_TYPE, bind=True)

logger.info("Starting Worker to work on task %s", WORKER_TASK_TYPE)
logger.info(
"Connected to broker rabbitmq (%s:%s/%s) as %s",
rabbitmq_config.host,
Expand All @@ -343,7 +354,7 @@ def start(self) -> None:
)

self.celery_worker = self.celery_app.Worker(
hostname=f"worker-{WORKER_TASK_TYPE}@{socket.gethostname()}",
hostname=f"worker-{'_'.join(WORKER_TASK_TYPES)}@{socket.gethostname()}",
loglevel=logging.getLevelName(self.config.log_level),
autoscale=(1, 1),
)
Expand All @@ -353,7 +364,7 @@ def start(self) -> None:

UpdateProgressHandler = Callable[[float, str], None]
WorkerTaskF = Callable[
[str, ProtobufDict, UpdateProgressHandler],
[str, ProtobufDict, UpdateProgressHandler, str],
Tuple[
Optional[str],
List[EsdlMessage],
Expand All @@ -362,21 +373,25 @@ def start(self) -> None:

WORKER: Worker = None # type: ignore [assignment] # noqa
WORKER_TASK_FUNCTION: WorkerTaskF = None # type: ignore [assignment] # noqa
WORKER_TASK_TYPE: str = None # type: ignore [assignment] # noqa
WORKER_TASK_TYPES: list[str] = None # type: ignore [assignment] # noqa


def initialize_worker(
task_type: str,
task_types: list[str],
task_function: WorkerTaskF,
) -> None:
"""Initialize and run the `Worker`.

:param task_type: Technical name of the task. Needs to be equal to the name of the celery task
to which the orchestrator forwards the task.
:param task_types: Technical name of the tasks. Needs to be equal to the name of the celery task
to which the orchestrator forwards the task. May connect to one or more tasks.
:param task_function: Function which performs the Celery task.
"""
global WORKER_TASK_FUNCTION, WORKER_TASK_TYPE, WORKER
WORKER_TASK_TYPE = task_type
global WORKER_TASK_FUNCTION, WORKER_TASK_TYPES, WORKER
WORKER_TASK_TYPES = task_types
if len(WORKER_TASK_TYPES) < 1:
raise RuntimeError(
f"Should connect to one or more worker task types. Only found {len(WORKER_TASK_TYPES)}"
)
WORKER_TASK_FUNCTION = task_function
WORKER = Worker()
WORKER.start()
Loading