diff --git a/core/taskmanager.py b/core/taskmanager.py index 8a1c087f4..8050993e1 100644 --- a/core/taskmanager.py +++ b/core/taskmanager.py @@ -22,9 +22,11 @@ def register_task(cls, task_class: Type[Task], task_name: str | None = None): task_name = task_class.__name__ logging.info(f"Registering task: {task_name}") task = task_class.find(name=task_name) - if not task: + if not task and task_class != ExportTask: logging.info(f"Task {task_name} not found in database, creating.") task_dict = task_class._defaults.copy() + print(task_dict) + print(task_class) task_dict["name"] = task_name task = task_class(**task_dict).save() cls._store[task_name] = task @@ -88,7 +90,6 @@ def run_task(cls, task_name: str, task_params: TaskParams): task.save() try: - logging.info(f"Running task {task_name}") if task_params.params: logging.debug( f"Running task {task_name} with params {task_params.params}" diff --git a/core/taskscheduler.py b/core/taskscheduler.py index 479d836f3..2224b06a0 100644 --- a/core/taskscheduler.py +++ b/core/taskscheduler.py @@ -58,7 +58,7 @@ def setup_periodic_tasks(sender, **kwargs): return -@app.task +@app.task(name="run_task") def run_task(task_name: str, params: str): """Runs a task. diff --git a/core/web/apiv2/system.py b/core/web/apiv2/system.py index cf9942d98..12e3fe49e 100644 --- a/core/web/apiv2/system.py +++ b/core/web/apiv2/system.py @@ -1,8 +1,8 @@ +from celery import Celery from fastapi import APIRouter, Depends from pydantic import BaseModel, ConfigDict from core.config.config import yeti_config -from core.taskscheduler import app from core.web.apiv2.auth import get_current_active_user # API endpoints @@ -49,6 +49,12 @@ def get_config() -> SystemConfigResponse: @router.get("/workers", dependencies=[Depends(get_current_active_user)]) def get_worker_status() -> WorkerStatusResponse: + app = Celery( + "tasks", + broker=f"redis://{yeti_config.get('redis', 'host')}/", + worker_pool_restarts=True, + ) + inspect = app.control.inspect(timeout=5, destination=None) registered = {} @@ -73,6 +79,12 @@ def get_worker_status() -> WorkerStatusResponse: def restart_worker(worker_name: str) -> WorkerRestartResponse: """Restarts a single or all Celery workers.""" destination = [worker_name] if worker_name != "all" else None + app = Celery( + "tasks", + broker=f"redis://{yeti_config.get('redis', 'host')}/", + worker_pool_restarts=True, + ) + response = app.control.broadcast( "pool_restart", arguments={"reload": True}, diff --git a/core/web/apiv2/tasks.py b/core/web/apiv2/tasks.py index fd3dc31f7..ad0f665fb 100644 --- a/core/web/apiv2/tasks.py +++ b/core/web/apiv2/tasks.py @@ -1,10 +1,11 @@ import io +from celery import Celery, signature from fastapi import APIRouter, HTTPException from fastapi.responses import StreamingResponse from pydantic import BaseModel, ConfigDict -from core import taskscheduler +from core.config.config import yeti_config from core.schemas.task import ExportTask, Task, TaskParams, TaskType, TaskTypes from core.schemas.template import Template @@ -48,7 +49,14 @@ def run(task_name, params: TaskParams | None = None) -> dict[str, str]: """Runs a task asynchronously.""" if params is None: params = TaskParams() - taskscheduler.run_task.delay(task_name, params.model_dump_json()) + + Celery( + "tasks", + broker=f"redis://{yeti_config.get('redis', 'host')}/", + worker_pool_restarts=True, + ) + sig = signature("run_task", (task_name, params.model_dump_json())) + sig.apply_async() return {"status": "ok"} diff --git a/tests/apiv2/tasks.py b/tests/apiv2/tasks.py index 224caee8c..85237df8d 100644 --- a/tests/apiv2/tasks.py +++ b/tests/apiv2/tasks.py @@ -4,6 +4,7 @@ import unittest from unittest import mock +import celery from fastapi.testclient import TestClient from core import database_arango, taskmanager @@ -72,15 +73,18 @@ def test_run_task(self, mock_delay): self.assertEqual(data["status"], "ok") mock_delay.assert_called_once_with("FakeTask", '{"params":{}}') - @mock.patch("core.taskscheduler.run_task.delay") - def test_run_task_with_params(self, mock_delay): + # @mock.patch("core.taskscheduler.run_task.delay") + @mock.patch("celery.signature") + def test_run_task_with_params(self, mock_signature): response = client.post( "/api/v2/tasks/FakeTask/run", json={"params": {"value": "test"}} ) data = response.json() self.assertEqual(response.status_code, 200, data) self.assertEqual(data["status"], "ok") - mock_delay.assert_called_once_with("FakeTask", '{"params":{"value":"test"}}') + mock_signature.assert_called_once_with( + "run_task", '{"params":{"value":"test"}}' + ) class ExportTaskTest(unittest.TestCase):