diff --git a/tuner_utils/yellowfin.py b/tuner_utils/yellowfin.py index 0961a08..f45f35d 100644 --- a/tuner_utils/yellowfin.py +++ b/tuner_utils/yellowfin.py @@ -203,9 +203,17 @@ def step(self): if group['weight_decay'] != 0: grad = grad.add(group['weight_decay'], p.data) - if self._clip_thresh != None: - torch.nn.utils.clip_grad_norm(self._var_list, self._clip_thresh) - + #if self._clip_thresh != None: + # torch.nn.utils.clip_grad_norm(self._var_list, self._clip_thresh) + if self._clip_thresh is not None: + if isinstance(self._var_list[0], dict): + params = [] + for p in self._var_list: + params.extend(p['params']) + torch.nn.utils.clip_grad_norm(params, self._clip_thresh) + else: + torch.nn.utils.clip_grad_norm(self._var_list, self._clip_thresh) + # apply update self._optimizer.step()