From 419df4599c3fedfd5c2c791706ef821d1abc59ae Mon Sep 17 00:00:00 2001 From: Jean-Roch Vlimant Date: Fri, 27 Sep 2019 04:26:05 -0700 Subject: [PATCH 1/2] catch train failure --- nnlo/mpi/process.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/nnlo/mpi/process.py b/nnlo/mpi/process.py index 2f1e3cd..a5ceb3e 100644 --- a/nnlo/mpi/process.py +++ b/nnlo/mpi/process.py @@ -553,7 +553,13 @@ def train(self): self.model.set_weights(self.weights) Timeline.begin("train_on_batch") - train_metrics = self.model.train_on_batch( x=batch[0], y=batch[1] ) + try: + train_metrics = self.model.train_on_batch( x=batch[0], y=batch[1] ) + except Exception as e: + print("Exception in train_on_batch") + print(str(e)) + print("trying to escape gracefully") + Timeline.end("train_on_batch") if epoch_metrics.shape != train_metrics.shape: epoch_metrics = np.zeros( train_metrics.shape) From 8f5986f7a426f545c4efba9d95fb71aff6aebb97 Mon Sep 17 00:00:00 2001 From: Vladimir Loncar Date: Fri, 11 Oct 2019 07:52:42 -0700 Subject: [PATCH 2/2] bork the training procedure if worker fails --- nnlo/mpi/process.py | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/nnlo/mpi/process.py b/nnlo/mpi/process.py index a5ceb3e..393366f 100644 --- a/nnlo/mpi/process.py +++ b/nnlo/mpi/process.py @@ -238,6 +238,7 @@ def apply_update(self): 'update': 12, 'begin_gem': 13, 'update_gem': 14, + 'bork': 99, } # This dict is for reverse tag lookups. inv_tag_lookup = { value:key for key,value in tag_lookup.items() } @@ -367,6 +368,12 @@ def send_exit_to_parent(self): if self.parent_rank is not None: self.send( None, 'exit' ) + def send_bork_to_parent(self): + if self.is_shadow( sync = True): return + """Send exit tag to parent process, if parent process exists""" + if self.parent_rank is not None: + self.send( None, 'bork' ) + def send_history_to_parent(self): if self.is_shadow():return """Send keras history or dict of keras histories""" @@ -556,9 +563,10 @@ def train(self): try: train_metrics = self.model.train_on_batch( x=batch[0], y=batch[1] ) except Exception as e: - print("Exception in train_on_batch") - print(str(e)) - print("trying to escape gracefully") + logging.warn("Exception in train_on_batch") + logging.warn(str(e)) + self.send_bork_to_parent() + self.stop_training = True Timeline.end("train_on_batch") if epoch_metrics.shape != train_metrics.shape: @@ -751,6 +759,13 @@ def do_worker_finish_sequence(self, worker_id): #self.histories[key] = self.recv_history_from_child(worker_id) self.running_workers.remove(worker_id) self.num_sync_workers -= 1 + + def do_worker_fail_sequence(self, worker_id): + """Actions to take when a worker fails during training""" + #self.histories.update( self.recv_history_from_child(worker_id) ) + self.running_workers.remove(worker_id) + self.num_sync_workers -= 1 + self.stop_training = True def process_message(self, status): """Extracts message source and tag from the MPI status object and processes the message. @@ -771,6 +786,8 @@ def process_message(self, status): self.do_gem_sequence(source) elif tag == 'exit': self.do_worker_finish_sequence(source) + elif tag == 'exit': + self.do_worker_fail_sequence(source) else: raise ValueError("Tag %s not recognized" % tag) return tag