1 change: 1 addition & 0 deletions .gitignore
@@ -0,0 +1 @@
+**/.ipynb_checkpoints
Empty file modified cvivit.py
100644 → 100755
2 changes: 1 addition & 1 deletion distributed.py
100644 → 100755
@@ -79,7 +79,7 @@ def init_distributed_device(args):
         if args.distributed and not args.no_set_device_rank:
             device = 'cuda:%d' % args.local_rank
         else:
-            device = 'cuda:0'
+            device = 'cuda:1'
         torch.cuda.set_device(device)
     else:
         device = 'cpu'
168 changes: 168 additions & 0 deletions ipynb/log_writer.ipynb

Large diffs are not rendered by default.

382 changes: 382 additions & 0 deletions ipynb/partial_maskgit_test.ipynb
@@ -0,0 +1,382 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "470f8139-b421-4f48-8200-f23b562ab73c",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/root/anaconda3/envs/torch_1.x/lib/python3.10/site-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension: libtorch_cuda_cu.so: cannot open shared object file: No such file or directory\n",
" warn(f\"Failed to load image Python extension: {e}\")\n",
"/root/anaconda3/envs/torch_1.x/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"import random\n",
"import math\n",
"import gc\n",
"import os\n",
"\n",
"from torch.utils.data import Dataset, DataLoader, IterableDataset\n",
"from webdataset.handlers import warn_and_continue\n",
"from easydict import EasyDict as edict\n",
"from matplotlib import pyplot as plt\n",
"from open_clip import tokenizer\n",
"from einops import rearrange\n",
"import webdataset as wds\n",
"import numpy as np\n",
"import torchvision\n",
"import open_clip\n",
"import torch\n",
"import cv2\n",
"\n",
"from distributed import init_distributed_device, is_primary\n",
"from vivq import VIVQ, BASE_SHAPE\n",
"from utils import sample_paella\n",
"from paella import DenoiseUNet"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "2e5743c5-35ae-4d20-9d37-d1deb7634895",
"metadata": {},
"outputs": [],
"source": [
"gc.collect()\n",
"torch.cuda.empty_cache()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "3f43a5f9-9136-4aba-8dc7-893953d1bf26",
"metadata": {},
"outputs": [],
"source": [
"SEP = os.path.sep\n",
"ROOT_PATH = SEP.join(os.getcwd().split(SEP)[:-4])"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "f3a355cb-b491-481f-9068-836a0ae0d227",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Launching with args: {'run_name': 'Paella_Test', 'model': 'maskgit', 'dataset': 'first_stage', 'urls': {'videos': '/home/projects/DataSets/webvid/dataset/00000.tar', 'images': '/home/projects/DataSets/Coyo/coyo-700m-webdataset/{00000..03506}.tar'}, 'total_steps': 300000, 'batch_size': 4, 'num_workers': 1, 'log_period': 2000, 'extra_ckpt': 10000, 'accum_grad': 2, 'vq_path': 'models/vivq_8192_drop_video/model_250000.pt', 'du_path': 'models/Paella_Test_/model_90000.pt', 'dim': 1224, 'num_tokens': 8192, 'max_seq_len': 1536, 'depth': 22, 'dim_context': 1024, 'heads': 22, 'clip_len': 10, 'skip_frames': 5, 'n_nodes': 1, 'dist_url': 'env://', 'dist_backend': 'nccl', 'no_set_device_rank': False}\n"
]
}
],
"source": [
"args = edict({})\n",
"args.run_name = \"Paella_Test\"\n",
"args.model = \"maskgit\"\n",
"args.dataset = \"first_stage\"\n",
"args.urls = {\n",
"\n",
" \"videos\": \"/home/projects/DataSets/webvid/dataset/00000.tar\",\n",
" \"images\": \"/home/projects/DataSets/Coyo/coyo-700m-webdataset/{00000..03506}.tar\"\n",
"}\n",
"\n",
"args.total_steps = 300_000\n",
"args.batch_size = 4\n",
"args.num_workers = 1\n",
"args.log_period = 2000\n",
"args.extra_ckpt = 10_000\n",
"args.accum_grad = 2\n",
"\n",
"args.vq_path = 'models/vivq_8192_drop_video/model_250000.pt' \n",
"args.du_path = 'models/Paella_Test_/model_90000.pt'\n",
"args.dim = 1224 # 1224\n",
"args.num_tokens = 8192\n",
"args.max_seq_len = 6 * 16 * 16\n",
"args.depth = 22 # 22\n",
"args.dim_context = 1024 # for clip, 512 for T5\n",
"args.heads = 22 # 22\n",
"\n",
"args.clip_len = 10\n",
"args.skip_frames = 5\n",
"\n",
"args.n_nodes = 1\n",
"args.dist_url = \"env://\"\n",
"args.dist_backend = \"nccl\"\n",
"args.no_set_device_rank = False\n",
"\n",
"print(\"Launching with args: \", args)\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "dc057466-129f-4faa-aba0-2b70ccb1ef0d",
"metadata": {},
"outputs": [],
"source": [
"class MixImageVideoDataset(IterableDataset):\n",
" def __init__(self, args):\n",
" super().__init__()\n",
" self.batch_size = args.batch_size # TODO: split this into image bs and video bs\n",
" self.video_dataset, self.image_dataset = self.init_dataloaders(args)\n",
"\n",
" def init_dataloaders(self, args):\n",
" video_dataset = wds.WebDataset(args.urls[\"videos\"], resampled=True, handler=warn_and_continue).decode(wds.torch_video,\n",
" handler=warn_and_continue).map(ProcessVideos(clip_len=args.clip_len, skip_frames=args.skip_frames),\n",
" handler=warn_and_continue).to_tuple(\"image\", \"video\", \"txt\", handler=warn_and_continue).shuffle(690, handler=warn_and_continue)\n",
" image_dataset = wds.WebDataset(args.urls[\"images\"], resampled=True, handler=warn_and_continue).decode(\"rgb\").map(\n",
" ProcessImages(), handler=warn_and_continue).to_tuple(\"jpg\", \"txt\", handler=warn_and_continue).shuffle(6969, initial=10000)\n",
" return video_dataset, image_dataset\n",
"\n",
" def __iter__(self):\n",
" sources = [iter(self.image_dataset), iter(self.video_dataset)]\n",
" # sources = [iter(self.video_dataset), iter(self.image_dataset)]\n",
" # sources = [iter(self.video_dataset)]\n",
" while True:\n",
" for source in sources:\n",
" for _ in range(self.batch_size):\n",
" try:\n",
" yield next(source)\n",
" except StopIteration:\n",
" return\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "9287e3a2-6fcc-4885-833a-dfca0ed419ac",
"metadata": {},
"outputs": [],
"source": [
"def collate_first_stage(batch):\n",
" images = torch.stack([i[0] for i in batch], dim=0)\n",
" videos = torch.stack([i[1] for i in batch], dim=0)\n",
" return [images, videos]\n",
"\n",
"\n",
"def collate_second_stage(batch):\n",
"\n",
" if len(batch[0]) == 2:\n",
" images = torch.stack([i[0] for i in batch], dim=0)\n",
" videos = None\n",
" captions = [i[1] for i in batch]\n",
" else:\n",
" images = torch.stack([i[0] for i in batch], dim=0)\n",
" videos = torch.stack([i[1] for i in batch], dim=0)\n",
" captions = [i[2] for i in batch]\n",
" \n",
" return [images, videos, captions]\n",
"\n",
"\n",
"class ProcessVideos:\n",
" def __init__(self, clip_len=10, skip_frames=4):\n",
" self.video_transform = torchvision.transforms.Compose([\n",
" torchvision.transforms.Resize(128),\n",
" torchvision.transforms.RandomCrop(128)\n",
" ])\n",
" self.clip_len = clip_len\n",
" self.skip_frames = skip_frames\n",
" print(f\"Using clip length of {clip_len} and {skip_frames} skip frames.\")\n",
"\n",
" def __call__(self, data):\n",
" video = data[\"mp4\"][0]\n",
" max_seek = video.shape[0] - (self.clip_len * self.skip_frames)\n",
" if max_seek < 0:\n",
" raise Exception(f\"Video too short ({video.shape[0]} frames), skipping.\")\n",
" start = math.floor(random.uniform(0., max_seek))\n",
" video = video[start:start+(self.clip_len*self.skip_frames)+1:self.skip_frames]\n",
" video = video.permute(0, 3, 1, 2) / 255.\n",
" if self.video_transform:\n",
" video = self.video_transform(video)\n",
" image, video = video[0], video[1:]\n",
" data[\"image\"] = image\n",
" data[\"video\"] = video\n",
" if video.shape[0] != 10:\n",
" raise Exception(\"Not 10 frames. But I should find the real cause lol for this happening.\")\n",
" return data\n",
"\n",
"\n",
"class ProcessImages:\n",
" def __init__(self,):\n",
" self.transforms = torchvision.transforms.Compose([\n",
" torchvision.transforms.ToTensor(),\n",
" torchvision.transforms.Resize(128),\n",
" torchvision.transforms.RandomCrop(128),\n",
" ])\n",
"\n",
" def __call__(self, data):\n",
" data[\"jpg\"] = self.transforms(data[\"jpg\"])\n",
" return data\n",
"\n",
"\n",
"def get_dataloader(args):\n",
" if args.dataset == \"first_stage\":\n",
" dataset = wds.WebDataset(args.dataset_path, resampled=True, handler=warn_and_continue).decode(wds.torch_video,\n",
" handler=warn_and_continue).map(ProcessVideos(clip_len=args.clip_len, skip_frames=args.skip_frames),\n",
" handler=warn_and_continue).to_tuple(\"image\", \"video\", handler=warn_and_continue).shuffle(690, handler=warn_and_continue)\n",
" \n",
" dataloader = DataLoader(dataset, batch_size=args.batch_size, num_workers=args.num_workers, collate_fn=collate_first_stage) # TODO: num_workers=args.num_workers\n",
"\n",
" elif args.dataset == \"second_stage\":\n",
" dataset = MixImageVideoDataset(args)\n",
" dataloader = DataLoader(dataset, batch_size=args.batch_size, collate_fn=collate_second_stage, num_workers=args.num_workers) # TODO: num_workers=args.num_workers\n",
"\n",
" else:\n",
" dataset = VideoDataset(video_transform=transforms)\n",
" dataloader = DataLoader(dataset, batch_size=args.batch_size, num_workers=args.num_workers) # TODO: add num_workers=args.num_workers\n",
" return dataloader"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "009d46c6-e1d8-4e4c-8c6c-7b9c9dad5f86",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"모델 로딩 완.\n"
]
}
],
"source": [
"device = init_distributed_device(args)\n",
"\n",
"vqmodel = VIVQ(codebook_size=args.num_tokens, model = 'maskgit').to(device)\n",
"vqmodel.load_state_dict(torch.load(args.vq_path, map_location=device))\n",
"vqmodel.vqmodule.q_step_counter += int(1e9)\n",
"vqmodel.eval().requires_grad_(False)\n",
"\n",
"model = DenoiseUNet(num_labels = args.num_tokens, down_levels = [4, 6, 8],\n",
" up_levels = [8, 6, 4], c_clip = args.dim_context).to(device)\n",
"\n",
"# model.load_state_dict(torch.load(args.du_path, map_location = device))\n",
"\n",
"clip_model, _, _ = open_clip.create_model_and_transforms('ViT-H-14', pretrained = 'laion2b_s32b_b79k',\n",
" cache_dir = '/fsx/max/.cache')\n",
"\n",
"del clip_model.visual\n",
"model.eval()\n",
"clip_model = clip_model.to(device).eval().requires_grad_(False)\n",
"\n",
"print('모델 로딩 완.')"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "3fab0607-8715-4fe1-b8f3-deee13a789cd",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Using clip length of 10 and 5 skip frames.\n"
]
}
],
"source": [
"dataset = wds.WebDataset(args.urls[\"videos\"], resampled=True, handler=warn_and_continue).decode(wds.torch_video,\n",
" handler=warn_and_continue).map(ProcessVideos(clip_len=args.clip_len, skip_frames=args.skip_frames),\n",
" handler=warn_and_continue).to_tuple(\"image\", \"video\", \"txt\", handler=warn_and_continue).shuffle(690, handler=warn_and_continue)\n",
"dataloader = DataLoader(dataset, batch_size=args.batch_size, num_workers=args.num_workers, collate_fn=collate_first_stage)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "db18e543-b071-4f5e-ada9-56294ee5a46c",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([4, 3, 128, 128]), torch.Size([4, 10, 3, 128, 128]))"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sample = next(iter(dataloader))\n",
"image, video = sample\n",
"image.size(), video.size()"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "8dcd4ba1-142c-46d7-8dbd-8f384ec50f8f",
"metadata": {},
"outputs": [],
"source": [
"cool_captions_text = open('cool_captions.txt').read().splitlines()\n",
"text_tokens = tokenizer.tokenize(cool_captions_text).to(device)\n",
"text_tokens_embeddings = clip_model.encode_text(text_tokens).float() "
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "f1212368-7642-4fe2-9a8b-3162fba55813",
"metadata": {},
"outputs": [],
"source": [
"cool_captions_sampled = []\n",
"for caption_embedding in text_tokens_embeddings.chunk(10):\n",
"\n",
" caption_embedding = caption_embedding[0].float().to(device)\n",
" caption_embedding = caption_embedding.unsqueeze(0)\n",
" \n",
" sampled_text = sample_paella(model, caption_embedding)\n",
" sampled_text = vqmodel.decode_indices(sampled_text)\n",
"\n",
" for s in sampled_text:\n",
" cool_captions_sampled.append(s.cpu())"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2f7dfe17-6941-4e9e-b44b-505d158b6112",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "torch_1.x",
"language": "python",
"name": "torch_1.x"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
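
A note on MixImageVideoDataset above: its __iter__ pulls batch_size samples from each source in turn, so with num_workers=1 consecutive DataLoader batches alternate between all-image and all-video. The sketch below is a minimal standalone illustration of that pattern; the generator stand-ins for the two webdataset pipelines are assumptions, not the notebook's actual sources.

# Minimal sketch of the interleaving in MixImageVideoDataset.__iter__:
# yield batch_size items from each source in turn; stop when any source ends.
def interleave(sources, batch_size):
    while True:
        for source in sources:
            for _ in range(batch_size):
                try:
                    yield next(source)
                except StopIteration:
                    return

image_source = iter(f"img_{i}" for i in range(8))  # stand-in for the image WebDataset
video_source = iter(f"vid_{i}" for i in range(8))  # stand-in for the video WebDataset

print(list(interleave([image_source, video_source], batch_size=4)))
# ['img_0', ..., 'img_3', 'vid_0', ..., 'vid_3', 'img_4', ..., 'vid_7']

Because each run of batch_size samples fills exactly one DataLoader batch, collate_second_stage can rely on a batch being homogeneous: 2-tuples (jpg, txt) for images or 3-tuples (image, video, txt) for videos.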
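
On the "Not 10 frames" guard in ProcessVideos: the slice arithmetic itself can come up one frame short. max_seek = total_frames - clip_len * skip_frames is treated as an inclusive upper bound for start, and when start equals max_seek (forced whenever max_seek == 0) the last sampled index equals total_frames and falls off the end of the tensor, leaving only 9 video frames after the first frame is split off as the image. A standalone check of the index arithmetic, with a plain list standing in for the frame tensor:

# Mirror the slice from ProcessVideos.__call__ on plain indices.
clip_len, skip_frames = 10, 5

def clip_indices(total_frames, start):
    frames = list(range(total_frames))
    return frames[start : start + clip_len * skip_frames + 1 : skip_frames]

# Headroom case: 11 frames come back (1 image + 10 video frames), as intended.
print(len(clip_indices(total_frames=60, start=0)))  # 11

# Boundary case: total_frames == clip_len * skip_frames forces max_seek = 0
# and start = 0; the last index would be 50 == total_frames, which is out of
# range, so only 10 frames come back (1 image + 9 video frames).
print(len(clip_indices(total_frames=50, start=0)))  # 10

If this reading is right, subtracting one more frame when computing max_seek (total_frames - clip_len * skip_frames - 1) keeps the last index in range; that is an inference from the arithmetic, not a change made in this PR.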