Add Sparsified sgd #75
base: develop
Changes from all commits: 89bc359, 30301c5, c252224, 2efaca4, fce28a9, c399a29, d3fb819, 839039c
```diff
@@ -1,11 +1,13 @@
 import torch
 import torch.distributed as dist
+import numpy as np
+import torch.nn.functional as F

 from utils import checkpoint
 from utils import log
 from utils.metrics import AverageMeter
 from utils.helper import Timeit, maybe_range, update_best_runtime_metric
-from utils.communication import aggregate_gradients, global_average
+from utils.communication import aggregate_gradients, global_average, aggregate_sparsified_gradients
 from utils.utils import convert_dtype

 from datasets import create_dataset
```
```diff
@@ -22,6 +24,7 @@ def train_epoch(model, optimizer, criterion, scheduler, options, timeit):
         scheduler.step()

         data = convert_dtype(options.dtype, data)
         if options.force_target_dtype:
             target = convert_dtype(options.dtype, target)

```
```diff
@@ -32,9 +35,27 @@ def train_epoch(model, optimizer, criterion, scheduler, options, timeit):
         output = model(data)
         loss = criterion(output, target)
         loss.backward()
-        aggregate_gradients(model, options.world_size)
+
+        if options.opt_name == 'sparsified_sgd':
+            aggregate_sparsified_gradients(model, options.world_size,
+                                           options.sparse_grad_size,
+                                           options.random_sparse,
+                                           optimizer,
+                                           scheduler.get_lr())
+        else:
+            aggregate_gradients(model, options.world_size)
+
         optimizer.step()
+
+        if options.model_name == 'logistic_regression' and options.train_validate:
+            t = options.runtime['current_epoch'] * options.train_num_samples_per_device + batch_idx * options.batch_size
+            optimizer.update_estimated_weights(model, t, options.sparse_grad_size)
+
+            if t % options.compute_loss_every == 0:
+                print("Train validation....")
+                timeit.pause()
+                train_validate(optimizer, model, options)
+                timeit.resume()
+
         with torch.no_grad():
             loss = loss.item()
             loss = global_average(loss, 1).item()
```
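For readers without the repo open: `aggregate_sparsified_gradients` lives in `utils.communication` and, judging by this call site, takes the model, the world size, the number of gradient coordinates to exchange (`sparse_grad_size`), a random-vs-deterministic selection flag (`random_sparse`), the optimizer, and the current learning rate. The sketch below shows the general shape of such a sparsified all-reduce (random-k or top-k selection, then an exchange of index/value pairs); it illustrates the technique only and is not the repo's implementation:

```python
import torch
import torch.distributed as dist


def sparsify(grad, k, random_sparse):
    """Pick k coordinates of a flattened gradient: random-k or top-k."""
    flat = grad.view(-1)
    if random_sparse:
        indices = torch.randperm(flat.numel(), device=flat.device)[:k]
    else:
        _, indices = torch.topk(flat.abs(), k)
    return indices, flat[indices]


def sparsified_all_reduce(grad, k, random_sparse, world_size):
    """Average a gradient across workers while exchanging only k
    (index, value) pairs per worker instead of the dense tensor."""
    indices, values = sparsify(grad, k, random_sparse)

    # Collect every worker's selected coordinates.
    gathered_indices = [torch.zeros_like(indices) for _ in range(world_size)]
    gathered_values = [torch.zeros_like(values) for _ in range(world_size)]
    dist.all_gather(gathered_indices, indices)
    dist.all_gather(gathered_values, values)

    # Scatter-add each contribution back into a dense buffer, then average.
    dense = torch.zeros_like(grad).view(-1)
    for idx, val in zip(gathered_indices, gathered_values):
        dense.index_add_(0, idx, val)
    grad.view(-1).copy_(dense / world_size)
```

A real implementation would also handle error feedback (accumulating the coordinates that were dropped), which is part of what makes sparsified SGD converge; that detail is omitted here.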
```diff
@@ -46,6 +67,45 @@ def train_epoch(model, optimizer, criterion, scheduler, options, timeit):
         timeit.resume()


+def train_validate(optimizer, model, options):
+    """ Validation on train data by using weighted average of parameters """
+    estimated_weights = optimizer.get_estimated_weights(model)
+    num_samples = 0
+    l1 = options.l1_coef
+    l2 = options.l2_coef
+
+    loss = 0
+
+    for batch_idx, (data, target) in zip(maybe_range(options.max_batch_per_epoch),
+                                         options.val_loader):
+        data = convert_dtype(options.dtype, data)
+        if options.force_target_dtype:
+            target = convert_dtype(options.dtype, target)
+
+        if options.use_cuda:
+            data, target = data.cuda(), target.cuda()
+        target = target * 2 - 1
+
+        for weight in estimated_weights:
+            w = weight.squeeze()
+            batch_loss = np.log(1 + np.exp(-target * (data @ w)))
```
Member:

This is a soft-margin loss, right? Could use https://pytorch.org/docs/stable/nn.html#softmarginloss and https://pytorch.org/docs/stable/torch.html#torch.matmul here. Especially since numpy ops run on the CPU, not the GPU, so the tensors would have to leave the device.

Collaborator (Author):

Yes, I'll use the soft-margin loss instead, thanks.
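For reference, the reviewer's suggestion would look roughly like this (a sketch only; the tensors are placeholders with shapes matching the diff above):

```python
import torch
import torch.nn.functional as F

# Hypothetical stand-ins: one mini-batch and one weight vector.
data = torch.randn(8, 5)
target = torch.randint(0, 2, (8,)).float() * 2 - 1  # labels in {-1, +1}
w = torch.randn(5)

# torch.matmul runs on whatever device the tensors live on (CPU or GPU),
# unlike the numpy expression in the diff.
logits = torch.matmul(data, w)

# soft_margin_loss(x, y) computes log(1 + exp(-y * x)); reduction='sum'
# matches the manual batch_loss.sum() accumulation below.
batch_loss = F.soft_margin_loss(logits, target, reduction='sum')
```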
```diff
+            loss += batch_loss.sum().item()
+
+        num_samples += data.size()[0]
+
+    train_loss = global_average(loss, num_samples).item()
+
+    l2_loss = sum(weight.norm(2) ** 2 for weight in estimated_weights).item()
```
Member:

Same as above, could use https://pytorch.org/docs/stable/nn.html#l1loss and https://pytorch.org/docs/stable/nn.html#torch.nn.MSELoss

Collaborator (Author):

As we just need to calculate the L1 and L2 norms of a tensor here, I think using loss functions would make it more complicated.
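Both routes compute the same numbers; a quick sketch of the trade-off the author is pointing at (the parameter list here is a placeholder):

```python
import torch

estimated_weights = [torch.randn(1, 10)]  # placeholder parameter list

# What the diff does: direct norms, no extra machinery.
l2_loss = sum(weight.norm(2) ** 2 for weight in estimated_weights).item()
l1_loss = sum(weight.norm(1) for weight in estimated_weights).item()

# The nn-module route needs an explicit all-zeros "target" per tensor:
# MSELoss(reduction='sum')(w, 0) == ||w||_2^2 and
# L1Loss(reduction='sum')(w, 0) == ||w||_1.
mse = torch.nn.MSELoss(reduction='sum')
mae = torch.nn.L1Loss(reduction='sum')
l2_loss_alt = sum(mse(w, torch.zeros_like(w)) for w in estimated_weights).item()
l1_loss_alt = sum(mae(w, torch.zeros_like(w)) for w in estimated_weights).item()

assert abs(l2_loss - l2_loss_alt) < 1e-4 and abs(l1_loss - l1_loss_alt) < 1e-4
```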
```diff
+    train_loss += l2 / 2 * l2_loss
+    l1_loss = sum(weight.norm(1) for weight in estimated_weights).item()
+    train_loss += l1 * l1_loss
+
+    print("Global Train Loss: " + str(train_loss))
+
+    with open(options.ckpt_run_dir + "/" + str(dist.get_rank()) + "_train_validation.txt", "a+") as file:
+        file.write(str(train_loss) + "\n")
```
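Taken together, `train_validate` evaluates the L1/L2-regularized logistic (soft-margin) objective at the averaged iterate, estimated over the sampled batches. Writing it out (notation mine, not the repo's: w̄ is the estimated weight vector, λ₁ is `l1_coef`, λ₂ is `l2_coef`, and labels are mapped to yᵢ ∈ {-1, +1} by the `target * 2 - 1` line):

```math
L(\bar{w}) = \frac{1}{n}\sum_{i=1}^{n} \log\left(1 + e^{-y_i\, x_i^{\top}\bar{w}}\right) + \frac{\lambda_2}{2}\,\lVert \bar{w} \rVert_2^2 + \lambda_1\,\lVert \bar{w} \rVert_1
```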
```diff
+
+
 def validate(model, optimizer, criterion, metrics, options):
     model.eval()
```
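The averaged iterate itself comes from the optimizer's `update_estimated_weights` / `get_estimated_weights` pair, which this PR calls but does not show. As a rough illustration only, a weighted running average of the parameters could be maintained like this (the class name, method names, and the weighting scheme below are all assumptions, not the PR's optimizer code):

```python
import torch


class AveragedIterateTracker:
    """Toy sketch: keep a weighted running average of model parameters.

    Sparsified-SGD-style averaging often weights iterate t by something
    like (t + 1) so that later iterates dominate; that choice is an
    assumption here, not taken from the PR.
    """

    def __init__(self, model):
        # Start the average from a detached copy of the current weights.
        self.average = [p.detach().clone() for p in model.parameters()]
        self.weight_sum = 1.0

    def update(self, model, t):
        # Incremental weighted mean: avg += (w_t / W) * (p - avg),
        # where W is the total weight accumulated so far.
        step_weight = float(t) + 1.0
        self.weight_sum += step_weight
        for avg_p, p in zip(self.average, model.parameters()):
            avg_p.add_((p.detach() - avg_p) * (step_weight / self.weight_sum))

    def get(self):
        return self.average
```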
```diff
@@ -123,7 +183,7 @@ def __call__(self, model, optimizer, criterion, metrics, scheduler, options):
                                   options.batch_size), 0)

         # train the model and evaluate the model per args.eval_freq
-        max_epochs = min(options.train_epochs, options.max_train_steps)\
+        max_epochs = min(options.train_epochs, options.max_train_steps) \
             if options.max_train_steps else options.train_epochs
         start_epoch = options.runtime['current_epoch'] if options.resume else 0
         options.runtime['records'] = options.runtime.get('records', [])
```
```diff
@@ -134,6 +194,7 @@ def __call__(self, model, optimizer, criterion, metrics, scheduler, options):
         timeit = Timeit(0 if len(options.runtime['cumu_time_val']) == 0
                         else options.runtime['cumu_time_val'][-1])

         for epoch in range(start_epoch, max_epochs):
             options.runtime['current_epoch'] = epoch
```
Member:

Single-letter variable names are usually not that good for readability (except for something like `i` as a counter in a for loop); verbose variable names make the code more readable.

Collaborator (Author):

I agree, I'll fix it.
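To make the suggestion concrete, a before/after on the squeezed-weight line from `train_validate` (the tensors are placeholders and the new name is just one possibility):

```python
import torch

# Hypothetical stand-ins for the loop variables in train_validate.
data = torch.randn(4, 3)
target = torch.tensor([1.0, -1.0, 1.0, -1.0])
weight = torch.randn(1, 3)

# Before: single-letter name, hard to search for and easy to shadow.
w = weight.squeeze()

# After: a descriptive name documents what the tensor is.
weight_vector = weight.squeeze()
margins = target * (data @ weight_vector)  # per-sample margins y_i * x_i^T w
```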