Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions core/taskmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,11 @@ def register_task(cls, task_class: Type[Task], task_name: str | None = None):
task_name = task_class.__name__
logging.info(f"Registering task: {task_name}")
task = task_class.find(name=task_name)
if not task:
if not task and task_class != ExportTask:
logging.info(f"Task {task_name} not found in database, creating.")
task_dict = task_class._defaults.copy()
print(task_dict)
print(task_class)
task_dict["name"] = task_name
task = task_class(**task_dict).save()
cls._store[task_name] = task
Expand Down Expand Up @@ -88,7 +90,6 @@ def run_task(cls, task_name: str, task_params: TaskParams):
task.save()

try:
logging.info(f"Running task {task_name}")
if task_params.params:
logging.debug(
f"Running task {task_name} with params {task_params.params}"
Expand Down
2 changes: 1 addition & 1 deletion core/taskscheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def setup_periodic_tasks(sender, **kwargs):
return


@app.task
@app.task(name="run_task")
def run_task(task_name: str, params: str):
"""Runs a task.

Expand Down
14 changes: 13 additions & 1 deletion core/web/apiv2/system.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from celery import Celery
from fastapi import APIRouter, Depends
from pydantic import BaseModel, ConfigDict

from core.config.config import yeti_config
from core.taskscheduler import app
from core.web.apiv2.auth import get_current_active_user

# API endpoints
Expand Down Expand Up @@ -49,6 +49,12 @@ def get_config() -> SystemConfigResponse:

@router.get("/workers", dependencies=[Depends(get_current_active_user)])
def get_worker_status() -> WorkerStatusResponse:
app = Celery(
"tasks",
broker=f"redis://{yeti_config.get('redis', 'host')}/",
worker_pool_restarts=True,
)

inspect = app.control.inspect(timeout=5, destination=None)

registered = {}
Expand All @@ -73,6 +79,12 @@ def get_worker_status() -> WorkerStatusResponse:
def restart_worker(worker_name: str) -> WorkerRestartResponse:
"""Restarts a single or all Celery workers."""
destination = [worker_name] if worker_name != "all" else None
app = Celery(
"tasks",
broker=f"redis://{yeti_config.get('redis', 'host')}/",
worker_pool_restarts=True,
)

response = app.control.broadcast(
"pool_restart",
arguments={"reload": True},
Expand Down
12 changes: 10 additions & 2 deletions core/web/apiv2/tasks.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import io

from celery import Celery, signature
from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, ConfigDict

from core import taskscheduler
from core.config.config import yeti_config
from core.schemas.task import ExportTask, Task, TaskParams, TaskType, TaskTypes
from core.schemas.template import Template

Expand Down Expand Up @@ -48,7 +49,14 @@ def run(task_name, params: TaskParams | None = None) -> dict[str, str]:
"""Runs a task asynchronously."""
if params is None:
params = TaskParams()
taskscheduler.run_task.delay(task_name, params.model_dump_json())

Celery(
"tasks",
broker=f"redis://{yeti_config.get('redis', 'host')}/",
worker_pool_restarts=True,
)
sig = signature("run_task", (task_name, params.model_dump_json()))
sig.apply_async()
return {"status": "ok"}


Expand Down
10 changes: 7 additions & 3 deletions tests/apiv2/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import unittest
from unittest import mock

import celery
from fastapi.testclient import TestClient

from core import database_arango, taskmanager
Expand Down Expand Up @@ -72,15 +73,18 @@ def test_run_task(self, mock_delay):
self.assertEqual(data["status"], "ok")
mock_delay.assert_called_once_with("FakeTask", '{"params":{}}')

@mock.patch("core.taskscheduler.run_task.delay")
def test_run_task_with_params(self, mock_delay):
# @mock.patch("core.taskscheduler.run_task.delay")
@mock.patch("celery.signature")
def test_run_task_with_params(self, mock_signature):
response = client.post(
"/api/v2/tasks/FakeTask/run", json={"params": {"value": "test"}}
)
data = response.json()
self.assertEqual(response.status_code, 200, data)
self.assertEqual(data["status"], "ok")
mock_delay.assert_called_once_with("FakeTask", '{"params":{"value":"test"}}')
mock_signature.assert_called_once_with(
"run_task", '{"params":{"value":"test"}}'
)


class ExportTaskTest(unittest.TestCase):
Expand Down
Loading