diff --git a/src/cdtools/models/base.py b/src/cdtools/models/base.py
index b31061d7..52f40306 100644
--- a/src/cdtools/models/base.py
+++ b/src/cdtools/models/base.py
@@ -414,7 +414,20 @@ def closure():
                     return total_loss
 
                 # This takes the step for this minibatch
-                loss += optimizer.step(closure).detach().cpu().numpy()
+                # If an exception occurs at this point, the root cause may be related
+                # to the presence of large gradients that produce `nan` losses.
+                # We write a custom error message here to help communicate this point.
+                try:
+                    loss += optimizer.step(closure).detach().cpu().numpy()
+                except Exception as e:
+                    print('\n\n')
+                    msg = 'An error has occurred during the parameter update and\n'\
+                          'loss calculation step. This problem may be related to excessively\n'\
+                          'large parameter gradients. Please try lowering the learning\n'\
+                          'rates or disabling grad_required for one or more parameters\n'\
+
+                    raise Exception(msg) from e
 
             loss /= normalization
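
For context, the standalone sketch below (not part of cdtools) reproduces the failure mode this patch guards against and uses the same wrapping pattern: a deliberately oversized learning rate drives the loss to inf/nan, the closure raises, and the `except` block re-raises with a clearer message while `from e` preserves the original traceback. The toy objective, the SGD learning rate, and the explicit non-finite check are illustrative assumptions, not code from base.py.

```python
# Illustrative sketch only, assuming a plain PyTorch setup; it mirrors the
# try/except-and-re-raise pattern introduced in this patch.
import torch

x = torch.tensor([1.0], requires_grad=True)
# Deliberately huge learning rate so the optimization diverges within a few steps.
optimizer = torch.optim.SGD([x], lr=1e3)

def closure():
    optimizer.zero_grad()
    loss = (x ** 4).sum()
    if not torch.isfinite(loss):
        # Stand-in for whatever downstream error a nan/inf loss would trigger.
        raise RuntimeError(f'non-finite loss encountered: {loss.item()}')
    loss.backward()
    return loss

try:
    for _ in range(20):
        loss = optimizer.step(closure).detach().cpu().numpy()
except Exception as e:
    msg = ('An error occurred during the parameter update and loss '
           'calculation step. This is often caused by excessively large '
           'gradients; try lowering the learning rate.')
    # 'from e' chains the original exception, so its traceback stays visible.
    raise Exception(msg) from e
```

The key design choice is `raise ... from e`: the friendlier message is shown to the user without hiding the underlying exception, so the real point of failure can still be debugged from the chained traceback.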