From 7bf98f3ea02ee1b39667946c553ddb549f0208a3 Mon Sep 17 00:00:00 2001 From: EdwardTyantov Date: Wed, 5 Jul 2017 19:04:08 +0300 Subject: [PATCH] add fix for dict params; grad clip --- tuner_utils/yellowfin.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/tuner_utils/yellowfin.py b/tuner_utils/yellowfin.py index 0961a08..f45f35d 100644 --- a/tuner_utils/yellowfin.py +++ b/tuner_utils/yellowfin.py @@ -203,9 +203,17 @@ def step(self): if group['weight_decay'] != 0: grad = grad.add(group['weight_decay'], p.data) - if self._clip_thresh != None: - torch.nn.utils.clip_grad_norm(self._var_list, self._clip_thresh) - + #if self._clip_thresh != None: + # torch.nn.utils.clip_grad_norm(self._var_list, self._clip_thresh) + if self._clip_thresh is not None: + if isinstance(self._var_list[0], dict): + params = [] + for p in self._var_list: + params.extend(p['params']) + torch.nn.utils.clip_grad_norm(params, self._clip_thresh) + else: + torch.nn.utils.clip_grad_norm(self._var_list, self._clip_thresh) + # apply update self._optimizer.step()