From b81af664b11c18ffc045fa8b2f55bf1123ccd753 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Tue, 29 Nov 2022 17:14:52 -0800
Subject: [PATCH 1/3] Start some cleanup

---
 distributed.py   |  87 ++++++++++++++++++++++
 train_maskgit.py | 182 +++++++++++++++++++++++++----------------
 2 files changed, 183 insertions(+), 86 deletions(-)
 create mode 100644 distributed.py

diff --git a/distributed.py b/distributed.py
new file mode 100644
index 0000000..0937403
--- /dev/null
+++ b/distributed.py
@@ -0,0 +1,87 @@
+import os
+
+import torch
+
+
+def is_global_primary(args):
+    return args.rank == 0
+
+
+def is_local_primary(args):
+    return args.local_rank == 0
+
+
+def is_primary(args, local=False):
+    return is_local_primary(args) if local else is_global_primary(args)
+
+
+def is_using_distributed():
+    if 'WORLD_SIZE' in os.environ:
+        return int(os.environ['WORLD_SIZE']) > 1
+    if 'SLURM_NTASKS' in os.environ:
+        return int(os.environ['SLURM_NTASKS']) > 1
+    return False
+
+
+def world_info_from_env():
+    local_rank = 0
+    for v in ('LOCAL_RANK', 'MPI_LOCALRANKID', 'SLURM_LOCALID', 'OMPI_COMM_WORLD_LOCAL_RANK'):
+        if v in os.environ:
+            local_rank = int(os.environ[v])
+            break
+    global_rank = 0
+    for v in ('RANK', 'PMI_RANK', 'SLURM_PROCID', 'OMPI_COMM_WORLD_RANK'):
+        if v in os.environ:
+            global_rank = int(os.environ[v])
+            break
+    world_size = 1
+    for v in ('WORLD_SIZE', 'PMI_SIZE', 'SLURM_NTASKS', 'OMPI_COMM_WORLD_SIZE'):
+        if v in os.environ:
+            world_size = int(os.environ[v])
+            break
+
+    return local_rank, global_rank, world_size
+
+
+def init_distributed_device(args):
+    # Distributed training = training on more than one GPU.
+    # Works in both single and multi-node scenarios.
+    args.distributed = False
+    args.world_size = 1
+    args.rank = 0  # global rank
+    args.local_rank = 0
+    if is_using_distributed():
+        if 'SLURM_PROCID' in os.environ:
+            # DDP via SLURM
+            args.local_rank, args.rank, args.world_size = world_info_from_env()
+            # SLURM var -> torch.distributed vars in case needed
+            os.environ['LOCAL_RANK'] = str(args.local_rank)
+            os.environ['RANK'] = str(args.rank)
+            os.environ['WORLD_SIZE'] = str(args.world_size)
+            torch.distributed.init_process_group(
+                backend=args.dist_backend,
+                init_method=args.dist_url,
+                world_size=args.world_size,
+                rank=args.rank,
+            )
+        else:
+            # DDP via torchrun, torch.distributed.launch
+            args.local_rank, _, _ = world_info_from_env()
+            torch.distributed.init_process_group(
+                backend=args.dist_backend,
+                init_method=args.dist_url)
+            args.world_size = torch.distributed.get_world_size()
+            args.rank = torch.distributed.get_rank()
+        args.distributed = True
+
+    if torch.cuda.is_available():
+        if args.distributed and not args.no_set_device_rank:
+            device = 'cuda:%d' % args.local_rank
+        else:
+            device = 'cuda:0'
+        torch.cuda.set_device(device)
+    else:
+        device = 'cpu'
+    args.device = device
+    device = torch.device(device)
+    return device
diff --git a/train_maskgit.py b/train_maskgit.py
index 7e9072a..7e2c264 100644
--- a/train_maskgit.py
+++ b/train_maskgit.py
@@ -1,38 +1,78 @@
 import os
 import time
+
 import numpy as np
+import open_clip
 import torch
+import torch.multiprocessing as mp
 import wandb
+from open_clip import tokenizer
 from torch import optim
+from torch.nn.parallel import DistributedDataParallel
 from torch.utils.data import DataLoader, TensorDataset
 from tqdm import tqdm
-import torch.multiprocessing as mp
-import torch.distributed as dist
-from torch.nn.parallel import DistributedDataParallel
-from vivq import VIVQ, BASE_SHAPE
+
+from distributed import init_distributed_device, is_primary, is_local_primary
 from maskgit import MaskGit
 from paella import DenoiseUNet
-from utils import get_dataloader, sample_paella, sample_maskgit
+from utils import get_dataloader, sample_paella
+from vivq import VIVQ, BASE_SHAPE
 
 # from transformers import T5Tokenizer, T5Model
+import argparse
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--run-name', type=str, default=None)
+parser.add_argument('--model', type=str, default='paella')
+parser.add_argument('--dataset', type=str, default='second_stage')
+
+parser.add_argument('--total-steps', type=int, default=300_000)
+parser.add_argument('-b', '--batch-size', type=int, default=4)
+parser.add_argument('-j', '--num-workers', type=int, default=6)
+parser.add_argument('--log-period', type=int, default=2000)
+parser.add_argument('--extra-ckpt', type=int, default=10_000)
+parser.add_argument('--accum-grad', type=int, default=2)
+
+# args.vq_path = "./models/server/vivq_8192_5_skipframes/model_100000.pt"
+parser.add_argument('--vq-path', type=str, default='/fsx/phenaki/src/models/model_120000.pt')
+
+parser.add_argument('--dim', type=int, default=1224)
+parser.add_argument('--num-tokens', type=int, default=8192)
+parser.add_argument('--max-seq-len', type=int, default=6 * 16 * 16)
+parser.add_argument('--depth', type=int, default=22)
+parser.add_argument('--dim-context', type=int, default=1024)
+parser.add_argument('--heads', type=int, default=22)
+
+parser.add_argument('--clip-len', type=int, default=10)
+parser.add_argument('--skip-frames', type=int, default=10)
+
+
+def main():
+    args = parser.parse_args()
+
+    # FIXME turn into arg(s)
+    args.urls = {
+        # "videos": "file:C:/Users/d6582/Documents/ml/phenaki/data/webvid/tar_files/0.tar",
+        # "videos": "/fsx/mas/phenaki/data/webvid/tar_files/{0..249}.tar",
+        "videos": "/fsx/phenaki/data/videos/tar_files/{0..1243}.tar",
+        # "images": "file:C:/Users/d6582/Documents/ml/paella/evaluations/laion-30k/000069.tar"
+        # "images": "pipe:aws s3 cp s3://s-laion/improved-aesthetics-laion-2B-en-subsets/aesthetics_tars/{000000..060207}.tar -"
+        "images": "/fsx/phenaki/coyo-700m/coyo-data-2/{00000..20892}.tar"
+    }
+
+    # FIXME handle autogen run name if empty
+    assert args.run_name is not None
 
-def train(proc_id, args):
     if os.path.exists(f"results/{args.run_name}/log.pt"):
         resume = True
     else:
         resume = False
 
-    parallel = len(args.devices) > 1
-    device = torch.device(proc_id)
-    if parallel:
-        torch.cuda.set_device(proc_id)
+    if torch.cuda.is_available():
+        torch.backends.cuda.matmul.allow_tf32 = True
         torch.backends.cudnn.benchmark = True
-        dist.init_process_group(backend="nccl", init_method="file:///fsx/phenaki/src/dist_file",
-                                world_size=args.n_nodes * len(args.devices),
-                                rank=proc_id + len(args.devices) * args.node_id)
-    torch.set_num_threads(6)
+
+    device = init_distributed_device(args)
 
     vqmodel = VIVQ(codebook_size=args.num_tokens).to(device)
     vqmodel.load_state_dict(torch.load(args.vq_path, map_location=device))
@@ -47,13 +87,23 @@ def train(proc_id, args):
     clip_model = clip_model.to(device).eval().requires_grad_(False)
 
     if args.model == "maskgit":
-        model = MaskGit(dim=args.dim, num_tokens=args.num_tokens, max_seq_len=args.max_seq_len, depth=args.depth, dim_context=args.dim_context, heads=args.heads).to(device)
+        model = MaskGit(
+            dim=args.dim,
+            num_tokens=args.num_tokens,
+            max_seq_len=args.max_seq_len,
+            depth=args.depth,
+            dim_context=args.dim_context,
+            heads=args.heads).to(device)
     elif args.model == "paella":
-        model = DenoiseUNet(num_labels=args.num_tokens, down_levels=[4, 6, 8], up_levels=[8, 6, 4], c_clip=args.dim_context).to(device)
+        model = DenoiseUNet(
+            num_labels=args.num_tokens,
+            down_levels=[4, 6, 8],
+            up_levels=[8, 6, 4],
+            c_clip=args.dim_context).to(device)
     else:
         raise NotImplementedError()
 
-    if not proc_id and args.node_id == 0:
+    if is_primary(args):
         print(f"Starting run '{args.run_name}'....")
         print(f"Batch Size check: {args.n_nodes * args.batch_size * args.accum_grad * len(args.devices)}")
         print(f"Number of Parameters: {sum([p.numel() for p in model.parameters()])}")
@@ -66,10 +116,12 @@ def train(proc_id, args):
     grad_accum_steps = args.accum_grad
 
     # scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=lr, total_steps=args.total_steps, pct_start=0.1, div_factor=25, anneal_strategy='cos')
-    scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=lr, total_steps=args.total_steps, pct_start=0.1, div_factor=25, final_div_factor=1 / 25, anneal_strategy='linear')
+    scheduler = optim.lr_scheduler.OneCycleLR(
+        optimizer, max_lr=lr, total_steps=args.total_steps,
+        pct_start=0.1, div_factor=25, final_div_factor=1 / 25, anneal_strategy='linear')
 
     if resume:
-        if not proc_id and args.node_id == 0:
+        if is_primary(args):
             print("Loading last checkpoint....")
         logs = torch.load(f"results/{args.run_name}/log.pt")
         run_id = logs["wandb_run_id"]
@@ -78,16 +130,17 @@ def train(proc_id, args):
         accuracies = logs["accuracies"]
         total_loss, total_acc = losses[-1] * start_step, accuracies[-1] * start_step
         model.load_state_dict(torch.load(f"models/{args.run_name}/model.pt", map_location=device))
-        if not proc_id and args.node_id == 0:
+        if is_primary(args):
             print("Loaded model....")
         opt_state = torch.load(f"models/{args.run_name}/optim.pt", map_location=device)
         last_lr = opt_state["param_groups"][0]["lr"]
         with torch.no_grad():
             while last_lr > optimizer.param_groups[0]["lr"]:
                 scheduler.step()
-        if not proc_id and args.node_id == 0:
+        if is_primary(args):
             print(f"Initialized scheduler")
-            print(f"Sanity check => Last-LR: {last_lr} == Current-LR: {optimizer.param_groups[0]['lr']} -> {last_lr == optimizer.param_groups[0]['lr']}")
+            print(f"Sanity check => Last-LR: {last_lr} == Current-LR: {optimizer.param_groups[0]['lr']} "
+                  f"-> {last_lr == optimizer.param_groups[0]['lr']}")
         optimizer.load_state_dict(opt_state)
         del opt_state
     else:
@@ -96,19 +149,22 @@ def train(proc_id, args):
         accuracies = []
         start_step, total_loss, total_acc = 0, 0, 0
 
-    if not proc_id and args.node_id == 0:
+    if is_primary(args):
         wandb.init(project="DenoiseGIT", name=args.run_name, entity="wand-tech", config=vars(args), id=run_id,
                    resume="allow")
         os.makedirs(f"results/{args.run_name}", exist_ok=True)
         os.makedirs(f"models/{args.run_name}", exist_ok=True)
         wandb.watch(model)
 
-    if parallel:
+    if args.distributed:
         model = DistributedDataParallel(model, device_ids=[device], output_device=device)
 
     model.train()
 
     # images, videos = next(iter(dataset))
-    pbar = tqdm(enumerate(dataset, start=start_step), total=args.total_steps, initial=start_step) if args.node_id == 0 and proc_id == 0 else enumerate(dataset, start=start_step)
+    if is_primary(args):
+        pbar = tqdm(enumerate(dataset, start=start_step), total=args.total_steps, initial=start_step)
+    else:
+        pbar = enumerate(dataset, start=start_step)
     # pbar = tqdm(range(1000000))
     for step, (images, videos, captions) in pbar:
     # for step in pbar:
@@ -158,7 +214,7 @@ def train(proc_id, args):
             scheduler.step()
             optimizer.zero_grad()
 
-        if not proc_id and args.node_id == 0:
+        if is_primary(args):
             log = {
                 'loss': total_loss / (step + 1),
                 'curr_loss': loss.item(),
@@ -171,7 +227,7 @@ def train(proc_id, args):
             pbar.set_postfix(log)
             wandb.log(log)
 
-        if args.node_id == 0 and proc_id == 0 and step % args.log_period == 0:
+        if is_primary(args) and step % args.log_period == 0:
             losses.append(total_loss / (step + 1))
             accuracies.append(total_acc / (step + 1))
 
@@ -194,7 +250,9 @@ def train(proc_id, args):
                 text_tokens = text_tokens.to(device)
                 cool_captions_embeddings = clip_model.encode_text(text_tokens).float()
 
-                cool_captions = DataLoader(TensorDataset(cool_captions_embeddings.repeat_interleave(n, dim=0)), batch_size=1)
+                cool_captions = DataLoader(
+                    TensorDataset(cool_captions_embeddings.repeat_interleave(n, dim=0)),
+                    batch_size=1)
                 cool_captions_sampled = []
                 st = time.time()
                 for caption_embedding in cool_captions:
@@ -235,7 +293,10 @@ def train(proc_id, args):
                 log_table = wandb.Table(data=log_data, columns=["Caption", "Video", "Orig", "Recon"])
                 wandb.log({"Log": log_table})
 
-                log_data_cool = [[cool_captions_text[i]] + [wandb.Video(cool_captions_sampled[i].cpu().mul(255).add_(0.5).clamp_(0, 255).numpy())] for i in range(len(cool_captions_text))]
+                log_data_cool = [
+                    [cool_captions_text[i]] +
+                    [wandb.Video(cool_captions_sampled[i].cpu().mul(255).add_(0.5).clamp_(0, 255).numpy())]
+                    for i in range(len(cool_captions_text))]
                 log_table_cool = wandb.Table(data=log_data_cool, columns=["Caption", "Video"])
                 wandb.log({"Log Cool": log_table_cool})
 
@@ -249,64 +310,13 @@ def train(proc_id, args):
             torch.save(model.module.state_dict(), f"models/{args.run_name}/model.pt")
             torch.save(optimizer.state_dict(), f"models/{args.run_name}/optim.pt")
             torch.save(grad_scaler.state_dict(), f"models/{args.run_name}/scaler.pt")
-            torch.save({'step': step, 'losses': losses, 'accuracies': accuracies, 'wandb_run_id': run_id}, f"results/{args.run_name}/log.pt")
-
-
-def launch(args):
-    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(d) for d in args.devices])
-    if len(args.devices) == 1:
-        train(0, args)
-    else:
-        os.environ["MASTER_ADDR"] = "localhost"
-        os.environ["MASTER_PORT"] = "33751"
-        p = mp.spawn(train, nprocs=len(args.devices), args=(args,))
+            torch.save({
+                'step': step, 'losses': losses, 'accuracies': accuracies,
+                'wandb_run_id': run_id}, f"results/{args.run_name}/log.pt")
 
 
 if __name__ == '__main__':
-    import argparse
-    parser = argparse.ArgumentParser()
-    args = parser.parse_args()
-    args.run_name = "Paella_Test_4"
-    args.model = "paella"
-    args.dataset = "second_stage"
-    args.urls = {
-        # "videos": "file:C:/Users/d6582/Documents/ml/phenaki/data/webvid/tar_files/0.tar",
-        # "videos": "/fsx/mas/phenaki/data/webvid/tar_files/{0..249}.tar",
-        "videos": "/fsx/phenaki/data/videos/tar_files/{0..1243}.tar",
-        # "images": "file:C:/Users/d6582/Documents/ml/paella/evaluations/laion-30k/000069.tar"
-        # "images": "pipe:aws s3 cp s3://s-laion/improved-aesthetics-laion-2B-en-subsets/aesthetics_tars/{000000..060207}.tar -"
-        "images": "/fsx/phenaki/coyo-700m/coyo-data-2/{00000..20892}.tar"
-    }
-    args.total_steps = 300_000
-    args.batch_size = 4
-    args.num_workers = 10
-    args.log_period = 2000
-    args.extra_ckpt = 10_000
-    args.accum_grad = 2
-
-    args.vq_path = "/fsx/phenaki/src/models/model_120000.pt"  # ./models/server/vivq_8192_5_skipframes/model_100000.pt
-    # args.vq_path = "./models/server/vivq_8192_5_skipframes/model_100000.pt"
-    args.dim = 1224  # 1224
-    args.num_tokens = 8192
-    args.max_seq_len = 6 * 16 * 16
-    args.depth = 22  # 22
-    args.dim_context = 1024  # for clip, 512 for T5
-    args.heads = 22  # 22
-
-    args.clip_len = 10
-    args.skip_frames = 5
-
-    args.n_nodes = 3
-    # args.n_nodes = 1
-    args.node_id = int(os.environ["SLURM_PROCID"])
-    # args.node_id = 0
-    args.devices = [0, 1, 2, 3, 4, 5, 6, 7]
-    # args.devices = [0]
-
-    print("Launching with args: ", args)
-    launch(
-        args
-    )
+    main()
 
 # device = "cuda"
 # model = MaskGit(dim=args.dim, num_tokens=args.num_tokens, max_seq_len=args.max_seq_len, depth=args.depth,

From 3c975bfe87189d0fe8c0fe9d61eea087da2b0c2b Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Tue, 29 Nov 2022 17:27:30 -0800
Subject: [PATCH 2/3] Add dist args

---
 train_maskgit.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/train_maskgit.py b/train_maskgit.py
index 7e2c264..4d4a895 100644
--- a/train_maskgit.py
+++ b/train_maskgit.py
@@ -33,6 +33,11 @@ parser.add_argument('--extra-ckpt', type=int, default=10_000)
 parser.add_argument('--accum-grad', type=int, default=2)
 
+parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
+parser.add_argument("--dist-backend", default="nccl", type=str, help="distributed backend")
+parser.add_argument("--no-set-device-rank", default=False, action="store_true",
+                    help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc).")
+
 # args.vq_path = "./models/server/vivq_8192_5_skipframes/model_100000.pt"
 parser.add_argument('--vq-path', type=str, default='/fsx/phenaki/src/models/model_120000.pt')

From 80b1e2b1122e05dd9b5de5cc95312ccd55332dd9 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Tue, 29 Nov 2022 17:42:16 -0800
Subject: [PATCH 3/3] Fix world size log

---
 train_maskgit.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/train_maskgit.py b/train_maskgit.py
index 4d4a895..0a89485 100644
--- a/train_maskgit.py
+++ b/train_maskgit.py
@@ -110,7 +110,7 @@ def main():
 
     if is_primary(args):
         print(f"Starting run '{args.run_name}'....")
-        print(f"Batch Size check: {args.n_nodes * args.batch_size * args.accum_grad * len(args.devices)}")
+        print(f"Batch Size check: {args.world_size * args.batch_size * args.accum_grad}")
         print(f"Number of Parameters: {sum([p.numel() for p in model.parameters()])}")
 
     lr = 3e-4
@@ -155,8 +155,8 @@ def main():
         start_step, total_loss, total_acc = 0, 0, 0
 
     if is_primary(args):
-        wandb.init(project="DenoiseGIT", name=args.run_name, entity="wand-tech", config=vars(args), id=run_id,
-                   resume="allow")
+        #wandb.init(project="DenoiseGIT", name=args.run_name, entity="wand-tech", config=vars(args), id=run_id,
+        #           resume="allow")
         os.makedirs(f"results/{args.run_name}", exist_ok=True)
         os.makedirs(f"models/{args.run_name}", exist_ok=True)
         wandb.watch(model)
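
A quick arithmetic check of the "Batch Size check" change in PATCH 3/3, using the hard-coded values deleted in PATCH 1/3 (n_nodes = 3, devices = [0..7], batch_size = 4, accum_grad = 2; illustrative numbers only): the two expressions agree because init_distributed_device() folds nodes x GPUs-per-node into world_size.

    # old: n_nodes * batch_size * accum_grad * len(devices) = 3 * 4 * 2 * 8 = 192
    # new: world_size * batch_size * accum_grad, with world_size = 3 nodes * 8 GPUs = 24
    assert 3 * 4 * 2 * 8 == 24 * 4 * 2 == 192  # effective samples per optimizer step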
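
Usage note: with launch()/mp.spawn removed in PATCH 1/3, the script is expected to be driven by torchrun or SLURM rather than spawning its own workers. A minimal sketch of how the new helpers compose, assuming distributed.py from PATCH 1/3 is importable; the flags mirror the defaults added in PATCH 2/3, and the file name sketch_launch.py is hypothetical:

    # sketch_launch.py (illustrative only, not part of the patch series)
    import argparse

    from distributed import init_distributed_device, is_primary

    parser = argparse.ArgumentParser()
    parser.add_argument("--dist-url", default="env://", type=str)
    parser.add_argument("--dist-backend", default="nccl", type=str)
    parser.add_argument("--no-set-device-rank", default=False, action="store_true")
    args = parser.parse_args()

    # Under `torchrun --nproc_per_node=8 sketch_launch.py`, torchrun exports RANK,
    # LOCAL_RANK and WORLD_SIZE, so init_process_group(init_method="env://") can
    # resolve the process group without explicit rank/world_size arguments. Under
    # SLURM, world_info_from_env() reads SLURM_PROCID / SLURM_LOCALID / SLURM_NTASKS
    # instead. Without any launcher, is_using_distributed() is False and this
    # degrades gracefully to a single-process, single-device run.
    device = init_distributed_device(args)  # also sets args.rank / args.local_rank / args.world_size

    if is_primary(args):  # global rank 0 only: logging, checkpoints, wandb
        print(f"world_size={args.world_size}, rank={args.rank}, device={device}")

Deriving rank and world size from environment variables keeps a single code path for torchrun, SLURM and plain single-GPU runs, which is what allows the hard-coded n_nodes/node_id/devices settings to be deleted.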