diff --git a/README.md b/README.md index 631c038..7d8a3bd 100644 --- a/README.md +++ b/README.md @@ -10,16 +10,22 @@ Tasks are run in a process pool of configurable size. You define tasks by subclassing `Task`: class DoStuff(Task): - + def run(self): print("Look at me, I'm runniiiiiing ...") - + Tasks accept two parameters during creation * `config`: Something picklable to customize the tasks behavior at runtime * `dependencies`: A list of `Task` instances that need to be done before we start this task If task execution fails, a `DaggerException` is raised, with information about which tasks completed -and which failed. +and which failed. There is also a possibility to make Dagger resume execution of a failed task graph +from the point of failure by assigning tasks unique ids: + +``` +run_tasks([awesome_task, very_awesome_task], resume_id="awesome_tasks") +``` + -See also [examples folder](dagger/examples). \ No newline at end of file +See also [examples folder](dagger/examples). diff --git a/dagger/__init__.py b/dagger/__init__.py index f2c6117..e317914 100644 --- a/dagger/__init__.py +++ b/dagger/__init__.py @@ -1,2 +1,2 @@ from run import DaggerException, run_tasks -from task import Task \ No newline at end of file +from task import Task diff --git a/dagger/run.py b/dagger/run.py index 6ab69d5..9c80b1d 100644 --- a/dagger/run.py +++ b/dagger/run.py @@ -3,6 +3,40 @@ import time +import pickle + +from os.path import isfile +from os import remove + +def save_state(state, filename): + """ + :param state: dictionary containing current dag state + :param filename: filename to save into + """ + logging.info("Saving DAG state into {}...".format(filename)) + with open(filename, 'wb') as writefile: + pickle.dump(state, writefile) + logging.info("Done! 
Run 'run_tasks' with the same id flag to pick up") +def load_state(filename): + """ + :param filename: filename to read from + :return: dictionary containing DAG state + """ + logging.info("Loading DAG state from {}...".format(filename)) + + with open(filename, 'rb') as readfile: + recovered_state = pickle.load(readfile) + return recovered_state + +def get_filename(id_string): + """ + :param id_string: id to turn into filename + :return: properly formatted filename + """ + id_string = id_string.replace(" ", "_") + return "{}.dump".format(id_string) + def _run_in_process(task): """ @@ -52,7 +86,7 @@ def __str__(self): ) -def run_tasks(initial_tasks, pool_size=None, tick=1): +def run_tasks(initial_tasks, pool_size=None, tick=1, resume_id=''): """ Run tasks, guaranteeing that their dependencies will be run before them. Work is distributed in a process pool to profit from parallelization. @@ -60,21 +94,34 @@ If one of the tasks fails, all currently running tasks will be run to completion. Afterwards, a DaggerException is raised, containing sets of completed, pending and failed tasks. + If resume_id is set, then the next time run_tasks is called with the same id, Dagger will try to pick up the + previous state and skip running all the tasks that were completed last time. + :param initial_tasks: Iterable of Task instances. :param pool_size: Size of process pool. 
Default is the number of CPUs + :param tick: Frequency of dagger ticks in seconds + :param resume_id: Id of the DAG to trigger resuming from an old state """ - pending_tasks = set(initial_tasks) - for task in initial_tasks: - task.check_circular_dependencies([]) - pending_tasks |= set(task.get_all_dependencies()) - done_tasks = set() + if resume_id and isfile(get_filename(resume_id)): + # if we have an id set and a dump file, we try to resume from previous state + logging.info("recovering from a previously saved state...") + recovered_state = load_state(get_filename(resume_id)) + initial_tasks = recovered_state['pending_tasks'] | recovered_state['failed_tasks'] + done_tasks = recovered_state['done_tasks'] + pending_tasks = set(initial_tasks) + else: + # if not, we start from scratch + pending_tasks = set(initial_tasks) + done_tasks = set() + for task in initial_tasks: + task.check_circular_dependencies([]) + pending_tasks |= set(task.get_all_dependencies()) - return run_partial_tasks(pending_tasks, done_tasks, pool_size, tick) + return run_partial_tasks(pending_tasks, done_tasks, pool_size, tick, resume_id) -def run_partial_tasks(pending_tasks, done_tasks, pool_size=None, tick=1): +def run_partial_tasks(pending_tasks, done_tasks, pool_size=None, tick=1, resume_id=''): """ Run a graph of tasks where some are already finished. Useful for attempting a rerun of a failed dagger execution. 
""" @@ -134,8 +181,17 @@ def task_done(res): if error_state["success"]: logging.info("All tasks are done!") + if resume_id and isfile(get_filename(resume_id)): + # if we successfully completed everything, remove the dump if its present + logging.info("Removing previously created state") + remove(get_filename(resume_id)) return True logging.critical("Tasks execution failed") error_state["done_tasks"] |= done_tasks + + if resume_id: + # pickle the state to resume from it later if the id is provided + save_state(error_state, get_filename(resume_id)) + raise DaggerException(error_state["pending_tasks"], error_state["done_tasks"], error_state["failed_tasks"]) diff --git a/tests/test_resume.py b/tests/test_resume.py new file mode 100644 index 0000000..7ba824a --- /dev/null +++ b/tests/test_resume.py @@ -0,0 +1,53 @@ +import pytest + +from dagger import run_tasks, Task, DaggerException +from multiprocessing import Array + + +# arrays in shared memory to test number of task executions +array_faulty = Array("i", [0]) +array_extract = Array("i", [0]) + + +class FaultyTask(Task): + # a task that is implemented with error + def run(self): + array_faulty[0] += 1 + None.fail() + +# a correct implementation of a Faulty tasks' run +def new_run(self): + array_faulty[0] += 20 + +class ExtractTask(Task): + def run(self): + # a task doing some important and long stuff + array_extract[0] += 1 + +def test_resume(): + """ + Test the option to persist DAG state across runs in case of failure of + a task + """ + extract_1 = ExtractTask({}) + faultyTask = FaultyTask({}, [extract_1]) + extract_2 = ExtractTask({}, [faultyTask]) + + # this should fail + with pytest.raises(DaggerException): + run_tasks([extract_2], resume_id="test") + + # now we change implementation of FaultyTask run method + + FaultyTask.run = new_run + + # and rerun from where we left of + extract_1 = ExtractTask({}) + faultyTask = FaultyTask({}, [extract_1]) + extract_2 = ExtractTask({}, [faultyTask]) + + 
run_tasks([extract_2], resume_id="test") + # assert that we are running a faulty task once, and a correct task once + assert array_faulty[0] == 21 + # assert that we don't repeat tasks that are done + assert array_extract[0] == 2