87 changes: 87 additions & 0 deletions distributed.py
@@ -0,0 +1,87 @@
import os

import torch


def is_global_primary(args):
return args.rank == 0


def is_local_primary(args):
return args.local_rank == 0


def is_primary(args, local=False):
return is_local_primary(args) if local else is_global_primary(args)


def is_using_distributed():
if 'WORLD_SIZE' in os.environ:
return int(os.environ['WORLD_SIZE']) > 1
if 'SLURM_NTASKS' in os.environ:
return int(os.environ['SLURM_NTASKS']) > 1
return False


def world_info_from_env():
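    # Precedence: torchrun-style vars (LOCAL_RANK / RANK / WORLD_SIZE) first,
    # then MPI/PMI, SLURM, and OpenMPI fall-backs.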
local_rank = 0
for v in ('LOCAL_RANK', 'MPI_LOCALRANKID', 'SLURM_LOCALID', 'OMPI_COMM_WORLD_LOCAL_RANK'):
if v in os.environ:
local_rank = int(os.environ[v])
break
global_rank = 0
for v in ('RANK', 'PMI_RANK', 'SLURM_PROCID', 'OMPI_COMM_WORLD_RANK'):
if v in os.environ:
global_rank = int(os.environ[v])
break
world_size = 1
for v in ('WORLD_SIZE', 'PMI_SIZE', 'SLURM_NTASKS', 'OMPI_COMM_WORLD_SIZE'):
if v in os.environ:
world_size = int(os.environ[v])
break

return local_rank, global_rank, world_size


def init_distributed_device(args):
# Distributed training = training on more than one GPU.
# Works in both single and multi-node scenarios.
args.distributed = False
args.world_size = 1
args.rank = 0 # global rank
args.local_rank = 0
if is_using_distributed():
if 'SLURM_PROCID' in os.environ:
# DDP via SLURM
args.local_rank, args.rank, args.world_size = world_info_from_env()
# SLURM var -> torch.distributed vars in case needed
os.environ['LOCAL_RANK'] = str(args.local_rank)
os.environ['RANK'] = str(args.rank)
os.environ['WORLD_SIZE'] = str(args.world_size)
torch.distributed.init_process_group(
backend=args.dist_backend,
init_method=args.dist_url,
world_size=args.world_size,
rank=args.rank,
)
else:
# DDP via torchrun, torch.distributed.launch
args.local_rank, _, _ = world_info_from_env()
torch.distributed.init_process_group(
backend=args.dist_backend,
init_method=args.dist_url)
args.world_size = torch.distributed.get_world_size()
args.rank = torch.distributed.get_rank()
args.distributed = True

if torch.cuda.is_available():
if args.distributed and not args.no_set_device_rank:
device = 'cuda:%d' % args.local_rank
else:
device = 'cuda:0'
torch.cuda.set_device(device)
else:
device = 'cpu'
args.device = device
device = torch.device(device)
return device
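
A minimal usage sketch (not part of this diff) of how the helpers above are meant to be wired together, mirroring what train_maskgit.py below does. The toy model is an assumption, and a CUDA machine with a torchrun-style launcher (which sets RANK, LOCAL_RANK, and WORLD_SIZE) is presumed:

import argparse

import torch
from torch.nn.parallel import DistributedDataParallel

from distributed import init_distributed_device, is_primary

parser = argparse.ArgumentParser()
parser.add_argument("--dist-url", default="env://", type=str)
parser.add_argument("--dist-backend", default="nccl", type=str)
parser.add_argument("--no-set-device-rank", default=False, action="store_true")
args = parser.parse_args()

device = init_distributed_device(args)  # fills args.rank, args.local_rank, args.world_size, args.distributed
model = torch.nn.Linear(8, 8).to(device)  # stand-in for a real model
if args.distributed:
    model = DistributedDataParallel(model, device_ids=[device], output_device=device)
if is_primary(args):
    print(f"rank {args.rank}/{args.world_size} on {device}")

Run e.g. as: torchrun --nproc_per_node=2 sketch.py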
193 changes: 104 additions & 89 deletions train_maskgit.py
@@ -1,38 +1,83 @@
import os
import time

import numpy as np
import open_clip
import torch
import torch.multiprocessing as mp
import wandb
from open_clip import tokenizer
from torch import optim
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
from vivq import VIVQ, BASE_SHAPE

from distributed import init_distributed_device, is_primary, is_local_primary
from maskgit import MaskGit
from paella import DenoiseUNet
from utils import get_dataloader, sample_paella
# from transformers import T5Tokenizer, T5Model

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--run-name', type=str, default=None)
parser.add_argument('--model', type=str, default='paella')
parser.add_argument('--dataset', type=str, default='second_stage')

parser.add_argument('--total-steps', type=int, default=300_000)
parser.add_argument('-b', '--batch-size', type=int, default=4)
parser.add_argument('-j', '--num-workers', type=int, default=6)
parser.add_argument('--log-period', type=int, default=2000)
parser.add_argument('--extra-ckpt', type=int, default=10_000)
parser.add_argument('--accum-grad', type=int, default=2)

parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
parser.add_argument("--dist-backend", default="nccl", type=str, help="distributed backend")
parser.add_argument("--no-set-device-rank", default=False, action="store_true",
help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc).")

# args.vq_path = "./models/server/vivq_8192_5_skipframes/model_100000.pt"
parser.add_argument('--vq-path', type=str, default='/fsx/phenaki/src/models/model_120000.pt')

parser.add_argument('--dim', type=int, default=1224)
parser.add_argument('--num-tokens', type=int, default=8192)
parser.add_argument('--max-seq-len', type=int, default=6 * 16 * 16)
parser.add_argument('--depth', type=int, default=22)
parser.add_argument('--dim-context', type=int, default=1024)
parser.add_argument('--heads', type=int, default=22)

parser.add_argument('--clip-len', type=int, default=10)
parser.add_argument('--skip-frames', type=int, default=10)
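
# Example (hypothetical) launch command; torchrun populates the env vars that distributed.py reads:
#   torchrun --nproc_per_node=8 train_maskgit.py --run-name my_run --model paella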


def main():
args = parser.parse_args()

# FIXME turn into arg(s)
args.urls = {
# "videos": "file:C:/Users/d6582/Documents/ml/phenaki/data/webvid/tar_files/0.tar",
# "videos": "/fsx/mas/phenaki/data/webvid/tar_files/{0..249}.tar",
"videos": "/fsx/phenaki/data/videos/tar_files/{0..1243}.tar",
# "images": "file:C:/Users/d6582/Documents/ml/paella/evaluations/laion-30k/000069.tar"
# "images": "pipe:aws s3 cp s3://s-laion/improved-aesthetics-laion-2B-en-subsets/aesthetics_tars/{000000..060207}.tar -"
"images": "/fsx/phenaki/coyo-700m/coyo-data-2/{00000..20892}.tar"
}

# FIXME handle autogen run name if empty
    assert args.run_name is not None

    train(0, args)  # assumed hand-off to the training loop below; rank/world size come from the env

def train(proc_id, args):
if os.path.exists(f"results/{args.run_name}/log.pt"):
resume = True
else:
resume = False
if torch.cuda.is_available():
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True
torch.set_num_threads(6)

device = init_distributed_device(args)

vqmodel = VIVQ(codebook_size=args.num_tokens).to(device)
vqmodel.load_state_dict(torch.load(args.vq_path, map_location=device))
@@ -47,15 +92,25 @@ def train(proc_id, args):
clip_model = clip_model.to(device).eval().requires_grad_(False)

if args.model == "maskgit":
model = MaskGit(
dim=args.dim,
num_tokens=args.num_tokens,
max_seq_len=args.max_seq_len,
depth=args.depth,
dim_context=args.dim_context,
heads=args.heads).to(device)
elif args.model == "paella":
model = DenoiseUNet(
num_labels=args.num_tokens,
down_levels=[4, 6, 8],
up_levels=[8, 6, 4],
c_clip=args.dim_context).to(device)
else:
raise NotImplementedError()

if is_primary(args):
print(f"Starting run '{args.run_name}'....")
print(f"Batch Size check: {args.n_nodes * args.batch_size * args.accum_grad * len(args.devices)}")
print(f"Batch Size check: {args.world_size * args.batch_size * args.accum_grad}")
print(f"Number of Parameters: {sum([p.numel() for p in model.parameters()])}")

lr = 3e-4
@@ -66,10 +121,12 @@ def train(proc_id, args):

grad_accum_steps = args.accum_grad
# scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=lr, total_steps=args.total_steps, pct_start=0.1, div_factor=25, anneal_strategy='cos')
scheduler = optim.lr_scheduler.OneCycleLR(
optimizer, max_lr=lr, total_steps=args.total_steps,
pct_start=0.1, div_factor=25, final_div_factor=1 / 25, anneal_strategy='linear')
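    # Note: with div_factor=25 and final_div_factor=1/25, the schedule warms up from lr/25 to lr
    # and then holds there, since min_lr = initial_lr / final_div_factor = (lr/25) / (1/25) = lr.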

if resume:
if is_primary(args):
print("Loading last checkpoint....")
logs = torch.load(f"results/{args.run_name}/log.pt")
run_id = logs["wandb_run_id"]
@@ -78,16 +135,17 @@ def train(proc_id, args):
accuracies = logs["accuracies"]
total_loss, total_acc = losses[-1] * start_step, accuracies[-1] * start_step
model.load_state_dict(torch.load(f"models/{args.run_name}/model.pt", map_location=device))
        if is_primary(args):
print("Loaded model....")
opt_state = torch.load(f"models/{args.run_name}/optim.pt", map_location=device)
last_lr = opt_state["param_groups"][0]["lr"]
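        # OneCycleLR state isn't checkpointed, so replay scheduler steps until the LR
        # catches up with the one saved in the optimizer state.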
with torch.no_grad():
while last_lr > optimizer.param_groups[0]["lr"]:
scheduler.step()
        if is_primary(args):
            print("Initialized scheduler")
print(f"Sanity check => Last-LR: {last_lr} == Current-LR: {optimizer.param_groups[0]['lr']} -> {last_lr == optimizer.param_groups[0]['lr']}")
print(f"Sanity check => Last-LR: {last_lr} == Current-LR: {optimizer.param_groups[0]['lr']} "
f"-> {last_lr == optimizer.param_groups[0]['lr']}")
optimizer.load_state_dict(opt_state)
del opt_state
else:
@@ -96,19 +154,22 @@ def train(proc_id, args):
accuracies = []
start_step, total_loss, total_acc = 0, 0, 0

    if is_primary(args):
#wandb.init(project="DenoiseGIT", name=args.run_name, entity="wand-tech", config=vars(args), id=run_id,
# resume="allow")
os.makedirs(f"results/{args.run_name}", exist_ok=True)
os.makedirs(f"models/{args.run_name}", exist_ok=True)
wandb.watch(model)

if args.distributed:
model = DistributedDataParallel(model, device_ids=[device], output_device=device)

model.train()
# images, videos = next(iter(dataset))
if is_primary(args):
pbar = tqdm(enumerate(dataset, start=start_step), total=args.total_steps, initial=start_step)
else:
pbar = enumerate(dataset, start=start_step)
# pbar = tqdm(range(1000000))
for step, (images, videos, captions) in pbar:
# for step in pbar:
@@ -158,7 +219,7 @@ def train(proc_id, args):
scheduler.step()
optimizer.zero_grad()

if is_primary(args):
log = {
'loss': total_loss / (step + 1),
'curr_loss': loss.item(),
@@ -171,7 +232,7 @@ def train(proc_id, args):
pbar.set_postfix(log)
wandb.log(log)

if is_primary(args) and step % args.log_period == 0:
losses.append(total_loss / (step + 1))
accuracies.append(total_acc / (step + 1))

@@ -194,7 +255,9 @@ def train(proc_id, args):
text_tokens = text_tokens.to(device)
cool_captions_embeddings = clip_model.encode_text(text_tokens).float()

cool_captions = DataLoader(
TensorDataset(cool_captions_embeddings.repeat_interleave(n, dim=0)),
batch_size=1)
cool_captions_sampled = []
st = time.time()
for caption_embedding in cool_captions:
@@ -235,7 +298,10 @@ def train(proc_id, args):
log_table = wandb.Table(data=log_data, columns=["Caption", "Video", "Orig", "Recon"])
wandb.log({"Log": log_table})

log_data_cool = [
[cool_captions_text[i]] +
[wandb.Video(cool_captions_sampled[i].cpu().mul(255).add_(0.5).clamp_(0, 255).numpy())]
for i in range(len(cool_captions_text))]
log_table_cool = wandb.Table(data=log_data_cool, columns=["Caption", "Video"])
wandb.log({"Log Cool": log_table_cool})

@@ -249,64 +315,13 @@ def train(proc_id, args):
torch.save(model.module.state_dict(), f"models/{args.run_name}/model.pt")
torch.save(optimizer.state_dict(), f"models/{args.run_name}/optim.pt")
torch.save(grad_scaler.state_dict(), f"models/{args.run_name}/scaler.pt")
torch.save({
'step': step, 'losses': losses, 'accuracies': accuracies,
'wandb_run_id': run_id}, f"results/{args.run_name}/log.pt")


if __name__ == '__main__':
main()

# device = "cuda"
# model = MaskGit(dim=args.dim, num_tokens=args.num_tokens, max_seq_len=args.max_seq_len, depth=args.depth,