diff --git a/trainer.py b/trainer.py index 91131a9..bed2621 100644 --- a/trainer.py +++ b/trainer.py @@ -148,7 +148,7 @@ def _set_device(args): gpus = [] for device in device_type: - if device_type == -1: + if device == -1: device = torch.device("cpu") else: device = torch.device("cuda:{}".format(device))