Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion nnlo/mpi/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@ def apply_update(self):
'update': 12,
'begin_gem': 13,
'update_gem': 14,
'bork': 99,
}
# This dict is for reverse tag lookups.
inv_tag_lookup = { value:key for key,value in tag_lookup.items() }
Expand Down Expand Up @@ -367,6 +368,12 @@ def send_exit_to_parent(self):
if self.parent_rank is not None:
self.send( None, 'exit' )

def send_bork_to_parent(self):
if self.is_shadow( sync = True): return
"""Send exit tag to parent process, if parent process exists"""
if self.parent_rank is not None:
self.send( None, 'bork' )

def send_history_to_parent(self):
if self.is_shadow():return
"""Send keras history or dict of keras histories"""
Expand Down Expand Up @@ -553,7 +560,14 @@ def train(self):
self.model.set_weights(self.weights)

Timeline.begin("train_on_batch")
train_metrics = self.model.train_on_batch( x=batch[0], y=batch[1] )
try:
train_metrics = self.model.train_on_batch( x=batch[0], y=batch[1] )
except Exception as e:
logging.warn("Exception in train_on_batch")
logging.warn(str(e))
self.send_bork_to_parent()
self.stop_training = True

Timeline.end("train_on_batch")
if epoch_metrics.shape != train_metrics.shape:
epoch_metrics = np.zeros( train_metrics.shape)
Expand Down Expand Up @@ -745,6 +759,13 @@ def do_worker_finish_sequence(self, worker_id):
#self.histories[key] = self.recv_history_from_child(worker_id)
self.running_workers.remove(worker_id)
self.num_sync_workers -= 1

def do_worker_fail_sequence(self, worker_id):
"""Actions to take when a worker fails during training"""
#self.histories.update( self.recv_history_from_child(worker_id) )
self.running_workers.remove(worker_id)
self.num_sync_workers -= 1
self.stop_training = True

def process_message(self, status):
"""Extracts message source and tag from the MPI status object and processes the message.
Expand All @@ -765,6 +786,8 @@ def process_message(self, status):
self.do_gem_sequence(source)
elif tag == 'exit':
self.do_worker_finish_sequence(source)
elif tag == 'exit':
self.do_worker_fail_sequence(source)
else:
raise ValueError("Tag %s not recognized" % tag)
return tag
Expand Down