diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 00000000..336d1a22 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,3 @@ +{ + "editor.defaultFormatter": "ms-python.black-formatter" +} \ No newline at end of file diff --git a/README_EAMM.md b/README_EAMM.md new file mode 100644 index 00000000..1b72cc66 --- /dev/null +++ b/README_EAMM.md @@ -0,0 +1,65 @@ +# EAMM + +[Source](https://github.com/jixinya/EAMM) + +## 分工 + +高鹏 1120210558:个人完成。 + +## 准备工作 + +下载模型权重 [google-drive](https://drive.google.com/file/d/1IL9LjH3JegyMqJABqMxrX3StAq_v8Gtp/view?usp=sharing), + +将里面的三个 `.pth.tar` 文件放到 `checkpoints/EAMM/` 路径下。 + +该路径下应该包括以下8个文件:`1-6000.pth.tar`, `5-3000.pth.tar`, `124_52000.pth.tar`, `M003_template.npy`, `shape_predictor_68_face_landmarks.dat`, `mb1_120x120.pth`, `bfm_noneck_v3.pkl`, `param_mean_std_62d_120x120.pkl` 。 + +然后进入 `talkingface\utils\pose_3ddfa\FaceBoxes\utils` 目录,运行以下命令(需要安装 Cython 库): + +```bash +python3 build.py build_ext --inplace +``` + + +## 使用方法 + +**预处理** + +预处理在框架中自动完成,包括三个部分: + +- 将视频按帧导出为图片序列,大小为 (256,256,3) +- 提取音频的mfcc特征导出为 `.npy` 文件 +- 提取人脸的姿态数据导出为 `.npy` 文件 + +处理好的文件分别位于数据集中 `preprocessed_data` 目录下的 `crop`, `MEAD_MFCC` 和 `pose`。 + + +**训练** + +在作者开源的代码中,训练所需的数据集和模型部分,存在一些错误,包括: + +- 数据集构建的代码中包括作者本机的文件地址,文档或注释中也并未说明这些文件是怎么来的,需要哪些处理,自己尝试复现可能需要花费大量时间和算力,故而放弃 +![dataset_fault](./saved/eamm/dataset_fault.png) +- 训练所需的模型代码中存在大量变量未定义的报错,应该是没有将完整的代码传上去 + +综上所述,最后未能复现训练部分。但是,训练相关的代码均已补充完整,后面可以在此基础上尝试填补上述缺漏。 + +**推理** + +```bash +python .\run_talkingface.py -m EAMM -d mead --config_files '.\talkingface\properties\overall.yaml .\talkingface\properties\model\EAMM.yaml .\talkingface\properties\model\EAMM\evaluate.yaml' --evaluate_model_file .\talkingface\properties\model\EAMM\evaluate.yaml +``` + +推理过程的配置主要在 `evaluate.yaml` 中,主要的配置包括: + +- **source_image**:生成结果所使用的人脸图片 +- **driving_video**:生成结果所使用的表情视频 +- **in_file**:生成结果所使用的驱动语音 +- **pose_file**:生成结果所使用的人脸的姿态数据 +- **emotion**:生成的表情,包括 'angry', 
'contempt','disgusted','fear','happy','neutral','sad','surprised' + +据论文所述,上述文件均不必来自同一源视频。 + +生成的结果位于 `saved\eamm\output\` 下,`mp4` 是不带声音的,`mov` 是带声音的,`emotion.mov` 是生成结果,`all.mov` 是和驱动视频的效果对比。 + +![result](./saved/eamm/output/all.gif) \ No newline at end of file diff --git a/checkpoints/EAMM/M003_template.npy b/checkpoints/EAMM/M003_template.npy new file mode 100644 index 00000000..23d507d2 Binary files /dev/null and b/checkpoints/EAMM/M003_template.npy differ diff --git a/checkpoints/EAMM/bfm_noneck_v3.pkl b/checkpoints/EAMM/bfm_noneck_v3.pkl new file mode 100644 index 00000000..7a44e1e2 Binary files /dev/null and b/checkpoints/EAMM/bfm_noneck_v3.pkl differ diff --git a/checkpoints/EAMM/param_mean_std_62d_120x120.pkl b/checkpoints/EAMM/param_mean_std_62d_120x120.pkl new file mode 100644 index 00000000..110c0f3e Binary files /dev/null and b/checkpoints/EAMM/param_mean_std_62d_120x120.pkl differ diff --git a/checkpoints/EAMM/shape_predictor_68_face_landmarks.dat b/checkpoints/EAMM/shape_predictor_68_face_landmarks.dat new file mode 100644 index 00000000..e0ec20d6 Binary files /dev/null and b/checkpoints/EAMM/shape_predictor_68_face_landmarks.dat differ diff --git a/saved/eamm/dataset_fault.png b/saved/eamm/dataset_fault.png new file mode 100644 index 00000000..dcbcd2f4 Binary files /dev/null and b/saved/eamm/dataset_fault.png differ diff --git a/saved/eamm/output/all.gif b/saved/eamm/output/all.gif new file mode 100644 index 00000000..0b2284c9 Binary files /dev/null and b/saved/eamm/output/all.gif differ diff --git a/saved/eamm/output/all.mov b/saved/eamm/output/all.mov new file mode 100644 index 00000000..3ed1389a Binary files /dev/null and b/saved/eamm/output/all.mov differ diff --git a/saved/eamm/output/all.mp4 b/saved/eamm/output/all.mp4 new file mode 100644 index 00000000..fd556475 Binary files /dev/null and b/saved/eamm/output/all.mp4 differ diff --git a/saved/eamm/output/emotion.mov b/saved/eamm/output/emotion.mov new file mode 100644 index 00000000..d6760d35 
Binary files /dev/null and b/saved/eamm/output/emotion.mov differ diff --git a/saved/eamm/output/emotion.mp4 b/saved/eamm/output/emotion.mp4 new file mode 100644 index 00000000..4a1b0fc9 Binary files /dev/null and b/saved/eamm/output/emotion.mp4 differ diff --git a/saved/eamm/output/neutral.mp4 b/saved/eamm/output/neutral.mp4 new file mode 100644 index 00000000..06315549 Binary files /dev/null and b/saved/eamm/output/neutral.mp4 differ diff --git a/talkingface/data/dataset/eamm_dataset.py b/talkingface/data/dataset/eamm_dataset.py new file mode 100644 index 00000000..8bc09e62 --- /dev/null +++ b/talkingface/data/dataset/eamm_dataset.py @@ -0,0 +1,382 @@ +import os +import random +from glob import glob +from os.path import basename, dirname, isfile, join + +import cv2 +import librosa +import numpy as np +import python_speech_features +import torch +import torch.backends.cudnn as cudnn +from imageio import mimread +from skimage import img_as_float32, io, transform +from skimage.color import gray2rgb +from sklearn.model_selection import train_test_split +from torch import nn, optim +from torch.utils import data as data_utils +from tqdm import tqdm + +from talkingface.data.dataprocess.wav2lip_process import Wav2LipAudio +from talkingface.data.dataset.dataset import Dataset +from talkingface.utils.augmentation import AllAugmentationTransform +from talkingface.utils.filter1 import OneEuroFilter + + +class EAMMDataset(Dataset): + from pathlib import Path + + def __init__(self, config, datasplit): + super().__init__(config, datasplit) + self.type = config['dataset_name'] + if self.config['train']: + self._build_dataset() + else: + self.videos = np.random.rand(10,256,256,3) #! 
伪造数据,以便跑通流程 + + def _build_dataset(self): + if self.type == 'Vox': + self._init_vox() + elif self.type == 'LRW': + self._init_lrw() + elif self.type == 'MEAD': + self._init_mead() + + def _init_vox(self): + self.root_dir = self.config['dataset_root_dir'] + self.audio_dir = os.path.join(self.root_dir,'MFCC') + self.image_dir = os.path.join(self.root_dir,'align_img') + self.pose_dir = os.path.join(self.root_dir,'align_pose') + + assert len(os.listdir(self.audio_dir)) == len(os.listdir(self.image_dir)), 'audio and image length not equal' + + #! 作者没有告知这里是什么,如何得到 + self.videos=np.load('/mnt/lustre/share_data/jixinya/VoxCeleb1_Cut/right.npy') + self.frame_shape = tuple(self.config['dataset_frame_shape']) + self.pairs_list = None + self.id_sampling = self.config['dataset_id_sampling'] + + if os.path.exists(os.path.join(self.pose_dir, 'train_fo')): + assert os.path.exists(os.path.join(self.pose_dir, 'test_fo')) + print("Use predefined train-test split.") + if self.id_sampling: + train_videos = {os.path.basename(video).split('#')[0] for video in + os.listdir(os.path.join(self.image_dir, 'train'))} + train_videos = list(train_videos) + else: + train_videos = np.load('/mnt/lustre/share_data/jixinya/VoxCeleb1_Cut/right.npy')# get_list(self.pose_dir, 'train_fo') + + self.image_dir = os.path.join(self.image_dir, 'train_fo' if self.datasplit else 'test_fo') + self.audio_dir = os.path.join(self.audio_dir, 'train' if self.datasplit else 'test') + self.pose_dir = os.path.join(self.pose_dir, 'train_fo' if self.datasplit else 'test_fo') + else: + print("Use random train-test split.") + train_videos, test_videos = train_test_split(self.videos, random_state=self.config['seed'], test_size=0.2) + + if self.datasplit: + self.videos = train_videos + else: + self.videos = test_videos + + self.is_train = self.datasplit + + if self.is_train: + self.transform = AllAugmentationTransform( + flip_param={ + 'horizontal_flip': self.config['dataset_augmentation_flip_horizontal_flip'], + 'time_flip': 
self.config['dataset_augmentation_flip_time_flip'], + }, + jitter_param={ + 'brightness': self.config['dataset_augmentation_jitter_brightness'], + 'contrast': self.config['dataset_augmentation_jitter_contrast'], + 'saturation': self.config['dataset_augmentation_jitter_saturation'], + 'hue': self.config['dataset_augmentation_jitter_hue'], + } + ) + else: + self.transform = None + + def _init_lrw(self): + self.root_dir = self.config['dataset_root_dir'] + self.audio_dir = os.path.join(self.root_dir,'MFCC') + self.image_dir = os.path.join(self.root_dir,'Image') + self.pose_dir = os.path.join(self.root_dir,'pose') + assert len(os.listdir(self.audio_dir)) == len(os.listdir(self.image_dir)), 'audio and image length not equal' + + self.frame_shape = tuple(self.config['dataset_frame_shape']) + + self.id_sampling = self.config['dataset_id_sampling'] + + if os.path.exists(os.path.join(self.pose_dir, 'train_fo')): + assert os.path.exists(os.path.join(self.pose_dir, 'test_fo')) + print("Use predefined train-test split.") + if self.id_sampling: + train_videos = {os.path.basename(video).split('#')[0] for video in + os.listdir(os.path.join(self.image_dir, 'train'))} + train_videos = list(train_videos) + else: + train_videos = np.load('../LRW/list/train_fo.npy')# get_list(self.pose_dir, 'train_fo') + test_videos=np.load('../LRW/list/test_fo.npy') + + self.image_dir = os.path.join(self.image_dir, 'train_fo' if self.datasplit else 'test_fo') + self.audio_dir = os.path.join(self.audio_dir, 'train' if self.datasplit else 'test') + self.pose_dir = os.path.join(self.pose_dir, 'train_fo' if self.datasplit else 'test_fo') + else: + print("Use random train-test split.") + train_videos, test_videos = train_test_split(self.videos, random_state=self.config['seed'], test_size=0.2) + + if self.datasplit: + self.videos = train_videos + else: + self.videos = test_videos + + self.is_train = self.datasplit + + if self.is_train: + self.transform = AllAugmentationTransform( + flip_param={ + 
'horizontal_flip': self.config['dataset_augmentation_flip_horizontal_flip'], + 'time_flip': self.config['dataset_augmentation_flip_time_flip'], + }, + jitter_param={ + 'brightness': self.config['dataset_augmentation_jitter_brightness'], + 'contrast': self.config['dataset_augmentation_jitter_contrast'], + 'saturation': self.config['dataset_augmentation_jitter_saturation'], + 'hue': self.config['dataset_augmentation_jitter_hue'], + } + ) + else: + self.transform = None + + def _init_mead(self): + self.root_dir = self.config['dataset_root_dir'] + self.audio_dir = os.path.join(self.root_dir,'MEAD_MFCC') + self.image_dir = os.path.join(self.root_dir,'MEAD_fomm_crop') + self.pose_dir = os.path.join(self.root_dir,'MEAD_fomm_pose_crop') + self.videos = np.load('/mnt/lustre/share_data/jixinya/MEAD/MEAD_fomm_audio_less_crop.npy') + self.dict = np.load('/mnt/lustre/share_data/jixinya/MEAD/MEAD_fomm_neu_dic_crop.npy',allow_pickle=True).item() + self.frame_shape = tuple(self.config['dataset_frame_shape']) + self.id_sampling = self.config['dataset_id_sampling'] + if os.path.exists(os.path.join(self.root_dir, 'train')): + assert os.path.exists(os.path.join(self.root_dir, 'test')) + print("Use predefined train-test split.") + if self.id_sampling: + train_videos = {os.path.basename(video).split('#')[0] for video in + os.listdir(os.path.join(self.root_dir, 'train'))} + train_videos = list(train_videos) + else: + train_videos = os.listdir(os.path.join(self.root_dir, 'train')) + test_videos = os.listdir(os.path.join(self.root_dir, 'test')) + self.root_dir = os.path.join(self.root_dir, 'train' if self.datasplit else 'test') + else: + print("Use random train-test split.") + train_videos, test_videos = train_test_split(self.videos, random_state=self.config['seed'], test_size=0.2) + if self.datasplit: + self.videos = train_videos + else: + self.videos = test_videos + self.is_train = self.datasplit + if self.is_train: + self.transform = AllAugmentationTransform( + crop_mouth_param={ + 
'center_x': self.config['dataset_augmentation_crop_mouth_center_x'], + 'center_y': self.config['dataset_augmentation_crop_mouth_center_y'], + 'mask_width': self.config['dataset_augmentation_crop_mouth_mask_width'], + 'mask_height': self.config['dataset_augmentation_crop_mouth_mask_height'], + }, + rotation_param={ + 'degrees': self.config['dataset_augmentation_rotation_degrees'], + }, + perspective_param={ + 'pers_num': self.config['dataset_augmentation_perspective_pers_num'], + 'enlarge_num': self.config['dataset_augmentation_perspective_enlarge_num'], + }, + flip_param={ + 'horizontal_flip': self.config['dataset_augmentation_flip_horizontal_flip'], + 'time_flip': self.config['dataset_augmentation_flip_time_flip'], + }, + jitter_param={ + 'brightness': self.config['dataset_augmentation_jitter_brightness'], + 'contrast': self.config['dataset_augmentation_jitter_contrast'], + 'saturation': self.config['dataset_augmentation_jitter_saturation'], + 'hue': self.config['dataset_augmentation_jitter_hue'], + } + ) + else: + self.transform = None + + def __len__(self): + return len(self.videos) + + def __getitem__(self, idx): + if self.type == 'Vox': + return self._getitem_vox(idx) + elif self.type == 'LRW': + return self._getitem_lrw(idx) + elif self.type == 'MEAD': + return self._getitem_mead(idx) + + def _getitem_vox(self, idx): + if self.is_train and self.id_sampling: + name = self.videos[idx].split('.')[0] + path = np.random.choice(glob.glob(os.path.join(self.root_dir, name + '*.mp4'))) + else: + name = self.videos[idx].split('.')[0] + + audio_path = os.path.join(self.audio_dir, name+'.npy') + pose_path = os.path.join(self.pose_dir,name+'.npy') + path = os.path.join(self.image_dir, name) + + video_name = os.path.basename(path) + if os.path.isdir(path): + frames = os.listdir(path) + num_frames = len(frames) + frame_idx = np.sort(np.random.choice(num_frames, replace=True, size=2)) + video_array = [img_as_float32(io.imread(os.path.join(path, frames[idx]))) for idx in 
frame_idx] + mfcc = np.load(audio_path) + pose = np.load(pose_path) + + try: + len(mfcc) > 16 + except: + print('wrongmfcc len:',audio_path) + if 16 < len(mfcc) < 24 : + r = 0 + else: + r = random.choice([x for x in range(3, len(mfcc)-20)]) + + mfccs = [] + poses = [] + video_array = [] + for ind in range(1, 17): + t_mfcc = mfcc[r+ind][:, 1:] + mfccs.append(t_mfcc) + t_pose = pose[r+ind,:-1] + poses.append(t_pose) + image = img_as_float32(io.imread(os.path.join(path, str(r + ind)+'.png'))) + video_array.append(image) + mfccs = np.array(mfccs) + poses = np.array(poses) + video_array = np.array(video_array) + example_image = img_as_float32(io.imread(os.path.join(path, str(r)+'.png'))) + else: + print('Wrong, data path not an existing file.') + + if self.transform is not None: + video_array = self.transform(video_array) + + out = {} + driving = np.array(video_array, dtype='float32') + spatial_size = np.array(driving.shape[1:3][::-1])[np.newaxis] + driving_pose = np.array(poses, dtype='float32') + example_image = np.array(example_image, dtype='float32') + out['example_image'] = example_image.transpose((2, 0, 1)) + out['driving_pose'] = driving_pose + out['driving'] = driving.transpose((0, 3, 1, 2)) + out['driving_audio'] = np.array(mfccs, dtype='float32') + return out + + def _getitem_lrw(self, idx): + if self.is_train and self.id_sampling: + name = self.videos[idx].split('.')[0] + path = np.random.choice(glob.glob(os.path.join(self.root_dir, name + '*.mp4'))) + else: + name = self.videos[idx].split('.')[0] + audio_path = os.path.join(self.audio_dir, name) + pose_path = os.path.join(self.pose_dir,name) + path = os.path.join(self.image_dir, name) + video_name = os.path.basename(path) + if os.path.isdir(path): + # mfcc loading + r = random.choice([x for x in range(3, 8)]) + example_image = img_as_float32(io.imread(os.path.join(path, str(r)+'.png'))) + mfccs = [] + for ind in range(1, 17): + # t_mfcc = mfcc[(r + ind - 3) * 4: (r + ind + 4) * 4, 1:] + t_mfcc = 
np.load(os.path.join(audio_path,str(r + ind)+'.npy'),allow_pickle=True)[:, 1:] + mfccs.append(t_mfcc) + mfccs = np.array(mfccs) + poses = [] + video_array = [] + for ind in range(1, 17): + t_pose = np.load(os.path.join(self.pose_dir,name+'.npy'))[r+ind,:-1] + poses.append(t_pose) + image = img_as_float32(io.imread(os.path.join(path, str(r + ind)+'.png'))) + video_array.append(image) + poses = np.array(poses) + video_array = np.array(video_array) + else: + print('Wrong, data path not an existing file.') + if self.transform is not None: + video_array = self.transform(video_array) + out = {} + driving = np.array(video_array, dtype='float32') + spatial_size = np.array(driving.shape[1:3][::-1])[np.newaxis] + driving_pose = np.array(poses, dtype='float32') + example_image = np.array(example_image, dtype='float32') + out['example_image'] = example_image.transpose((2, 0, 1)) + out['driving_pose'] = driving_pose + out['driving'] = driving.transpose((0, 3, 1, 2)) + out['driving_audio'] = np.array(mfccs, dtype='float32') + return out + + def _getitem_mead(self, idx): + if self.is_train and self.id_sampling: + name = self.videos[idx] + path = np.random.choice(glob.glob(os.path.join(self.root_dir, name + '*.mp4'))) + else: + name = self.videos[idx] + path = os.path.join(self.image_dir, name) + video_name = os.path.basename(path) + id_name = path.split('/')[-2] + neu_list = self.dict[id_name] + neu_path = os.path.join(self.image_dir, np.random.choice(neu_list)) + audio_path = os.path.join(self.audio_dir, name+'.npy') + pose_path = os.path.join(self.pose_dir,name+'.npy') + if self.is_train and os.path.isdir(path): + mfcc = np.load(audio_path) + pose_raw = np.load(pose_path) + one_euro_filter = OneEuroFilter(mincutoff=0.01, beta=0.7, dcutoff=1.0, freq=100) + pose = np.zeros((len(pose_raw),7)) + for j in range(len(pose_raw)): + pose[j]=one_euro_filter.process(pose_raw[j]) + neu_frames = os.listdir(neu_path) + num_neu_frames = len(neu_frames) + frame_idx = 
np.random.choice(num_neu_frames) + example_image = img_as_float32(io.imread(os.path.join(neu_path, neu_frames[frame_idx]))) + try: + len(mfcc) > 16 + except: + print('wrongmfcc len:',audio_path) + if 16 < len(mfcc) < 24 : + r = 0 + else: + r = random.choice([x for x in range(3, len(mfcc)-20)]) + mfccs = [] + poses = [] + video_array = [] + for ind in range(1, 17): + t_mfcc = mfcc[r+ind][:, 1:] + mfccs.append(t_mfcc) + t_pose = pose[r+ind,:-1] + poses.append(t_pose) + image = img_as_float32(io.imread(os.path.join(path, str(r + ind)+'.png'))) + video_array.append(image) + mfccs = np.array(mfccs) + poses = np.array(poses) + video_array = np.array(video_array) + else: + print('Wrong, data path not an existing file.') + if self.transform is not None: + video_array = self.transform(video_array) + out = {} + if self.is_train: + driving = np.array(video_array, dtype='float32') + driving_pose = np.array(poses, dtype='float32') + example_image = np.array(example_image, dtype='float32') + out['example_image'] = example_image.transpose((2, 0, 1)) + out['driving_pose'] = driving_pose + out['driving'] = driving.transpose((0, 3, 1, 2)) + out['driving_audio'] = np.array(mfccs, dtype='float32') + return out + diff --git a/talkingface/model/audio_driven_talkingface/eamm.py b/talkingface/model/audio_driven_talkingface/eamm.py new file mode 100644 index 00000000..711968ea --- /dev/null +++ b/talkingface/model/audio_driven_talkingface/eamm.py @@ -0,0 +1,135 @@ +import torch +import torch.nn as nn +from torch.nn import functional as F +from torch.nn.init import xavier_normal_, constant_ +from tqdm import tqdm +from os import listdir, path +import numpy as np +import os, subprocess +from glob import glob +import cv2 +from talkingface.model.layers import Conv2d, Conv2dTranspose, nonorm_Conv2d +from talkingface.model.abstract_talkingface import AbstractTalkingFace +from talkingface.data.dataprocess.wav2lip_process import Wav2LipPreprocessForInference, Wav2LipAudio +from talkingface.utils 
import ensure_dir +from talkingface.model.audio_driven_talkingface.eamm_modules.generator import OcclusionAwareGenerator +from talkingface.model.audio_driven_talkingface.eamm_modules.discriminator import MultiScaleDiscriminator +from talkingface.model.audio_driven_talkingface.eamm_modules.keypoint_detector import KPDetector, Audio_Feature, KPDetector_a +from talkingface.model.audio_driven_talkingface.eamm_modules.util import AT_net, Emotion_k + +class EAMM(AbstractTalkingFace): + def __init__(self, config): + super().__init__() + self.config = config + print("init EAMM module") + self._build_model() + + def forward(self): + pass + + def predict(self): + pass + + def calculate_loss(self, interaction, valid=False): + pass + + def generate_batch(self): + pass + + def _build_model(self): + self.generator = OcclusionAwareGenerator( + num_kp=self.config['model_common_num_kp'], + num_channels=self.config['model_common_num_channels'], + estimate_jacobian=self.config['model_common_estimate_jacobian'], + block_expansion=self.config['model_generator_block_expansion'], + max_features=self.config['model_generator_max_features'], + num_down_blocks=self.config['model_generator_num_down_blocks'], + num_bottleneck_blocks=self.config['model_generator_num_bottleneck_blocks'], + estimate_occlusion_map=self.config['model_generator_estimate_occlusion_map'], + dense_motion_params={ + 'block_expansion': self.config['model_generator_dense_motion_block_expansion'], + 'max_features': self.config['model_generator_dense_motion_max_features'], + 'num_blocks': self.config['model_generator_dense_motion_num_blocks'], + 'scale_factor': self.config['model_generator_dense_motion_scale_factor'], + }, + ) + + if self.config['use_gpu'] and torch.cuda.is_available(): + self.generator.to(self.config['device_ids'][0]) + else: + self.generator.cpu() + + if self.config['verbose']: + print(self.generator) + + self.discriminator = MultiScaleDiscriminator( + num_kp=self.config['model_common_num_kp'], + 
num_channels=self.config['model_common_num_channels'], + estimate_jacobian=self.config['model_common_estimate_jacobian'], + scales=self.config['model_discriminator_scales'], + block_expansion=self.config['model_discriminator_block_expansion'], + max_features=self.config['model_discriminator_max_features'], + num_blocks=self.config['model_discriminator_num_blocks'], + sn=self.config['model_discriminator_sn'], + ) + if self.config['use_gpu'] and torch.cuda.is_available(): + self.discriminator.to(self.config['device_ids'][0]) + else: + self.discriminator.cpu() + + if self.config['verbose']: + print(self.discriminator) + + self.kp_detector = KPDetector( + num_kp=self.config['model_common_num_kp'], + num_channels=self.config['model_common_num_channels'], + estimate_jacobian=self.config['model_common_estimate_jacobian'], + temperature=self.config['model_kp_detector_temperature'], + block_expansion=self.config['model_kp_detector_block_expansion'], + max_features=self.config['model_kp_detector_max_features'], + scale_factor=self.config['model_kp_detector_scale_factor'], + num_blocks=self.config['model_kp_detector_num_blocks'], + ) + + self.kp_detector_a = KPDetector_a( + num_kp=self.config['model_audio_num_kp'], + num_channels=self.config['model_audio_num_channels'], + num_channels_a=self.config['model_audio_num_channels_a'], + estimate_jacobian=self.config['model_common_estimate_jacobian'], + temperature=self.config['model_kp_detector_temperature'], + block_expansion=self.config['model_kp_detector_block_expansion'], + max_features=self.config['model_kp_detector_max_features'], + scale_factor=self.config['model_kp_detector_scale_factor'], + num_blocks=self.config['model_kp_detector_num_blocks'], + ) + + if self.config['use_gpu'] and torch.cuda.is_available(): + self.kp_detector.to(self.config['device_ids'][0]) + self.kp_detector_a.to(self.config['device_ids'][0]) + else: + self.kp_detector.cpu() + self.kp_detector_a.cpu() + + self.audio_feature = AT_net() + 
self.emo_feature = Emotion_k( + block_expansion=32, + num_channels=3, + max_features=1024, + num_blocks=5, + scale_factor=0.25, + num_classes=8, + ) + + if self.config['use_gpu'] and torch.cuda.is_available(): + self.audio_feature.to(self.config['device_ids'][0]) + self.emo_feature.to(self.config['device_ids'][0]) + else: + self.audio_feature.cpu() + self.emo_feature.cpu() + + if self.config['verbose']: + print(self.kp_detector) + print(self.kp_detector_a) + print(self.audio_feature) + print(self.emo_feature) + diff --git a/talkingface/model/audio_driven_talkingface/eamm_modules/__init__.py b/talkingface/model/audio_driven_talkingface/eamm_modules/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/talkingface/model/audio_driven_talkingface/eamm_modules/batchnorm.py b/talkingface/model/audio_driven_talkingface/eamm_modules/batchnorm.py new file mode 100644 index 00000000..5f4e763f --- /dev/null +++ b/talkingface/model/audio_driven_talkingface/eamm_modules/batchnorm.py @@ -0,0 +1,315 @@ +# -*- coding: utf-8 -*- +# File : batchnorm.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. 
+ +import collections + +import torch +import torch.nn.functional as F + +from torch.nn.modules.batchnorm import _BatchNorm +from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast + +from .comm import SyncMaster + +__all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d'] + + +def _sum_ft(tensor): + """sum over the first and last dimention""" + return tensor.sum(dim=0).sum(dim=-1) + + +def _unsqueeze_ft(tensor): + """add new dementions at the front and the tail""" + return tensor.unsqueeze(0).unsqueeze(-1) + + +_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size']) +_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std']) + + +class _SynchronizedBatchNorm(_BatchNorm): + def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True): + super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine) + + self._sync_master = SyncMaster(self._data_parallel_master) + + self._is_parallel = False + self._parallel_id = None + self._slave_pipe = None + + def forward(self, input): + # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation. + if not (self._is_parallel and self.training): + return F.batch_norm( + input, self.running_mean, self.running_var, self.weight, self.bias, + self.training, self.momentum, self.eps) + + # Resize the input to (B, C, -1). + input_shape = input.size() + input = input.view(input.size(0), self.num_features, -1) + + # Compute the sum and square-sum. + sum_size = input.size(0) * input.size(2) + input_sum = _sum_ft(input) + input_ssum = _sum_ft(input ** 2) + + # Reduce-and-broadcast the statistics. + if self._parallel_id == 0: + mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size)) + else: + mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size)) + + # Compute the output. 
+ if self.affine: + # MJY:: Fuse the multiplication for speed. + output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias) + else: + output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std) + + # Reshape it. + return output.view(input_shape) + + def __data_parallel_replicate__(self, ctx, copy_id): + self._is_parallel = True + self._parallel_id = copy_id + + # parallel_id == 0 means master device. + if self._parallel_id == 0: + ctx.sync_master = self._sync_master + else: + self._slave_pipe = ctx.sync_master.register_slave(copy_id) + + def _data_parallel_master(self, intermediates): + """Reduce the sum and square-sum, compute the statistics, and broadcast it.""" + + # Always using same "device order" makes the ReduceAdd operation faster. + # Thanks to:: Tete Xiao (http://tetexiao.com/) + intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device()) + + to_reduce = [i[1][:2] for i in intermediates] + to_reduce = [j for i in to_reduce for j in i] # flatten + target_gpus = [i[1].sum.get_device() for i in intermediates] + + sum_size = sum([i[1].sum_size for i in intermediates]) + sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce) + mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size) + + broadcasted = Broadcast.apply(target_gpus, mean, inv_std) + + outputs = [] + for i, rec in enumerate(intermediates): + outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2]))) + + return outputs + + def _compute_mean_std(self, sum_, ssum, size): + """Compute the mean and standard-deviation with sum and square-sum. This method + also maintains the moving average on the master device.""" + assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.' 
+ mean = sum_ / size + sumvar = ssum - sum_ * mean + unbias_var = sumvar / (size - 1) + bias_var = sumvar / size + + self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data + self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data + + return mean, bias_var.clamp(self.eps) ** -0.5 + + +class SynchronizedBatchNorm1d(_SynchronizedBatchNorm): + r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a + mini-batch. + + .. math:: + + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + + This module differs from the built-in PyTorch BatchNorm1d as the mean and + standard-deviation are reduced across all devices during training. + + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device using + the statistics only on that device, which accelerated the computation and + is also easy to implement, but the statistics might be inaccurate. + Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + + During training, this layer keeps a running estimate of its computed mean + and variance. The running sum is kept with a default momentum of 0.1. + + During evaluation, this running mean/variance is used for normalization. 
+ + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm + + Args: + num_features: num_features from an expected input of size + `batch_size x num_features [x width]` + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. Default: ``True`` + + Shape: + - Input: :math:`(N, C)` or :math:`(N, C, L)` + - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) + + Examples: + >>> # With Learnable Parameters + >>> m = SynchronizedBatchNorm1d(100) + >>> # Without Learnable Parameters + >>> m = SynchronizedBatchNorm1d(100, affine=False) + >>> input = torch.autograd.Variable(torch.randn(20, 100)) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 2 and input.dim() != 3: + raise ValueError('expected 2D or 3D input (got {}D input)' + .format(input.dim())) + super(SynchronizedBatchNorm1d, self)._check_input_dim(input) + + +class SynchronizedBatchNorm2d(_SynchronizedBatchNorm): + r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch + of 3d inputs + + .. math:: + + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + + This module differs from the built-in PyTorch BatchNorm2d as the mean and + standard-deviation are reduced across all devices during training. + + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device using + the statistics only on that device, which accelerated the computation and + is also easy to implement, but the statistics might be inaccurate. 
+ Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + + During training, this layer keeps a running estimate of its computed mean + and variance. The running sum is kept with a default momentum of 0.1. + + During evaluation, this running mean/variance is used for normalization. + + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm + + Args: + num_features: num_features from an expected input of + size batch_size x num_features x height x width + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. Default: ``True`` + + Shape: + - Input: :math:`(N, C, H, W)` + - Output: :math:`(N, C, H, W)` (same shape as input) + + Examples: + >>> # With Learnable Parameters + >>> m = SynchronizedBatchNorm2d(100) + >>> # Without Learnable Parameters + >>> m = SynchronizedBatchNorm2d(100, affine=False) + >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45)) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 4: + raise ValueError('expected 4D input (got {}D input)' + .format(input.dim())) + super(SynchronizedBatchNorm2d, self)._check_input_dim(input) + + +class SynchronizedBatchNorm3d(_SynchronizedBatchNorm): + r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch + of 4d inputs + + .. 
math:: + + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + + This module differs from the built-in PyTorch BatchNorm3d as the mean and + standard-deviation are reduced across all devices during training. + + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device using + the statistics only on that device, which accelerated the computation and + is also easy to implement, but the statistics might be inaccurate. + Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + + During training, this layer keeps a running estimate of its computed mean + and variance. The running sum is kept with a default momentum of 0.1. + + During evaluation, this running mean/variance is used for normalization. + + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm + or Spatio-temporal BatchNorm + + Args: + num_features: num_features from an expected input of + size batch_size x num_features x depth x height x width + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. 
Default: ``True`` + + Shape: + - Input: :math:`(N, C, D, H, W)` + - Output: :math:`(N, C, D, H, W)` (same shape as input) + + Examples: + >>> # With Learnable Parameters + >>> m = SynchronizedBatchNorm3d(100) + >>> # Without Learnable Parameters + >>> m = SynchronizedBatchNorm3d(100, affine=False) + >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10)) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 5: + raise ValueError('expected 5D input (got {}D input)' + .format(input.dim())) + super(SynchronizedBatchNorm3d, self)._check_input_dim(input) diff --git a/talkingface/model/audio_driven_talkingface/eamm_modules/comm.py b/talkingface/model/audio_driven_talkingface/eamm_modules/comm.py new file mode 100644 index 00000000..922f8c4a --- /dev/null +++ b/talkingface/model/audio_driven_talkingface/eamm_modules/comm.py @@ -0,0 +1,137 @@ +# -*- coding: utf-8 -*- +# File : comm.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import queue +import collections +import threading + +__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] + + +class FutureResult(object): + """A thread-safe future implementation. Used only as one-to-one pipe.""" + + def __init__(self): + self._result = None + self._lock = threading.Lock() + self._cond = threading.Condition(self._lock) + + def put(self, result): + with self._lock: + assert self._result is None, 'Previous result has\'t been fetched.' 
+            self._result = result
+            self._cond.notify()
+
+    def get(self):
+        with self._lock:
+            if self._result is None:
+                self._cond.wait()
+
+            res = self._result
+            self._result = None
+            return res
+
+
+_MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
+_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
+
+
+class SlavePipe(_SlavePipeBase):
+    """Pipe for master-slave communication."""
+
+    def run_slave(self, msg):
+        self.queue.put((self.identifier, msg))
+        ret = self.result.get()
+        self.queue.put(True)
+        return ret
+
+
+class SyncMaster(object):
+    """An abstract `SyncMaster` object.
+
+    - During the replication, as the data parallel will trigger a callback of each module, all slave devices should
+      call `register(id)` and obtain a `SlavePipe` to communicate with the master.
+    - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected,
+      and passed to a registered callback.
+    - After receiving the messages, the master device should gather the information and determine the message to be
+      passed back to each slave device.
+    """
+
+    def __init__(self, master_callback):
+        """
+
+        Args:
+            master_callback: a callback to be invoked after having collected messages from slave devices.
+        """
+        self._master_callback = master_callback
+        self._queue = queue.Queue()
+        self._registry = collections.OrderedDict()
+        self._activated = False
+
+    def __getstate__(self):
+        return {'master_callback': self._master_callback}
+
+    def __setstate__(self, state):
+        self.__init__(state['master_callback'])
+
+    def register_slave(self, identifier):
+        """
+        Register a slave device.
+
+        Args:
+            identifier: an identifier, usually is the device id.
+
+        Returns: a `SlavePipe` object which can be used to communicate with the master device.
+
+        """
+        if self._activated:
+            assert self._queue.empty(), 'Queue is not clean before next initialization.'
+            self._activated = False
+            self._registry.clear()
+        future = FutureResult()
+        self._registry[identifier] = _MasterRegistry(future)
+        return SlavePipe(identifier, self._queue, future)
+
+    def run_master(self, master_msg):
+        """
+        Main entry for the master device in each forward pass.
+        The messages were first collected from each device (including the master device), and then
+        a callback will be invoked to compute the message to be sent back to each device
+        (including the master device).
+
+        Args:
+            master_msg: the message that the master wants to send to itself. This will be placed as the first
+            message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
+
+        Returns: the message to be sent back to the master device.
+
+        """
+        self._activated = True
+
+        intermediates = [(0, master_msg)]
+        for i in range(self.nr_slaves):
+            intermediates.append(self._queue.get())
+
+        results = self._master_callback(intermediates)
+        assert results[0][0] == 0, 'The first result should belongs to the master.'
+ + for i, res in results: + if i == 0: + continue + self._registry[i].result.put(res) + + for i in range(self.nr_slaves): + assert self._queue.get() is True + + return results[0][1] + + @property + def nr_slaves(self): + return len(self._registry) diff --git a/talkingface/model/audio_driven_talkingface/eamm_modules/dense_motion.py b/talkingface/model/audio_driven_talkingface/eamm_modules/dense_motion.py new file mode 100644 index 00000000..a74b3e9e --- /dev/null +++ b/talkingface/model/audio_driven_talkingface/eamm_modules/dense_motion.py @@ -0,0 +1,157 @@ +from torch import nn +import torch.nn.functional as F +import torch +from .util import ( + Hourglass, + AntiAliasInterpolation2d, + make_coordinate_grid, + kp2gaussian, +) + + +class DenseMotionNetwork(nn.Module): + """ + Module that predicting a dense motion from sparse motion representation given by kp_source and kp_driving + """ + + def __init__( + self, + block_expansion, + num_blocks, + max_features, + num_kp, + num_channels, + estimate_occlusion_map=False, + scale_factor=1, + kp_variance=0.01, + ): + super(DenseMotionNetwork, self).__init__() + self.hourglass = Hourglass( + block_expansion=block_expansion, + in_features=(num_kp + 1) * (num_channels + 1), + max_features=max_features, + num_blocks=num_blocks, + ) + + self.mask = nn.Conv2d( + self.hourglass.out_filters, num_kp + 1, kernel_size=(7, 7), padding=(3, 3) + ) + + if estimate_occlusion_map: + self.occlusion = nn.Conv2d( + self.hourglass.out_filters, 1, kernel_size=(7, 7), padding=(3, 3) + ) + else: + self.occlusion = None + + self.num_kp = num_kp + self.scale_factor = scale_factor + self.kp_variance = kp_variance + + if self.scale_factor != 1: + self.down = AntiAliasInterpolation2d(num_channels, self.scale_factor) + + def create_heatmap_representations(self, source_image, kp_driving, kp_source): + """ + Eq 6. 
in the paper H_k(z) + """ + spatial_size = source_image.shape[2:] + gaussian_driving = kp2gaussian( + kp_driving, spatial_size=spatial_size, kp_variance=self.kp_variance + ) + gaussian_source = kp2gaussian( + kp_source, spatial_size=spatial_size, kp_variance=self.kp_variance + ) + heatmap = gaussian_driving - gaussian_source # [4,10,H,W] + + # adding background feature + zeros = torch.zeros(heatmap.shape[0], 1, spatial_size[0], spatial_size[1]).type( + heatmap.type() + ) + heatmap = torch.cat([zeros, heatmap], dim=1) + heatmap = heatmap.unsqueeze(2) # [4,11,1,h,w] + return heatmap + + def create_sparse_motions(self, source_image, kp_driving, kp_source): + """ + Eq 4. in the paper T_{s<-d}(z) + """ + bs, _, h, w = source_image.shape + identity_grid = make_coordinate_grid((h, w), type=kp_source["value"].type()) + identity_grid = identity_grid.view(1, 1, h, w, 2) + coordinate_grid = identity_grid - kp_driving["value"].view( + bs, self.num_kp, 1, 1, 2 + ) # [4,10,64,64,2] + if "jacobian" in kp_driving: + jacobian = torch.matmul( + kp_source["jacobian"], torch.inverse(kp_driving["jacobian"]) + ) + jacobian = jacobian.unsqueeze(-3).unsqueeze(-3) + jacobian = jacobian.repeat(1, 1, h, w, 1, 1) + coordinate_grid = torch.matmul(jacobian, coordinate_grid.unsqueeze(-1)) + coordinate_grid = coordinate_grid.squeeze(-1) + + driving_to_source = coordinate_grid + kp_source["value"].view( + bs, self.num_kp, 1, 1, 2 + ) + + # adding background feature + identity_grid = identity_grid.repeat(bs, 1, 1, 1, 1) + sparse_motions = torch.cat([identity_grid, driving_to_source], dim=1) + return sparse_motions + + def create_deformed_source_image(self, source_image, sparse_motions): + """ + Eq 7. 
in the paper \hat{T}_{s<-d}(z) + """ + bs, _, h, w = source_image.shape + source_repeat = ( + source_image.unsqueeze(1) + .unsqueeze(1) + .repeat(1, self.num_kp + 1, 1, 1, 1, 1) + ) + source_repeat = source_repeat.view(bs * (self.num_kp + 1), -1, h, w) + sparse_motions = sparse_motions.view((bs * (self.num_kp + 1), h, w, -1)) + sparse_deformed = F.grid_sample(source_repeat, sparse_motions) + sparse_deformed = sparse_deformed.view((bs, self.num_kp + 1, -1, h, w)) + return sparse_deformed + + def forward(self, source_image, kp_driving, kp_source): + if self.scale_factor != 1: + source_image = self.down(source_image) # [4,3,H*scale,W*scale] + + bs, _, h, w = source_image.shape + + out_dict = dict() + heatmap_representation = self.create_heatmap_representations( + source_image, kp_driving, kp_source + ) # [4,11,1,64,64] + sparse_motion = self.create_sparse_motions( + source_image, kp_driving, kp_source + ) # [4,11,64,64,2] + deformed_source = self.create_deformed_source_image( + source_image, sparse_motion + ) # [4,11,3,64,64] + out_dict["sparse_deformed"] = deformed_source + + input = torch.cat([heatmap_representation, deformed_source], dim=2) + input = input.view(bs, -1, h, w) # [4,11*4,64,64] + + prediction = self.hourglass(input) # [4,108,64,64] + + mask = self.mask(prediction) + mask = F.softmax(mask, dim=1) # [4,11,64,64] + out_dict["mask"] = mask + mask = mask.unsqueeze(2) + sparse_motion = sparse_motion.permute(0, 1, 4, 2, 3) + deformation = (sparse_motion * mask).sum(dim=1) + deformation = deformation.permute(0, 2, 3, 1) # [4,64,64,2] + + out_dict["deformation"] = deformation + + # Sec. 
3.2 in the paper + if self.occlusion: + occlusion_map = torch.sigmoid(self.occlusion(prediction)) + out_dict["occlusion_map"] = occlusion_map # [4,1,64,64] + + return out_dict diff --git a/talkingface/model/audio_driven_talkingface/eamm_modules/discriminator.py b/talkingface/model/audio_driven_talkingface/eamm_modules/discriminator.py new file mode 100644 index 00000000..07c1692a --- /dev/null +++ b/talkingface/model/audio_driven_talkingface/eamm_modules/discriminator.py @@ -0,0 +1,95 @@ +from torch import nn +import torch.nn.functional as F +from .util import kp2gaussian +import torch + + +class DownBlock2d(nn.Module): + """ + Simple block for processing video (encoder). + """ + + def __init__(self, in_features, out_features, norm=False, kernel_size=4, pool=False, sn=False): + super(DownBlock2d, self).__init__() + self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size) + + if sn: + self.conv = nn.utils.spectral_norm(self.conv) + + if norm: + self.norm = nn.InstanceNorm2d(out_features, affine=True) + else: + self.norm = None + self.pool = pool + + def forward(self, x): + out = x + out = self.conv(out) + if self.norm: + out = self.norm(out) + out = F.leaky_relu(out, 0.2) + if self.pool: + out = F.avg_pool2d(out, (2, 2)) + return out + + +class Discriminator(nn.Module): + """ + Discriminator similar to Pix2Pix + """ + + def __init__(self, num_channels=3, block_expansion=64, num_blocks=4, max_features=512, + sn=False, use_kp=False, num_kp=10, kp_variance=0.01, **kwargs): + super(Discriminator, self).__init__() + + down_blocks = [] + for i in range(num_blocks): + down_blocks.append( + DownBlock2d(num_channels + num_kp * use_kp if i == 0 else min(max_features, block_expansion * (2 ** i)), + min(max_features, block_expansion * (2 ** (i + 1))), + norm=(i != 0), kernel_size=4, pool=(i != num_blocks - 1), sn=sn)) + + self.down_blocks = nn.ModuleList(down_blocks) + self.conv = nn.Conv2d(self.down_blocks[-1].conv.out_channels, 
out_channels=1, kernel_size=1) + if sn: + self.conv = nn.utils.spectral_norm(self.conv) + self.use_kp = use_kp + self.kp_variance = kp_variance + + def forward(self, x, kp=None): + feature_maps = [] + out = x + if self.use_kp: + heatmap = kp2gaussian(kp, x.shape[2:], self.kp_variance) + out = torch.cat([out, heatmap], dim=1) + + for down_block in self.down_blocks: + feature_maps.append(down_block(out)) + out = feature_maps[-1] + prediction_map = self.conv(out) + + return feature_maps, prediction_map + + +class MultiScaleDiscriminator(nn.Module): + """ + Multi-scale (scale) discriminator + """ + + def __init__(self, scales=(), **kwargs): + super(MultiScaleDiscriminator, self).__init__() + self.scales = scales + discs = {} + for scale in scales: + discs[str(scale).replace('.', '-')] = Discriminator(**kwargs) + self.discs = nn.ModuleDict(discs) + + def forward(self, x, kp=None): + out_dict = {} + for scale, disc in self.discs.items(): + scale = str(scale).replace('-', '.') + key = 'prediction_' + scale + feature_maps, prediction_map = disc(x[key], kp) + out_dict['feature_maps_' + scale] = feature_maps + out_dict['prediction_map_' + scale] = prediction_map + return out_dict diff --git a/talkingface/model/audio_driven_talkingface/eamm_modules/function.py b/talkingface/model/audio_driven_talkingface/eamm_modules/function.py new file mode 100644 index 00000000..d7ce0f4c --- /dev/null +++ b/talkingface/model/audio_driven_talkingface/eamm_modules/function.py @@ -0,0 +1,75 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Created on Thu Sep 30 17:45:24 2021 + +@author: SENSETIME\jixinya1 +""" + +import torch + + +def calc_mean_std(feat, eps=1e-5): + # eps is a small value added to the variance to avoid divide-by-zero. 
+ size = feat.size() + assert (len(size) == 4) + N, C = size[:2] + feat_var = feat.view(N, C, -1).var(dim=2) + eps + feat_std = feat_var.sqrt().view(N, C, 1, 1) + feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1) + return feat_mean, feat_std + + +def adaptive_instance_normalization(content_feat, style_feat): + assert (content_feat.size()[:2] == style_feat.size()[:2]) + size = content_feat.size() + style_mean, style_std = calc_mean_std(style_feat) + content_mean, content_std = calc_mean_std(content_feat) + + normalized_feat = (content_feat - content_mean.expand( + size)) / content_std.expand(size) + return normalized_feat * style_std.expand(size) + style_mean.expand(size) + + +def _calc_feat_flatten_mean_std(feat): + # takes 3D feat (C, H, W), return mean and std of array within channels + assert (feat.size()[0] == 3) + assert (isinstance(feat, torch.FloatTensor)) + feat_flatten = feat.view(3, -1) + mean = feat_flatten.mean(dim=-1, keepdim=True) + std = feat_flatten.std(dim=-1, keepdim=True) + return feat_flatten, mean, std + + +def _mat_sqrt(x): + U, D, V = torch.svd(x) + return torch.mm(torch.mm(U, D.pow(0.5).diag()), V.t()) + + +def coral(source, target): + # assume both source and target are 3D array (C, H, W) + # Note: flatten -> f + + source_f, source_f_mean, source_f_std = _calc_feat_flatten_mean_std(source) + source_f_norm = (source_f - source_f_mean.expand_as( + source_f)) / source_f_std.expand_as(source_f) + source_f_cov_eye = \ + torch.mm(source_f_norm, source_f_norm.t()) + torch.eye(3) + + target_f, target_f_mean, target_f_std = _calc_feat_flatten_mean_std(target) + target_f_norm = (target_f - target_f_mean.expand_as( + target_f)) / target_f_std.expand_as(target_f) + target_f_cov_eye = \ + torch.mm(target_f_norm, target_f_norm.t()) + torch.eye(3) + + source_f_norm_transfer = torch.mm( + _mat_sqrt(target_f_cov_eye), + torch.mm(torch.inverse(_mat_sqrt(source_f_cov_eye)), + source_f_norm) + ) + + source_f_transfer = source_f_norm_transfer * \ + 
target_f_std.expand_as(source_f_norm) + \ + target_f_mean.expand_as(source_f_norm) + + return source_f_transfer.view(source.size()) \ No newline at end of file diff --git a/talkingface/model/audio_driven_talkingface/eamm_modules/generator.py b/talkingface/model/audio_driven_talkingface/eamm_modules/generator.py new file mode 100644 index 00000000..a31155d6 --- /dev/null +++ b/talkingface/model/audio_driven_talkingface/eamm_modules/generator.py @@ -0,0 +1,97 @@ +import torch +from torch import nn +import torch.nn.functional as F +from .util import ResBlock2d, SameBlock2d, UpBlock2d, DownBlock2d +from .dense_motion import DenseMotionNetwork + + +class OcclusionAwareGenerator(nn.Module): + """ + Generator that given source image and and keypoints try to transform image according to movement trajectories + induced by keypoints. Generator follows Johnson architecture. + """ + + def __init__(self, num_channels, num_kp, block_expansion, max_features, num_down_blocks, + num_bottleneck_blocks, estimate_occlusion_map=False, dense_motion_params=None, estimate_jacobian=False): + super(OcclusionAwareGenerator, self).__init__() + + if dense_motion_params is not None: + self.dense_motion_network = DenseMotionNetwork(num_kp=num_kp, num_channels=num_channels, + estimate_occlusion_map=estimate_occlusion_map, + **dense_motion_params) + else: + self.dense_motion_network = None + + self.first = SameBlock2d(num_channels, block_expansion, kernel_size=(7, 7), padding=(3, 3)) + + down_blocks = [] + for i in range(num_down_blocks): + in_features = min(max_features, block_expansion * (2 ** i)) + out_features = min(max_features, block_expansion * (2 ** (i + 1))) + down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1))) + self.down_blocks = nn.ModuleList(down_blocks) + + up_blocks = [] + for i in range(num_down_blocks): + in_features = min(max_features, block_expansion * (2 ** (num_down_blocks - i))) + out_features = min(max_features, block_expansion * (2 
** (num_down_blocks - i - 1))) + up_blocks.append(UpBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1))) + self.up_blocks = nn.ModuleList(up_blocks) + + self.bottleneck = torch.nn.Sequential() + in_features = min(max_features, block_expansion * (2 ** num_down_blocks)) + for i in range(num_bottleneck_blocks): + self.bottleneck.add_module('r' + str(i), ResBlock2d(in_features, kernel_size=(3, 3), padding=(1, 1))) + + self.final = nn.Conv2d(block_expansion, num_channels, kernel_size=(7, 7), padding=(3, 3)) + self.estimate_occlusion_map = estimate_occlusion_map + self.num_channels = num_channels + + def deform_input(self, inp, deformation): + _, h_old, w_old, _ = deformation.shape + _, _, h, w = inp.shape + if h_old != h or w_old != w: + deformation = deformation.permute(0, 3, 1, 2) + deformation = F.interpolate(deformation, size=(h, w), mode='bilinear') + deformation = deformation.permute(0, 2, 3, 1) + return F.grid_sample(inp, deformation) + + def forward(self, source_image, kp_driving, kp_source): + # Encoding (downsampling) part + out = self.first(source_image) #[4,64,H,W] + for i in range(len(self.down_blocks)): + out = self.down_blocks[i](out) #[4,256,H/4,W/4] + + # Transforming feature representation according to deformation and occlusion + output_dict = {} + if self.dense_motion_network is not None: + dense_motion = self.dense_motion_network(source_image=source_image, kp_driving=kp_driving, + kp_source=kp_source) + output_dict['mask'] = dense_motion['mask'] + output_dict['sparse_deformed'] = dense_motion['sparse_deformed'] + + if 'occlusion_map' in dense_motion: + occlusion_map = dense_motion['occlusion_map'] + output_dict['occlusion_map'] = occlusion_map + else: + occlusion_map = None + deformation = dense_motion['deformation'] + out = self.deform_input(out, deformation) + + if occlusion_map is not None: + if out.shape[2] != occlusion_map.shape[2] or out.shape[3] != occlusion_map.shape[3]: + occlusion_map = F.interpolate(occlusion_map, 
size=out.shape[2:], mode='bilinear') + out = out * occlusion_map + + output_dict["deformed"] = self.deform_input(source_image, deformation) + + # Decoding part + out = self.bottleneck(out) #[4,256,64,64] + for i in range(len(self.up_blocks)): + out = self.up_blocks[i](out) + out = self.final(out) + out = torch.sigmoid(out) #[4,3,256,256] + + output_dict["prediction"] = out + + return output_dict diff --git a/talkingface/model/audio_driven_talkingface/eamm_modules/keypoint_detector.py b/talkingface/model/audio_driven_talkingface/eamm_modules/keypoint_detector.py new file mode 100644 index 00000000..0bcf2ffc --- /dev/null +++ b/talkingface/model/audio_driven_talkingface/eamm_modules/keypoint_detector.py @@ -0,0 +1,259 @@ +from torch import nn +import torch +import torch.nn.functional as F +from .util import Hourglass, make_coordinate_grid, AntiAliasInterpolation2d, Ct_encoder, EmotionNet, AF2F, AF2F_s, draw_heatmap + + +class KPDetector(nn.Module): + """ + Detecting a keypoints. Return keypoint position and jacobian near each keypoint. 
+ """ + + def __init__(self, block_expansion, num_kp, num_channels, max_features, + num_blocks, temperature, estimate_jacobian=False, scale_factor=1, + single_jacobian_map=False, pad=0): + super(KPDetector, self).__init__() + + self.predictor = Hourglass(block_expansion, in_features=num_channels, + max_features=max_features, num_blocks=num_blocks) + + self.kp = nn.Conv2d(in_channels=self.predictor.out_filters, out_channels=num_kp, kernel_size=(7, 7), + padding=pad) + + if estimate_jacobian: + self.num_jacobian_maps = 1 if single_jacobian_map else num_kp + self.jacobian = nn.Conv2d(in_channels=self.predictor.out_filters, + out_channels=4 * self.num_jacobian_maps, kernel_size=(7, 7), padding=pad) + self.jacobian.weight.data.zero_() + self.jacobian.bias.data.copy_(torch.tensor([1, 0, 0, 1] * self.num_jacobian_maps, dtype=torch.float)) + else: + self.jacobian = None + + self.temperature = temperature + self.scale_factor = scale_factor + if self.scale_factor != 1: + self.down = AntiAliasInterpolation2d(num_channels, self.scale_factor) + + + + + def gaussian2kp(self, heatmap): + """ + Extract the mean and from a heatmap + """ + shape = heatmap.shape + heatmap = heatmap.unsqueeze(-1) #[4,10,58,58,1] + grid = make_coordinate_grid(shape[2:], heatmap.type()).unsqueeze_(0).unsqueeze_(0) #[1,1,58,58,2] + value = (heatmap * grid).sum(dim=(2, 3)) #[4,10,2] + kp = {'value': value} + + return kp + + def audio_feature(self, x, heatmap): + + # prediction = self.kp(x) #[4,10,H/4-6, W/4-6] + + # final_shape = prediction.shape + # heatmap = prediction.view(final_shape[0], final_shape[1], -1) #[4, 10, 58*58] + # heatmap = F.softmax(heatmap / self.temperature, dim=2) + # heatmap = heatmap.view(*final_shape) #[4,10,58,58] + + # out = self.gaussian2kp(heatmap) + final_shape = heatmap.squeeze(2).shape + + if self.jacobian is not None: + jacobian_map = self.jacobian(x) ##[4,40,H/4-6, W/4-6] + jacobian_map = jacobian_map.reshape(final_shape[0], self.num_jacobian_maps, 4, final_shape[2], + 
final_shape[3]) + heatmap = heatmap.unsqueeze(2) + + jacobian = heatmap * jacobian_map #[4,10,4,H/4-6, W/4-6] + jacobian = jacobian.view(final_shape[0], final_shape[1], 4, -1) + jacobian = jacobian.sum(dim=-1) #[4,10,4] + jacobian = jacobian.view(jacobian.shape[0], jacobian.shape[1], 2, 2) #[4,10,2,2] + + return jacobian + + def forward(self, x): #torch.Size([4, 3, H, W]) + if self.scale_factor != 1: + x = self.down(x) # 0.25 [4, 3, H/4, W/4] + + feature_map = self.predictor(x) #[4,3+32,H/4, W/4] + prediction = self.kp(feature_map) #[4,10,H/4-6, W/4-6] + + final_shape = prediction.shape + + heatmap = prediction.view(final_shape[0], final_shape[1], -1) #[4, 10, 58*58] + heatmap = F.softmax(heatmap / self.temperature, dim=2) + heatmap = heatmap.view(*final_shape) #[4,10,58,58] + + out = self.gaussian2kp(heatmap) + out['heatmap'] = heatmap + + if self.jacobian is not None: + jacobian_map = self.jacobian(feature_map) ##[4,40,H/4-6, W/4-6] + jacobian_map = jacobian_map.reshape(final_shape[0], self.num_jacobian_maps, 4, final_shape[2], + final_shape[3]) + heatmap = heatmap.unsqueeze(2) + + jacobian = heatmap * jacobian_map #[4,10,4,H/4-6, W/4-6] + jacobian = jacobian.view(final_shape[0], final_shape[1], 4, -1) + jacobian = jacobian.sum(dim=-1) #[4,10,4] + jacobian = jacobian.view(jacobian.shape[0], jacobian.shape[1], 2, 2) #[4,10,2,2] + out['jacobian'] = jacobian + + return out + + + + +class KPDetector_a(nn.Module): + """ + Detecting a keypoints. Return keypoint position and jacobian near each keypoint. 
+ """ + + def __init__(self, block_expansion, num_kp, num_channels,num_channels_a, max_features, + num_blocks, temperature, estimate_jacobian=False, scale_factor=1, + single_jacobian_map=False, pad=0): + super(KPDetector_a, self).__init__() + + self.predictor = Hourglass(block_expansion, in_features=num_channels_a, + max_features=max_features, num_blocks=num_blocks) + + self.kp = nn.Conv2d(in_channels=self.predictor.out_filters, out_channels=num_kp, kernel_size=(7, 7), + padding=pad) + + if estimate_jacobian: + self.num_jacobian_maps = 1 if single_jacobian_map else num_kp + self.jacobian = nn.Conv2d(in_channels=self.predictor.out_filters, + out_channels=4 * self.num_jacobian_maps, kernel_size=(7, 7), padding=pad) + self.jacobian.weight.data.zero_() + self.jacobian.bias.data.copy_(torch.tensor([1, 0, 0, 1] * self.num_jacobian_maps, dtype=torch.float)) + else: + self.jacobian = None + + self.temperature = temperature + self.scale_factor = scale_factor + if self.scale_factor != 1: + self.down = AntiAliasInterpolation2d(num_channels, self.scale_factor) + + + + + def gaussian2kp(self, heatmap): + """ + Extract the mean and from a heatmap + """ + shape = heatmap.shape + heatmap = heatmap.unsqueeze(-1) #[4,10,58,58,1] + grid = make_coordinate_grid(shape[2:], heatmap.type()).unsqueeze_(0).unsqueeze_(0) #[1,1,58,58,2] + value = (heatmap * grid).sum(dim=(2, 3)) #[4,10,2] + kp = {'value': value} + + return kp + + def audio_feature(self, x, heatmap): + + # prediction = self.kp(x) #[4,10,H/4-6, W/4-6] + + # final_shape = prediction.shape + # heatmap = prediction.view(final_shape[0], final_shape[1], -1) #[4, 10, 58*58] + # heatmap = F.softmax(heatmap / self.temperature, dim=2) + # heatmap = heatmap.view(*final_shape) #[4,10,58,58] + + # out = self.gaussian2kp(heatmap) + final_shape = heatmap.squeeze(2).shape + + if self.jacobian is not None: + jacobian_map = self.jacobian(x) ##[4,40,H/4-6, W/4-6] + jacobian_map = jacobian_map.reshape(final_shape[0], self.num_jacobian_maps, 4, 
final_shape[2], + final_shape[3]) + heatmap = heatmap.unsqueeze(2) + + jacobian = heatmap * jacobian_map #[4,10,4,H/4-6, W/4-6] + jacobian = jacobian.view(final_shape[0], final_shape[1], 4, -1) + jacobian = jacobian.sum(dim=-1) #[4,10,4] + jacobian = jacobian.view(jacobian.shape[0], jacobian.shape[1], 2, 2) #[4,10,2,2] + + return jacobian + + def forward(self, feature_map): #torch.Size([4, 3, H, W]) + + prediction = self.kp(feature_map) #[4,10,H/4-6, W/4-6] + + final_shape = prediction.shape + + heatmap = prediction.view(final_shape[0], final_shape[1], -1) #[4, 10, 58*58] + heatmap = F.softmax(heatmap / self.temperature, dim=2) + heatmap = heatmap.view(*final_shape) #[4,10,58,58] + + out = self.gaussian2kp(heatmap) + out['heatmap'] = heatmap + + if self.jacobian is not None: + jacobian_map = self.jacobian(feature_map) ##[4,40,H/4-6, W/4-6] + jacobian_map = jacobian_map.reshape(final_shape[0], self.num_jacobian_maps, 4, final_shape[2], + final_shape[3]) + heatmap = heatmap.unsqueeze(2) + + jacobian = heatmap * jacobian_map #[4,10,4,H/4-6, W/4-6] + jacobian = jacobian.view(final_shape[0], final_shape[1], 4, -1) + jacobian = jacobian.sum(dim=-1) #[4,10,4] + jacobian = jacobian.view(jacobian.shape[0], jacobian.shape[1], 2, 2) #[4,10,2,2] + out['jacobian'] = jacobian + + return out + + +class Audio_Feature(nn.Module): + def __init__(self): + super(Audio_Feature, self).__init__() + + self.con_encoder = Ct_encoder() + self.emo_encoder = EmotionNet() + self.decoder = AF2F_s() + + + + def forward(self, x): + x = x.unsqueeze(1) + + c = self.con_encoder(x) + e = self.emo_encoder(x) + + # d = torch.cat([c, e], dim=1) + d = self.decoder(c) + + + return d +''' +def forward(self, x, cube, audio): #torch.Size([4, 3, H, W]) + if self.scale_factor != 1: + x = self.down(x) # 0.25 [4, 3, H/4, W/4] + + cube = cube.unsqueeze(1) + feature = torch.cat([x,cube,audio],dim=1) + feature_map = self.predictor(feature) #[4,3+32,H/4, W/4] + prediction = self.kp(feature_map) #[4,10,H/4-6, W/4-6] + 
    final_shape = prediction.shape
    heatmap = prediction.view(final_shape[0], final_shape[1], -1) #[4, 10, 58*58]
    heatmap = F.softmax(heatmap / self.temperature, dim=2)
    heatmap = heatmap.view(*final_shape) #[4,10,58,58]

    out = self.gaussian2kp(heatmap)
    out['heatmap'] = heatmap
    if self.jacobian is not None:
        jacobian_map = self.jacobian(feature_map) ##[4,40,H/4-6, W/4-6]
        jacobian_map = jacobian_map.reshape(final_shape[0], self.num_jacobian_maps, 4, final_shape[2],
                                            final_shape[3])
        heatmap = heatmap.unsqueeze(2)

        jacobian = heatmap * jacobian_map #[4,10,4,H/4-6, W/4-6]
        jacobian = jacobian.view(final_shape[0], final_shape[1], 4, -1)
        jacobian = jacobian.sum(dim=-1) #[4,10,4]
        jacobian = jacobian.view(jacobian.shape[0], jacobian.shape[1], 2, 2) #[4,10,2,2]
        out['jacobian'] = jacobian

    return out
'''
diff --git a/talkingface/model/audio_driven_talkingface/eamm_modules/model.py b/talkingface/model/audio_driven_talkingface/eamm_modules/model.py
new file mode 100644
index 00000000..5fe8210b
--- /dev/null
+++ b/talkingface/model/audio_driven_talkingface/eamm_modules/model.py
@@ -0,0 +1,587 @@
from torch import nn
import torch
import torch.nn.functional as F
from .util import AntiAliasInterpolation2d, make_coordinate_grid
from torchvision import models
import numpy as np
from torch.autograd import grad


class Vgg19(torch.nn.Module):
    """
    Vgg19 network for perceptual loss. See Sec 3.3.

    Wraps torchvision's pretrained VGG-19 feature extractor and exposes the
    intermediate activations of five sub-ranges of its layers, which the
    perceptual loss compares between real and generated images.
    All parameters are frozen unless ``requires_grad=True`` is passed.
    """
    def __init__(self, requires_grad=False):
        super(Vgg19, self).__init__()
        # NOTE(review): ``pretrained=True`` is the legacy torchvision API;
        # newer torchvision versions expect ``weights=...`` — confirm the
        # pinned torchvision version before upgrading.
        vgg_pretrained_features = models.vgg19(pretrained=True).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        # Split the VGG feature stack at layer indices 2/7/12/21/30; judging
        # by the h_reluN names in forward(), each slice is presumably meant to
        # end at a ReLU activation — verify against the VGG-19 layer table.
        for x in range(2):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(2, 7):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(7, 12):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(12, 21):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        for x in range(21, 30):
            self.slice5.add_module(str(x), vgg_pretrained_features[x])

        # Standard ImageNet per-channel normalization constants, stored as
        # frozen parameters so they move with the module across devices.
        self.mean = torch.nn.Parameter(data=torch.Tensor(np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))),
                                       requires_grad=False)
        self.std = torch.nn.Parameter(data=torch.Tensor(np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))),
                                      requires_grad=False)

        if not requires_grad:
            # Freeze the whole backbone: the perceptual loss only reads
            # activations, it never fine-tunes VGG.
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        # Normalize the input image batch, then collect the activation of
        # each of the five slices; callers compare these lists element-wise.
        X = (X - self.mean) / self.std
        h_relu1 = self.slice1(X)
        h_relu2 = self.slice2(h_relu1)
        h_relu3 = self.slice3(h_relu2)
        h_relu4 = self.slice4(h_relu3)
        h_relu5 = self.slice5(h_relu4)
        out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
        return out


class ImagePyramide(torch.nn.Module):
    """
    Create image pyramide for computing pyramide perceptual loss.
    See Sec 3.3
    """
    def __init__(self, scales, num_channels):
        super(ImagePyramide, self).__init__()
        downs = {}
        # One anti-aliased downsampler per requested scale. '.' is not a
        # valid ModuleDict key character, hence the '.' -> '-' replacement.
        for scale in scales:
            downs[str(scale).replace('.', '-')] = AntiAliasInterpolation2d(num_channels, scale)
        self.downs = nn.ModuleDict(downs)

    def forward(self, x):
        out_dict = {}
        for scale, down_module in self.downs.items():
            # Undo the '.' -> '-' key mangling when naming the outputs, so
            # consumers can look up e.g. 'prediction_0.5'.
            out_dict['prediction_' + str(scale).replace('-', '.')] = down_module(x)
        return out_dict


class Transform:
    """
    Random tps transformation for equivariance constraints. See Sec 3.3

    Samples a random affine perturbation of the identity (one per batch
    element) and, optionally, a thin-plate-spline (TPS) residual on a
    points_tps x points_tps control grid.
    """
    def __init__(self, bs, **kwargs):
        # Random affine transform: identity plus Gaussian noise scaled by
        # sigma_affine; shape [bs, 2, 3].
        noise = torch.normal(mean=0, std=kwargs['sigma_affine'] * torch.ones([bs, 2, 3]))
        self.theta = noise + torch.eye(2, 3).view(1, 2, 3)
        self.bs = bs

        # TPS part is enabled only when both of its hyper-parameters are given.
        if ('sigma_tps' in kwargs) and ('points_tps' in kwargs):
            self.tps = True
            self.control_points = make_coordinate_grid((kwargs['points_tps'], kwargs['points_tps']), type=noise.type())
            self.control_points = self.control_points.unsqueeze(0)
            self.control_params = torch.normal(mean=0,
                                               std=kwargs['sigma_tps'] * torch.ones([bs, 1, kwargs['points_tps'] ** 2]))
        else:
            self.tps = False

    def transform_frame(self, frame):
        # Warp a [bs, C, H, W] frame by sampling it at the transformed
        # coordinate grid.
        grid = make_coordinate_grid(frame.shape[2:], type=frame.type()).unsqueeze(0) #[1,256,256,2]
        grid = grid.view(1, frame.shape[2] * frame.shape[3], 2)
        grid = self.warp_coordinates(grid).view(self.bs, frame.shape[2], frame.shape[3], 2)
        return F.grid_sample(frame, grid, padding_mode="reflection")

    def inverse_transform_frame(self, frame):
        # Same as transform_frame but using the (approximate) inverse warp.
        grid = make_coordinate_grid(frame.shape[2:], type=frame.type()).unsqueeze(0) #[1,256,256,2]
        grid = grid.view(1, frame.shape[2] * frame.shape[3], 2)
        grid = self.inverse_warp_coordinates(grid).view(self.bs, frame.shape[2], frame.shape[3], 2)
        return F.grid_sample(frame, grid, padding_mode="reflection")

    def warp_coordinates(self, coordinates):
        # Apply the affine part: y = A x + t.
        theta = self.theta.type(coordinates.type())
        theta = theta.unsqueeze(1)
        transformed = torch.matmul(theta[:, :, :, :2],
                                   coordinates.unsqueeze(-1)) + theta[:, :, :, 2:]
        transformed = transformed.squeeze(-1)

        if self.tps:
            # TPS residual: radial basis r^2 * log(r) around each control
            # point, weighted by the sampled control_params.
            # NOTE(review): ``distances`` here is an L1 (abs-sum) distance —
            # classic TPS uses the Euclidean distance; confirm this matches
            # the upstream FOMM implementation.
            control_points = self.control_points.type(coordinates.type())
            control_params = self.control_params.type(coordinates.type())
            distances = coordinates.view(coordinates.shape[0], -1, 1, 2) - control_points.view(1, 1, -1, 2)
            distances = torch.abs(distances).sum(-1)

            result = distances ** 2
            result = result * torch.log(distances + 1e-6)  # +1e-6 guards log(0)
            result = result * control_params
            result = result.sum(dim=2).view(self.bs, coordinates.shape[1], 1)
            transformed = transformed + result

        return transformed

    def inverse_warp_coordinates(self, coordinates):
        # Invert the affine part by lifting [A|t] to a 3x3 homogeneous
        # matrix, inverting it, and keeping the top two rows.
        theta = self.theta.type(coordinates.type())
        theta = theta.unsqueeze(1)
        # NOTE(review): hard-coded .cuda() breaks CPU-only execution — this
        # should presumably follow coordinates' device instead.
        a = torch.FloatTensor([[[[0,0,1]]]]).repeat([self.bs,1,1,1]).cuda()
        c = torch.cat((theta,a),2)
        d = c.inverse()[:,:,:2,:]
        d = d.type(coordinates.type())
        transformed = torch.matmul(d[:, :, :, :2], coordinates.unsqueeze(-1)) + d[:, :, :, 2:]
        transformed = transformed.squeeze(-1)

        if self.tps:
            # NOTE(review): this adds the same *forward* TPS residual as
            # warp_coordinates rather than its inverse, so the inverse is
            # only exact for the affine component — confirm intent.
            control_points = self.control_points.type(coordinates.type())
            control_params = self.control_params.type(coordinates.type())
            distances = coordinates.view(coordinates.shape[0], -1, 1, 2) - control_points.view(1, 1, -1, 2)
            distances = torch.abs(distances).sum(-1)

            result = distances ** 2
            result = result * torch.log(distances + 1e-6)
            result = result * control_params
            result = result.sum(dim=2).view(self.bs, coordinates.shape[1], 1)
            transformed = transformed + result


        return transformed

    def jacobian(self, coordinates):
        # Jacobian of the warp at the given coordinates, obtained via
        # autograd: one grad pass per output coordinate component.
        coordinates.requires_grad=True
        new_coordinates = self.warp_coordinates(coordinates)#[4,10,2]
        grad_x = grad(new_coordinates[..., 0].sum(), coordinates, create_graph=True)
        grad_y = grad(new_coordinates[..., 1].sum(), coordinates, create_graph=True)
        jacobian = torch.cat([grad_x[0].unsqueeze(-2), grad_y[0].unsqueeze(-2)], dim=-2)
        return jacobian


def detach_kp(kp):
    # Detach every tensor in a keypoint dict (value/jacobian/heatmap) from
    # the autograd graph; used when feeding keypoints to the discriminator.
    return {key:
value.detach() for key, value in kp.items()} + +class TrainPart1Model(torch.nn.Module): + """ + Merge all generator related updates into single model for better multi-gpu usage + """ + + def __init__(self, kp_extractor, kp_extractor_a, audio_feature, generator, discriminator, train_params, device_ids): + super(TrainPart1Model, self).__init__() + self.kp_extractor = kp_extractor + self.kp_extractor_a = kp_extractor_a + + self.audio_feature = audio_feature + self.generator = generator + self.discriminator = discriminator + self.train_params = train_params + self.scales = train_params['scales'] + self.disc_scales = self.discriminator.scales + self.pyramid = ImagePyramide(self.scales, generator.num_channels) + if torch.cuda.is_available(): + self.pyramid = self.pyramid.cuda() + + self.loss_weights = train_params['loss_weights'] + + if sum(self.loss_weights['perceptual']) != 0: + self.vgg = Vgg19() + if torch.cuda.is_available(): + self.vgg = self.vgg.cuda() + + + self.mse_loss_fn = nn.MSELoss().cuda() + def forward(self, x): + + kp_source = self.kp_extractor(x['example_image']) + + kp_driving = [] + for i in range(16): + kp_driving.append(self.kp_extractor(x['driving'][:,i])) + + kp_driving_a = [] #x['example_image'], + deco_out = self.audio_feature(x['example_image'], x['driving_audio'], x['driving_pose'], self.train_params['jaco_net']) + loss_values = {} + + if self.loss_weights['audio'] != 0: + + kp_driving_a = [] + for i in range(16): + kp_driving_a.append(self.kp_extractor_a(deco_out[:,i]))# + + + loss_value = 0 + loss_heatmap = 0 + loss_jacobian = 0 + loss_perceptual = 0 + for i in range(len(kp_driving)): + loss_jacobian += (torch.abs(kp_driving[i]['jacobian'] - kp_driving_a[i]['jacobian']).mean())*self.loss_weights['audio'] + + # loss_jacobian = loss_jacobian*self.loss_weights['audio'] + loss_heatmap += (torch.abs(kp_driving[i]['heatmap'] - kp_driving_a[i]['heatmap']).mean())*self.loss_weights['audio']*100 + + + loss_value += 
(torch.abs(kp_driving[i]['value'].detach() - kp_driving_a[i]['value']).mean())*self.loss_weights['audio'] + + loss_values['loss_value'] = loss_value/len(kp_driving) + loss_values['loss_heatmap'] = loss_heatmap/len(kp_driving) + loss_values['loss_jacobian'] = loss_jacobian/len(kp_driving) + + + if self.train_params['generator'] == 'not': + # loss_values['perceptual'] = self.mse_loss_fn(deco_out,deco_out) + for i in range(1): #0,len(kp_driving),4 + + generated = self.generator(x['example_image'], kp_source=kp_source, kp_driving=kp_driving_a[i]) + generated.update({'kp_source': kp_source, 'kp_driving': kp_driving_a}) + elif self.train_params['generator'] == 'visual': + for i in range(0,len(kp_driving),4): #0,len(kp_driving),4 + + generated = self.generator(x['example_image'], kp_source=kp_source, kp_driving=kp_driving[i]) + generated.update({'kp_source': kp_source, 'kp_driving': kp_driving}) + + pyramide_real = self.pyramid(x['driving'][:,i]) + pyramide_generated = self.pyramid(generated['prediction']) + + if sum(self.loss_weights['perceptual']) != 0: + value_total = 0 + for scale in self.scales: + x_vgg = self.vgg(pyramide_generated['prediction_' + str(scale)]) + y_vgg = self.vgg(pyramide_real['prediction_' + str(scale)]) + + for i, weight in enumerate(self.loss_weights['perceptual']): + value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean() + value_total += self.loss_weights['perceptual'][i] * value + loss_perceptual += value_total + + length = int((len(kp_driving)-1)/4)+1 + loss_values['perceptual'] = loss_perceptual/length + elif self.train_params['generator'] == 'audio': + for i in range(0,len(kp_driving),4): #0,len(kp_driving),4 + + generated = self.generator(x['example_image'], kp_source=kp_source, kp_driving=kp_driving_a[i]) + generated.update({'kp_source': kp_source, 'kp_driving': kp_driving_a}) + + pyramide_real = self.pyramid(x['driving'][:,i]) + pyramide_generated = self.pyramid(generated['prediction']) + + if sum(self.loss_weights['perceptual']) != 0: + 
value_total = 0 + for scale in self.scales: + x_vgg = self.vgg(pyramide_generated['prediction_' + str(scale)]) + y_vgg = self.vgg(pyramide_real['prediction_' + str(scale)]) + + for i, weight in enumerate(self.loss_weights['perceptual']): + value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean() + value_total += self.loss_weights['perceptual'][i] * value + loss_perceptual += value_total + + length = int((len(kp_driving)-1)/4)+1 + loss_values['perceptual'] = loss_perceptual/length + else: + print('wrong train_params: ', self.train_params['generator']) + + + + return loss_values,generated + + +class TrainPart2Model(torch.nn.Module): + """ + Merge all generator related updates into single model for better multi-gpu usage + """ + + def __init__(self, kp_extractor, emo_feature, kp_extractor_a, audio_feature, generator, discriminator, train_params, device_ids): + super(TrainPart2Model, self).__init__() + self.kp_extractor = kp_extractor + self.kp_extractor_a = kp_extractor_a + + self.audio_feature = audio_feature + self.emo_feature = emo_feature + self.generator = generator + self.discriminator = discriminator + self.train_params = train_params + self.scales = train_params['scales'] + self.disc_scales = self.discriminator.scales + self.pyramid = ImagePyramide(self.scales, generator.num_channels) + if torch.cuda.is_available(): + self.pyramid = self.pyramid.cuda() + + self.loss_weights = train_params['loss_weights'] + + if sum(self.loss_weights['perceptual']) != 0: + self.vgg = Vgg19() + if torch.cuda.is_available(): + self.vgg = self.vgg.cuda() + + self.mse_loss_fn = nn.MSELoss().cuda() + self.CroEn_loss = nn.CrossEntropyLoss().cuda() + def forward(self, x): + + kp_source = self.kp_extractor(x['example_image']) + + kp_driving = [] + kp_emo = [] + for i in range(16): + kp_driving.append(self.kp_extractor(x['driving'][:,i])) + # kp_emo.append(self.emo_detector(x['driving'][:,i])) + + kp_driving_a = [] #x['example_image'], + deco_out = self.audio_feature(x['example_image'], 
x['driving_audio'], x['driving_pose'], self.train_params['jaco_net']) + # emo_out = self.emo_feature(x['example_image'], x['driving_audio'], x['driving_pose'], self.train_params['jaco_net']) + loss_values = {} + + if self.loss_weights['emo'] != 0: + + kp_driving_a = [] + fakes = [] + for i in range(16): + kp_driving_a.append(self.kp_extractor_a(deco_out[:,i]))# + value = self.kp_extractor_a(deco_out[:,i])['value'] + jacobian = self.kp_extractor_a(deco_out[:,i])['jacobian'] + if self.train_params['type'] == 'linear_4' : + out, fake = self.emo_feature(x['transformed_driving'][:,i],value,jacobian) + kp_emo.append(out) + fakes.append(fake) + # kp_emo.append(self.emo_feature(x['transformed_driving'][:,i],value,jacobian)) + elif self.train_params['type'] == 'linear_10': + # kp_emo.append(self.emo_feature.linear_10(x['transformed_driving'][:,i],value,jacobian)) + + out, fake = self.emo_feature.linear_10(x['transformed_driving'][:,i],value,jacobian) + kp_emo.append(out) + fakes.append(fake) + elif self.train_params['type'] == 'linear_4_new': + # kp_emo.append(self.emo_feature.linear_10(x['transformed_driving'][:,i],value,jacobian)) + + out, fake = self.emo_feature.linear_4(x['transformed_driving'][:,i],value,jacobian) + kp_emo.append(out) + fakes.append(fake) + elif self.train_params['type'] == 'linear_np_4': + # kp_emo.append(self.emo_feature.linear_10(x['transformed_driving'][:,i],value,jacobian)) + + out, fake = self.emo_feature.linear_np_4(x['transformed_driving'][:,i],value,jacobian) + kp_emo.append(out) + fakes.append(fake) + elif self.train_params['type'] == 'linear_np_10': + # kp_emo.append(self.emo_feature.linear_10(x['transformed_driving'][:,i],value,jacobian)) + + out, fake = self.emo_feature.linear_np_10(x['transformed_driving'][:,i],value,jacobian) + kp_emo.append(out) + fakes.append(fake) + + loss_value = 0 + + loss_jacobian = 0 + + loss_classify = 0 + kp_all = kp_driving_a + + for i in range(len(kp_driving)): + + if self.train_params['type'] == 'linear_4' or 
self.train_params['type'] == 'linear_4_new' or self.train_params['type'] == 'linear_np_4': + loss_jacobian += (torch.abs(kp_driving[i]['jacobian'][:,1] - kp_driving_a[i]['jacobian'][:,1] -kp_emo[i]['jacobian'][:,0]).mean())*self.loss_weights['emo'] + loss_jacobian += (torch.abs(kp_driving[i]['jacobian'][:,4] - kp_driving_a[i]['jacobian'][:,4] -kp_emo[i]['jacobian'][:,1]).mean())*self.loss_weights['emo'] + loss_jacobian += (torch.abs(kp_driving[i]['jacobian'][:,6] - kp_driving_a[i]['jacobian'][:,6] -kp_emo[i]['jacobian'][:,2]).mean())*self.loss_weights['emo'] + loss_jacobian += (torch.abs(kp_driving[i]['jacobian'][:,8] - kp_driving_a[i]['jacobian'][:,8] -kp_emo[i]['jacobian'][:,3]).mean())*self.loss_weights['emo'] + + loss_classify += self.CroEn_loss(fakes[i],x['emotion']) + loss_value += (torch.abs(kp_driving[i]['value'][:,1] .detach() - kp_driving_a[i]['value'][:,1] - kp_emo[i]['value'][:,0] ).mean())*self.loss_weights['emo'] + loss_value += (torch.abs(kp_driving[i]['value'][:,4] .detach() - kp_driving_a[i]['value'][:,4] - kp_emo[i]['value'][:,1] ).mean())*self.loss_weights['emo'] + loss_value += (torch.abs(kp_driving[i]['value'][:,6] .detach() - kp_driving_a[i]['value'][:,6] - kp_emo[i]['value'][:,2] ).mean())*self.loss_weights['emo'] + loss_value += (torch.abs(kp_driving[i]['value'][:,8] .detach() - kp_driving_a[i]['value'][:,8] - kp_emo[i]['value'][:,3] ).mean())*self.loss_weights['emo'] + kp_all[i]['jacobian'][:,1] = kp_emo[i]['jacobian'][:,0] + kp_driving_a[i]['jacobian'][:,1] + kp_all[i]['jacobian'][:,4] = kp_emo[i]['jacobian'][:,1] + kp_driving_a[i]['jacobian'][:,4] + kp_all[i]['jacobian'][:,6] = kp_emo[i]['jacobian'][:,2] + kp_driving_a[i]['jacobian'][:,6] + kp_all[i]['jacobian'][:,8] = kp_emo[i]['jacobian'][:,3] + kp_driving_a[i]['jacobian'][:,8] + kp_all[i]['value'][:,1] = kp_emo[i]['value'][:,0] + kp_driving_a[i]['value'][:,1] + kp_all[i]['value'][:,4] = kp_emo[i]['value'][:,1] + kp_driving_a[i]['value'][:,4] + kp_all[i]['value'][:,6] = 
kp_emo[i]['value'][:,2] + kp_driving_a[i]['value'][:,6] + kp_all[i]['value'][:,8] = kp_emo[i]['value'][:,3] + kp_driving_a[i]['value'][:,8] + elif self.train_params['type'] == 'linear_10' or self.train_params['type'] == 'linear_np_10': + loss_jacobian += (torch.abs(kp_driving[i]['jacobian'] - kp_driving_a[i]['jacobian'] -kp_emo[i]['jacobian']).mean())*self.loss_weights['emo'] + + loss_classify += self.CroEn_loss(fakes[i],x['emotion']) + loss_value += (torch.abs(kp_driving[i]['value'].detach() - kp_driving_a[i]['value'] - kp_emo[i]['value'] ).mean())*self.loss_weights['emo'] + + # kp_all[i]['value'] = kp_emo[i]['value'] + kp_driving_a[i]['value'] + + loss_values['loss_value'] = loss_value/len(kp_driving) + # loss_values['loss_heatmap'] = loss_heatmap/len(kp_driving) + loss_values['loss_jacobian'] = loss_jacobian/len(kp_driving) + if self.train_params['classify'] == True: + loss_values['loss_classify'] = loss_classify/len(kp_driving) + else: + loss_values['loss_classify'] = torch.tensor(0, device = loss_values['loss_value'].device) + + return loss_values, generated + + +class GeneratorFullModel(torch.nn.Module): + """ + Merge all generator related updates into single model for better multi-gpu usage + """ + + def __init__(self, kp_extractor, kp_extractor_a, audio_feature, generator, discriminator, train_params): + super(GeneratorFullModel, self).__init__() + self.kp_extractor = kp_extractor + self.kp_extractor_a = kp_extractor_a + # self.content_encoder = content_encoder + # self.emotion_encoder = emotion_encoder + self.audio_feature = audio_feature + self.generator = generator + self.discriminator = discriminator + self.train_params = train_params + self.scales = train_params['scales'] + self.disc_scales = self.discriminator.scales + self.pyramid = ImagePyramide(self.scales, generator.num_channels) + if torch.cuda.is_available(): + self.pyramid = self.pyramid.cuda() + + self.loss_weights = train_params['loss_weights'] + + if sum(self.loss_weights['perceptual']) != 
0: + self.vgg = Vgg19() + if torch.cuda.is_available(): + self.vgg = self.vgg.cuda() + + self.pca = torch.FloatTensor(np.load('.../LRW/list/U_106.npy'))[:, :16].cuda() + self.mean = torch.FloatTensor(np.load('.../LRW/list/mean_106.npy')).cuda() + + def forward(self, x): + # source_a_f = self.audio_feature(x['source_audio'],x['source_lm'],x[]) + # source_a_f = self.audio_feature(self.content_encoder(x['source_audio'].unsqueeze(1)), self.emotion_encoder(x['source_audio'].unsqueeze(1))) + # kp_source = self.kp_extractor(x['source']) + # kp_source_a = self.kp_extractor_a(x['source'], x['source_cube'], source_a_f) + # driving_a_f = self.audio_feature(self.content_encoder(x['driving_audio'].unsqueeze(1)), self.emotion_encoder(x['driving_audio'].unsqueeze(1))) + # driving_a_f = self.audio_feature(x['driving_audio']) + # kp_driving = self.kp_extractor(x['driving']) + # kp_driving_a = self.kp_extractor_a(x['driving'], x['driving_cube'], driving_a_f) + + kp_driving = [] + for i in range(16): + kp_driving.append(self.kp_extractor(x['driving'][:,i],x['driving_landmark'][:,i],self.loss_weights['equivariance_value'])) + + kp_driving_a = [] + fc_out, deco_out = self.audio_feature(x['example_landmark'], x['driving_audio'], x['driving_pose']) + fake_lmark=fc_out + x['example_landmark'].expand_as(fc_out) + + + fake_lmark = torch.mm( fake_lmark, self.pca.t() ) + fake_lmark = fake_lmark + self.mean.expand_as(fake_lmark) + + + fake_lmark = fake_lmark.unsqueeze(0) + + # for i in range(16): + # kp_driving_a.append() + + # generated = self.generator(x['source'], kp_source=kp_source, kp_driving=kp_driving) + # generated.update({'kp_source': kp_source, 'kp_driving': kp_driving}) + + loss_values = {} + + pyramide_real = self.pyramid(x['driving']) + pyramide_generated = self.pyramid(generated['prediction']) + + if self.loss_weights['audio'] != 0: + value = torch.abs(kp_source['jacobian'].detach() - kp_source_a['jacobian'].detach()).mean() + torch.abs(kp_driving['jacobian'].detach() - 
kp_driving_a['jacobian']).mean() + value = value/2 + loss_values['jacobian'] = value*self.loss_weights['audio'] + value = torch.abs(kp_source['heatmap'].detach() - kp_source_a['heatmap'].detach()).mean() + torch.abs(kp_driving['heatmap'].detach() - kp_driving_a['heatmap']).mean() + value = value/2 + loss_values['heatmap'] = value*self.loss_weights['audio'] + value = torch.abs(kp_source['value'].detach() - kp_source_a['value'].detach()).mean() + torch.abs(kp_driving['value'].detach() - kp_driving_a['value']).mean() + value = value/2 + loss_values['value'] = value*self.loss_weights['audio'] + + if sum(self.loss_weights['perceptual']) != 0: + value_total = 0 + for scale in self.scales: + x_vgg = self.vgg(pyramide_generated['prediction_' + str(scale)]) + y_vgg = self.vgg(pyramide_real['prediction_' + str(scale)]) + + for i, weight in enumerate(self.loss_weights['perceptual']): + value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean() + value_total += self.loss_weights['perceptual'][i] * value + loss_values['perceptual'] = value_total + + if self.loss_weights['generator_gan'] != 0: + discriminator_maps_generated = self.discriminator(pyramide_generated, kp=detach_kp(kp_driving)) + discriminator_maps_real = self.discriminator(pyramide_real, kp=detach_kp(kp_driving)) + value_total = 0 + for scale in self.disc_scales: + key = 'prediction_map_%s' % scale + value = ((1 - discriminator_maps_generated[key]) ** 2).mean() + value_total += self.loss_weights['generator_gan'] * value + loss_values['gen_gan'] = value_total + + if sum(self.loss_weights['feature_matching']) != 0: + value_total = 0 + for scale in self.disc_scales: + key = 'feature_maps_%s' % scale + for i, (a, b) in enumerate(zip(discriminator_maps_real[key], discriminator_maps_generated[key])): + if self.loss_weights['feature_matching'][i] == 0: + continue + value = torch.abs(a - b).mean() + value_total += self.loss_weights['feature_matching'][i] * value + loss_values['feature_matching'] = value_total + + if 
            (self.loss_weights['equivariance_value'] + self.loss_weights['equivariance_jacobian']) != 0:
            # Equivariance constraints: keypoints detected on a randomly
            # warped frame must map back onto the original keypoints.
            transform = Transform(x['driving'].shape[0], **self.train_params['transform_params'])
            transformed_frame = transform.transform_frame(x['driving'])
            # NOTE(review): transformed_landmark is computed but never used
            # in this forward — dead code or a missing loss term; confirm.
            transformed_landmark = transform.inverse_warp_coordinates(x['driving_landmark'])
            transformed_kp = self.kp_extractor(transformed_frame)

            # NOTE(review): ``generated`` is never assigned in this forward
            # (the self.generator(...) call above is commented out), so this
            # raises NameError at runtime — matches the README's note that
            # the upstream training code is incomplete.
            generated['transformed_frame'] = transformed_frame
            generated['transformed_kp'] = transformed_kp

            ## Value loss part
            if self.loss_weights['equivariance_value'] != 0:
                # |kp(driving) - T(kp(T^-1(driving)))| should be ~0.
                value = torch.abs(kp_driving['value'] - transform.warp_coordinates(transformed_kp['value'])).mean()
                loss_values['equivariance_value'] = self.loss_weights['equivariance_value'] * value

            ## jacobian loss part
            if self.loss_weights['equivariance_jacobian'] != 0:
                # Product of the driving-jacobian inverse with the warped
                # jacobian should be the identity.
                jacobian_transformed = torch.matmul(transform.jacobian(transformed_kp['value']),
                                                    transformed_kp['jacobian'])

                normed_driving = torch.inverse(kp_driving['jacobian'])
                normed_transformed = jacobian_transformed
                value = torch.matmul(normed_driving, normed_transformed)

                eye = torch.eye(2).view(1, 1, 2, 2).type(value.type())

                value = torch.abs(eye - value).mean()
                loss_values['equivariance_jacobian'] = self.loss_weights['equivariance_jacobian'] * value

        return loss_values, generated


class DiscriminatorFullModel(torch.nn.Module):
    """
    Merge all discriminator related updates into single model for better multi-gpu usage

    Computes the least-squares GAN discriminator loss over an image pyramid:
    (1 - D(real))^2 + D(fake)^2, summed across scales and weighted by
    train_params['loss_weights']['discriminator_gan'].
    """

    def __init__(self, kp_extractor, generator, discriminator, train_params):
        super(DiscriminatorFullModel, self).__init__()
        self.kp_extractor = kp_extractor
        self.generator = generator
        self.discriminator = discriminator
        self.train_params = train_params
        # Pyramid scales come from the discriminator itself so the two
        # always agree.
        self.scales = self.discriminator.scales
        self.pyramid = ImagePyramide(self.scales, generator.num_channels)
        if torch.cuda.is_available():
            self.pyramid = self.pyramid.cuda()

        self.loss_weights = train_params['loss_weights']

    def forward(self, x, generated):
        # ``generated`` is the generator-side output dict; its prediction is
        # detached so only the discriminator receives gradients here.
        pyramide_real = self.pyramid(x['driving'])
        pyramide_generated = self.pyramid(generated['prediction'].detach())

        kp_driving = generated['kp_driving']
        discriminator_maps_generated = self.discriminator(pyramide_generated, kp=detach_kp(kp_driving))
        discriminator_maps_real = self.discriminator(pyramide_real, kp=detach_kp(kp_driving))

        loss_values = {}
        value_total = 0
        for scale in self.scales:
            key = 'prediction_map_%s' % scale
            # LSGAN objective: push D(real) toward 1 and D(fake) toward 0.
            value = (1 - discriminator_maps_real[key]) ** 2 + discriminator_maps_generated[key] ** 2
            value_total += self.loss_weights['discriminator_gan'] * value.mean()
        loss_values['disc_gan'] = value_total

        return loss_values
diff --git a/talkingface/model/audio_driven_talkingface/eamm_modules/model_delta_map.py b/talkingface/model/audio_driven_talkingface/eamm_modules/model_delta_map.py
new file mode 100644
index 00000000..f0cbff49
--- /dev/null
+++ b/talkingface/model/audio_driven_talkingface/eamm_modules/model_delta_map.py
@@ -0,0 +1,500 @@
from torch import nn
import torch
import torch.nn.functional as F
from .util import AntiAliasInterpolation2d, make_coordinate_grid
from torchvision import models
import numpy as np
from torch.autograd import grad


class Vgg19(torch.nn.Module):
    """
    Vgg19 network for perceptual loss. See Sec 3.3.
+ """ + def __init__(self, requires_grad=False): + super(Vgg19, self).__init__() + vgg_pretrained_features = models.vgg19(pretrained=True).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + for x in range(2): + self.slice1.add_module(str(x), vgg_pretrained_features[x]) + for x in range(2, 7): + self.slice2.add_module(str(x), vgg_pretrained_features[x]) + for x in range(7, 12): + self.slice3.add_module(str(x), vgg_pretrained_features[x]) + for x in range(12, 21): + self.slice4.add_module(str(x), vgg_pretrained_features[x]) + for x in range(21, 30): + self.slice5.add_module(str(x), vgg_pretrained_features[x]) + + self.mean = torch.nn.Parameter(data=torch.Tensor(np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))), + requires_grad=False) + self.std = torch.nn.Parameter(data=torch.Tensor(np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))), + requires_grad=False) + + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + X = (X - self.mean) / self.std + h_relu1 = self.slice1(X) + h_relu2 = self.slice2(h_relu1) + h_relu3 = self.slice3(h_relu2) + h_relu4 = self.slice4(h_relu3) + h_relu5 = self.slice5(h_relu4) + out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5] + return out + + +class ImagePyramide(torch.nn.Module): + """ + Create image pyramide for computing pyramide perceptual loss. 
See Sec 3.3 + """ + def __init__(self, scales, num_channels): + super(ImagePyramide, self).__init__() + downs = {} + for scale in scales: + downs[str(scale).replace('.', '-')] = AntiAliasInterpolation2d(num_channels, scale) + self.downs = nn.ModuleDict(downs) + + def forward(self, x): + out_dict = {} + for scale, down_module in self.downs.items(): + out_dict['prediction_' + str(scale).replace('-', '.')] = down_module(x) + return out_dict + + +class Transform: + """ + Random tps transformation for equivariance constraints. See Sec 3.3 + """ + def __init__(self, bs, **kwargs): + noise = torch.normal(mean=0, std=kwargs['sigma_affine'] * torch.ones([bs, 2, 3])) + self.theta = noise + torch.eye(2, 3).view(1, 2, 3) + self.bs = bs + + if ('sigma_tps' in kwargs) and ('points_tps' in kwargs): + self.tps = True + self.control_points = make_coordinate_grid((kwargs['points_tps'], kwargs['points_tps']), type=noise.type()) + self.control_points = self.control_points.unsqueeze(0) + self.control_params = torch.normal(mean=0, + std=kwargs['sigma_tps'] * torch.ones([bs, 1, kwargs['points_tps'] ** 2])) + else: + self.tps = False + + def transform_frame(self, frame): + grid = make_coordinate_grid(frame.shape[2:], type=frame.type()).unsqueeze(0) #[1,256,256,2] + grid = grid.view(1, frame.shape[2] * frame.shape[3], 2) + grid = self.warp_coordinates(grid).view(self.bs, frame.shape[2], frame.shape[3], 2) + return F.grid_sample(frame, grid, padding_mode="reflection") + + def inverse_transform_frame(self, frame): + grid = make_coordinate_grid(frame.shape[2:], type=frame.type()).unsqueeze(0) #[1,256,256,2] + grid = grid.view(1, frame.shape[2] * frame.shape[3], 2) + grid = self.inverse_warp_coordinates(grid).view(self.bs, frame.shape[2], frame.shape[3], 2) + return F.grid_sample(frame, grid, padding_mode="reflection") + + def warp_coordinates(self, coordinates): + theta = self.theta.type(coordinates.type()) + theta = theta.unsqueeze(1) + transformed = torch.matmul(theta[:, :, :, :2], 
coordinates.unsqueeze(-1)) + theta[:, :, :, 2:] + transformed = transformed.squeeze(-1) + + if self.tps: + control_points = self.control_points.type(coordinates.type()) + control_params = self.control_params.type(coordinates.type()) + distances = coordinates.view(coordinates.shape[0], -1, 1, 2) - control_points.view(1, 1, -1, 2) + distances = torch.abs(distances).sum(-1) + + result = distances ** 2 + result = result * torch.log(distances + 1e-6) + result = result * control_params + result = result.sum(dim=2).view(self.bs, coordinates.shape[1], 1) + transformed = transformed + result + + return transformed + + def inverse_warp_coordinates(self, coordinates): + theta = self.theta.type(coordinates.type()) + theta = theta.unsqueeze(1) + a = torch.FloatTensor([[[[0,0,1]]]]).repeat([self.bs,1,1,1]).cuda() + c = torch.cat((theta,a),2) + d = c.inverse()[:,:,:2,:] + d = d.type(coordinates.type()) + transformed = torch.matmul(d[:, :, :, :2], coordinates.unsqueeze(-1)) + d[:, :, :, 2:] + transformed = transformed.squeeze(-1) + + if self.tps: + control_points = self.control_points.type(coordinates.type()) + control_params = self.control_params.type(coordinates.type()) + distances = coordinates.view(coordinates.shape[0], -1, 1, 2) - control_points.view(1, 1, -1, 2) + distances = torch.abs(distances).sum(-1) + + result = distances ** 2 + result = result * torch.log(distances + 1e-6) + result = result * control_params + result = result.sum(dim=2).view(self.bs, coordinates.shape[1], 1) + transformed = transformed + result + + + return transformed + + def jacobian(self, coordinates): + coordinates.requires_grad=True + new_coordinates = self.warp_coordinates(coordinates)#[4,10,2] + grad_x = grad(new_coordinates[..., 0].sum(), coordinates, create_graph=True) + grad_y = grad(new_coordinates[..., 1].sum(), coordinates, create_graph=True) + jacobian = torch.cat([grad_x[0].unsqueeze(-2), grad_y[0].unsqueeze(-2)], dim=-2) + return jacobian + + +def detach_kp(kp): + return {key: 
value.detach() for key, value in kp.items()} + +class TrainFullModel(torch.nn.Module): + """ + Merge all generator related updates into single model for better multi-gpu usage + """ + + def __init__(self, kp_extractor, emo_feature, kp_extractor_a, audio_feature, generator, discriminator, train_params, device_ids): + super(TrainFullModel, self).__init__() + self.kp_extractor = kp_extractor + self.kp_extractor_a = kp_extractor_a + # self.emo_detector = emo_detector + # self.content_encoder = content_encoder + # self.emotion_encoder = emotion_encoder + self.audio_feature = audio_feature + self.emo_feature = emo_feature + self.generator = generator + self.discriminator = discriminator + self.train_params = train_params + self.scales = train_params['scales'] + self.disc_scales = self.discriminator.scales + self.pyramid = ImagePyramide(self.scales, generator.num_channels) + if torch.cuda.is_available(): + self.pyramid = self.pyramid.cuda() + + self.loss_weights = train_params['loss_weights'] + + if sum(self.loss_weights['perceptual']) != 0: + self.vgg = Vgg19() + if torch.cuda.is_available(): + self.vgg = self.vgg.cuda() + + # self.pca = torch.FloatTensor(np.load('/mnt/lustre/jixinya/Home/LRW/list/U_106.npy'))[:, :16].to(device_ids[0]) + # self.mean = torch.FloatTensor(np.load('/mnt/lustre/jixinya/Home/LRW/list/mean_106.npy')).to(device_ids[0]) + self.mse_loss_fn = nn.MSELoss().cuda() + self.CroEn_loss = nn.CrossEntropyLoss().cuda() + def forward(self, x): + # source_a_f = self.audio_feature(x['source_audio'],x['source_lm'],x[]) + # source_a_f = self.audio_feature(self.content_encoder(x['source_audio'].unsqueeze(1)), self.emotion_encoder(x['source_audio'].unsqueeze(1))) + kp_source = self.kp_extractor(x['example_image']) + + kp_driving = [] + kp_emo = [] + for i in range(16): + kp_driving.append(self.kp_extractor(x['driving'][:,i])) + # kp_emo.append(self.emo_detector(x['driving'][:,i])) + # print('KP_driving ', 
file=open('/mnt/lustre/jixinya/Home/fomm_audio/log/LRW_test.txt', 'a')) + kp_driving_a = [] #x['example_image'], + deco_out = self.audio_feature(x['example_image'], x['driving_audio'], x['driving_pose'], self.train_params['jaco_net']) + # emo_out = self.emo_feature(x['example_image'], x['driving_audio'], x['driving_pose'], self.train_params['jaco_net']) + loss_values = {} + + if self.loss_weights['emo'] != 0: + + kp_driving_a = [] + fakes = [] + for i in range(16): + kp_driving_a.append(self.kp_extractor_a(deco_out[:,i]))# + value = self.kp_extractor_a(deco_out[:,i])['value'] + jacobian = self.kp_extractor_a(deco_out[:,i])['jacobian'] + if self.train_params['type'] == 'map_4': + out, fake = self.emo_feature.map_4(x['transformed_driving'][:,i],value,jacobian) + kp_emo.append(out) + fakes.append(fake) + # kp_emo.append(self.emo_feature(x['transformed_driving'][:,i],value,jacobian)) + elif self.train_params['type'] == 'map_10': + # kp_emo.append(self.emo_feature.linear_10(x['transformed_driving'][:,i],value,jacobian)) + + out, fake = self.emo_feature(x['transformed_driving'][:,i],value,jacobian) + kp_emo.append(out) + fakes.append(fake) + # kp_emo.append(self.emo_feature(x['transformed_driving'][:,i],value,jacobian)) + # print('Kp_audio_driving ', file=open('/mnt/lustre/jixinya/Home/fomm_audio/log/LRW_test.txt', 'a')) + loss_value = 0 + # loss_heatmap = 0 + loss_jacobian = 0 + loss_perceptual = 0 + loss_classify = 0 + kp_all = kp_driving_a + for i in range(len(kp_driving)): + if self.train_params['type'] == 'map_4': + loss_jacobian += (torch.abs(kp_driving[i]['jacobian'][:,1] - kp_driving_a[i]['jacobian'][:,1] -kp_emo[i]['jacobian'][:,0]).mean())*self.loss_weights['emo'] + loss_jacobian += (torch.abs(kp_driving[i]['jacobian'][:,4] - kp_driving_a[i]['jacobian'][:,4] -kp_emo[i]['jacobian'][:,1]).mean())*self.loss_weights['emo'] + loss_jacobian += (torch.abs(kp_driving[i]['jacobian'][:,6] - kp_driving_a[i]['jacobian'][:,6] 
-kp_emo[i]['jacobian'][:,2]).mean())*self.loss_weights['emo'] + loss_jacobian += (torch.abs(kp_driving[i]['jacobian'][:,8] - kp_driving_a[i]['jacobian'][:,8] -kp_emo[i]['jacobian'][:,3]).mean())*self.loss_weights['emo'] + + loss_classify += self.CroEn_loss(fakes[i],x['emotion']) + loss_value += (torch.abs(kp_driving[i]['value'][:,1] .detach() - kp_driving_a[i]['value'][:,1] - kp_emo[i]['value'][:,0] ).mean())*self.loss_weights['emo'] + loss_value += (torch.abs(kp_driving[i]['value'][:,4] .detach() - kp_driving_a[i]['value'][:,4] - kp_emo[i]['value'][:,1] ).mean())*self.loss_weights['emo'] + loss_value += (torch.abs(kp_driving[i]['value'][:,6] .detach() - kp_driving_a[i]['value'][:,6] - kp_emo[i]['value'][:,2] ).mean())*self.loss_weights['emo'] + loss_value += (torch.abs(kp_driving[i]['value'][:,8] .detach() - kp_driving_a[i]['value'][:,8] - kp_emo[i]['value'][:,3] ).mean())*self.loss_weights['emo'] + kp_all[i]['jacobian'][:,1] = kp_emo[i]['jacobian'][:,0] + kp_driving_a[i]['jacobian'][:,1] + kp_all[i]['jacobian'][:,4] = kp_emo[i]['jacobian'][:,1] + kp_driving_a[i]['jacobian'][:,4] + kp_all[i]['jacobian'][:,6] = kp_emo[i]['jacobian'][:,2] + kp_driving_a[i]['jacobian'][:,6] + kp_all[i]['jacobian'][:,8] = kp_emo[i]['jacobian'][:,3] + kp_driving_a[i]['jacobian'][:,8] + kp_all[i]['value'][:,1] = kp_emo[i]['value'][:,0] + kp_driving_a[i]['value'][:,1] + kp_all[i]['value'][:,4] = kp_emo[i]['value'][:,1] + kp_driving_a[i]['value'][:,4] + kp_all[i]['value'][:,6] = kp_emo[i]['value'][:,2] + kp_driving_a[i]['value'][:,6] + kp_all[i]['value'][:,8] = kp_emo[i]['value'][:,3] + kp_driving_a[i]['value'][:,8] + elif self.train_params['type'] == 'map_10': + loss_jacobian += (torch.abs(kp_driving[i]['jacobian'] - kp_driving_a[i]['jacobian'] -kp_emo[i]['jacobian']).mean())*self.loss_weights['emo'] + + loss_classify += self.CroEn_loss(fakes[i],x['emotion']) + loss_value += (torch.abs(kp_driving[i]['value'].detach() - kp_driving_a[i]['value'] - kp_emo[i]['value'] 
).mean())*self.loss_weights['emo'] + + # kp_all[i]['value'] = kp_emo[i]['value'] + kp_driving_a[i]['value'] + + loss_values['loss_value'] = loss_value/len(kp_driving) + # loss_values['loss_heatmap'] = loss_heatmap/len(kp_driving) + loss_values['loss_jacobian'] = loss_jacobian/len(kp_driving) + loss_values['loss_classify'] = loss_classify/len(kp_driving) + + if self.train_params['generator'] == 'not': + loss_values['perceptual'] = self.mse_loss_fn(deco_out,deco_out) + for i in range(1): #0,len(kp_driving),4 + + generated = self.generator(x['example_image'], kp_source=kp_source, kp_driving=kp_all[i]) + generated.update({'kp_source': kp_source, 'kp_driving': kp_all}) + elif self.train_params['generator'] == 'visual': + for i in range(0,len(kp_driving),4): #0,len(kp_driving),4 + + generated = self.generator(x['example_image'], kp_source=kp_source, kp_driving=kp_driving[i]) + generated.update({'kp_source': kp_source, 'kp_driving': kp_driving}) + + pyramide_real = self.pyramid(x['driving'][:,i]) + pyramide_generated = self.pyramid(generated['prediction']) + + if sum(self.loss_weights['perceptual']) != 0: + value_total = 0 + for scale in self.scales: + x_vgg = self.vgg(pyramide_generated['prediction_' + str(scale)]) + y_vgg = self.vgg(pyramide_real['prediction_' + str(scale)]) + + for i, weight in enumerate(self.loss_weights['perceptual']): + value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean() + value_total += self.loss_weights['perceptual'][i] * value + loss_perceptual += value_total + + length = int((len(kp_driving)-1)/4)+1 + loss_values['perceptual'] = loss_perceptual/length + elif self.train_params['generator'] == 'audio': + for i in range(0,len(kp_driving),4): #0,len(kp_driving),4 + + generated = self.generator(x['example_image'], kp_source=kp_source, kp_driving=kp_driving_a[i]) + generated.update({'kp_source': kp_source, 'kp_driving': kp_driving_a}) + + pyramide_real = self.pyramid(x['driving'][:,i]) + pyramide_generated = self.pyramid(generated['prediction']) + 
+ if sum(self.loss_weights['perceptual']) != 0: + value_total = 0 + for scale in self.scales: + x_vgg = self.vgg(pyramide_generated['prediction_' + str(scale)]) + y_vgg = self.vgg(pyramide_real['prediction_' + str(scale)]) + + for i, weight in enumerate(self.loss_weights['perceptual']): + value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean() + value_total += self.loss_weights['perceptual'][i] * value + loss_perceptual += value_total + + length = int((len(kp_driving)-1)/4)+1 + loss_values['perceptual'] = loss_perceptual/length + else: + print('wrong train_params: ', self.train_params['generator']) + + + + return loss_values,generated + +class GeneratorFullModel(torch.nn.Module): + """ + Merge all generator related updates into single model for better multi-gpu usage + """ + + def __init__(self, kp_extractor, kp_extractor_a, audio_feature, generator, discriminator, train_params): + super(GeneratorFullModel, self).__init__() + self.kp_extractor = kp_extractor + self.kp_extractor_a = kp_extractor_a + # self.content_encoder = content_encoder + # self.emotion_encoder = emotion_encoder + self.audio_feature = audio_feature + self.generator = generator + self.discriminator = discriminator + self.train_params = train_params + self.scales = train_params['scales'] + self.disc_scales = self.discriminator.scales + self.pyramid = ImagePyramide(self.scales, generator.num_channels) + if torch.cuda.is_available(): + self.pyramid = self.pyramid.cuda() + + self.loss_weights = train_params['loss_weights'] + + if sum(self.loss_weights['perceptual']) != 0: + self.vgg = Vgg19() + if torch.cuda.is_available(): + self.vgg = self.vgg.cuda() + + self.pca = torch.FloatTensor(np.load('.../LRW/list/U_106.npy'))[:, :16].cuda() + self.mean = torch.FloatTensor(np.load('.../LRW/list/mean_106.npy')).cuda() + + def forward(self, x): + # source_a_f = self.audio_feature(x['source_audio'],x['source_lm'],x[]) + # source_a_f = self.audio_feature(self.content_encoder(x['source_audio'].unsqueeze(1)), 
self.emotion_encoder(x['source_audio'].unsqueeze(1))) + # kp_source = self.kp_extractor(x['source']) + # kp_source_a = self.kp_extractor_a(x['source'], x['source_cube'], source_a_f) + # driving_a_f = self.audio_feature(self.content_encoder(x['driving_audio'].unsqueeze(1)), self.emotion_encoder(x['driving_audio'].unsqueeze(1))) + # driving_a_f = self.audio_feature(x['driving_audio']) + # kp_driving = self.kp_extractor(x['driving']) + # kp_driving_a = self.kp_extractor_a(x['driving'], x['driving_cube'], driving_a_f) + + kp_driving = [] + for i in range(16): + kp_driving.append(self.kp_extractor(x['driving'][:,i],x['driving_landmark'][:,i],self.loss_weights['equivariance_value'])) + + kp_driving_a = [] + fc_out, deco_out = self.audio_feature(x['example_landmark'], x['driving_audio'], x['driving_pose']) + fake_lmark=fc_out + x['example_landmark'].expand_as(fc_out) + + + fake_lmark = torch.mm( fake_lmark, self.pca.t() ) + fake_lmark = fake_lmark + self.mean.expand_as(fake_lmark) + + + fake_lmark = fake_lmark.unsqueeze(0) + + # for i in range(16): + # kp_driving_a.append() + + # generated = self.generator(x['source'], kp_source=kp_source, kp_driving=kp_driving) + # generated.update({'kp_source': kp_source, 'kp_driving': kp_driving}) + + loss_values = {} + + pyramide_real = self.pyramid(x['driving']) + pyramide_generated = self.pyramid(generated['prediction']) + + if self.loss_weights['audio'] != 0: + value = torch.abs(kp_source['jacobian'].detach() - kp_source_a['jacobian'].detach()).mean() + torch.abs(kp_driving['jacobian'].detach() - kp_driving_a['jacobian']).mean() + value = value/2 + loss_values['jacobian'] = value*self.loss_weights['audio'] + value = torch.abs(kp_source['heatmap'].detach() - kp_source_a['heatmap'].detach()).mean() + torch.abs(kp_driving['heatmap'].detach() - kp_driving_a['heatmap']).mean() + value = value/2 + loss_values['heatmap'] = value*self.loss_weights['audio'] + value = torch.abs(kp_source['value'].detach() - 
kp_source_a['value'].detach()).mean() + torch.abs(kp_driving['value'].detach() - kp_driving_a['value']).mean() + value = value/2 + loss_values['value'] = value*self.loss_weights['audio'] + + if sum(self.loss_weights['perceptual']) != 0: + value_total = 0 + for scale in self.scales: + x_vgg = self.vgg(pyramide_generated['prediction_' + str(scale)]) + y_vgg = self.vgg(pyramide_real['prediction_' + str(scale)]) + + for i, weight in enumerate(self.loss_weights['perceptual']): + value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean() + value_total += self.loss_weights['perceptual'][i] * value + loss_values['perceptual'] = value_total + + if self.loss_weights['generator_gan'] != 0: + discriminator_maps_generated = self.discriminator(pyramide_generated, kp=detach_kp(kp_driving)) + discriminator_maps_real = self.discriminator(pyramide_real, kp=detach_kp(kp_driving)) + value_total = 0 + for scale in self.disc_scales: + key = 'prediction_map_%s' % scale + value = ((1 - discriminator_maps_generated[key]) ** 2).mean() + value_total += self.loss_weights['generator_gan'] * value + loss_values['gen_gan'] = value_total + + if sum(self.loss_weights['feature_matching']) != 0: + value_total = 0 + for scale in self.disc_scales: + key = 'feature_maps_%s' % scale + for i, (a, b) in enumerate(zip(discriminator_maps_real[key], discriminator_maps_generated[key])): + if self.loss_weights['feature_matching'][i] == 0: + continue + value = torch.abs(a - b).mean() + value_total += self.loss_weights['feature_matching'][i] * value + loss_values['feature_matching'] = value_total + + if (self.loss_weights['equivariance_value'] + self.loss_weights['equivariance_jacobian']) != 0: + transform = Transform(x['driving'].shape[0], **self.train_params['transform_params']) + transformed_frame = transform.transform_frame(x['driving']) + transformed_landmark = transform.inverse_warp_coordinates(x['driving_landmark']) + transformed_kp = self.kp_extractor(transformed_frame) + + generated['transformed_frame'] = 
transformed_frame + generated['transformed_kp'] = transformed_kp + + ## Value loss part + if self.loss_weights['equivariance_value'] != 0: + value = torch.abs(kp_driving['value'] - transform.warp_coordinates(transformed_kp['value'])).mean() + loss_values['equivariance_value'] = self.loss_weights['equivariance_value'] * value + + ## jacobian loss part + if self.loss_weights['equivariance_jacobian'] != 0: + jacobian_transformed = torch.matmul(transform.jacobian(transformed_kp['value']), + transformed_kp['jacobian']) + + normed_driving = torch.inverse(kp_driving['jacobian']) + normed_transformed = jacobian_transformed + value = torch.matmul(normed_driving, normed_transformed) + + eye = torch.eye(2).view(1, 1, 2, 2).type(value.type()) + + value = torch.abs(eye - value).mean() + loss_values['equivariance_jacobian'] = self.loss_weights['equivariance_jacobian'] * value + + return loss_values, generated + + +class DiscriminatorFullModel(torch.nn.Module): + """ + Merge all discriminator related updates into single model for better multi-gpu usage + """ + + def __init__(self, kp_extractor, generator, discriminator, train_params): + super(DiscriminatorFullModel, self).__init__() + self.kp_extractor = kp_extractor + self.generator = generator + self.discriminator = discriminator + self.train_params = train_params + self.scales = self.discriminator.scales + self.pyramid = ImagePyramide(self.scales, generator.num_channels) + if torch.cuda.is_available(): + self.pyramid = self.pyramid.cuda() + + self.loss_weights = train_params['loss_weights'] + + def forward(self, x, generated): + pyramide_real = self.pyramid(x['driving']) + pyramide_generated = self.pyramid(generated['prediction'].detach()) + + kp_driving = generated['kp_driving'] + discriminator_maps_generated = self.discriminator(pyramide_generated, kp=detach_kp(kp_driving)) + discriminator_maps_real = self.discriminator(pyramide_real, kp=detach_kp(kp_driving)) + + loss_values = {} + value_total = 0 + for scale in 
self.scales: + key = 'prediction_map_%s' % scale + value = (1 - discriminator_maps_real[key]) ** 2 + discriminator_maps_generated[key] ** 2 + value_total += self.loss_weights['discriminator_gan'] * value.mean() + loss_values['disc_gan'] = value_total + + return loss_values diff --git a/talkingface/model/audio_driven_talkingface/eamm_modules/model_gen.py b/talkingface/model/audio_driven_talkingface/eamm_modules/model_gen.py new file mode 100644 index 00000000..9dc3d404 --- /dev/null +++ b/talkingface/model/audio_driven_talkingface/eamm_modules/model_gen.py @@ -0,0 +1,516 @@ +from torch import nn +import torch +import torch.nn.functional as F +from .util import AntiAliasInterpolation2d, make_coordinate_grid +from torchvision import models +import numpy as np +from torch.autograd import grad + + +class Vgg19(torch.nn.Module): + """ + Vgg19 network for perceptual loss. See Sec 3.3. + """ + def __init__(self, requires_grad=False): + super(Vgg19, self).__init__() + vgg_pretrained_features = models.vgg19(pretrained=True).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + for x in range(2): + self.slice1.add_module(str(x), vgg_pretrained_features[x]) + for x in range(2, 7): + self.slice2.add_module(str(x), vgg_pretrained_features[x]) + for x in range(7, 12): + self.slice3.add_module(str(x), vgg_pretrained_features[x]) + for x in range(12, 21): + self.slice4.add_module(str(x), vgg_pretrained_features[x]) + for x in range(21, 30): + self.slice5.add_module(str(x), vgg_pretrained_features[x]) + + self.mean = torch.nn.Parameter(data=torch.Tensor(np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))), + requires_grad=False) + self.std = torch.nn.Parameter(data=torch.Tensor(np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))), + requires_grad=False) + + if not requires_grad: + for param in self.parameters(): + param.requires_grad 
= False + + def forward(self, X): + X = (X - self.mean) / self.std + h_relu1 = self.slice1(X) + h_relu2 = self.slice2(h_relu1) + h_relu3 = self.slice3(h_relu2) + h_relu4 = self.slice4(h_relu3) + h_relu5 = self.slice5(h_relu4) + out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5] + return out + + +class ImagePyramide(torch.nn.Module): + """ + Create image pyramide for computing pyramide perceptual loss. See Sec 3.3 + """ + def __init__(self, scales, num_channels): + super(ImagePyramide, self).__init__() + downs = {} + for scale in scales: + downs[str(scale).replace('.', '-')] = AntiAliasInterpolation2d(num_channels, scale) + self.downs = nn.ModuleDict(downs) + + def forward(self, x): + out_dict = {} + for scale, down_module in self.downs.items(): + out_dict['prediction_' + str(scale).replace('-', '.')] = down_module(x) + return out_dict + + +class Transform: + """ + Random tps transformation for equivariance constraints. See Sec 3.3 + """ + def __init__(self, bs, **kwargs): + noise = torch.normal(mean=0, std=kwargs['sigma_affine'] * torch.ones([bs, 2, 3])) + self.theta = noise + torch.eye(2, 3).view(1, 2, 3) + self.bs = bs + + if ('sigma_tps' in kwargs) and ('points_tps' in kwargs): + self.tps = True + self.control_points = make_coordinate_grid((kwargs['points_tps'], kwargs['points_tps']), type=noise.type()) + self.control_points = self.control_points.unsqueeze(0) + self.control_params = torch.normal(mean=0, + std=kwargs['sigma_tps'] * torch.ones([bs, 1, kwargs['points_tps'] ** 2])) + else: + self.tps = False + + def transform_frame(self, frame): + grid = make_coordinate_grid(frame.shape[2:], type=frame.type()).unsqueeze(0) #[1,256,256,2] + grid = grid.view(1, frame.shape[2] * frame.shape[3], 2) + grid = self.warp_coordinates(grid).view(self.bs, frame.shape[2], frame.shape[3], 2) + return F.grid_sample(frame, grid, padding_mode="reflection") + + def inverse_transform_frame(self, frame): + grid = make_coordinate_grid(frame.shape[2:], type=frame.type()).unsqueeze(0) 
#[1,256,256,2] + grid = grid.view(1, frame.shape[2] * frame.shape[3], 2) + grid = self.inverse_warp_coordinates(grid).view(self.bs, frame.shape[2], frame.shape[3], 2) + return F.grid_sample(frame, grid, padding_mode="reflection") + + def warp_coordinates(self, coordinates): + theta = self.theta.type(coordinates.type()) + theta = theta.unsqueeze(1) + transformed = torch.matmul(theta[:, :, :, :2], coordinates.unsqueeze(-1)) + theta[:, :, :, 2:] + transformed = transformed.squeeze(-1) + + if self.tps: + control_points = self.control_points.type(coordinates.type()) + control_params = self.control_params.type(coordinates.type()) + distances = coordinates.view(coordinates.shape[0], -1, 1, 2) - control_points.view(1, 1, -1, 2) + distances = torch.abs(distances).sum(-1) + + result = distances ** 2 + result = result * torch.log(distances + 1e-6) + result = result * control_params + result = result.sum(dim=2).view(self.bs, coordinates.shape[1], 1) + transformed = transformed + result + + return transformed + + def inverse_warp_coordinates(self, coordinates): + theta = self.theta.type(coordinates.type()) + theta = theta.unsqueeze(1) + a = torch.FloatTensor([[[[0,0,1]]]]).repeat([self.bs,1,1,1]).cuda() + c = torch.cat((theta,a),2) + d = c.inverse()[:,:,:2,:] + d = d.type(coordinates.type()) + transformed = torch.matmul(d[:, :, :, :2], coordinates.unsqueeze(-1)) + d[:, :, :, 2:] + transformed = transformed.squeeze(-1) + + if self.tps: + control_points = self.control_points.type(coordinates.type()) + control_params = self.control_params.type(coordinates.type()) + distances = coordinates.view(coordinates.shape[0], -1, 1, 2) - control_points.view(1, 1, -1, 2) + distances = torch.abs(distances).sum(-1) + + result = distances ** 2 + result = result * torch.log(distances + 1e-6) + result = result * control_params + result = result.sum(dim=2).view(self.bs, coordinates.shape[1], 1) + transformed = transformed + result + + + return transformed + + def jacobian(self, coordinates): + 
coordinates.requires_grad=True + new_coordinates = self.warp_coordinates(coordinates)#[4,10,2] + grad_x = grad(new_coordinates[..., 0].sum(), coordinates, create_graph=True) + grad_y = grad(new_coordinates[..., 1].sum(), coordinates, create_graph=True) + jacobian = torch.cat([grad_x[0].unsqueeze(-2), grad_y[0].unsqueeze(-2)], dim=-2) + return jacobian + + +def detach_kp(kp): + return {key: value.detach() for key, value in kp.items()} + +class TrainFullModel(torch.nn.Module): + """ + Merge all generator related updates into single model for better multi-gpu usage + """ + + def __init__(self, kp_extractor, emo_feature, kp_extractor_a, audio_feature, generator, discriminator, train_params, device_ids): + super(TrainFullModel, self).__init__() + self.kp_extractor = kp_extractor + self.kp_extractor_a = kp_extractor_a + # self.emo_detector = emo_detector + # self.content_encoder = content_encoder + # self.emotion_encoder = emotion_encoder + self.audio_feature = audio_feature + self.emo_feature = emo_feature + self.generator = generator + self.discriminator = discriminator + self.train_params = train_params + self.scales = train_params['scales'] + self.disc_scales = self.discriminator.scales + self.pyramid = ImagePyramide(self.scales, generator.num_channels) + if torch.cuda.is_available(): + self.pyramid = self.pyramid.cuda() + + self.loss_weights = train_params['loss_weights'] + + if sum(self.loss_weights['perceptual']) != 0: + self.vgg = Vgg19() + if torch.cuda.is_available(): + self.vgg = self.vgg.cuda() + + # self.pca = torch.FloatTensor(np.load('/mnt/lustre/jixinya/Home/LRW/list/U_106.npy'))[:, :16].to(device_ids[0]) + # self.mean = torch.FloatTensor(np.load('/mnt/lustre/jixinya/Home/LRW/list/mean_106.npy')).to(device_ids[0]) + self.mse_loss_fn = nn.MSELoss().cuda() + self.CroEn_loss = nn.CrossEntropyLoss().cuda() + def forward(self, x): + # source_a_f = self.audio_feature(x['source_audio'],x['source_lm'],x[]) + # source_a_f = 
self.audio_feature(self.content_encoder(x['source_audio'].unsqueeze(1)), self.emotion_encoder(x['source_audio'].unsqueeze(1))) + kp_source = self.kp_extractor(x['example_image']) + # print(x['name'],len(x['name'])) + kp_driving = [] + kp_emo = [] + for i in range(16): + kp_driving.append(self.kp_extractor(x['driving'][:,i])) + # kp_emo.append(self.emo_detector(x['driving'][:,i])) + # print('KP_driving ', file=open('/mnt/lustre/jixinya/Home/fomm_audio/log/LRW_test.txt', 'a')) + kp_driving_a = [] #x['example_image'], + deco_out = self.audio_feature(x['example_image'], x['driving_audio'], x['driving_pose'], self.train_params['jaco_net']) + # emo_out = self.emo_feature(x['example_image'], x['driving_audio'], x['driving_pose'], self.train_params['jaco_net']) + loss_values = {} + + if self.loss_weights['emo'] != 0: + + kp_driving_a = [] + fakes = [] + for i in range(16): + kp_driving_a.append(self.kp_extractor_a(deco_out[:,i]))# + value = self.kp_extractor_a(deco_out[:,i])['value'] + jacobian = self.kp_extractor_a(deco_out[:,i])['jacobian'] + if self.train_params['type'] == 'linear_4' and x['name'][0] == 0: + out, fake = self.emo_feature(x['transformed_driving'][:,i],value,jacobian) + kp_emo.append(out) + fakes.append(fake) + # kp_emo.append(self.emo_feature(x['transformed_driving'][:,i],value,jacobian)) + elif self.train_params['type'] == 'linear_10' and x['name'][0] == 0: + # kp_emo.append(self.emo_feature.linear_10(x['transformed_driving'][:,i],value,jacobian)) + + out, fake = self.emo_feature.linear_10(x['transformed_driving'][:,i],value,jacobian) + kp_emo.append(out) + fakes.append(fake) + elif self.train_params['type'] == 'linear_4_new' and x['name'][0] == 0: + # kp_emo.append(self.emo_feature.linear_10(x['transformed_driving'][:,i],value,jacobian)) + + out, fake = self.emo_feature.linear_4(x['transformed_driving'][:,i],value,jacobian) + kp_emo.append(out) + fakes.append(fake) + elif self.train_params['type'] == 'linear_np_4': + # 
kp_emo.append(self.emo_feature.linear_10(x['transformed_driving'][:,i],value,jacobian)) + + out, fake = self.emo_feature.linear_np_4(x['transformed_driving'][:,i],value,jacobian) + kp_emo.append(out) + fakes.append(fake) + elif self.train_params['type'] == 'linear_np_10': + # kp_emo.append(self.emo_feature.linear_10(x['transformed_driving'][:,i],value,jacobian)) + + out, fake = self.emo_feature.linear_np_10(x['transformed_driving'][:,i],value,jacobian) + kp_emo.append(out) + fakes.append(fake) + # kp_emo.append(self.emo_feature(x['transformed_driving'][:,i],value,jacobian)) + # print('Kp_audio_driving ', file=open('/mnt/lustre/jixinya/Home/fomm_audio/log/LRW_test.txt', 'a')) + + loss_perceptual = 0 + + kp_all = kp_driving_a + if self.train_params['smooth'] == True: + value_all = torch.randn(len(kp_driving),out['value'].shape[0],out['value'].shape[1],out['value'].shape[2]).cuda() + jacobian_all = torch.randn(len(kp_driving),out['jacobian'].shape[0],out['jacobian'].shape[1],2,2).cuda() + print(len(kp_driving)) + for i in range(len(kp_driving)): + # if x['name'][i] == 'LRW': + # loss_jacobian += (torch.abs(kp_driving[i]['jacobian'] - kp_driving_a[i]['jacobian']).mean())*self.loss_weights['emo'] + + # loss_value += (torch.abs(kp_driving[i]['value'].detach() - kp_driving_a[i]['value']).mean())*self.loss_weights['emo'] + # loss_classify += self.mse_loss_fn(deco_out,deco_out) + if self.train_params['type'] == 'linear_4' and x['name'][0] == 0: + + kp_all[i]['jacobian'][:,1] = kp_emo[i]['jacobian'][:,0] + kp_driving_a[i]['jacobian'][:,1] + kp_all[i]['jacobian'][:,4] = kp_emo[i]['jacobian'][:,1] + kp_driving_a[i]['jacobian'][:,4] + kp_all[i]['jacobian'][:,6] = kp_emo[i]['jacobian'][:,2] + kp_driving_a[i]['jacobian'][:,6] + kp_all[i]['jacobian'][:,8] = kp_emo[i]['jacobian'][:,3] + kp_driving_a[i]['jacobian'][:,8] + kp_all[i]['value'][:,1] = kp_emo[i]['value'][:,0] + kp_driving_a[i]['value'][:,1] + kp_all[i]['value'][:,4] = kp_emo[i]['value'][:,1] + 
kp_driving_a[i]['value'][:,4] + kp_all[i]['value'][:,6] = kp_emo[i]['value'][:,2] + kp_driving_a[i]['value'][:,6] + kp_all[i]['value'][:,8] = kp_emo[i]['value'][:,3] + kp_driving_a[i]['value'][:,8] + + # kp_all[i]['value'] = kp_emo[i]['value'] + kp_driving_a[i]['value'] + + + if self.train_params['smooth'] == True: + loss_smooth = 0 + loss_smooth += (torch.abs(value_all[2:,:,:,:] + value_all[:-2,:,:,:].detach() -2*value_all[1:-1,:,:,:].detach()).mean())*self.loss_weights['emo'] *100 + loss_smooth += (torch.abs(jacobian_all[2:,:,:,:] + jacobian_all[:-2,:,:,:].detach() -2*jacobian_all[1:-1,:,:,:].detach()).mean())*self.loss_weights['emo'] *100 + loss_values['loss_smooth'] = loss_smooth/len(kp_driving) + else: + loss_values['loss_smooth'] = self.mse_loss_fn(deco_out,deco_out) + if self.train_params['generator'] == 'not': + loss_values['perceptual'] = self.mse_loss_fn(deco_out,deco_out) + for i in range(1): #0,len(kp_driving),4 + + generated = self.generator(x['example_image'], kp_source=kp_source, kp_driving=kp_all[i]) + generated.update({'kp_source': kp_source, 'kp_driving': kp_all}) + elif self.train_params['generator'] == 'visual': + for i in range(0,len(kp_driving),4): #0,len(kp_driving),4 + + generated = self.generator(x['example_image'], kp_source=kp_source, kp_driving=kp_driving[i]) + generated.update({'kp_source': kp_source, 'kp_driving': kp_driving}) + + pyramide_real = self.pyramid(x['driving'][:,i]) + pyramide_generated = self.pyramid(generated['prediction']) + + if sum(self.loss_weights['perceptual']) != 0: + value_total = 0 + for scale in self.scales: + x_vgg = self.vgg(pyramide_generated['prediction_' + str(scale)]) + y_vgg = self.vgg(pyramide_real['prediction_' + str(scale)]) + + for i, weight in enumerate(self.loss_weights['perceptual']): + value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean() + value_total += self.loss_weights['perceptual'][i] * value + loss_perceptual += value_total + + length = int((len(kp_driving)-1)/4)+1 + 
loss_values['perceptual'] = loss_perceptual/length + elif self.train_params['generator'] == 'audio': + for i in range(0,len(kp_driving),4): #0,len(kp_driving),4 + + generated = self.generator(x['example_image'], kp_source=kp_source, kp_driving=kp_all[i]) + generated.update({'kp_source': kp_source, 'kp_driving': kp_all}) + + pyramide_real = self.pyramid(x['driving'][:,i]) + pyramide_generated = self.pyramid(generated['prediction']) + # loss_mse = nn.MSELoss(generated['prediction'],x['driving'][:,i]) + if sum(self.loss_weights['perceptual']) != 0: + value_total = 0 + for scale in self.scales: + x_vgg = self.vgg(pyramide_generated['prediction_' + str(scale)]) + y_vgg = self.vgg(pyramide_real['prediction_' + str(scale)]) + + for i, weight in enumerate(self.loss_weights['perceptual']): + value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean() + value_total += self.loss_weights['perceptual'][i] * value + loss_perceptual += value_total + + length = int((len(kp_driving)-1)/4)+1 + loss_values['perceptual'] = loss_perceptual/length + # loss_values['mse'] = loss_mse/length + + else: + print('wrong train_params: ', self.train_params['generator']) + + + + return loss_values,generated + +class GeneratorFullModel(torch.nn.Module): + """ + Merge all generator related updates into single model for better multi-gpu usage + """ + + def __init__(self, kp_extractor, kp_extractor_a, audio_feature, generator, discriminator, train_params): + super(GeneratorFullModel, self).__init__() + self.kp_extractor = kp_extractor + self.kp_extractor_a = kp_extractor_a + # self.content_encoder = content_encoder + # self.emotion_encoder = emotion_encoder + self.audio_feature = audio_feature + self.generator = generator + self.discriminator = discriminator + self.train_params = train_params + self.scales = train_params['scales'] + self.disc_scales = self.discriminator.scales + self.pyramid = ImagePyramide(self.scales, generator.num_channels) + if torch.cuda.is_available(): + self.pyramid = 
self.pyramid.cuda() + + self.loss_weights = train_params['loss_weights'] + + if sum(self.loss_weights['perceptual']) != 0: + self.vgg = Vgg19() + if torch.cuda.is_available(): + self.vgg = self.vgg.cuda() + + self.pca = torch.FloatTensor(np.load('.../LRW/list/U_106.npy'))[:, :16].cuda() + self.mean = torch.FloatTensor(np.load('.../LRW/list/mean_106.npy')).cuda() + + def forward(self, x): + # source_a_f = self.audio_feature(x['source_audio'],x['source_lm'],x[]) + # source_a_f = self.audio_feature(self.content_encoder(x['source_audio'].unsqueeze(1)), self.emotion_encoder(x['source_audio'].unsqueeze(1))) + # kp_source = self.kp_extractor(x['source']) + # kp_source_a = self.kp_extractor_a(x['source'], x['source_cube'], source_a_f) + # driving_a_f = self.audio_feature(self.content_encoder(x['driving_audio'].unsqueeze(1)), self.emotion_encoder(x['driving_audio'].unsqueeze(1))) + # driving_a_f = self.audio_feature(x['driving_audio']) + # kp_driving = self.kp_extractor(x['driving']) + # kp_driving_a = self.kp_extractor_a(x['driving'], x['driving_cube'], driving_a_f) + + kp_driving = [] + for i in range(16): + kp_driving.append(self.kp_extractor(x['driving'][:,i],x['driving_landmark'][:,i],self.loss_weights['equivariance_value'])) + + kp_driving_a = [] + fc_out, deco_out = self.audio_feature(x['example_landmark'], x['driving_audio'], x['driving_pose']) + fake_lmark=fc_out + x['example_landmark'].expand_as(fc_out) + + + fake_lmark = torch.mm( fake_lmark, self.pca.t() ) + fake_lmark = fake_lmark + self.mean.expand_as(fake_lmark) + + + fake_lmark = fake_lmark.unsqueeze(0) + + # for i in range(16): + # kp_driving_a.append() + + # generated = self.generator(x['source'], kp_source=kp_source, kp_driving=kp_driving) + # generated.update({'kp_source': kp_source, 'kp_driving': kp_driving}) + + loss_values = {} + + pyramide_real = self.pyramid(x['driving']) + pyramide_generated = self.pyramid(generated['prediction']) + + if self.loss_weights['audio'] != 0: + value = 
torch.abs(kp_source['jacobian'].detach() - kp_source_a['jacobian'].detach()).mean() + torch.abs(kp_driving['jacobian'].detach() - kp_driving_a['jacobian']).mean() + value = value/2 + loss_values['jacobian'] = value*self.loss_weights['audio'] + value = torch.abs(kp_source['heatmap'].detach() - kp_source_a['heatmap'].detach()).mean() + torch.abs(kp_driving['heatmap'].detach() - kp_driving_a['heatmap']).mean() + value = value/2 + loss_values['heatmap'] = value*self.loss_weights['audio'] + value = torch.abs(kp_source['value'].detach() - kp_source_a['value'].detach()).mean() + torch.abs(kp_driving['value'].detach() - kp_driving_a['value']).mean() + value = value/2 + loss_values['value'] = value*self.loss_weights['audio'] + + if sum(self.loss_weights['perceptual']) != 0: + value_total = 0 + for scale in self.scales: + x_vgg = self.vgg(pyramide_generated['prediction_' + str(scale)]) + y_vgg = self.vgg(pyramide_real['prediction_' + str(scale)]) + + for i, weight in enumerate(self.loss_weights['perceptual']): + value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean() + value_total += self.loss_weights['perceptual'][i] * value + loss_values['perceptual'] = value_total + + if self.loss_weights['generator_gan'] != 0: + discriminator_maps_generated = self.discriminator(pyramide_generated, kp=detach_kp(kp_driving)) + discriminator_maps_real = self.discriminator(pyramide_real, kp=detach_kp(kp_driving)) + value_total = 0 + for scale in self.disc_scales: + key = 'prediction_map_%s' % scale + value = ((1 - discriminator_maps_generated[key]) ** 2).mean() + value_total += self.loss_weights['generator_gan'] * value + loss_values['gen_gan'] = value_total + + if sum(self.loss_weights['feature_matching']) != 0: + value_total = 0 + for scale in self.disc_scales: + key = 'feature_maps_%s' % scale + for i, (a, b) in enumerate(zip(discriminator_maps_real[key], discriminator_maps_generated[key])): + if self.loss_weights['feature_matching'][i] == 0: + continue + value = torch.abs(a - b).mean() + 
value_total += self.loss_weights['feature_matching'][i] * value + loss_values['feature_matching'] = value_total + + if (self.loss_weights['equivariance_value'] + self.loss_weights['equivariance_jacobian']) != 0: + transform = Transform(x['driving'].shape[0], **self.train_params['transform_params']) + transformed_frame = transform.transform_frame(x['driving']) + transformed_landmark = transform.inverse_warp_coordinates(x['driving_landmark']) + transformed_kp = self.kp_extractor(transformed_frame) + + generated['transformed_frame'] = transformed_frame + generated['transformed_kp'] = transformed_kp + + ## Value loss part + if self.loss_weights['equivariance_value'] != 0: + value = torch.abs(kp_driving['value'] - transform.warp_coordinates(transformed_kp['value'])).mean() + loss_values['equivariance_value'] = self.loss_weights['equivariance_value'] * value + + ## jacobian loss part + if self.loss_weights['equivariance_jacobian'] != 0: + jacobian_transformed = torch.matmul(transform.jacobian(transformed_kp['value']), + transformed_kp['jacobian']) + + normed_driving = torch.inverse(kp_driving['jacobian']) + normed_transformed = jacobian_transformed + value = torch.matmul(normed_driving, normed_transformed) + + eye = torch.eye(2).view(1, 1, 2, 2).type(value.type()) + + value = torch.abs(eye - value).mean() + loss_values['equivariance_jacobian'] = self.loss_weights['equivariance_jacobian'] * value + + return loss_values, generated + + +class DiscriminatorFullModel(torch.nn.Module): + """ + Merge all discriminator related updates into single model for better multi-gpu usage + """ + + def __init__(self, kp_extractor, generator, discriminator, train_params): + super(DiscriminatorFullModel, self).__init__() + self.kp_extractor = kp_extractor + self.generator = generator + self.discriminator = discriminator + self.train_params = train_params + self.scales = self.discriminator.scales + self.pyramid = ImagePyramide(self.scales, generator.num_channels) + if 
torch.cuda.is_available(): + self.pyramid = self.pyramid.cuda() + + self.loss_weights = train_params['loss_weights'] + + def forward(self, x, generated): + pyramide_real = self.pyramid(x['driving']) + pyramide_generated = self.pyramid(generated['prediction'].detach()) + + kp_driving = generated['kp_driving'] + discriminator_maps_generated = self.discriminator(pyramide_generated, kp=detach_kp(kp_driving)) + discriminator_maps_real = self.discriminator(pyramide_real, kp=detach_kp(kp_driving)) + + loss_values = {} + value_total = 0 + for scale in self.scales: + key = 'prediction_map_%s' % scale + value = (1 - discriminator_maps_real[key]) ** 2 + discriminator_maps_generated[key] ** 2 + value_total += self.loss_weights['discriminator_gan'] * value.mean() + loss_values['disc_gan'] = value_total + + return loss_values diff --git a/talkingface/model/audio_driven_talkingface/eamm_modules/ops.py b/talkingface/model/audio_driven_talkingface/eamm_modules/ops.py new file mode 100644 index 00000000..ed4f285f --- /dev/null +++ b/talkingface/model/audio_driven_talkingface/eamm_modules/ops.py @@ -0,0 +1,77 @@ +import torch +import torchvision +import torch.nn as nn +import torch.nn.init as init +from torch.autograd import Variable + + +def linear(channel_in, channel_out, + activation=nn.ReLU, + normalizer=nn.BatchNorm1d): + layer = list() + bias = True if not normalizer else False + + layer.append(nn.Linear(channel_in, channel_out, bias=bias)) + _apply(layer, activation, normalizer, channel_out) + # init.kaiming_normal(layer[0].weight) + + return nn.Sequential(*layer) + + +def conv2d(channel_in, channel_out, + ksize=3, stride=1, padding=1, + activation=nn.ReLU, + normalizer=nn.BatchNorm2d): + layer = list() + bias = True if not normalizer else False + + layer.append(nn.Conv2d(channel_in, channel_out, + ksize, stride, padding, + bias=bias)) + _apply(layer, activation, normalizer, channel_out) + # init.kaiming_normal(layer[0].weight) + + return nn.Sequential(*layer) + + +def 
conv_transpose2d(channel_in, channel_out, + ksize=4, stride=2, padding=1, + activation=nn.ReLU, + normalizer=nn.BatchNorm2d): + layer = list() + bias = True if not normalizer else False + + layer.append(nn.ConvTranspose2d(channel_in, channel_out, + ksize, stride, padding, + bias=bias)) + _apply(layer, activation, normalizer, channel_out) + # init.kaiming_normal(layer[0].weight) + + return nn.Sequential(*layer) + + +def nn_conv2d(channel_in, channel_out, + ksize=3, stride=1, padding=1, + scale_factor=2, + activation=nn.ReLU, + normalizer=nn.BatchNorm2d): + layer = list() + bias = True if not normalizer else False + + layer.append(nn.UpsamplingNearest2d(scale_factor=scale_factor)) + layer.append(nn.Conv2d(channel_in, channel_out, + ksize, stride, padding, + bias=bias)) + _apply(layer, activation, normalizer, channel_out) + # init.kaiming_normal(layer[1].weight) + + return nn.Sequential(*layer) + + +def _apply(layer, activation, normalizer, channel_out=None): + if normalizer: + layer.append(normalizer(channel_out)) + if activation: + layer.append(activation()) + return layer + diff --git a/talkingface/model/audio_driven_talkingface/eamm_modules/stylegan2.py b/talkingface/model/audio_driven_talkingface/eamm_modules/stylegan2.py new file mode 100644 index 00000000..8367036b --- /dev/null +++ b/talkingface/model/audio_driven_talkingface/eamm_modules/stylegan2.py @@ -0,0 +1,923 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Created on Thu Jul 8 01:03:50 2021 + +@author: thea +""" + +""" +The network architectures is based on PyTorch implemenation of StyleGAN2Encoder. +Original PyTorch repo: https://github.com/rosinality/style-based-gan-pytorch +Origianl StyelGAN2 paper: https://github.com/NVlabs/stylegan2 +We use the network architeture for our single-image traning setting. 
+""" + +import math +import numpy as np +import random + +import torch +from torch import nn +from torch.nn import functional as F + + +def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): + return F.leaky_relu(input + bias, negative_slope) * scale + + +class FusedLeakyReLU(nn.Module): + def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5): + super().__init__() + self.bias = nn.Parameter(torch.zeros(1, channel, 1, 1)) + self.negative_slope = negative_slope + self.scale = scale + + def forward(self, input): + # print("FusedLeakyReLU: ", input.abs().mean()) + out = fused_leaky_relu(input, self.bias, + self.negative_slope, + self.scale) + # print("FusedLeakyReLU: ", out.abs().mean()) + return out + + +def upfirdn2d_native( + input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 +): + _, minor, in_h, in_w = input.shape + kernel_h, kernel_w = kernel.shape + + out = input.view(-1, minor, in_h, 1, in_w, 1) + out = F.pad(out, [0, up_x - 1, 0, 0, 0, up_y - 1, 0, 0]) + out = out.view(-1, minor, in_h * up_y, in_w * up_x) + + out = F.pad( + out, [max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] + ) + out = out[ + :, + :, + max(-pad_y0, 0): out.shape[2] - max(-pad_y1, 0), + max(-pad_x0, 0): out.shape[3] - max(-pad_x1, 0), + ] + + # out = out.permute(0, 3, 1, 2) + out = out.reshape( + [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] + ) + w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) + out = F.conv2d(out, w) + out = out.reshape( + -1, + minor, + in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, + in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, + ) + # out = out.permute(0, 2, 3, 1) + + return out[:, :, ::down_y, ::down_x] + + +def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): + return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]) + + +class PixelNorm(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + return 
input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8) + + +def make_kernel(k): + k = torch.tensor(k, dtype=torch.float32) + + if len(k.shape) == 1: + k = k[None, :] * k[:, None] + + k /= k.sum() + + return k + + +class Upsample(nn.Module): + def __init__(self, kernel, factor=2): + super().__init__() + + self.factor = factor + kernel = make_kernel(kernel) * (factor ** 2) + self.register_buffer('kernel', kernel) + + p = kernel.shape[0] - factor + + pad0 = (p + 1) // 2 + factor - 1 + pad1 = p // 2 + + self.pad = (pad0, pad1) + + def forward(self, input): + out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad) + + return out + + +class Downsample(nn.Module): + def __init__(self, kernel, factor=2): + super().__init__() + + self.factor = factor + kernel = make_kernel(kernel) + self.register_buffer('kernel', kernel) + + p = kernel.shape[0] - factor + + pad0 = (p + 1) // 2 + pad1 = p // 2 + + self.pad = (pad0, pad1) + + def forward(self, input): + out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad) + + return out + + +class Blur(nn.Module): + def __init__(self, kernel, pad, upsample_factor=1): + super().__init__() + + kernel = make_kernel(kernel) + + if upsample_factor > 1: + kernel = kernel * (upsample_factor ** 2) + + self.register_buffer('kernel', kernel) + + self.pad = pad + + def forward(self, input): + out = upfirdn2d(input, self.kernel, pad=self.pad) + + return out + + +class EqualConv2d(nn.Module): + def __init__( + self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True + ): + super().__init__() + + self.weight = nn.Parameter( + torch.randn(out_channel, in_channel, kernel_size, kernel_size) + ) + self.scale = math.sqrt(1) / math.sqrt(in_channel * (kernel_size ** 2)) + + self.stride = stride + self.padding = padding + + if bias: + self.bias = nn.Parameter(torch.zeros(out_channel)) + + else: + self.bias = None + + def forward(self, input): + # print("Before EqualConv2d: ", 
input.abs().mean()) + out = F.conv2d( + input, + self.weight * self.scale, + bias=self.bias, + stride=self.stride, + padding=self.padding, + ) + # print("After EqualConv2d: ", out.abs().mean(), (self.weight * self.scale).abs().mean()) + + return out + + def __repr__(self): + return ( + f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},' + f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})' + ) + + +class EqualLinear(nn.Module): + def __init__( + self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None + ): + super().__init__() + + self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul)) + + if bias: + self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init)) + + else: + self.bias = None + + self.activation = activation + + self.scale = (math.sqrt(1) / math.sqrt(in_dim)) * lr_mul + self.lr_mul = lr_mul + + def forward(self, input): + if self.activation: + out = F.linear(input, self.weight * self.scale) + out = fused_leaky_relu(out, self.bias * self.lr_mul) + + else: + out = F.linear( + input, self.weight * self.scale, bias=self.bias * self.lr_mul + ) + + return out + + def __repr__(self): + return ( + f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})' + ) + + +class ScaledLeakyReLU(nn.Module): + def __init__(self, negative_slope=0.2): + super().__init__() + + self.negative_slope = negative_slope + + def forward(self, input): + out = F.leaky_relu(input, negative_slope=self.negative_slope) + + return out * math.sqrt(2) + + +class ModulatedConv2d(nn.Module): + def __init__( + self, + in_channel, + out_channel, + kernel_size, + style_dim, + demodulate=True, + upsample=False, + downsample=False, + blur_kernel=[1, 3, 3, 1], + ): + super().__init__() + + self.eps = 1e-8 + self.kernel_size = kernel_size + self.in_channel = in_channel + self.out_channel = out_channel + self.upsample = upsample + self.downsample = downsample + + if upsample: + factor = 2 + p = 
(len(blur_kernel) - factor) - (kernel_size - 1) + pad0 = (p + 1) // 2 + factor - 1 + pad1 = p // 2 + 1 + + self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor) + + if downsample: + factor = 2 + p = (len(blur_kernel) - factor) + (kernel_size - 1) + pad0 = (p + 1) // 2 + pad1 = p // 2 + + self.blur = Blur(blur_kernel, pad=(pad0, pad1)) + + fan_in = in_channel * kernel_size ** 2 + self.scale = math.sqrt(1) / math.sqrt(fan_in) + self.padding = kernel_size // 2 + + self.weight = nn.Parameter( + torch.randn(1, out_channel, in_channel, kernel_size, kernel_size) + ) + + if style_dim is not None and style_dim > 0: + self.modulation = EqualLinear(style_dim, in_channel, bias_init=1) + + self.demodulate = demodulate + + def __repr__(self): + return ( + f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, ' + f'upsample={self.upsample}, downsample={self.downsample})' + ) + + def forward(self, input, style): + batch, in_channel, height, width = input.shape + + if style is not None: + style = self.modulation(style).view(batch, 1, in_channel, 1, 1) + else: + style = torch.ones(batch, 1, in_channel, 1, 1).cuda() + weight = self.scale * self.weight * style + + if self.demodulate: + demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8) + weight = weight * demod.view(batch, self.out_channel, 1, 1, 1) + + weight = weight.view( + batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size + ) + + if self.upsample: + input = input.view(1, batch * in_channel, height, width) + weight = weight.view( + batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size + ) + weight = weight.transpose(1, 2).reshape( + batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size + ) + out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch) + _, _, height, width = out.shape + out = out.view(batch, self.out_channel, height, width) + out = self.blur(out) + + elif self.downsample: + input = 
self.blur(input) + _, _, height, width = input.shape + input = input.view(1, batch * in_channel, height, width) + out = F.conv2d(input, weight, padding=0, stride=2, groups=batch) + _, _, height, width = out.shape + out = out.view(batch, self.out_channel, height, width) + + else: + input = input.view(1, batch * in_channel, height, width) + out = F.conv2d(input, weight, padding=self.padding, groups=batch) + _, _, height, width = out.shape + out = out.view(batch, self.out_channel, height, width) + + return out + + +class NoiseInjection(nn.Module): + def __init__(self): + super().__init__() + + self.weight = nn.Parameter(torch.zeros(1)) + + def forward(self, image, noise=None): + if noise is None: + batch, _, height, width = image.shape + noise = image.new_empty(batch, 1, height, width).normal_() + + return image + self.weight * noise + + +class ConstantInput(nn.Module): + def __init__(self, channel, size=4): + super().__init__() + + self.input = nn.Parameter(torch.randn(1, channel, size, size)) + + def forward(self, input): + batch = input.shape[0] + out = self.input.repeat(batch, 1, 1, 1) + + return out + + +class StyledConv(nn.Module): + def __init__( + self, + in_channel, + out_channel, + kernel_size, + style_dim=None, + upsample=False, + blur_kernel=[1, 3, 3, 1], + demodulate=True, + inject_noise=False, #True + ): + super().__init__() + + self.inject_noise = inject_noise + self.conv = ModulatedConv2d( + in_channel, + out_channel, + kernel_size, + style_dim, + upsample=upsample, + blur_kernel=blur_kernel, + demodulate=demodulate, + ) + + self.noise = NoiseInjection() + # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1)) + # self.activate = ScaledLeakyReLU(0.2) + self.activate = FusedLeakyReLU(out_channel) + + def forward(self, input, style=None, noise=None): + out = self.conv(input, style) + if self.inject_noise: + out = self.noise(out, noise=noise) + # out = out + self.bias + out = self.activate(out) + + return out + + +class ToRGB(nn.Module): + def 
__init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]): + super().__init__() + + if upsample: + self.upsample = Upsample(blur_kernel) + + self.conv = ModulatedConv2d(in_channel, 3+32, 1, style_dim, demodulate=False) + self.bias = nn.Parameter(torch.zeros(1, 3+32, 1, 1)) + + def forward(self, input, style, skip=None): + out = self.conv(input, style) + out = out + self.bias + + if skip is not None: + skip = self.upsample(skip) + + out = out + skip + + return out + + +class Generator(nn.Module): + def __init__( + self, + size, + style_dim, + n_mlp, + channel_multiplier=1, + blur_kernel=[1, 3, 3, 1], + lr_mlp=0.01, + ): + super().__init__() + + self.size = size + + self.style_dim = style_dim + + layers = [PixelNorm()] + + for i in range(n_mlp): + layers.append( + EqualLinear( + style_dim, style_dim, lr_mul=lr_mlp, activation='fused_lrelu' + ) + ) + + self.style = nn.Sequential(*layers) + + self.channels = { + 4: 256, + 8: 256, + 16: 128, + 32: 64, + 64: 32 * channel_multiplier, + 128: 16 * channel_multiplier, + 256: 8 * channel_multiplier, + 512: 4 * channel_multiplier, + 1024: 2 * channel_multiplier, + } + + self.input = ConstantInput(self.channels[4]) + self.conv1 = StyledConv( + self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel + ) + self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False) + + self.log_size = int(math.log(size, 2)) + self.num_layers = (self.log_size - 2) * 2 + 1 + + self.convs = nn.ModuleList() + self.upsamples = nn.ModuleList() + self.to_rgbs = nn.ModuleList() + self.noises = nn.Module() + + in_channel = self.channels[4] + + for layer_idx in range(self.num_layers): + res = (layer_idx + 5) // 2 + shape = [1, 1, 2 ** res, 2 ** res] + self.noises.register_buffer(f'noise_{layer_idx}', torch.randn(*shape)) + + for i in range(3, self.log_size + 1): + out_channel = self.channels[2 ** i] + + self.convs.append( + StyledConv( + in_channel, + out_channel, + 3, + style_dim, + upsample=True, + 
blur_kernel=blur_kernel, + ) + ) + + self.convs.append( + StyledConv( + out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel + ) + ) + + self.to_rgbs.append(ToRGB(out_channel, style_dim)) + + in_channel = out_channel + + self.n_latent = self.log_size * 2 - 2 + + def make_noise(self): + device = self.input.input.device + + noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)] + + for i in range(3, self.log_size + 1): + for _ in range(2): + noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device)) + + return noises + + def mean_latent(self, n_latent): + latent_in = torch.randn( + n_latent, self.style_dim, device=self.input.input.device + ) + latent = self.style(latent_in).mean(0, keepdim=True) + + return latent + + def get_latent(self, input): + return self.style(input) + + def forward( + self, + styles, + return_latents=False, + inject_index=None, + truncation=1, + truncation_latent=None, + input_is_latent=False, + noise=None, + randomize_noise=True, + ): + if not input_is_latent: + styles = [self.style(s) for s in styles] + + if noise is None: + if randomize_noise: + noise = [None] * self.num_layers + else: + noise = [ + getattr(self.noises, f'noise_{i}') for i in range(self.num_layers) + ] + + if truncation < 1: + style_t = [] + + for style in styles: + style_t.append( + truncation_latent + truncation * (style - truncation_latent) + ) + + styles = style_t + + if len(styles) < 2: + inject_index = self.n_latent + + if len(styles[0].shape) < 3: + latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + + else: + latent = styles[0] + + else: + if inject_index is None: + inject_index = random.randint(1, self.n_latent - 1) + + latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1) + + latent = torch.cat([latent, latent2], 1) + + # out = self.input(latent) + out = styles[0].unsqueeze(-1).unsqueeze(-1).repeat(1,1,4,4) + out = self.conv1(out, latent[:, 0], 
noise=noise[0]) + + skip = self.to_rgb1(out, latent[:, 1]) + + i = 1 + for conv1, conv2, noise1, noise2, to_rgb in zip( + self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs + ): + out = conv1(out, latent[:, i], noise=noise1) + out = conv2(out, latent[:, i + 1], noise=noise2) + skip = to_rgb(out, latent[:, i + 2], skip) + + i += 2 + + image = skip + + if return_latents: + return image, latent + + else: + return image, None + + +class ConvLayer(nn.Sequential): + def __init__( + self, + in_channel, + out_channel, + kernel_size, + downsample=False, + blur_kernel=[1, 3, 3, 1], + bias=True, + activate=True, + ): + layers = [] + + if downsample: + factor = 2 + p = (len(blur_kernel) - factor) + (kernel_size - 1) + pad0 = (p + 1) // 2 + pad1 = p // 2 + + layers.append(Blur(blur_kernel, pad=(pad0, pad1))) + + stride = 2 + self.padding = 0 + + else: + stride = 1 + self.padding = kernel_size // 2 + + layers.append( + EqualConv2d( + in_channel, + out_channel, + kernel_size, + padding=self.padding, + stride=stride, + bias=bias and not activate, + ) + ) + + if activate: + if bias: + layers.append(FusedLeakyReLU(out_channel)) + + else: + layers.append(ScaledLeakyReLU(0.2)) + + super().__init__(*layers) + + +class ResBlock(nn.Module): + def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1], downsample=True, skip_gain=1.0): + super().__init__() + + self.skip_gain = skip_gain + self.conv1 = ConvLayer(in_channel, in_channel, 3) + self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=downsample, blur_kernel=blur_kernel) + + if in_channel != out_channel or downsample: + self.skip = ConvLayer( + in_channel, out_channel, 1, downsample=downsample, activate=False, bias=False + ) + else: + self.skip = nn.Identity() + + def forward(self, input): + out = self.conv1(input) + out = self.conv2(out) + + skip = self.skip(input) + out = (out * self.skip_gain + skip) / math.sqrt(self.skip_gain ** 2 + 1.0) + + return out + + +class 
StyleGAN2Discriminator(nn.Module): + def __init__(self, input_nc, ndf=64, n_layers=3, no_antialias=False, size=None, opt=None): + super().__init__() + self.opt = opt + self.stddev_group = 16 + if size is None: + size = 2 ** int((np.rint(np.log2(min(opt.load_size, opt.crop_size))))) + if "patch" in self.opt.netD and self.opt.D_patch_size is not None: + size = 2 ** int(np.log2(self.opt.D_patch_size)) + + blur_kernel = [1, 3, 3, 1] + channel_multiplier = ndf / 64 + channels = { + 4: min(384, int(4096 * channel_multiplier)), + 8: min(384, int(2048 * channel_multiplier)), + 16: min(384, int(1024 * channel_multiplier)), + 32: min(384, int(512 * channel_multiplier)), + 64: int(256 * channel_multiplier), + 128: int(128 * channel_multiplier), + 256: int(64 * channel_multiplier), + 512: int(32 * channel_multiplier), + 1024: int(16 * channel_multiplier), + } + + convs = [ConvLayer(3, channels[size], 1)] + + log_size = int(math.log(size, 2)) + + in_channel = channels[size] + + if "smallpatch" in self.opt.netD: + final_res_log2 = 4 + elif "patch" in self.opt.netD: + final_res_log2 = 3 + else: + final_res_log2 = 2 + + for i in range(log_size, final_res_log2, -1): + out_channel = channels[2 ** (i - 1)] + + convs.append(ResBlock(in_channel, out_channel, blur_kernel)) + + in_channel = out_channel + + self.convs = nn.Sequential(*convs) + + if False and "tile" in self.opt.netD: + in_channel += 1 + self.final_conv = ConvLayer(in_channel, channels[4], 3) + if "patch" in self.opt.netD: + self.final_linear = ConvLayer(channels[4], 1, 3, bias=False, activate=False) + else: + self.final_linear = nn.Sequential( + EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu'), + EqualLinear(channels[4], 1), + ) + + def forward(self, input, get_minibatch_features=False): + if "patch" in self.opt.netD and self.opt.D_patch_size is not None: + h, w = input.size(2), input.size(3) + y = torch.randint(h - self.opt.D_patch_size, ()) + x = torch.randint(w - self.opt.D_patch_size, ()) + input 
= input[:, :, y:y + self.opt.D_patch_size, x:x + self.opt.D_patch_size] + out = input + for i, conv in enumerate(self.convs): + out = conv(out) + # print(i, out.abs().mean()) + # out = self.convs(input) + + batch, channel, height, width = out.shape + + if False and "tile" in self.opt.netD: + group = min(batch, self.stddev_group) + stddev = out.view( + group, -1, 1, channel // 1, height, width + ) + stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8) + stddev = stddev.mean([2, 3, 4], keepdim=True).squeeze(2) + stddev = stddev.repeat(group, 1, height, width) + out = torch.cat([out, stddev], 1) + + out = self.final_conv(out) + # print(out.abs().mean()) + + if "patch" not in self.opt.netD: + out = out.view(batch, -1) + out = self.final_linear(out) + + return out + + +class TileStyleGAN2Discriminator(StyleGAN2Discriminator): + def forward(self, input): + B, C, H, W = input.size(0), input.size(1), input.size(2), input.size(3) + size = self.opt.D_patch_size + Y = H // size + X = W // size + input = input.view(B, C, Y, size, X, size) + input = input.permute(0, 2, 4, 1, 3, 5).contiguous().view(B * Y * X, C, size, size) + return super().forward(input) + + +class StyleGAN2Encoder(nn.Module): + def __init__(self, input_nc, output_nc, ngf=64, use_dropout=False, n_blocks=6, padding_type='reflect', no_antialias=False, opt=None): + super().__init__() + assert opt is not None + self.opt = opt + channel_multiplier = ngf / 32 + channels = { + 4: min(512, int(round(4096 * channel_multiplier))), + 8: min(512, int(round(2048 * channel_multiplier))), + 16: min(512, int(round(1024 * channel_multiplier))), + 32: min(512, int(round(512 * channel_multiplier))), + 64: int(round(256 * channel_multiplier)), + 128: int(round(128 * channel_multiplier)), + 256: int(round(64 * channel_multiplier)), + 512: int(round(32 * channel_multiplier)), + 1024: int(round(16 * channel_multiplier)), + } + + blur_kernel = [1, 3, 3, 1] + + cur_res = 2 ** int((np.rint(np.log2(min(opt.load_size, 
opt.crop_size))))) + convs = [nn.Identity(), + ConvLayer(3, channels[cur_res], 1)] + + num_downsampling = self.opt.stylegan2_G_num_downsampling + for i in range(num_downsampling): + in_channel = channels[cur_res] + out_channel = channels[cur_res // 2] + convs.append(ResBlock(in_channel, out_channel, blur_kernel, downsample=True)) + cur_res = cur_res // 2 + + for i in range(n_blocks // 2): + n_channel = channels[cur_res] + convs.append(ResBlock(n_channel, n_channel, downsample=False)) + + self.convs = nn.Sequential(*convs) + + def forward(self, input, layers=[], get_features=False): + feat = input + feats = [] + if -1 in layers: + layers.append(len(self.convs) - 1) + for layer_id, layer in enumerate(self.convs): + feat = layer(feat) + # print(layer_id, " features ", feat.abs().mean()) + if layer_id in layers: + feats.append(feat) + + if get_features: + return feat, feats + else: + return feat + + +class StyleGAN2Decoder(nn.Module): + def __init__(self, input_nc, output_nc, ngf=64, use_dropout=False, n_blocks=6, padding_type='reflect', no_antialias=False, opt=None): + super().__init__() + assert opt is not None + self.opt = opt + + blur_kernel = [1, 3, 3, 1] + + channel_multiplier = ngf / 32 + channels = { + 4: min(512, int(round(4096 * channel_multiplier))), + 8: min(512, int(round(2048 * channel_multiplier))), + 16: min(512, int(round(1024 * channel_multiplier))), + 32: min(512, int(round(512 * channel_multiplier))), + 64: int(round(256 * channel_multiplier)), + 128: int(round(128 * channel_multiplier)), + 256: int(round(64 * channel_multiplier)), + 512: int(round(32 * channel_multiplier)), + 1024: int(round(16 * channel_multiplier)), + } + + num_downsampling = self.opt.stylegan2_G_num_downsampling + cur_res = 2 ** int((np.rint(np.log2(min(opt.load_size, opt.crop_size))))) // (2 ** num_downsampling) + convs = [] + + for i in range(n_blocks // 2): + n_channel = channels[cur_res] + convs.append(ResBlock(n_channel, n_channel, downsample=False)) + + for i in 
range(num_downsampling): + in_channel = channels[cur_res] + out_channel = channels[cur_res * 2] + inject_noise = "small" not in self.opt.netG + convs.append( + StyledConv(in_channel, out_channel, 3, upsample=True, blur_kernel=blur_kernel, inject_noise=inject_noise) + ) + cur_res = cur_res * 2 + + convs.append(ConvLayer(channels[cur_res], 3, 1)) + + self.convs = nn.Sequential(*convs) + + def forward(self, input): + return self.convs(input) + + +class StyleGAN2Generator(nn.Module): + def __init__(self, input_nc, output_nc, ngf=64, use_dropout=False, n_blocks=6, padding_type='reflect', no_antialias=False, opt=None): + super().__init__() + self.opt = opt + self.encoder = StyleGAN2Encoder(input_nc, output_nc, ngf, use_dropout, n_blocks, padding_type, no_antialias, opt) + self.decoder = StyleGAN2Decoder(input_nc, output_nc, ngf, use_dropout, n_blocks, padding_type, no_antialias, opt) + + def forward(self, input, layers=[], encode_only=False): + feat, feats = self.encoder(input, layers, True) + if encode_only: + return feats + else: + fake = self.decoder(feat) + + if len(layers) > 0: + return fake, feats + else: + return fake \ No newline at end of file diff --git a/talkingface/model/audio_driven_talkingface/eamm_modules/util.py b/talkingface/model/audio_driven_talkingface/eamm_modules/util.py new file mode 100644 index 00000000..a1b2c3ff --- /dev/null +++ b/talkingface/model/audio_driven_talkingface/eamm_modules/util.py @@ -0,0 +1,1924 @@ +from torch import nn + +import torch.nn.functional as F +import torch +import numpy as np +import cv2 +from .batchnorm import SynchronizedBatchNorm2d as BatchNorm2d + +from .stylegan2 import Generator + +import torch.nn as nn +import math +import torch.utils.model_zoo as model_zoo +from .function import adaptive_instance_normalization as adain + +import pdb + + +# Misc +img2mse = lambda x, y: torch.mean((x - y) ** 2) +mse2psnr = lambda x: -10.0 * torch.log(x) / torch.log(torch.Tensor([10.0])) +to8b = lambda x: (255 * np.clip(x, 0, 
1)).astype(np.uint8) + + +class InstanceNorm(nn.Module): + def __init__(self, epsilon=1e-8): + """ + @notice: avoid in-place ops. + https://discuss.pytorch.org/t/encounter-the-runtimeerror-one-of-the-variables-needed-for-gradient-computation-has-been-modified-by-an-inplace-operation/836/3 + """ + super(InstanceNorm, self).__init__() + self.epsilon = epsilon + + def forward(self, x): + x = x - torch.mean(x, (2, 3), True) + tmp = torch.mul(x, x) # or x ** 2 + tmp = torch.rsqrt(torch.mean(tmp, (2, 3), True) + self.epsilon) + return x * tmp + + +class ApplyStyle(nn.Module): + """ + @ref: https://github.com/lernapparat/lernapparat/blob/master/style_gan/pytorch_style_gan.ipynb + """ + + def __init__(self, latent_size, channels, use_wscale): + super(ApplyStyle, self).__init__() + self.linear = FC(latent_size, channels * 2, gain=1.0, use_wscale=use_wscale) + + def forward(self, x, latent): + style = self.linear(latent) # style => [batch_size, n_channels*2] + shape = [-1, 2, x.size(1), 1, 1] + style = style.view(shape) # [batch_size, 2, n_channels, ...] + x = x * (style[:, 0] + 1.0) + style[:, 1] + return x + + +class FC(nn.Module): + def __init__( + self, + in_channels, + out_channels, + gain=2 ** (0.5), + use_wscale=False, + lrmul=1.0, + bias=True, + ): + """ + The complete conversion of Dense/FC/Linear Layer of original Tensorflow version. 
+ """ + super(FC, self).__init__() + he_std = gain * in_channels ** (-0.5) # He init + if use_wscale: + init_std = 1.0 / lrmul + self.w_lrmul = he_std * lrmul + else: + init_std = he_std / lrmul + self.w_lrmul = lrmul + + self.weight = torch.nn.Parameter( + torch.randn(out_channels, in_channels) * init_std + ) + if bias: + self.bias = torch.nn.Parameter(torch.zeros(out_channels)) + self.b_lrmul = lrmul + else: + self.bias = None + + def forward(self, x): + if self.bias is not None: + out = F.linear(x, self.weight * self.w_lrmul, self.bias * self.b_lrmul) + else: + out = F.linear(x, self.weight * self.w_lrmul) + out = F.leaky_relu(out, 0.2, inplace=True) + return out + + +# Positional encoding (section 5.1) +class Embedder: + def __init__(self, **kwargs): + self.kwargs = kwargs + self.create_embedding_fn() + + def create_embedding_fn(self): + embed_fns = [] + d = self.kwargs["input_dims"] + out_dim = 0 + if self.kwargs["include_input"]: + embed_fns.append(lambda x: x) + out_dim += d + + max_freq = self.kwargs["max_freq_log2"] + N_freqs = self.kwargs["num_freqs"] + + if self.kwargs["log_sampling"]: + freq_bands = 2.0 ** torch.linspace(0.0, max_freq, steps=N_freqs) + else: + freq_bands = torch.linspace(2.0**0.0, 2.0**max_freq, steps=N_freqs) + + for freq in freq_bands: + for p_fn in self.kwargs["periodic_fns"]: + embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq)) + out_dim += d + + self.embed_fns = embed_fns + self.out_dim = out_dim + + def embed(self, inputs): + return torch.cat([fn(inputs) for fn in self.embed_fns], -1) + + +def get_embedder(multires, i=0): + if i == -1: + return nn.Identity(), 6 + + embed_kwargs = { + "include_input": True, + "input_dims": 6, + "max_freq_log2": multires - 1, + "num_freqs": multires, + "log_sampling": True, + "periodic_fns": [torch.sin, torch.cos], + } + + embedder_obj = Embedder(**embed_kwargs) + embed = lambda x, eo=embedder_obj: eo.embed(x) + return embed, embedder_obj.out_dim + + +def draw_heatmap(landmark, width, 
height): + batch = landmark.shape[0] + number = landmark.shape[1] + heatmap = np.zeros((batch, number, width, height), dtype=np.float32) + # draw mouth from mouth landmarks, landmarks: mouth landmark points, format: x1, y1, x2, y2, ..., x20, + + landmark = (landmark + 1) * 29 + for i in range(batch): + for pts_idx in range(number): + if int(landmark[i, pts_idx, 0]) < 0: + landmark[i, pts_idx, 0] = 0 + if int(landmark[i, pts_idx, 1]) < 0: + landmark[i, pts_idx, 1] = 0 + if int(landmark[i, pts_idx, 0]) > 57: + landmark[i, pts_idx, 0] = 57 + if int(landmark[i, pts_idx, 1]) > 57: + landmark[i, pts_idx, 1] = 57 + heatmap[ + i, pts_idx, int(landmark[i, pts_idx, 1]), int(landmark[i, pts_idx, 0]) + ] = 1 + if heatmap[i, pts_idx].sum() == 1: + heatmap[i, pts_idx] = cv2.GaussianBlur( + heatmap[i, pts_idx], ksize=(3, 3), sigmaX=1, sigmaY=1 + ) + + heatmap = torch.tensor(heatmap).cuda() + return heatmap + + +class NA_net(nn.Module): + def __init__(self): + super(NA_net, self).__init__() + + self.decon = nn.Sequential( + nn.ConvTranspose2d( + 1, 16, kernel_size=(2, 3), stride=2, padding=(2, 1), bias=True + ), # 16,16 + nn.BatchNorm2d(16), + nn.ReLU(True), + nn.ConvTranspose2d( + 16, 32, kernel_size=4, stride=2, padding=1, bias=True + ), # 8,8 + nn.BatchNorm2d(32), + nn.ReLU(True), + nn.ConvTranspose2d( + 32, 32 + 3, kernel_size=4, stride=2, padding=1, bias=True + ), # 16,16 + ) + + def forward(self, neutral): + feature = neutral.unsqueeze(1) + current_feature = self.decon(feature) + + return current_feature + + +class AT_net(nn.Module): + def __init__(self): + super(AT_net, self).__init__() + + down_blocks = [] + for i in range(8): + down_blocks.append( + DownBlock2d( + 3 if i == 0 else 2 * (2**i), + 2 * (2 ** (i + 1)), + kernel_size=3, + padding=1, + ) + ) + self.down_blocks = nn.ModuleList(down_blocks) + + # self.lmark_encoder = nn.Sequential( + # nn.Linear(16,256), + # nn.ReLU(True), + # nn.Linear(256,512), + # nn.ReLU(True), + # ) + self.pose_encoder = nn.Sequential( + 
            nn.Linear(6, 128),
            nn.ReLU(True),
            nn.Linear(128, 256),
            nn.ReLU(True),
        )
        # NOTE(review): "eocder" is a typo for "encoder" kept throughout the
        # file — renaming would break loading of released checkpoints.
        self.audio_eocder = nn.Sequential(
            conv2d(1, 64, 3, 1, 1),
            conv2d(64, 128, 3, 1, 1),
            nn.MaxPool2d(3, stride=(1, 2)),
            conv2d(128, 256, 3, 1, 1),
            conv2d(256, 256, 3, 1, 1),
            conv2d(256, 512, 3, 1, 1),
            nn.MaxPool2d(3, stride=(2, 2)),
        )
        self.audio_eocder_fc = nn.Sequential(
            nn.Linear(1024 * 12, 2048),
            nn.ReLU(True),
            nn.Linear(2048, 256),
            nn.ReLU(True),
        )
        self.lstm = nn.LSTM(256 * 4, 256, 3, batch_first=True)
        # self.lstm_fc = nn.Sequential(
        #     nn.Linear(256,16),
        #     )
        self.decon = nn.Sequential(
            nn.ConvTranspose2d(
                256, 256, kernel_size=6, stride=2, padding=1, bias=True
            ),  # 4,4
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(
                256, 128, kernel_size=4, stride=2, padding=1, bias=True
            ),  # 8,8
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(
                128, 128, kernel_size=4, stride=2, padding=1, bias=True
            ),  # 16,16
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(
                128, 128, kernel_size=4, stride=2, padding=1, bias=True
            ),  # 32,32
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(
                128, 32 + 3, kernel_size=4, stride=2, padding=1, bias=True
            ),  # 64,64
            # nn.ConvTranspose2d(128, 32*4, kernel_size=2, stride=2, padding=3, bias=True),#64,64
        )
        self.generator = Generator(64, 256, 8)

    def forward(self, example_image, audio, pose, jaco_net):
        # audio is indexed as (B, T, h, w) and pose as (B, T, 6) below —
        # assumption from the per-step slicing; confirm against the loader.
        hidden = (
            torch.autograd.Variable(
                torch.zeros(3, audio.size(0), 256).to(example_image.device)
            ),
            torch.autograd.Variable(
                torch.zeros(3, audio.size(0), 256).to(example_image.device)
            ),
        )
        outs = example_image
        for down_block in self.down_blocks:
            outs = down_block(outs)
        image_feature = outs
        image_feature = image_feature.view(image_feature.shape[0], -1)
        lstm_input = []
        for step_t in range(audio.size(1)):
            current_audio = audio[:, step_t, :, :].unsqueeze(1)
            current_feature = self.audio_eocder(current_audio)
            current_feature = current_feature.view(current_feature.size(0), -1)
            current_feature = self.audio_eocder_fc(current_feature)
            pose_f = self.pose_encoder(pose[:, step_t])
            features = torch.cat([image_feature, current_feature, pose_f], 1)
            lstm_input.append(features)
        lstm_input = torch.stack(lstm_input, dim=1)
        lstm_out, _ = self.lstm(lstm_input, hidden)
        fc_out = []
        deco_out = []
        for step_t in range(audio.size(1)):
            fc_in = lstm_out[:, step_t, :]
            # fc_out.append(self.lstm_fc(fc_in))
            if jaco_net == "cnn":
                fc_feature = torch.unsqueeze(fc_in, 2)
                fc_feature = torch.unsqueeze(fc_feature, 3)
                deco_out.append(self.decon(fc_feature))
            elif jaco_net == "gan":
                result, _ = self.generator([fc_in])
                deco_out.append(result)
            else:
                raise Exception("jaco_net type wrong")

        return torch.stack(deco_out, dim=1)


class Classify(nn.Module):
    # Single linear head mapping a 512-D feature to 8 class logits.
    def __init__(self):
        super(Classify, self).__init__()

        self.last_fc = nn.Linear(512, 8)

    def forward(self, feature):
        # mfcc= torch.unsqueeze(mfcc, 1)

        x = self.last_fc(feature)

        return x


class TF_net(nn.Module):
    # Variant of AT_net that additionally injects per-step emotion features:
    # AdaIN-style before the LSTM (adain_forward), after decoding
    # (adain_feature2), or by concatenation (forward, via lstm_two).
    def __init__(self):
        super(TF_net, self).__init__()

        down_blocks = []
        for i in range(8):
            down_blocks.append(
                DownBlock2d(
                    3 if i == 0 else 2 * (2**i),
                    2 * (2 ** (i + 1)),
                    kernel_size=3,
                    padding=1,
                )
            )
        self.down_blocks = nn.ModuleList(down_blocks)

        # self.lmark_encoder = nn.Sequential(
        #     nn.Linear(16,256),
        #     nn.ReLU(True),
        #     nn.Linear(256,512),
        #     nn.ReLU(True),
        #     )
        self.pose_encoder = nn.Sequential(
            nn.Linear(6, 128),
            nn.ReLU(True),
            nn.Linear(128, 256),
            nn.ReLU(True),
        )
        self.audio_eocder = nn.Sequential(
            conv2d(1, 64, 3, 1, 1),
            conv2d(64, 128, 3, 1, 1),
            nn.MaxPool2d(3, stride=(1, 2)),
            conv2d(128, 256, 3, 1, 1),
            conv2d(256, 256, 3, 1, 1),
            conv2d(256, 512, 3, 1, 1),
            nn.MaxPool2d(3, stride=(2, 2)),
        )
        self.audio_eocder_fc = nn.Sequential(
            nn.Linear(1024 * 12, 2048),
            nn.ReLU(True),
            nn.Linear(2048, 256),
            nn.ReLU(True),
        )
        self.lstm = nn.LSTM(256 * 4, 256, 3, batch_first=True)
        # Wider LSTM used by forward(), where the emotion feature is
        # concatenated to the per-step inputs instead of AdaIN-injected.
        self.lstm_two = nn.LSTM(256 * 6, 256, 3, batch_first=True)
        # self.lstm_fc = nn.Sequential(
        #     nn.Linear(256,16),
        #     )
        self.decon = nn.Sequential(
            nn.ConvTranspose2d(
                256, 256, kernel_size=6, stride=2, padding=1, bias=True
            ),  # 4,4
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(
                256, 128, kernel_size=4, stride=2, padding=1, bias=True
            ),  # 8,8
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(
                128, 128, kernel_size=4, stride=2, padding=1, bias=True
            ),  # 16,16
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(
                128, 128, kernel_size=4, stride=2, padding=1, bias=True
            ),  # 32,32
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(
                128, 32 + 3, kernel_size=4, stride=2, padding=1, bias=True
            ),  # 64,64
            # nn.ConvTranspose2d(128, 32*4, kernel_size=2, stride=2, padding=3, bias=True),#64,64
        )
        self.generator = Generator(64, 256, 8)
        self.instance_norm = InstanceNorm()
        # AdaIN-style modulation: style_mod conditions the 1024-D fused
        # vector (adain_forward); style_mod1 conditions the 35-channel
        # decoded map (adain_feature2).
        self.style_mod = ApplyStyle(512, 1024, use_wscale=True)
        self.style_mod1 = ApplyStyle(512, 35, use_wscale=True)

    def adain_forward(self, example_image, audio, pose, jaco_net, emo_features):
        # Emotion injected BEFORE the LSTM: the fused (image, audio, pose)
        # vector is instance-normalized and style-modulated per step.
        hidden = (
            torch.autograd.Variable(
                torch.zeros(3, audio.size(0), 256).to(example_image.device)
            ),
            torch.autograd.Variable(
                torch.zeros(3, audio.size(0), 256).to(example_image.device)
            ),
        )
        outs = example_image
        for down_block in self.down_blocks:
            outs = down_block(outs)
        image_feature = outs
        image_feature = image_feature.view(image_feature.shape[0], -1)
        lstm_input = []
        for step_t in range(audio.size(1)):
            current_audio = audio[:, step_t, :, :].unsqueeze(1)
            current_feature = self.audio_eocder(current_audio)
            current_feature = current_feature.view(current_feature.size(0), -1)
            current_feature = self.audio_eocder_fc(current_feature)  # 256
            pose_f = self.pose_encoder(pose[:, step_t])  # 256
            features = torch.cat([image_feature, current_feature, pose_f], 1)

            # Fake spatial dims so InstanceNorm/ApplyStyle see a 4-D tensor.
            features = torch.unsqueeze(torch.unsqueeze(features, -1), -1)
            features = self.instance_norm(features)
            x = self.style_mod(features, emo_features[step_t])
            # t = adain(torch.unsqueeze(torch.unsqueeze(features,-1),-1), torch.unsqueeze(torch.unsqueeze(emo_features[step_t],1),2))

            lstm_input.append(torch.squeeze(torch.squeeze(x, -1), -1))
        lstm_input = torch.stack(lstm_input, dim=1)
        lstm_out, _ = self.lstm(lstm_input, hidden)
        # fc_out = []
        deco_out = []
        for step_t in range(audio.size(1)):
            fc_in = lstm_out[:, step_t, :]
            # fc_out.append(self.lstm_fc(fc_in))
            if jaco_net == "cnn":
                fc_feature = torch.unsqueeze(fc_in, 2)
                fc_feature = torch.unsqueeze(fc_feature, 3)
                deco_out.append(self.decon(fc_feature))
            elif jaco_net == "gan":
                result, _ = self.generator([fc_in])
                deco_out.append(result)
            else:
                raise Exception("jaco_net type wrong")

        return torch.stack(deco_out, dim=1)

    def adain_feature2(self, example_image, audio, pose, jaco_net, emo_features):
        # Emotion injected AFTER decoding: the decoded 35-channel map is
        # instance-normalized and style-modulated ("cnn" path only).
        hidden = (
            torch.autograd.Variable(
                torch.zeros(3, audio.size(0), 256).to(example_image.device)
            ),
            torch.autograd.Variable(
                torch.zeros(3, audio.size(0), 256).to(example_image.device)
            ),
        )
        outs = example_image
        for down_block in self.down_blocks:
            outs = down_block(outs)
        image_feature = outs
        image_feature = image_feature.view(image_feature.shape[0], -1)
        lstm_input = []
        for step_t in range(audio.size(1)):
            current_audio = audio[:, step_t, :, :].unsqueeze(1)
            current_feature = self.audio_eocder(current_audio)
            current_feature = current_feature.view(current_feature.size(0), -1)
            current_feature = self.audio_eocder_fc(current_feature)  # 256
            pose_f = self.pose_encoder(pose[:, step_t])  # 256
            features = torch.cat([image_feature, current_feature, pose_f], 1)

            lstm_input.append(features)
        lstm_input = torch.stack(lstm_input, dim=1)
        lstm_out, _ = self.lstm(lstm_input, hidden)
        # fc_out = []
        deco_out = []
        for step_t in range(audio.size(1)):
            fc_in = lstm_out[:, step_t, :]
            # fc_out.append(self.lstm_fc(fc_in))
            if jaco_net == "cnn":
                fc_feature = torch.unsqueeze(fc_in, 2)
                fc_feature = torch.unsqueeze(fc_feature, 3)
                fc_feature = self.decon(fc_feature)
                fc_feature = self.instance_norm(fc_feature)
                t = self.style_mod1(fc_feature, emo_features[step_t])
                # emo_feature = torch.unsqueeze(torch.unsqueeze(emo_features[step_t],-1),-1)
                # emo_feature = emo_feature.repeat(1,fc_feature.shape[1],1,1)
                # t = adain(fc_feature, emo_feature)
                deco_out.append(t)
            elif jaco_net == "gan":
                result, _ = self.generator([fc_in])
                deco_out.append(result)
            else:
                raise Exception("jaco_net type wrong")

        return torch.stack(deco_out, dim=1)

    def forward(self, example_image, audio, pose, jaco_net, emo_features):
        # Emotion injected by CONCATENATION before the (wider) lstm_two.
        hidden = (
            torch.autograd.Variable(
                torch.zeros(3, audio.size(0), 256).to(example_image.device)
            ),
            torch.autograd.Variable(
                torch.zeros(3, audio.size(0), 256).to(example_image.device)
            ),
        )
        outs = example_image
        for down_block in self.down_blocks:
            outs = down_block(outs)
        image_feature = outs
        image_feature = image_feature.view(image_feature.shape[0], -1)
        lstm_input = []
        for step_t in range(audio.size(1)):
            current_audio = audio[:, step_t, :, :].unsqueeze(1)
            current_feature = self.audio_eocder(current_audio)
            current_feature = current_feature.view(current_feature.size(0), -1)
            current_feature = self.audio_eocder_fc(current_feature)  # 256
            pose_f = self.pose_encoder(pose[:, step_t])  # 256
            features = torch.cat(
                [image_feature, current_feature, pose_f, emo_features[step_t]], 1
            )
            lstm_input.append(features)
        lstm_input = torch.stack(lstm_input, dim=1)
        lstm_out, _ = self.lstm_two(lstm_input, hidden)
        fc_out = []
        deco_out = []
        for step_t in range(audio.size(1)):
            fc_in = lstm_out[:, step_t, :]
            # fc_out.append(self.lstm_fc(fc_in))
            if jaco_net == "cnn":
                fc_feature = torch.unsqueeze(fc_in, 2)
                fc_feature = torch.unsqueeze(fc_feature, 3)
                deco_out.append(self.decon(fc_feature))
            elif jaco_net == "gan":
                result, _ = self.generator([fc_in])
                deco_out.append(result)
            else:
                raise Exception("jaco_net type wrong")

        return torch.stack(deco_out, dim=1)


class AT_net2(nn.Module):
    # AT_net variant whose forward takes an extra `weight` factor applied to
    # the per-step audio feature (defined in the next section).
    def __init__(self):
        super(AT_net2, self).__init__()

        down_blocks = []
        for i in range(8):
            down_blocks.append(
                DownBlock2d(
                    3 if i == 0 else 2 * (2**i),
                    2 * (2 ** (i + 1)),
                    kernel_size=3,
                    padding=1,
                )
            )
        self.down_blocks = nn.ModuleList(down_blocks)

        # self.lmark_encoder = nn.Sequential(
        #     nn.Linear(16,256),
        #     nn.ReLU(True),
        #     nn.Linear(256,512),
        #     nn.ReLU(True),
        #     )
        self.pose_encoder = nn.Sequential(
            nn.Linear(6, 128),
            nn.ReLU(True),
            nn.Linear(128, 256),
            nn.ReLU(True),
        )
        self.audio_eocder = nn.Sequential(
            conv2d(1, 64, 3, 1, 1),
            conv2d(64, 128, 3, 1, 1),
            nn.MaxPool2d(3, stride=(1, 2)),
            conv2d(128, 256, 3, 1, 1),
            conv2d(256, 256, 3, 1, 1),
            conv2d(256, 512, 3, 1, 1),
            nn.MaxPool2d(3, stride=(2, 2)),
        )
        self.audio_eocder_fc = nn.Sequential(
            nn.Linear(1024 * 12, 2048),
            nn.ReLU(True),
            nn.Linear(2048, 256),
            nn.ReLU(True),
        )
        self.lstm = nn.LSTM(256 * 4, 256, 3, batch_first=True)
        # self.lstm_fc = nn.Sequential(
        #     nn.Linear(256,16),
        #     )
        self.decon = nn.Sequential(
            nn.ConvTranspose2d(
                256, 256, kernel_size=6, stride=2, padding=1, bias=True
            ),  # 4,4
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(
                256, 128, kernel_size=4, stride=2, padding=1, bias=True
            ),  # 8,8
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(
                128, 128, kernel_size=4, stride=2, padding=1, bias=True
            ),  # 16,16
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(
                128, 128, kernel_size=4, stride=2, padding=1, bias=True
            ),  # 32,32
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(
                128, 32 + 3, kernel_size=4, stride=2, padding=1, bias=True
            ),  # 64,64
            # nn.ConvTranspose2d(128, 32*4, kernel_size=2, stride=2,
padding=3, bias=True),#64,64 + ) + self.generator = Generator(64, 256, 8) + + def forward(self, example_image, audio, pose, jaco_net, weight): + hidden = ( + torch.autograd.Variable( + torch.zeros(3, audio.size(0), 256).to(example_image.device) + ), + torch.autograd.Variable( + torch.zeros(3, audio.size(0), 256).to(example_image.device) + ), + ) + outs = example_image + for down_block in self.down_blocks: + outs = down_block(outs) + image_feature = outs + image_feature = image_feature.view(image_feature.shape[0], -1) + lstm_input = [] + for step_t in range(audio.size(1)): + current_audio = audio[:, step_t, :, :].unsqueeze(1) + current_feature = self.audio_eocder(current_audio) + current_feature = current_feature.view(current_feature.size(0), -1) + current_feature = self.audio_eocder_fc(current_feature) * weight + pose_f = self.pose_encoder(pose[:, step_t]) + features = torch.cat([image_feature, current_feature, pose_f], 1) + lstm_input.append(features) + lstm_input = torch.stack(lstm_input, dim=1) + lstm_out, _ = self.lstm(lstm_input, hidden) + fc_out = [] + deco_out = [] + for step_t in range(audio.size(1)): + fc_in = lstm_out[:, step_t, :] + # fc_out.append(self.lstm_fc(fc_in)) + if jaco_net == "cnn": + fc_feature = torch.unsqueeze(fc_in, 2) + fc_feature = torch.unsqueeze(fc_feature, 3) + deco_out.append(self.decon(fc_feature)) + elif jaco_net == "gan": + result, _ = self.generator([fc_in]) + deco_out.append(result) + else: + raise Exception("jaco_net type wrong") + + return torch.stack(deco_out, dim=1) + + +class Ct_encoder(nn.Module): + def __init__(self): + super(Ct_encoder, self).__init__() + self.audio_eocder = nn.Sequential( + conv2d(1, 64, 3, 1, 1), + conv2d(64, 128, 3, 1, 1), + nn.MaxPool2d(3, stride=(1, 2)), + conv2d(128, 256, 3, 1, 1), + conv2d(256, 256, 3, 1, 1), + conv2d(256, 512, 3, 1, 1), + nn.MaxPool2d(3, stride=(2, 2)), + ) + self.audio_eocder_fc = nn.Sequential( + nn.Linear(1024 * 12, 2048), + nn.ReLU(True), + nn.Linear(2048, 256), + 
nn.ReLU(True), + ) + + def forward(self, audio): + feature = self.audio_eocder(audio) + feature = feature.view(feature.size(0), -1) + x = self.audio_eocder_fc(feature) + + return x + + +class EmotionNet(nn.Module): + def __init__(self): + super(EmotionNet, self).__init__() + + self.emotion_eocder = nn.Sequential( + conv2d(1, 64, 3, 1, 1), + nn.MaxPool2d((1, 3), stride=(1, 2)), # [1, 64, 12, 12] + conv2d(64, 128, 3, 1, 1), + conv2d(128, 256, 3, 1, 1), + nn.MaxPool2d((12, 1), stride=(12, 1)), # [1, 256, 1, 12] + conv2d(256, 512, 3, 1, 1), + nn.MaxPool2d((1, 2), stride=(1, 2)), # [1, 512, 1, 6] + ) + self.emotion_eocder_fc = nn.Sequential( + nn.Linear(512 * 6, 2048), + nn.ReLU(True), + nn.Linear(2048, 128), + nn.ReLU(True), + ) + + self.last_fc = nn.Linear(128, 8) + + self.re_id = nn.Sequential( + conv2d(512, 1024, 3, 1, 1), + nn.MaxPool2d((1, 2), stride=(1, 2)), # [1, 1024, 1, 3] + conv2d(1024, 1024, 3, 1, 1), + conv2d(1024, 2048, 3, 1, 1), + nn.MaxPool2d((1, 2), stride=(1, 2)), # [1, 2048, 1, 1] + ) + self.re_id_fc = nn.Sequential( + nn.Linear(2048, 512), + nn.ReLU(True), + nn.Linear(512, 128), + nn.ReLU(True), + ) + + def forward(self, mfcc): + # mfcc= torch.unsqueeze(mfcc, 1) + mfcc = torch.transpose(mfcc, 2, 3) + feature = self.emotion_eocder(mfcc) + + # id_feature = feature.detach() + + feature = feature.view(feature.size(0), -1) + x = self.emotion_eocder_fc(feature) + + # remove_feature = self.re_id(id_feature) + # remove_feature = remove_feature.view(remove_feature.size(0),-1) + # y = self.re_id_fc(remove_feature) + + return x + + +class AF2F(nn.Module): + def __init__(self): + super(AF2F, self).__init__() + self.decon = nn.Sequential( + nn.ConvTranspose2d( + 384, 256, kernel_size=6, stride=2, padding=1, bias=True + ), # 4,4 + nn.BatchNorm2d(256), + nn.ReLU(True), + nn.ConvTranspose2d( + 256, 128, kernel_size=4, stride=2, padding=1, bias=True + ), # 8,8 + nn.BatchNorm2d(128), + nn.ReLU(True), + nn.ConvTranspose2d( + 128, 64, kernel_size=4, stride=2, padding=1, 
                bias=True
            ),  # 16,16
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(
                64, 64, kernel_size=4, stride=2, padding=1, bias=True
            ),  # 32,32
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(
                64, 32 + 3, kernel_size=4, stride=2, padding=1, bias=True
            ),  # 64,64
        )

    def forward(self, content, emotion):
        features = torch.cat(
            [content, emotion], 1
        )  # concatenate content and emotion along the feature dimension
        features = torch.unsqueeze(features, 2)
        features = torch.unsqueeze(features, 3)
        x = self.decon(features)

        return x


class AF2F_s(nn.Module):
    # Single-stream variant of AF2F: decodes a 256-D content vector alone
    # (no emotion concat) to a 35-channel map, with a final ReLU.
    def __init__(self):
        super(AF2F_s, self).__init__()
        self.decon = nn.Sequential(
            nn.ConvTranspose2d(
                256, 256, kernel_size=6, stride=2, padding=1, bias=True
            ),  # 4,4
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(
                256, 128, kernel_size=4, stride=2, padding=1, bias=True
            ),  # 8,8
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(
                128, 64, kernel_size=4, stride=2, padding=1, bias=True
            ),  # 16,16
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(
                64, 64, kernel_size=4, stride=2, padding=1, bias=True
            ),  # 32,32
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(
                64, 32 + 3, kernel_size=4, stride=2, padding=1, bias=True
            ),  # 64,64
            nn.ReLU(),
        )

    def forward(self, content):
        # features = torch.cat([content, emotion], 1) #connect tensors inputs and dimension
        features = torch.unsqueeze(content, 2)
        features = torch.unsqueeze(features, 3)
        x = self.decon(features)

        return x


class A2I(nn.Module):
    # Audio (MFCC) to 2-channel image decoder.
    def __init__(self):
        super(A2I, self).__init__()
        self.audio_eocder = nn.Sequential(
            conv2d(1, 64, 3, 1, 1),
            conv2d(64, 128, 3, 1, 1),
            nn.MaxPool2d((1, 5), stride=(1, 2)),
            conv2d(128, 256, 3, 1, 1),
            conv2d(256, 256, 3, 1, 1),
            nn.MaxPool2d((5, 5), stride=(2, 2)),
        )
        self.decon = nn.Sequential(
            nn.ConvTranspose2d(
                256, 128, kernel_size=4, stride=2, padding=1, bias=True
            ),  # 8,8
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(
                128, 64, kernel_size=4, stride=2, padding=1, bias=True
            ),  # 16,16
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(
                64, 32, kernel_size=4, stride=2, padding=1, bias=True
            ),  # 32,32
            nn.BatchNorm2d(32),
            nn.ReLU(True),
            nn.ConvTranspose2d(
                32, 2, kernel_size=4, stride=2, padding=1, bias=True
            ),  # 64,64
            nn.ReLU(),
        )

    def forward(self, mfcc):
        mfcc = torch.unsqueeze(mfcc, 1)
        mfcc = torch.transpose(mfcc, 2, 3)
        feature = self.audio_eocder(mfcc)

        # id_feature = feature.detach()

        x = self.decon(feature)

        return x


def kp2gaussian(kp, spatial_size, kp_variance):
    """
    Transform a keypoint into gaussian like representation
    """
    mean = kp["value"]  # [4,10,2]

    coordinate_grid = make_coordinate_grid(spatial_size, mean.type())  # [h,w,2]
    number_of_leading_dimensions = len(mean.shape) - 1
    shape = (1,) * number_of_leading_dimensions + coordinate_grid.shape  # 5
    coordinate_grid = coordinate_grid.view(*shape)  # [1,1,h,w,2]
    repeats = mean.shape[:number_of_leading_dimensions] + (1, 1, 1)
    coordinate_grid = coordinate_grid.repeat(*repeats)  # [4,10,h,w,2]

    # Preprocess kp shape
    shape = mean.shape[:number_of_leading_dimensions] + (1, 1, 2)
    mean = mean.view(*shape)  # [4,10,1,1,2]

    mean_sub = coordinate_grid - mean

    # Isotropic gaussian centered at each keypoint.
    out = torch.exp(-0.5 * (mean_sub**2).sum(-1) / kp_variance)

    return out


def make_coordinate_grid(spatial_size, type):
    """
    Create a meshgrid [-1,1] x [-1,1] of given spatial_size.
    """
    # NOTE(review): the parameter `type` shadows the builtin; kept as-is for
    # API compatibility with callers passing tensor.type().
    h, w = spatial_size
    x = torch.arange(w).type(type)
    y = torch.arange(h).type(type)

    x = 2 * (x / (w - 1)) - 1
    y = 2 * (y / (h - 1)) - 1

    yy = y.view(-1, 1).repeat(1, w)
    xx = x.view(1, -1).repeat(h, 1)

    meshed = torch.cat([xx.unsqueeze_(2), yy.unsqueeze_(2)], 2)

    return meshed


class ResBlock2d(nn.Module):
    """
    Res block, preserve spatial resolution.
+ """ + + def __init__(self, in_features, kernel_size, padding): + super(ResBlock2d, self).__init__() + self.conv1 = nn.Conv2d( + in_channels=in_features, + out_channels=in_features, + kernel_size=kernel_size, + padding=padding, + ) + self.conv2 = nn.Conv2d( + in_channels=in_features, + out_channels=in_features, + kernel_size=kernel_size, + padding=padding, + ) + self.norm1 = BatchNorm2d(in_features, affine=True) + self.norm2 = BatchNorm2d(in_features, affine=True) + + def forward(self, x): + out = self.norm1(x) + out = F.relu(out) + out = self.conv1(out) + out = self.norm2(out) + out = F.relu(out) + out = self.conv2(out) + out += x + return out + + +class UpBlock2d(nn.Module): + """ + Upsampling block for use in decoder. + """ + + def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1): + super(UpBlock2d, self).__init__() + + self.conv = nn.Conv2d( + in_channels=in_features, + out_channels=out_features, + kernel_size=kernel_size, + padding=padding, + groups=groups, + ) + self.norm = BatchNorm2d(out_features, affine=True) + + def forward(self, x): + out = F.interpolate(x, scale_factor=2) + out = self.conv(out) + out = self.norm(out) + out = F.relu(out) + return out + + +class DownBlock2d(nn.Module): + """ + Downsampling block for use in encoder. + """ + + def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1): + super(DownBlock2d, self).__init__() + self.conv = nn.Conv2d( + in_channels=in_features, + out_channels=out_features, + kernel_size=kernel_size, + padding=padding, + groups=groups, + ) + self.norm = BatchNorm2d(out_features, affine=True) + self.pool = nn.AvgPool2d(kernel_size=(2, 2)) + + def forward(self, x): + out = self.conv(x) + out = self.norm(out) + out = F.relu(out) + out = self.pool(out) + return out + + +class SameBlock2d(nn.Module): + """ + Simple block, preserve spatial resolution. 
+ """ + + def __init__(self, in_features, out_features, groups=1, kernel_size=3, padding=1): + super(SameBlock2d, self).__init__() + self.conv = nn.Conv2d( + in_channels=in_features, + out_channels=out_features, + kernel_size=kernel_size, + padding=padding, + groups=groups, + ) + self.norm = BatchNorm2d(out_features, affine=True) + + def forward(self, x): + out = self.conv(x) + out = self.norm(out) + out = F.relu(out) + return out + + +class Encoder(nn.Module): + """ + Hourglass Encoder + """ + + def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256): + super(Encoder, self).__init__() + + down_blocks = [] + for i in range(num_blocks): + down_blocks.append( + DownBlock2d( + in_features + if i == 0 + else min(max_features, block_expansion * (2**i)), + min(max_features, block_expansion * (2 ** (i + 1))), + kernel_size=3, + padding=1, + ) + ) + self.down_blocks = nn.ModuleList(down_blocks) + + def forward(self, x): + outs = [x] + for down_block in self.down_blocks: + outs.append(down_block(outs[-1])) + return outs + + +class Decoder(nn.Module): + """ + Hourglass Decoder + """ + + def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256): + super(Decoder, self).__init__() + + up_blocks = [] + + for i in range(num_blocks)[::-1]: + in_filters = (1 if i == num_blocks - 1 else 2) * min( + max_features, block_expansion * (2 ** (i + 1)) + ) + out_filters = min(max_features, block_expansion * (2**i)) + up_blocks.append( + UpBlock2d(in_filters, out_filters, kernel_size=3, padding=1) + ) + + self.up_blocks = nn.ModuleList(up_blocks) + self.out_filters = block_expansion + in_features + + def forward(self, x): + out = x.pop() + for up_block in self.up_blocks: + out = up_block(out) + skip = x.pop() + out = torch.cat([out, skip], dim=1) + return out + + +class Hourglass(nn.Module): + """ + Hourglass architecture. 
+ """ + + def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256): + super(Hourglass, self).__init__() + self.encoder = Encoder(block_expansion, in_features, num_blocks, max_features) + self.decoder = Decoder(block_expansion, in_features, num_blocks, max_features) + self.out_filters = self.decoder.out_filters + + def forward(self, x): + return self.decoder(self.encoder(x)) + + +class AntiAliasInterpolation2d(nn.Module): + """ + Band-limited downsampling, for better preservation of the input signal. + """ + + def __init__(self, channels, scale): + super(AntiAliasInterpolation2d, self).__init__() + # sigma = (1 / scale - 1) / 2 + sigma = 1.5 + kernel_size = 2 * round(sigma * 4) + 1 + self.ka = kernel_size // 2 + self.kb = self.ka - 1 if kernel_size % 2 == 0 else self.ka + + kernel_size = [kernel_size, kernel_size] + sigma = [sigma, sigma] + # The gaussian kernel is the product of the + # gaussian function of each dimension. + kernel = 1 + meshgrids = torch.meshgrid( + [torch.arange(size, dtype=torch.float32) for size in kernel_size] + ) + for size, std, mgrid in zip(kernel_size, sigma, meshgrids): + mean = (size - 1) / 2 + kernel *= torch.exp(-((mgrid - mean) ** 2) / (2 * std**2)) + + # Make sure sum of values in gaussian kernel equals 1. 
+ kernel = kernel / torch.sum(kernel) + # Reshape to depthwise convolutional weight + kernel = kernel.view(1, 1, *kernel.size()) + kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1)) + + self.register_buffer("weight", kernel) + self.groups = channels + self.scale = scale + inv_scale = 1 / scale + self.int_inv_scale = int(inv_scale) + + def forward(self, input): + if self.scale == 1.0: + return input + + out = F.pad(input, (self.ka, self.kb, self.ka, self.kb)) + out = F.conv2d(out, weight=self.weight, groups=self.groups) + out = out[:, :, :: self.int_inv_scale, :: self.int_inv_scale] + + return out + + +def sigmoid(x): + return 1 / (1 + math.exp(-x)) + + +def norm_angle(angle): + norm_angle = sigmoid(10 * (abs(angle) / 0.7853975 - 1)) + return norm_angle + + +def conv3x3(in_planes, out_planes, stride=1): + "3x3 convolution with padding" + return nn.Conv2d( + in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False + ) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU() + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d( + planes, planes, kernel_size=3, stride=stride, padding=1, bias=False + ) + 
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU()
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out = out + residual
        out = self.relu(out)

        return out


class EmDetector(nn.Module):
    """
    Detecting a keypoints. Return keypoint position and jacobian near each keypoint.
    """
    # NOTE(review): the docstring appears copied from the keypoint detector;
    # what forward actually returns is (512-D embedding, 8-way logits):
    # hourglass features -> ResNet-18-style trunk -> Classify head.

    def __init__(
        self,
        block_expansion,
        num_channels,
        max_features,
        num_blocks,
        scale_factor=1,
        num_classes=8,
    ):
        super(EmDetector, self).__init__()
        self.inplanes = 64
        self.predictor = Hourglass(
            block_expansion,
            in_features=num_channels,
            max_features=max_features,
            num_blocks=num_blocks,
        )

        self.scale_factor = scale_factor
        if self.scale_factor != 1:
            # Anti-aliased downscaling of the input frame before the hourglass.
            self.down = AntiAliasInterpolation2d(num_channels, self.scale_factor)
        self.conv1 = nn.Conv2d(
            self.predictor.out_filters,
            64,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False,
        )
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        layers = [2, 2, 2, 2]  # ResNet-18 stage depths
        self.layer1 = self._make_layer(BasicBlock, 64, layers[0])
        self.layer2 = self._make_layer(BasicBlock, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(BasicBlock, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(BasicBlock, 512, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(512 * BasicBlock.expansion, num_classes)
        self.classify = Classify()

    def _make_layer(self, block, planes, blocks, stride=1):
        # Builds one ResNet stage; adds a 1x1-conv downsample branch when
        # the spatial size or channel count changes.
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def adain_feature(self, x):  # torch.Size([4, 3, H, W])
        # Returns only the hourglass feature map (for AdaIN conditioning).
        if self.scale_factor != 1:
            x = self.down(x)  # 0.25 [4, 3, H/4, W/4]

        feature_map = self.predictor(x)  # [4,3+32,H/4, W/4]

        # out = self.fc(out)

        return feature_map

    def forward(self, x):  # torch.Size([4, 3, H, W])
        if self.scale_factor != 1:
            x = self.down(x)  # 0.25 [4, 3, H/4, W/4]

        feature_map = self.predictor(x)  # [4,3+32,H/4, W/4]
        f = self.conv1(feature_map)  # [16,64,64,64]
        f = self.bn1(f)  # torch.Size([16, 64, 64, 64])
        f = self.relu(f)
        f = self.maxpool(f)  # [16, 64, 32, 32]

        f = self.layer1(f)  # [16, 64, 32, 32]
        f = self.layer2(f)  # [16, 128, 16, 16])
        f = self.layer3(f)  # [16, 256, 8, 8]
        f = self.layer4(f)  # [16, 512, 4, 4]
        f = self.avgpool(f)  # [16, 512, 1, 1]
        out = f.squeeze(3).squeeze(2)
        fake = self.classify(out)
        # out = self.fc(out)

        return out, fake


class Emotion_k(nn.Module):
    """
    Detecting a keypoints. Return keypoint position and jacobian near each keypoint.
+ """ + + def __init__( + self, + block_expansion, + num_channels, + max_features, + num_blocks, + scale_factor=1, + num_classes=8, + ): + super(Emotion_k, self).__init__() + self.inplanes = 64 + self.predictor = Hourglass( + block_expansion, + in_features=num_channels, + max_features=max_features, + num_blocks=num_blocks, + ) + + self.scale_factor = scale_factor + if self.scale_factor != 1: + self.down = AntiAliasInterpolation2d(num_channels, self.scale_factor) + self.conv1 = nn.Conv2d( + self.predictor.out_filters, + 64, + kernel_size=3, + stride=1, + padding=1, + bias=False, + ) + self.bn1 = nn.BatchNorm2d(64) + self.relu = nn.ReLU() + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + layers = [2, 2, 2, 2] + self.layer1 = self._make_layer(BasicBlock, 64, layers[0]) + self.layer2 = self._make_layer(BasicBlock, 128, layers[1], stride=2) + self.layer3 = self._make_layer(BasicBlock, 256, layers[2], stride=2) + self.layer4 = self._make_layer(BasicBlock, 512, layers[3], stride=2) + self.avgpool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Linear(512 * BasicBlock.expansion, num_classes) + + self.embed_fn, self.input_ch = get_embedder(10, 0) + + self.fc_p = nn.Sequential( + nn.Linear(10 * 126, 1024), + nn.ReLU(True), + nn.Linear(1024, 512), + nn.ReLU(True), + ) + self.fc_n = nn.Sequential( + nn.Linear(10 * 6, 128), + nn.ReLU(True), + nn.Linear(128, 512), + nn.ReLU(True), + ) + + self.fc_all = nn.Sequential( + nn.Linear(1024, 512), + nn.ReLU(True), + nn.Linear(512, 256), + nn.ReLU(True), + nn.Linear(256, 64), + nn.ReLU(True), + ) + + # self.fc_single = nn.Sequential( + # nn.Linear(512,256), + # nn.ReLU(True), + # nn.Linear(256,64), + # nn.ReLU(True), + # ) + + self.final = nn.Sequential( + nn.Conv1d(1, 2, 4, 2, 1), + nn.MaxPool1d(2, stride=2), + nn.ReLU(True), + nn.Conv1d(2, 4, 4, 2, 1), + nn.ReLU(True), + nn.Conv1d(4, 4, 3), + ) + + self.final_4 = nn.Sequential( + nn.Conv1d(4, 4, 3, 1, 1), + nn.MaxPool1d(2, stride=2), + nn.ReLU(True), + nn.Conv1d(4, 4, 3, 
1), + ) + + self.final_10 = nn.Sequential( + nn.Conv1d(4, 8, 3, 1, 1), # [B,8,16] + nn.MaxPool1d(2, stride=2), # [B,8,8] + nn.ReLU(True), + nn.Conv1d(8, 10, 3, 1), # [B,10,6] + ) + + self.classify = Classify() + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d( + self.inplanes, + planes * block.expansion, + kernel_size=1, + stride=stride, + bias=False, + ), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def linear_10(self, x, value, jacobian): # torch.Size([4, 3, H, W]) + if self.scale_factor != 1: + x = self.down(x) # 0.25 [4, 3, H/4, W/4] + + feature_map = self.predictor(x) # [4,3+32,H/4, W/4] + f = self.conv1(feature_map) # [16,64,64,64] + f = self.bn1(f) # torch.Size([16, 64, 64, 64]) + f = self.relu(f) + f = self.maxpool(f) # [16, 64, 32, 32] + + f = self.layer1(f) # [16, 64, 32, 32] + f = self.layer2(f) # [16, 128, 16, 16]) + f = self.layer3(f) # [16, 256, 8, 8] + f = self.layer4(f) # [16, 512, 4, 4] + f = self.avgpool(f) # [16, 512, 1, 1] + out = f.squeeze(3).squeeze(2) + fake = self.classify(out) + jacobian = jacobian.reshape(jacobian.shape[0], jacobian.shape[1], 4) + neu_input = torch.cat((value, jacobian), 2) + posi_input = self.embed_fn(neu_input) + posi_input = posi_input.reshape(posi_input.shape[0], -1) + ner_feature = self.fc_p(posi_input) + all_fc = self.fc_all(torch.cat((out, ner_feature), 1)).reshape(-1, 4, 16) + result = self.final_10(all_fc) + e_value = result[:, :, :2] + e_jacobian = result[:, :, 2:].reshape(result.shape[0], 10, 2, 2) + kp = {"value": e_value, "jacobian": e_jacobian} + + return kp, fake + + def linear_4(self, x, value, jacobian): # torch.Size([4, 3, H, W]) + if 
self.scale_factor != 1: + x = self.down(x) # 0.25 [4, 3, H/4, W/4] + + feature_map = self.predictor(x) # [4,3+32,H/4, W/4] + f = self.conv1(feature_map) # [16,64,64,64] + f = self.bn1(f) # torch.Size([16, 64, 64, 64]) + f = self.relu(f) + f = self.maxpool(f) # [16, 64, 32, 32] + + f = self.layer1(f) # [16, 64, 32, 32] + f = self.layer2(f) # [16, 128, 16, 16]) + f = self.layer3(f) # [16, 256, 8, 8] + f = self.layer4(f) # [16, 512, 4, 4] + f = self.avgpool(f) # [16, 512, 1, 1] + out = f.squeeze(3).squeeze(2) + fake = self.classify(out) + # jacobian = jacobian.reshape(jacobian.shape[0],jacobian.shape[1],4) + # neu_input = torch.cat((value,jacobian),2) + # posi_input = self.embed_fn(neu_input) + # posi_input =posi_input.reshape(posi_input.shape[0],-1) + # ner_feature = self.fc_p(posi_input) + # all_fc = self.fc_all(torch.cat((out,ner_feature),1)).reshape(-1,4,16) + all_fc = torch.unsqueeze(self.fc_single(out), 1) + result = self.final(all_fc) + e_value = result[:, :, :2] + e_jacobian = result[:, :, 2:].reshape(result.shape[0], 4, 2, 2) + kp = {"value": e_value, "jacobian": e_jacobian} + # out = self.fc(out) + + return kp, fake + + def linear_np_10(self, x, value, jacobian): # torch.Size([4, 3, H, W]) + if self.scale_factor != 1: + x = self.down(x) # 0.25 [4, 3, H/4, W/4] + + feature_map = self.predictor(x) # [4,3+32,H/4, W/4] + f = self.conv1(feature_map) # [16,64,64,64] + f = self.bn1(f) # torch.Size([16, 64, 64, 64]) + f = self.relu(f) + f = self.maxpool(f) # [16, 64, 32, 32] + + f = self.layer1(f) # [16, 64, 32, 32] + f = self.layer2(f) # [16, 128, 16, 16]) + f = self.layer3(f) # [16, 256, 8, 8] + f = self.layer4(f) # [16, 512, 4, 4] + f = self.avgpool(f) # [16, 512, 1, 1] + out = f.squeeze(3).squeeze(2) + fake = self.classify(out) + jacobian = jacobian.reshape(jacobian.shape[0], jacobian.shape[1], 4) + neu_input = torch.cat((value, jacobian), 2) + + posi_input = neu_input.reshape(neu_input.shape[0], -1) + ner_feature = self.fc_n(posi_input) + all_fc = 
self.fc_all(torch.cat((out, ner_feature), 1)).reshape(-1, 4, 16) + result = self.final_10(all_fc) + e_value = result[:, :, :2] + e_jacobian = result[:, :, 2:].reshape(result.shape[0], 10, 2, 2) + kp = {"value": e_value, "jacobian": e_jacobian} + # out = self.fc(out) + + return kp, fake + + def linear_np_4(self, x, value, jacobian): # torch.Size([4, 3, H, W]) + if self.scale_factor != 1: + x = self.down(x) # 0.25 [4, 3, H/4, W/4] + + feature_map = self.predictor(x) # [4,3+32,H/4, W/4] + f = self.conv1(feature_map) # [16,64,64,64] + f = self.bn1(f) # torch.Size([16, 64, 64, 64]) + f = self.relu(f) + f = self.maxpool(f) # [16, 64, 32, 32] + + f = self.layer1(f) # [16, 64, 32, 32] + f = self.layer2(f) # [16, 128, 16, 16]) + f = self.layer3(f) # [16, 256, 8, 8] + f = self.layer4(f) # [16, 512, 4, 4] + f = self.avgpool(f) # [16, 512, 1, 1] + out = f.squeeze(3).squeeze(2) + fake = self.classify(out) + jacobian = jacobian.reshape(jacobian.shape[0], jacobian.shape[1], 4) + neu_input = torch.cat((value, jacobian), 2) + + posi_input = neu_input.reshape(neu_input.shape[0], -1) + ner_feature = self.fc_n(posi_input) + all_fc = torch.unsqueeze(self.fc_all(torch.cat((out, ner_feature), 1)), 1) + result = self.final(all_fc) + e_value = result[:, :, :2] + e_jacobian = result[:, :, 2:].reshape(result.shape[0], 4, 2, 2) + kp = {"value": e_value, "jacobian": e_jacobian} + # out = self.fc(out) + + return kp, fake + + def emotion_feature(self, feature, value, jacobian): # torch.Size([4, 3, H, W]) + out = feature + fake = self.classify(out) + jacobian = jacobian.reshape(jacobian.shape[0], jacobian.shape[1], 4) + neu_input = torch.cat((value, jacobian), 2) + posi_input = self.embed_fn(neu_input) + posi_input = posi_input.reshape(posi_input.shape[0], -1) + ner_feature = self.fc_p(posi_input) + all_fc = torch.unsqueeze(self.fc_all(torch.cat((out, ner_feature), 1)), 1) + result = self.final(all_fc) + e_value = result[:, :, :2] + e_jacobian = result[:, :, 2:].reshape(result.shape[0], 4, 2, 2) 
+ kp = {"value": e_value, "jacobian": e_jacobian} + # out = self.fc(out) + + return kp, fake + + def feature(self, x): # torch.Size([4, 3, H, W]) + if self.scale_factor != 1: + x = self.down(x) # 0.25 [4, 3, H/4, W/4] + + feature_map = self.predictor(x) # [4,3+32,H/4, W/4] + f = self.conv1(feature_map) # [16,64,64,64] + f = self.bn1(f) # torch.Size([16, 64, 64, 64]) + f = self.relu(f) + f = self.maxpool(f) # [16, 64, 32, 32] + + f = self.layer1(f) # [16, 64, 32, 32] + f = self.layer2(f) # [16, 128, 16, 16]) + f = self.layer3(f) # [16, 256, 8, 8] + f = self.layer4(f) # [16, 512, 4, 4] + f = self.avgpool(f) # [16, 512, 1, 1] + out = f.squeeze(3).squeeze(2) + + # out = self.fc(out) + + return out + + def forward(self, x, value, jacobian): # torch.Size([4, 3, H, W]) + if self.scale_factor != 1: + x = self.down(x) # 0.25 [4, 3, H/4, W/4] + + feature_map = self.predictor(x) # [4,3+32,H/4, W/4] + f = self.conv1(feature_map) # [16,64,64,64] + f = self.bn1(f) # torch.Size([16, 64, 64, 64]) + f = self.relu(f) + f = self.maxpool(f) # [16, 64, 32, 32] + + f = self.layer1(f) # [16, 64, 32, 32] + f = self.layer2(f) # [16, 128, 16, 16]) + f = self.layer3(f) # [16, 256, 8, 8] + f = self.layer4(f) # [16, 512, 4, 4] + f = self.avgpool(f) # [16, 512, 1, 1] + out = f.squeeze(3).squeeze(2) + fake = self.classify(out) + jacobian = jacobian.reshape(jacobian.shape[0], jacobian.shape[1], 4) + neu_input = torch.cat((value, jacobian), 2) + posi_input = self.embed_fn(neu_input) + posi_input = posi_input.reshape(posi_input.shape[0], -1) + ner_feature = self.fc_p(posi_input) + all_fc = torch.unsqueeze(self.fc_all(torch.cat((out, ner_feature), 1)), 1) + result = self.final(all_fc) + e_value = result[:, :, :2] + e_jacobian = result[:, :, 2:].reshape(result.shape[0], 4, 2, 2) + kp = {"value": e_value, "jacobian": e_jacobian} + # out = self.fc(out) + + return kp, fake + + +class Emotion_map(nn.Module): + """ + Detecting a keypoints. Return keypoint position and jacobian near each keypoint. 
+ """ + + def __init__( + self, + block_expansion, + num_channels, + max_features, + num_blocks, + scale_factor=1, + num_classes=8, + ): + super(Emotion_map, self).__init__() + self.inplanes = 64 + self.predictor = Hourglass( + block_expansion, + in_features=num_channels, + max_features=max_features, + num_blocks=num_blocks, + ) + + self.scale_factor = scale_factor + if self.scale_factor != 1: + self.down = AntiAliasInterpolation2d(num_channels, self.scale_factor) + self.conv1 = nn.Conv2d( + self.predictor.out_filters, + 64, + kernel_size=3, + stride=1, + padding=1, + bias=False, + ) + self.bn1 = nn.BatchNorm2d(64) + self.relu = nn.ReLU() + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + layers = [2, 2, 2, 2] + self.layer1 = self._make_layer(BasicBlock, 64, layers[0]) + self.layer2 = self._make_layer(BasicBlock, 128, layers[1], stride=2) + self.layer3 = self._make_layer(BasicBlock, 256, layers[2], stride=2) + self.layer4 = self._make_layer(BasicBlock, 512, layers[3], stride=2) + self.avgpool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Linear(512 * BasicBlock.expansion, num_classes) + + self.embed_fn, self.input_ch = get_embedder(10, 0) + + self.fc_p = nn.Sequential( + nn.Linear(10 * 126, 1024), + nn.ReLU(True), + nn.Linear(1024, 512), + nn.ReLU(True), + ) + + self.fc_all = nn.Sequential(nn.Linear(1024, 2048), nn.ReLU(True)) + + self.final = nn.Sequential( + nn.ConvTranspose2d( + 128, 128, kernel_size=4, stride=2, padding=1, bias=True + ), # 8,8 + nn.BatchNorm2d(128), + nn.ReLU(True), + nn.ConvTranspose2d( + 128, 64, kernel_size=4, stride=2, padding=1, bias=True + ), # 16,16 + nn.BatchNorm2d(64), + nn.ReLU(True), + nn.ConvTranspose2d( + 64, 64, kernel_size=4, stride=2, padding=1, bias=True + ), # 32,32 + nn.BatchNorm2d(64), + nn.ReLU(True), + nn.ConvTranspose2d( + 64, 32 + 3, kernel_size=4, stride=2, padding=1, bias=True + ), # 64,64 + ) + + self.classify = Classify() + self.kp = nn.Conv2d( + in_channels=35, out_channels=10, kernel_size=(7, 7), 
padding=0 + ) + self.jacobian = nn.Conv2d( + in_channels=35, out_channels=4 * 10, kernel_size=(7, 7), padding=0 + ) + self.jacobian.weight.data.zero_() + self.jacobian.bias.data.copy_( + torch.tensor([1, 0, 0, 1] * 10, dtype=torch.float) + ) + self.temperature = 0.1 + + self.kp_4 = nn.Conv2d( + in_channels=35, out_channels=4, kernel_size=(7, 7), padding=0 + ) + self.jacobian_4 = nn.Conv2d( + in_channels=35, out_channels=4 * 4, kernel_size=(7, 7), padding=0 + ) + self.jacobian_4.weight.data.zero_() + self.jacobian_4.bias.data.copy_( + torch.tensor([1, 0, 0, 1] * 4, dtype=torch.float) + ) + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d( + self.inplanes, + planes * block.expansion, + kernel_size=1, + stride=stride, + bias=False, + ), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def gaussian2kp(self, heatmap): + """ + Extract the mean and from a heatmap + """ + shape = heatmap.shape + heatmap = heatmap.unsqueeze(-1) # [4,10,58,58,1] + grid = ( + make_coordinate_grid(shape[2:], heatmap.type()).unsqueeze_(0).unsqueeze_(0) + ) # [1,1,58,58,2] + value = (heatmap * grid).sum(dim=(2, 3)) # [4,10,2] + kp = {"value": value} + + return kp + + def map_4(self, x, value, jacobian): # torch.Size([4, 3, H, W]) + if self.scale_factor != 1: + x = self.down(x) # 0.25 [4, 3, H/4, W/4] + + feature_map = self.predictor(x) # [4,3+32,H/4, W/4] + f = self.conv1(feature_map) # [16,64,64,64] + f = self.bn1(f) # torch.Size([16, 64, 64, 64]) + f = self.relu(f) + f = self.maxpool(f) # [16, 64, 32, 32] + + f = self.layer1(f) # [16, 64, 32, 32] + f = self.layer2(f) # [16, 128, 16, 16]) + f = self.layer3(f) # [16, 256, 8, 8] + f 
= self.layer4(f) # [16, 512, 4, 4] + f = self.avgpool(f) # [16, 512, 1, 1] + out = f.squeeze(3).squeeze(2) + fake = self.classify(out) + jacobian = jacobian.reshape(jacobian.shape[0], jacobian.shape[1], 4) + neu_input = torch.cat((value, jacobian), 2) + posi_input = self.embed_fn(neu_input) + posi_input = posi_input.reshape(posi_input.shape[0], -1) + ner_feature = self.fc_p(posi_input) + all_fc = self.fc_all(torch.cat((out, ner_feature), 1)).reshape(-1, 128, 4, 4) + feature_map = self.final(all_fc) + prediction = self.kp_4(feature_map) # [4,10,H/4-6, W/4-6] + + final_shape = prediction.shape + + heatmap = prediction.view(final_shape[0], final_shape[1], -1) # [4, 10, 58*58] + heatmap = F.softmax(heatmap / self.temperature, dim=2) + heatmap = heatmap.view(*final_shape) # [4,10,58,58] + + out = self.gaussian2kp(heatmap) + out["heatmap"] = heatmap + + if self.jacobian is not None: + jacobian_map = self.jacobian_4(feature_map) ##[4,40,H/4-6, W/4-6] + jacobian_map = jacobian_map.reshape( + final_shape[0], 4, 4, final_shape[2], final_shape[3] + ) + heatmap = heatmap.unsqueeze(2) + + jacobian = heatmap * jacobian_map # [4,10,4,H/4-6, W/4-6] + jacobian = jacobian.view(final_shape[0], final_shape[1], 4, -1) + jacobian = jacobian.sum(dim=-1) # [4,10,4] + jacobian = jacobian.view( + jacobian.shape[0], jacobian.shape[1], 2, 2 + ) # [4,10,2,2] + out["jacobian"] = jacobian + + return out, fake + + def forward(self, x, value, jacobian): # torch.Size([4, 3, H, W]) + if self.scale_factor != 1: + x = self.down(x) # 0.25 [4, 3, H/4, W/4] + + feature_map = self.predictor(x) # [4,3+32,H/4, W/4] + f = self.conv1(feature_map) # [16,64,64,64] + f = self.bn1(f) # torch.Size([16, 64, 64, 64]) + f = self.relu(f) + f = self.maxpool(f) # [16, 64, 32, 32] + + f = self.layer1(f) # [16, 64, 32, 32] + f = self.layer2(f) # [16, 128, 16, 16]) + f = self.layer3(f) # [16, 256, 8, 8] + f = self.layer4(f) # [16, 512, 4, 4] + f = self.avgpool(f) # [16, 512, 1, 1] + out = f.squeeze(3).squeeze(2) + fake = 
self.classify(out) + jacobian = jacobian.reshape(jacobian.shape[0], jacobian.shape[1], 4) + neu_input = torch.cat((value, jacobian), 2) + posi_input = self.embed_fn(neu_input) + posi_input = posi_input.reshape(posi_input.shape[0], -1) + ner_feature = self.fc_p(posi_input) + all_fc = self.fc_all(torch.cat((out, ner_feature), 1)).reshape(-1, 128, 4, 4) + feature_map = self.final(all_fc) + + prediction = self.kp(feature_map) # [4,10,H/4-6, W/4-6] + + final_shape = prediction.shape + + heatmap = prediction.view(final_shape[0], final_shape[1], -1) # [4, 10, 58*58] + heatmap = F.softmax(heatmap / self.temperature, dim=2) + heatmap = heatmap.view(*final_shape) # [4,10,58,58] + + out = self.gaussian2kp(heatmap) + out["heatmap"] = heatmap + + if self.jacobian is not None: + jacobian_map = self.jacobian(feature_map) ##[4,40,H/4-6, W/4-6] + jacobian_map = jacobian_map.reshape( + final_shape[0], 10, 4, final_shape[2], final_shape[3] + ) + heatmap = heatmap.unsqueeze(2) + + jacobian = heatmap * jacobian_map # [4,10,4,H/4-6, W/4-6] + jacobian = jacobian.view(final_shape[0], final_shape[1], 4, -1) + jacobian = jacobian.sum(dim=-1) # [4,10,4] + jacobian = jacobian.view( + jacobian.shape[0], jacobian.shape[1], 2, 2 + ) # [4,10,2,2] + out["jacobian"] = jacobian + + return out, fake + + +def conv2d( + channel_in, + channel_out, + ksize=3, + stride=1, + padding=1, + activation=nn.ReLU, + normalizer=nn.BatchNorm2d, +): + layer = list() + bias = True if not normalizer else False + + layer.append(nn.Conv2d(channel_in, channel_out, ksize, stride, padding, bias=bias)) + _apply(layer, activation, normalizer, channel_out) + # init.kaiming_normal(layer[0].weight) + + return nn.Sequential(*layer) + + +def _apply(layer, activation, normalizer, channel_out=None): + if normalizer: + layer.append(normalizer(channel_out)) + if activation: + layer.append(activation()) + return layer diff --git a/talkingface/properties/dataset/mead.yaml b/talkingface/properties/dataset/mead.yaml new file mode 100644 
index 00000000..256f28d3 --- /dev/null +++ b/talkingface/properties/dataset/mead.yaml @@ -0,0 +1,2 @@ + +need_preprocess: True \ No newline at end of file diff --git a/talkingface/properties/model/EAMM.yaml b/talkingface/properties/model/EAMM.yaml new file mode 100644 index 00000000..38ecb0ce --- /dev/null +++ b/talkingface/properties/model/EAMM.yaml @@ -0,0 +1,24 @@ +# Prepocessed +data_root: 'dataset/mead/data' +preprocessed_root: 'dataset/mead/_data' +_preprocessed_root: 'dataset/mead/preprocessed_data' +need_preprocess: True +config_root: 'talkingface/properties/model/EAMM' + +# Train +mode: "train_part1" # "train_part1", "train_part1_fine_tune", "train_part2" +checkpoint_sub_dir: "/eamm" # 和overall.yaml里checkpoint_dir拼起来作为最终目录 +temp_sub_dir: "/eamm" # 和overall.yaml里temp_dir拼起来作为最终目录 +batch_size: 16 +ngpu: 1 +epochs: 5 + +audio_checkpoint: 'checkpoints/EAMM/1-6000.pth.tar' +checkpoint: 'checkpoints/EAMM/124_52000.pth.tar' +emo_checkpoint: 'checkpoints/EAMM/5-3000.pth.tar' +device_ids: [0] +verbose: False +use_gpu: False +device: 'cpu' +gpu_id: '0' +train: False \ No newline at end of file diff --git a/talkingface/properties/model/EAMM/MEAD_emo_video_aug_delta_4_crop_random_crop.yaml b/talkingface/properties/model/EAMM/MEAD_emo_video_aug_delta_4_crop_random_crop.yaml new file mode 100644 index 00000000..2207228f --- /dev/null +++ b/talkingface/properties/model/EAMM/MEAD_emo_video_aug_delta_4_crop_random_crop.yaml @@ -0,0 +1,105 @@ +dataset_params: + root_dir: /mnt/lustre/share_data/jixinya/MEAD/ + frame_shape: [256, 256, 3] + id_sampling: False + pairs_list: Random_choice + augmentation_params: + crop_mouth_param: + center_x: 135 + center_y: 190 + mask_width: 100 + mask_height: 60 + rotation_param: + degrees: 30 + perspective_param: + pers_num: 30 + enlarge_num: 40 + flip_param: + horizontal_flip: True + time_flip: False + jitter_param: + brightness: 0 + contrast: 0 + saturation: 0 + hue: 0 + +model_params: + common_params: + num_kp: 10 + num_channels: 3 + 
estimate_jacobian: True + audio_params: + num_kp: 10 + num_channels : 3 + num_channels_a : 3 + estimate_jacobian: True + kp_detector_params: + temperature: 0.1 + block_expansion: 32 + max_features: 1024 + scale_factor: 0.25 + num_blocks: 5 + generator_params: + block_expansion: 64 + max_features: 512 + num_down_blocks: 2 + num_bottleneck_blocks: 6 + estimate_occlusion_map: True + dense_motion_params: + block_expansion: 64 + max_features: 1024 + num_blocks: 5 + scale_factor: 0.25 + discriminator_params: + scales: [1] + block_expansion: 32 + max_features: 512 + num_blocks: 4 + sn: True + +train_params: + type: linear_4 + smooth: False + jaco_net: cnn + ldmark: fake + generator: not + train_generator: False + num_epochs: 300 + num_repeats: 1 + epoch_milestones: [60, 90] + lr_generator: 2.0e-4 + lr_discriminator: 2.0e-4 + lr_kp_detector: 2.0e-4 + lr_audio_feature: 2.0e-4 + batch_size: 16 + scales: [1, 0.5, 0.25, 0.125] + checkpoint_freq: 1 + transform_params: + sigma_affine: 0.05 + sigma_tps: 0.005 + points_tps: 5 + loss_weights: + generator_gan: 0 + discriminator_gan: 1 + feature_matching: [10, 10, 10, 10] + perceptual: [10, 10, 10, 10, 10] + equivariance_value: 0 + equivariance_jacobian: 0 + emo: 10 + +reconstruction_params: + num_videos: 1000 + format: '.mp4' + +animate_params: + num_pairs: 50 + format: '.mp4' + normalization_params: + adapt_movement_scale: False + use_relative_movement: True + use_relative_jacobian: True + +visualizer_params: + kp_size: 5 + draw_border: True + colormap: 'gist_rainbow' diff --git a/talkingface/properties/model/EAMM/evaluate.yaml b/talkingface/properties/model/EAMM/evaluate.yaml new file mode 100644 index 00000000..cbfa6de2 --- /dev/null +++ b/talkingface/properties/model/EAMM/evaluate.yaml @@ -0,0 +1,57 @@ +cpu: True +config: 'talkingface/properties/model/EAMM/MEAD_emo_video_aug_delta_4_crop_random_crop.yaml' +audio_checkpoint: 'checkpoints/EAMM/1-6000.pth.tar' +checkpoint: 'checkpoints/EAMM/124_52000.pth.tar' +emo_checkpoint: 
'checkpoints/EAMM/5-3000.pth.tar' + +source_image: 'dataset/mead/test/image/21.png' +driving_video: 'dataset/mead/test/video/disgusted.mp4' +in_file: 'dataset/mead/test/audio/sample1.mov' +pose_file: 'dataset/mead/test/pose/21.npy' +pose_given: 'dataset/mead/test/pose_long/0zn70Ak8lRc_Daniel_Auteuil_0zn70Ak8lRc_0002.npy' +emotion: 'disgusted' # 'angry', 'contempt','disgusted','fear','happy','neutral','sad','surprised' + + +result_path: 'saved/eamm/output/' +relative: False +adapt_scale: False + + +kp_loss: 0 +smooth_pose: True +pose_long: False +weight: 0 +add_emo: False +check_add: False +type: "linear_3" + + + +# model params + +model_common_num_kp: 10 +model_common_num_channels: 3 +model_common_estimate_jacobian: True +model_audio_num_kp: 10 +model_audio_num_channels : 3 +model_audio_num_channels_a : 3 +model_audio_estimate_jacobian: True +model_kp_detector_temperature: 0.1 +model_kp_detector_block_expansion: 32 +model_kp_detector_max_features: 1024 +model_kp_detector_scale_factor: 0.25 +model_kp_detector_num_blocks: 5 +model_generator_block_expansion: 64 +model_generator_max_features: 512 +model_generator_num_down_blocks: 2 +model_generator_num_bottleneck_blocks: 6 +model_generator_estimate_occlusion_map: True +model_generator_dense_motion_block_expansion: 64 +model_generator_dense_motion_max_features: 1024 +model_generator_dense_motion_num_blocks: 5 +model_generator_dense_motion_scale_factor: 0.25 +model_discriminator_scales: [1] +model_discriminator_block_expansion: 32 +model_discriminator_max_features: 512 +model_discriminator_num_blocks: 4 +model_discriminator_sn: True \ No newline at end of file diff --git a/talkingface/properties/model/EAMM/mb1_120x120.yml b/talkingface/properties/model/EAMM/mb1_120x120.yml new file mode 100644 index 00000000..a59df542 --- /dev/null +++ b/talkingface/properties/model/EAMM/mb1_120x120.yml @@ -0,0 +1,7 @@ +arch: mobilenet # MobileNet V1 +widen_factor: 1.0 +checkpoint_fp: checkpoints/EAMM/mb1_120x120.pth +bfm_fp: 
checkpoints/EAMM/bfm_noneck_v3.pkl # or configs/bfm_noneck_v3_slim.pkl +size: 120 +num_params: 62 +param_mean_std_fp: checkpoints/EAMM/param_mean_std_62d_120x120.pkl diff --git a/talkingface/properties/model/EAMM/train_part1.yaml b/talkingface/properties/model/EAMM/train_part1.yaml new file mode 100644 index 00000000..8403408c --- /dev/null +++ b/talkingface/properties/model/EAMM/train_part1.yaml @@ -0,0 +1,69 @@ +train_filelist: True +val_filelist: False + +dataset_name: Vox +dataset_root_dir: dataset/lrw/ +dataset_frame_shape: [256, 256, 3] +dataset_id_sampling: False +dataset_augmentation_flip_horizontal_flip: False +dataset_augmentation_flip_time_flip: False +dataset_augmentation_jitter_brightness: 0.1 +dataset_augmentation_jitter_contrast: 0.1 +dataset_augmentation_jitter_saturation: 0.1 +dataset_augmentation_jitter_hue: 0.1 + +model_common_num_kp: 10 +model_common_num_channels: 3 +model_common_estimate_jacobian: True +model_audio_num_kp: 10 +model_audio_num_channels : 3 +model_audio_num_channels_a : 3 +model_audio_estimate_jacobian: True +model_kp_detector_temperature: 0.1 +model_kp_detector_block_expansion: 32 +model_kp_detector_max_features: 1024 +model_kp_detector_scale_factor: 0.25 +model_kp_detector_num_blocks: 5 +model_generator_block_expansion: 64 +model_generator_max_features: 512 +model_generator_num_down_blocks: 2 +model_generator_num_bottleneck_blocks: 6 +model_generator_estimate_occlusion_map: True +model_generator_dense_motion_block_expansion: 64 +model_generator_dense_motion_max_features: 1024 +model_generator_dense_motion_num_blocks: 5 +model_generator_dense_motion_scale_factor: 0.25 +model_discriminator_scales: [1] +model_discriminator_block_expansion: 32 +model_discriminator_max_features: 512 +model_discriminator_num_blocks: 4 +model_discriminator_sn: True + +train_jaco_net: cnn +train_ldmark: fake +train_generator: not +train_num_epochs: 300 +train_num_repeats: 1 +train_epoch_milestones: [60, 90] +train_lr_generator: 2.0e-4 
+train_lr_discriminator: 2.0e-4 +train_lr_kp_detector: 2.0e-4 +train_lr_audio_feature: 2.0e-4 +train_batch_size: 8 +train_scales: [1, 0.5, 0.25, 0.125] +train_checkpoint_freq: 1 +train_transform_sigma_affine: 0.05 +train_transform_sigma_tps: 0.005 +train_transform_points_tps: 5 +train_loss_weights_generator_gan: 0 +train_loss_weights_discriminator_gan: 0 +train_loss_weights_feature_matching: [10, 10, 10, 10] +train_loss_weights_perceptual: [10, 10, 10, 10, 10] +train_loss_weights_equivariance_value: 0 +train_loss_weights_equivariance_jacobian: 0 +train_loss_weights_audio: 10 + +visualizer_kp_size: 5 +visualizer_draw_border: True +visualizer_colormap: 'gist_rainbow' + diff --git a/talkingface/properties/model/EAMM/train_part1_fine_tune.yaml b/talkingface/properties/model/EAMM/train_part1_fine_tune.yaml new file mode 100644 index 00000000..290e050e --- /dev/null +++ b/talkingface/properties/model/EAMM/train_part1_fine_tune.yaml @@ -0,0 +1,69 @@ +train_filelist: True +val_filelist: False + +dataset_name: LRW +dataset_root_dir: dataset/LRW/ +dataset_frame_shape: [256, 256, 3] +dataset_id_sampling: False +dataset_augmentation_flip_horizontal_flip: False +dataset_augmentation_flip_time_flip: False +dataset_augmentation_jitter_brightness: 0.1 +dataset_augmentation_jitter_contrast: 0.1 +dataset_augmentation_jitter_saturation: 0.1 +dataset_augmentation_jitter_hue: 0.1 + +model_common_num_kp: 10 +model_common_num_channels: 3 +model_common_estimate_jacobian: True +model_audio_num_kp: 10 +model_audio_num_channels : 3 +model_audio_num_channels_a : 3 +model_audio_estimate_jacobian: True +model_kp_detector_temperature: 0.1 +model_kp_detector_block_expansion: 32 +model_kp_detector_max_features: 1024 +model_kp_detector_scale_factor: 0.25 +model_kp_detector_num_blocks: 5 +model_generator_block_expansion: 64 +model_generator_max_features: 512 +model_generator_num_down_blocks: 2 +model_generator_num_bottleneck_blocks: 6 +model_generator_estimate_occlusion_map: True 
+model_generator_dense_motion_block_expansion: 64 +model_generator_dense_motion_max_features: 1024 +model_generator_dense_motion_num_blocks: 5 +model_generator_dense_motion_scale_factor: 0.25 +model_discriminator_scales: [1] +model_discriminator_block_expansion: 32 +model_discriminator_max_features: 512 +model_discriminator_num_blocks: 4 +model_discriminator_sn: True + +train_jaco_net: cnn +train_ldmark: fake +train_generator: audio +train_num_epochs: 300 +train_num_repeats: 1 +train_epoch_milestones: [60, 90] +train_lr_generator: 2.0e-4 +train_lr_discriminator: 2.0e-4 +train_lr_kp_detector: 2.0e-4 +train_lr_audio_feature: 2.0e-4 +train_batch_size: 6 +train_scales: [1, 0.5, 0.25, 0.125] +train_checkpoint_freq: 1 +train_transform_sigma_affine: 0.05 +train_transform_sigma_tps: 0.005 +train_transform_points_tps: 5 +train_loss_weights_generator_gan: 0 +train_loss_weights_discriminator_gan: 0 +train_loss_weights_feature_matching: [10, 10, 10, 10] +train_loss_weights_perceptual: [0.1, 0.1, 0.1, 0.1, 0.1] +train_loss_weights_equivariance_value: 0 +train_loss_weights_equivariance_jacobian: 0 +train_loss_weights_audio: 10 + +visualizer_kp_size: 5 +visualizer_draw_border: True +visualizer_colormap: 'gist_rainbow' + diff --git a/talkingface/properties/model/EAMM/train_part2.yaml b/talkingface/properties/model/EAMM/train_part2.yaml new file mode 100644 index 00000000..ea26ba82 --- /dev/null +++ b/talkingface/properties/model/EAMM/train_part2.yaml @@ -0,0 +1,78 @@ +train_filelist: True +val_filelist: False + +dataset_name: MEAD +dataset_root_dir: dataset/MEAD/ +dataset_frame_shape: [256, 256, 3] +dataset_id_sampling: False +dataset_augmentation_crop_mouth_center_x: 135 +dataset_augmentation_crop_mouth_center_y: 190 +dataset_augmentation_crop_mouth_mask_width: 100 +dataset_augmentation_crop_mouth_mask_height: 60 +dataset_augmentation_rotation_degrees: 30 +dataset_augmentation_perspective_pers_num: 30 +dataset_augmentation_perspective_enlarge_num: 40 
+dataset_augmentation_flip_horizontal_flip: True +dataset_augmentation_flip_time_flip: False +dataset_augmentation_jitter_brightness: 0 +dataset_augmentation_jitter_contrast: 0 +dataset_augmentation_jitter_saturation: 0 +dataset_augmentation_jitter_hue: 0 + +model_common_num_kp: 10 +model_common_num_channels: 3 +model_common_estimate_jacobian: True +model_audio_num_kp: 10 +model_audio_num_channels : 3 +model_audio_num_channels_a : 3 +model_audio_estimate_jacobian: True +model_kp_detector_temperature: 0.1 +model_kp_detector_block_expansion: 32 +model_kp_detector_max_features: 1024 +model_kp_detector_scale_factor: 0.25 +model_kp_detector_num_blocks: 5 +model_generator_block_expansion: 64 +model_generator_max_features: 512 +model_generator_num_down_blocks: 2 +model_generator_num_bottleneck_blocks: 6 +model_generator_estimate_occlusion_map: True +model_generator_dense_motion_block_expansion: 64 +model_generator_dense_motion_max_features: 1024 +model_generator_dense_motion_num_blocks: 5 +model_generator_dense_motion_scale_factor: 0.25 +model_discriminator_scales: [1] +model_discriminator_block_expansion: 32 +model_discriminator_max_features: 512 +model_discriminator_num_blocks: 4 +model_discriminator_sn: True + +train_type: linear_4 +train_smooth: False +train_jaco_net: cnn +train_ldmark: fake +train_generator: not +train_num_epochs: 300 +train_num_repeats: 1 +train_epoch_milestones: [60, 90] +train_lr_generator: 2.0e-4 +train_lr_discriminator: 2.0e-4 +train_lr_kp_detector: 2.0e-4 +train_lr_audio_feature: 2.0e-4 +train_batch_size: 16 +train_scales: [1, 0.5, 0.25, 0.125] +train_checkpoint_freq: 1 +train_transform_sigma_affine: 0.05 +train_transform_sigma_tps: 0.005 +train_transform_points_tps: 5 +train_loss_weights_generator_gan: 0 +train_loss_weights_discriminator_gan: 0 +train_loss_weights_feature_matching: [10, 10, 10, 10] +train_loss_weights_perceptual: [10, 10, 10, 10, 10] +train_loss_weights_equivariance_value: 0 +train_loss_weights_equivariance_jacobian: 0 
+train_loss_weights_emo: 10 + +visualizer_kp_size: 5 +visualizer_draw_border: True +visualizer_colormap: 'gist_rainbow' + diff --git a/talkingface/trainer/trainer.py b/talkingface/trainer/trainer.py index 2c34717b..e03246de 100644 --- a/talkingface/trainer/trainer.py +++ b/talkingface/trainer/trainer.py @@ -554,4 +554,1148 @@ def _valid_epoch(self, valid_data, loss_func=None, show_progress=False): if losses_dict["sync_loss"] < .75: self.model.config["syncnet_wt"] = 0.01 return average_loss_dict - \ No newline at end of file + + +class EAMMTrainer(Trainer): + import matplotlib + + matplotlib.use("Agg") + import yaml + from argparse import ArgumentParser + import skimage + import imageio + import skimage.transform as st + from talkingface.utils.filter1 import OneEuroFilter + import torch.utils + + from torch.autograd import Variable + from talkingface.utils.augmentation import AllAugmentationTransform + + from talkingface.model.audio_driven_talkingface.eamm_modules.generator import OcclusionAwareGenerator + from talkingface.model.audio_driven_talkingface.eamm_modules.keypoint_detector import KPDetector, KPDetector_a + from talkingface.model.audio_driven_talkingface.eamm_modules.util import AT_net, Emotion_k, Emotion_map, AT_net2 + + from scipy.spatial import ConvexHull + + import python_speech_features + import cv2 + import librosa + from skimage import transform as tf + import itertools + from talkingface.utils.eamm_logger import Logger + from torch.optim.lr_scheduler import MultiStepLR + from talkingface.model.audio_driven_talkingface.eamm_modules.model import DiscriminatorFullModel, TrainPart1Model, TrainPart2Model + + def __init__(self, config, model): + super(EAMMTrainer, self).__init__(config, model) + self.opt = config + self.model = model + self.detector = dlib.get_frontal_face_detector() + self.predictor = dlib.shape_predictor("checkpoints/EAMM/shape_predictor_68_face_landmarks.dat") + if self.opt['train']: + self.train_params = { + 'jaco_net': 
self.opt['train_jaco_net'], + 'ldmark': self.opt['train_ldmark'], + 'generator': self.opt['train_generator'], + 'num_epochs': self.opt['train_num_epochs'], + 'train_num_repeats': self.opt['train_num_repeats'], + 'epoch_milestones': self.opt['train_epoch_milestones'], + 'lr_generator': self.opt['train_lr_generator'], + 'lr_discriminator': self.opt['train_lr_discriminator'], + 'lr_kp_detector': self.opt['train_lr_kp_detector'], + 'lr_audio_feature': self.opt['train_lr_audio_feature'], + 'batch_size': self.opt['train_batch_size'], + 'scales': self.opt['train_scales'], + 'checkpoint_freq': self.opt['train_checkpoint_freq'], + 'transform_params': { + 'sigma_affine': self.opt['train_transform_sigma_affine'], + 'sigma_tps': self.opt['train_transform_sigma_tps'], + 'points_tps': self.opt['train_transform_points_tps'], + }, + 'loss_weights': { + 'generator_gan': self.opt['train_loss_weights_generator_gan'], + 'discriminator_gan': self.opt['train_loss_weights_discriminator_gan'], + 'feature_matching': self.opt['train_loss_weights_feature_matching'], + 'perceptual': self.opt['train_loss_weights_perceptual'], + 'equivariance_value': self.opt['train_loss_weights_equivariance_value'], + 'equivariance_jacobian': self.opt['train_loss_weights_equivariance_jacobian'], + 'audio': self.opt['train_loss_weights_audio'], + }, + } + self._init_train() + + def _init_train(self): + if self.opt['mode'] == 'train_part1': + self._init_train_part1() + elif self.opt['mode'] == 'train_part1_fine_tune': + self._init_train_part1_fine_tune() + elif self.opt['mode'] == 'train_part2': + self._init_train_part2() + + self.step = 0 + self.train_itr = 0 + self.test_itr = 0 + + def _init_train_part1(self): + self.optimizer_audio_feature = torch.optim.Adam( + self.itertools.chain(self.model.audio_feature.parameters(), self.model.kp_detector_a.parameters()), + lr=self.opt["train_lr_audio_feature"], + betas=(0.5, 0.999), + ) + if self.opt['checkpoint'] is not None: + self.start_epoch = self.Logger.load_cpk( + 
self.opt['checkpoint'], + self.model.generator, + self.model.discriminator, + self.model.kp_detector, + self.model.audio_feature, + None, + None, + None, #! + None, #! + ) + if self.opt['audio_checkpoint'] is not None: + pretrain = torch.load(self.opt['audio_checkpoint']) + self.model.kp_detector_a.load_state_dict(pretrain["kp_detector_a"]) + self.model.audio_feature.load_state_dict(pretrain["audio_feature"]) + self.optimizer_audio_feature.load_state_dict(pretrain["optimizer_audio_feature"]) + self.start_epoch = pretrain["epoch"] + else: + self.start_epoch = 0 + + self.scheduler_audio_feature = self.MultiStepLR( + self.optimizer_audio_feature, + self.opt["train_epoch_milestones"], + gamma=0.1, + last_epoch=-1 + self.start_epoch * (self.opt["train_lr_audio_feature"] != 0), + ) + + self.generator_full = self.TrainPart1Model( + self.model.kp_detector, + self.model.kp_detector_a, + self.model.audio_feature, + self.model.generator, + self.model.discriminator, + self.train_params, + self.opt['device_ids'], + ) + self.discriminator_full = self.DiscriminatorFullModel( + self.model.kp_detector, self.model.generator, self.model.discriminator, self.train_params + ) + + if self.gpu_available: + self.generator_full = self.generator_full.to(self.device) + self.discriminator_full = self.discriminator_full.to(self.device) + else: + self.generator_full = self.generator_full.cpu() + self.discriminator_full = self.discriminator_full.cpu() + + def _init_train_part1_fine_tune(self): + self.optimizer_generator = torch.optim.Adam( + self.model.generator.parameters(), lr=self.train_params["lr_generator"], betas=(0.5, 0.999) + ) + self.optimizer_discriminator = torch.optim.Adam( + self.model.discriminator.parameters(), + lr=self.train_params["lr_discriminator"], + betas=(0.5, 0.999), + ) + self.optimizer_audio_feature = torch.optim.Adam( + self.itertools.chain(self.model.audio_feature.parameters(), self.model.kp_detector_a.parameters()), + lr=self.train_params["lr_audio_feature"], + 
betas=(0.5, 0.999), + ) + + if self.opt['checkpoint'] is not None: + self.start_epoch = self.Logger.load_cpk( + self.opt['checkpoint'], + self.model.generator, + self.model.discriminator, + self.model.kp_detector, + self.model.audio_feature, + self.optimizer_generator, + self.optimizer_discriminator, + None, #! + None if self.train_params["lr_audio_feature"] == 0 else self.optimizer_audio_feature, + ) + if self.opt['audio_checkpoint'] is not None: + pretrain = torch.load(self.opt['audio_checkpoint']) + self.model.kp_detector_a.load_state_dict(pretrain["kp_detector_a"]) + self.model.audio_feature.load_state_dict(pretrain["audio_feature"]) + self.optimizer_audio_feature.load_state_dict(pretrain["optimizer_audio_feature"]) + self.start_epoch = pretrain["epoch"] + else: + self.start_epoch = 0 + + self.scheduler_generator = self.MultiStepLR( + self.optimizer_generator, + self.train_params["epoch_milestones"], + gamma=0.1, + last_epoch=self.start_epoch - 1, + ) + self.scheduler_discriminator = self.MultiStepLR( + self.optimizer_discriminator, + self.train_params["epoch_milestones"], + gamma=0.1, + last_epoch=self.start_epoch - 1, + ) + self.scheduler_audio_feature = self.MultiStepLR( + self.optimizer_audio_feature, + self.train_params["epoch_milestones"], + gamma=0.1, + last_epoch=-1 + self.start_epoch * (self.train_params["lr_audio_feature"] != 0), + ) + self.generator_full = self.TrainPart1Model( + self.model.kp_detector, + self.model.kp_detector_a, + self.model.audio_feature, + self.model.generator, + self.model.discriminator, + self.train_params, + self.opt['device_ids'], + ) + self.discriminator_full = self.DiscriminatorFullModel( + self.model.kp_detector, self.model.generator, self.model.discriminator, self.train_params + ) + + if self.gpu_available: + self.generator_full = self.generator_full.to(self.device) + self.discriminator_full = self.discriminator_full.to(self.device) + else: + self.generator_full = self.generator_full.cpu() + self.discriminator_full = 
self.discriminator_full.cpu() + + def _init_train_part2(self): + self.optimizer_emo_detector = torch.optim.Adam( + self.model.emo_detector.parameters(), + lr=self.train_params["lr_audio_feature"], + betas=(0.5, 0.999), + ) + self.optimizer_generator = torch.optim.Adam( + self.model.generator.parameters(), + lr=self.train_params["lr_generator"], + betas=(0.5, 0.999), + ) + self.optimizer_discriminator = torch.optim.Adam( + self.model.discriminator.parameters(), + lr=self.train_params["lr_discriminator"], + betas=(0.5, 0.999), + ) + self.optimizer_audio_feature = torch.optim.Adam( + self.itertools.chain(self.model.audio_feature.parameters(), self.model.kp_detector_a.parameters()), + lr=self.train_params["lr_audio_feature"], + betas=(0.5, 0.999), + ) + + if self.opt['checkpoint'] is not None: + start_epoch = self.Logger.load_cpk( + self.opt['checkpoint'], + self.model.generator, + self.model.discriminator, + self.model.kp_detector, + self.model.audio_feature, + self.optimizer_generator, + self.optimizer_discriminator, + None, #! + None if self.train_params["lr_audio_feature"] == 0 else self.optimizer_audio_feature, + ) + if self.opt['emo_checkpoint'] is not None: + pretrain = torch.load(self.opt['emo_checkpoint']) + tgt_state = self.emo_detector.state_dict() + strip = "module." 
+ if "emo_detector" in pretrain: + self.emo_detector.load_state_dict(pretrain["emo_detector"]) + self.optimizer_emo_detector.load_state_dict(pretrain["optimizer_emo_detector"]) + for name, param in pretrain.items(): + if isinstance(param, nn.Parameter): + param = param.data + if strip is not None and name.startswith(strip): + name = name[len(strip) :] + if name not in tgt_state: + continue + tgt_state[name].copy_(param) + print(name) + if self.opt['audio_checkpoint'] is not None: + pretrain = torch.load(self.opt['audio_checkpoint']) + self.kp_detector_a.load_state_dict(pretrain["kp_detector_a"]) + self.audio_feature.load_state_dict(pretrain["audio_feature"]) + self.optimizer_audio_feature.load_state_dict(pretrain["optimizer_audio_feature"]) + if "emo_detector" in pretrain: + self.emo_detector.load_state_dict(pretrain["emo_detector"]) + self.optimizer_emo_detector.load_state_dict(pretrain["optimizer_emo_detector"]) + self.start_epoch = pretrain["epoch"] + else: + self.start_epoch = 0 + + self.scheduler_emo_detector = self.MultiStepLR( + self.optimizer_emo_detector, + self.train_params["epoch_milestones"], + gamma=0.1, + last_epoch=-1 + start_epoch * (self.train_params["lr_audio_feature"] != 0), + ) + self.generator_full = self.TrainPart2Model( + self.model.kp_detector, + self.model.emo_detector, + self.model.kp_detector_a, + self.model.audio_feature, + self.model.generator, + self.model.discriminator, + self.train_params, + self.opt['device_ids'], + ) + self.discriminator_full = self.DiscriminatorFullModel( + self.model.kp_detector, self.model.generator, self.model.discriminator, self.train_params + ) + if self.gpu_available: + self.generator_full = self.generator_full.to(self.device) + self.discriminator_full = self.discriminator_full.to(self.device) + else: + self.generator_full = self.generator_full.cpu() + self.discriminator_full = self.discriminator_full.cpu() + + def _train_epoch(self, train_data, epoch_idx, loss_func=None, show_progress=False): + if 
self.opt['mode'] == 'train_part1': + return self._train_part1(train_data, epoch_idx, loss_func, show_progress) + elif self.opt['mode'] == 'train_part1_fine_tune': + return self._train_part1_fine_tune(train_data, epoch_idx, loss_func, show_progress) + elif self.opt['mode'] == 'train_part2': + return self._train_part2(train_data, epoch_idx, loss_func, show_progress) + + def _train_part1(self, train_data, epoch_idx, loss_func=None, show_progress=False): + for x in train_data: + losses_generator, generated = self.generator_full(x) + loss_values = [val.mean() for val in losses_generator.values()] + loss = sum(loss_values) + + self.tensorboard.add_scalar("Train", loss, self.train_itr) + self.tensorboard.add_scalar("Train_value", loss_values[0], self.train_itr) + self.tensorboard.add_scalar("Train_heatmap", loss_values[1], self.train_itr) + self.tensorboard.add_scalar("Train_jacobian", loss_values[2], self.train_itr) + + self.train_itr += 1 + loss.backward() + self.optimizer_audio_feature.step() + self.optimizer_audio_feature.zero_grad() + + losses = { + key: value.mean().detach().data.cpu().numpy() + for key, value in losses_generator.items() + } + + self.step += 1 + + self.scheduler_audio_feature.step() + return losses + + def _train_part1_fine_tune(self, train_data, epoch_idx, loss_func=None, show_progress=False): + for x in train_data: + losses_generator, generated = self.generator_full(x) + loss_values = [val.mean() for val in losses_generator.values()] + loss = sum(loss_values) + + self.tensorboard.add_scalar("Train", loss, self.train_itr) + self.tensorboard.add_scalar("Train_value", loss_values[0], self.train_itr) + self.tensorboard.add_scalar("Train_heatmap", loss_values[1], self.train_itr) + self.tensorboard.add_scalar("Train_jacobian", loss_values[2], self.train_itr) + self.tensorboard.add_scalar("Train_perceptual", loss_values[3], self.train_itr) + + self.train_itr += 1 + loss.backward() + + self.optimizer_audio_feature.step() + 
self.optimizer_audio_feature.zero_grad() + self.optimizer_generator.step() + self.optimizer_generator.zero_grad() + if self.train_params["loss_weights"]["discriminator_gan"] != 0: + self.optimizer_discriminator.zero_grad() + else: + losses_discriminator = {} + + losses_generator.update(losses_discriminator) + losses = { + key: value.mean().detach().data.cpu().numpy() + for key, value in losses_generator.items() + } + self.step += 1 + + self.scheduler_generator.step() + self.scheduler_discriminator.step() + self.scheduler_audio_feature.step() + + return losses + + def _train_part2(self, train_data, epoch_idx, loss_func=None, show_progress=False): + for x in train_data: + losses_generator, generated = self.generator_full(x) + loss_values = [val.mean() for val in losses_generator.values()] + loss = sum(loss_values) + + self.tensorboard.add_scalar("Train", loss, self.train_itr) + self.tensorboard.add_scalar("Train_value", loss_values[0], self.train_itr) + self.tensorboard.add_scalar("Train_jacobian", loss_values[1], self.train_itr) + self.tensorboard.add_scalar("Train_classify", loss_values[2], self.train_itr) + + self.train_itr += 1 + loss.backward() + + self.optimizer_emo_detector.step() + self.optimizer_emo_detector.zero_grad() + + losses = { + key: value.mean().detach().data.cpu().numpy() + for key, value in losses_generator.items() + } + self.step += 1 + + self.scheduler_emo_detector.step() + return losses + + def _valid_epoch(self, valid_data, loss_func=None, show_progress=False): + if self.opt['mode'] == 'train_part1': + return self._valid_part1(valid_data, loss_func, show_progress) + elif self.opt['mode'] == 'train_part1_fine_tune': + return self._valid_part1_fine_tune(valid_data, loss_func, show_progress) + elif self.opt['mode'] == 'train_part2': + return self._valid_part2(valid_data, loss_func, show_progress) + + def _valid_part1(self, valid_data, loss_func=None, show_progress=False): + for x in valid_data: + with torch.no_grad(): + losses_generator, 
generated = self.generator_full(x) + loss_values = [val.mean() for val in losses_generator.values()] + loss = sum(loss_values) + + self.tensorboard.add_scalar("Test", loss, self.test_itr) + self.tensorboard.add_scalar("Test_value", loss_values[0], self.test_itr) + self.tensorboard.add_scalar("Test_heatmap", loss_values[1], self.test_itr) + self.tensorboard.add_scalar("Test_jacobian", loss_values[2], self.test_itr) + + self.test_itr += 1 + losses = { + key: value.mean().detach().data.cpu().numpy() + for key, value in losses_generator.items() + } + + return losses + + def _valid_part1_fine_tune(self, valid_data, loss_func=None, show_progress=False): + for x in valid_data: + with torch.no_grad(): + losses_generator, generated = self.generator_full(x) + loss_values = [val.mean() for val in losses_generator.values()] + loss = sum(loss_values) + + self.tensorboard.add_scalar("Test", loss, self.test_itr) + self.tensorboard.add_scalar("Test_value", loss_values[0], self.test_itr) + self.tensorboard.add_scalar("Test_heatmap", loss_values[1], self.test_itr) + self.tensorboard.add_scalar("Test_jacobian", loss_values[2], self.test_itr) + self.tensorboard.add_scalar("Test_perceptual", loss_values[3], self.test_itr) + + self.test_itr += 1 + losses = { + key: value.mean().detach().data.cpu().numpy() + for key, value in losses_generator.items() + } + + return losses + + def _valid_part2(self, valid_data, loss_func=None, show_progress=False): + for x in valid_data: + with torch.no_grad(): + losses_generator, generated = self.generator_full(x) + loss_values = [val.mean() for val in losses_generator.values()] + loss = sum(loss_values) + + self.tensorboard.add_scalar("Test", loss, self.test_itr) + self.tensorboard.add_scalar("Test_value", loss_values[0], self.test_itr) + self.tensorboard.add_scalar("Test_jacobian", loss_values[1], self.test_itr) + self.tensorboard.add_scalar("Test_classify", loss_values[2], self.test_itr) + + self.test_itr += 1 + losses = { + key: 
value.mean().detach().data.cpu().numpy() + for key, value in losses_generator.items() + } + + return losses + + def load_checkpoints( + self, opt, checkpoint_path, audio_checkpoint_path, emo_checkpoint_path, cpu=False + ): + """ + load checkpoints + """ + with open(opt['config']) as f: + config = self.yaml.load(f, Loader=self.yaml.FullLoader) + + generator = self.OcclusionAwareGenerator( + **config["model_params"]["generator_params"], + **config["model_params"]["common_params"] + ) + if not cpu: + generator.cuda() + + kp_detector = self.KPDetector( + **config["model_params"]["kp_detector_params"], + **config["model_params"]["common_params"] + ) + if not cpu: + kp_detector.cuda() + + kp_detector_a = self.KPDetector_a( + **config["model_params"]["kp_detector_params"], + **config["model_params"]["audio_params"] + ) + + audio_feature = self.AT_net2() + if opt['type'].startswith("linear"): + emo_detector = self.Emotion_k( + block_expansion=32, + num_channels=3, + max_features=1024, + num_blocks=5, + scale_factor=0.25, + num_classes=8, + ) + elif opt['type'].startswith("map"): + emo_detector = self.Emotion_map( + block_expansion=32, + num_channels=3, + max_features=1024, + num_blocks=5, + scale_factor=0.25, + num_classes=8, + ) + if not cpu: + kp_detector_a.cuda() + audio_feature.cuda() + emo_detector.cuda() + + if cpu: + checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu")) + audio_checkpoint = torch.load( + audio_checkpoint_path, map_location=torch.device("cpu") + ) + emo_checkpoint = torch.load( + emo_checkpoint_path, map_location=torch.device("cpu") + ) + else: + checkpoint = torch.load(checkpoint_path) + audio_checkpoint = torch.load(audio_checkpoint_path) + emo_checkpoint = torch.load(emo_checkpoint_path) + + generator.load_state_dict(checkpoint["generator"]) + kp_detector.load_state_dict(checkpoint["kp_detector"]) + audio_feature.load_state_dict(audio_checkpoint["audio_feature"]) + 
kp_detector_a.load_state_dict(audio_checkpoint["kp_detector_a"]) + emo_detector.load_state_dict(emo_checkpoint["emo_detector"]) + + if not cpu: + generator = generator.cuda() + kp_detector = kp_detector.cuda() + audio_feature = audio_feature.cuda() + kp_detector_a = kp_detector_a.cuda() + emo_detector = emo_detector.cuda() + else: + generator = generator.cpu() + kp_detector = kp_detector.cpu() + audio_feature = audio_feature.cpu() + kp_detector_a = kp_detector_a.cpu() + emo_detector = emo_detector.cpu() + + generator.eval() + kp_detector.eval() + audio_feature.eval() + kp_detector_a.eval() + emo_detector.eval() + return generator, kp_detector, kp_detector_a, audio_feature, emo_detector + + def normalize_kp( + self, + kp_source, + kp_driving, + kp_driving_initial, + adapt_movement_scale=False, + use_relative_movement=False, + use_relative_jacobian=False, + ): + """ + normalize keypoints. + """ + if adapt_movement_scale: + source_area = self.ConvexHull(kp_source["value"][0].data.cpu().numpy()).volume + driving_area = self.ConvexHull( + kp_driving_initial["value"][0].data.cpu().numpy() + ).volume + adapt_movement_scale = np.sqrt(source_area) / np.sqrt(driving_area) + else: + adapt_movement_scale = 1 + + kp_new = {k: v for k, v in kp_driving.items()} + + if use_relative_movement: + kp_value_diff = kp_driving["value"] - kp_driving_initial["value"] + kp_value_diff *= adapt_movement_scale + kp_new["value"] = kp_value_diff + kp_source["value"] + + if use_relative_jacobian: + jacobian_diff = torch.matmul( + kp_driving["jacobian"], torch.inverse(kp_driving_initial["jacobian"]) + ) + kp_new["jacobian"] = torch.matmul(jacobian_diff, kp_source["jacobian"]) + + return kp_new + + def shape_to_np(self, shape, dtype="int"): + # initialize the list of (x, y)-coordinates + coords = np.zeros((shape.num_parts, 2), dtype=dtype) + + # loop over all facial landmarks and convert them + # to a 2-tuple of (x, y)-coordinates + for i in range(0, shape.num_parts): + coords[i] = 
(shape.part(i).x, shape.part(i).y) + + # return the list of (x, y)-coordinates + return coords + + def get_aligned_image(self, driving_video, opt): + """ + emotion video also crop centering and resize to 256 x 256. + """ + aligned_array = [] + + video_array = np.array(driving_video) + source_image = video_array[0] + source_image = np.array(source_image * 255, dtype=np.uint8) + gray = self.cv2.cvtColor(source_image, self.cv2.COLOR_BGR2GRAY) + rects = self.detector(gray, 1) # detect human face + for i, rect in enumerate(rects): + template = self.predictor(gray, rect) # detect 68 points + template = self.shape_to_np(template) + + if opt['emotion'] == "surprised" or opt['emotion'] == "fear": + template = template - [0, 10] + for i in range(len(video_array)): + image = np.array(video_array[i] * 255, dtype=np.uint8) + gray = self.cv2.cvtColor(image, self.cv2.COLOR_BGR2GRAY) + rects = self.detector(gray, 1) # detect human face + for j, rect in enumerate(rects): + shape = self.predictor(gray, rect) # detect 68 points + shape = self.shape_to_np(shape) + + pts2 = np.float32(template[:35, :]) + pts1 = np.float32(shape[:35, :]) # eye and nose + + tform = self.tf.SimilarityTransform() + tform.estimate( + pts2, pts1 + ) # Set the transformation matrix with the explicit parameters. + dst = self.tf.warp(image, tform, output_shape=(256, 256)) + + dst = np.array(dst, dtype=np.float32) + aligned_array.append(dst) + + return aligned_array + + def get_transformed_image(self, driving_video, opt): + """ + augmentation for emotion images. 
+ """ + video_array = np.array(driving_video) + with open(opt['config']) as f: + config = self.yaml.load(f, Loader=self.yaml.FullLoader) + transformations = self.AllAugmentationTransform( + **config["dataset_params"]["augmentation_params"] + ) + transformed_array = transformations(video_array) + return transformed_array + + def make_animation_smooth( + self, + source_image, + driving_video, + transformed_video, + deco_out, + kp_loss, + generator, + kp_detector, + kp_detector_a, + emo_detector, + opt, + relative=True, + adapt_movement_scale=True, + cpu=False, + ): + """ + generate target images. + """ + with torch.no_grad(): + predictions = [] + + source = torch.tensor(source_image[np.newaxis].astype(np.float32)).permute( + 0, 3, 1, 2 + ) + + if not cpu: + source = source.cuda() + else: + source = source.cpu() + + driving = torch.tensor( + np.array(driving_video)[np.newaxis].astype(np.float32) + ).permute(0, 4, 1, 2, 3) + transformed_driving = torch.tensor( + np.array(transformed_video)[np.newaxis].astype(np.float32) + ).permute(0, 4, 1, 2, 3) + + kp_source = kp_detector(source) + kp_driving_initial = kp_detector_a(deco_out[:, 0]) + + emo_driving_all = [] + features = [] + kp_driving_all = [] + for frame_idx in tqdm(range(len(deco_out[0]))): + driving_frame = driving[:, :, frame_idx] + transformed_frame = transformed_driving[:, :, frame_idx] + if not cpu: + driving_frame = driving_frame.cuda() + transformed_frame = transformed_frame.cuda() + else: + driving_frame = driving_frame.cpu() + transformed_frame = transformed_frame.cpu() + kp_driving = kp_detector_a(deco_out[:, frame_idx]) + kp_driving_all.append(kp_driving) + if opt['add_emo']: + value = kp_driving["value"] + jacobian = kp_driving["jacobian"] + if opt['type'] == "linear_3": + emo_driving, _ = emo_detector(transformed_frame, value, jacobian) + features.append( + emo_detector.feature(transformed_frame).data.cpu().numpy() + ) + + emo_driving_all.append(emo_driving) + features = np.array(features) + if 
opt['add_emo']: + one_euro_filter_v = self.OneEuroFilter( + mincutoff=1, beta=0.2, dcutoff=1.0, freq=100 + ) # 1 0.4 + one_euro_filter_j = self.OneEuroFilter( + mincutoff=1, beta=0.2, dcutoff=1.0, freq=100 + ) # 1 0.4 + + for j in range(len(emo_driving_all)): + emo_driving_all[j]["value"] = ( + one_euro_filter_v.process(emo_driving_all[j]["value"].cpu() * 100) + / 100 + ) + if not cpu: + emo_driving_all[j]["value"] = emo_driving_all[j]["value"].cuda() + emo_driving_all[j]["jacobian"] = ( + one_euro_filter_j.process( + emo_driving_all[j]["jacobian"].cpu() * 100 + ) + / 100 + ) + if not cpu: + emo_driving_all[j]["jacobian"] = emo_driving_all[j][ + "jacobian" + ].cuda() + + one_euro_filter_v = self.OneEuroFilter(mincutoff=0.05, beta=8, dcutoff=1.0, freq=100) + one_euro_filter_j = self.OneEuroFilter(mincutoff=0.05, beta=8, dcutoff=1.0, freq=100) + + for j in range(len(kp_driving_all)): + kp_driving_all[j]["value"] = ( + one_euro_filter_v.process(kp_driving_all[j]["value"].cpu() * 10) / 10 + ) + if not cpu: + kp_driving_all[j]["value"] = kp_driving_all[j]["value"].cuda() + kp_driving_all[j]["jacobian"] = ( + one_euro_filter_j.process(kp_driving_all[j]["jacobian"].cpu() * 10) / 10 + ) + if not cpu: + kp_driving_all[j]["jacobian"] = kp_driving_all[j]["jacobian"].cuda() + + for frame_idx in tqdm(range(len(deco_out[0]))): + if opt['check_add']: + kp_driving = kp_detector_a(deco_out[:, 0]) + else: + kp_driving = kp_driving_all[frame_idx] + + if opt['add_emo']: + emo_driving = emo_driving_all[frame_idx] + if opt['type'] == "linear_3": + kp_driving["value"][:, 1] = ( + kp_driving["value"][:, 1] + emo_driving["value"][:, 0] * 0.2 + ) + kp_driving["jacobian"][:, 1] = ( + kp_driving["jacobian"][:, 1] + + emo_driving["jacobian"][:, 0] * 0.2 + ) + kp_driving["value"][:, 4] = ( + kp_driving["value"][:, 4] + emo_driving["value"][:, 1] + ) + kp_driving["jacobian"][:, 4] = ( + kp_driving["jacobian"][:, 4] + emo_driving["jacobian"][:, 1] + ) + kp_driving["value"][:, 6] = ( + 
kp_driving["value"][:, 6] + emo_driving["value"][:, 2] + ) + kp_driving["jacobian"][:, 6] = ( + kp_driving["jacobian"][:, 6] + emo_driving["jacobian"][:, 2] + ) + + kp_norm = self.normalize_kp( + kp_source=kp_source, + kp_driving=kp_driving, + kp_driving_initial=kp_driving_initial, + use_relative_movement=relative, + use_relative_jacobian=relative, + adapt_movement_scale=adapt_movement_scale, + ) + out = generator(source, kp_source=kp_source, kp_driving=kp_norm) + + predictions.append( + np.transpose(out["prediction"].data.cpu().numpy(), [0, 2, 3, 1])[0] + ) + return predictions, features + + def test_auido(self, example_image, audio_feature, all_pose, opt): + """ + generate audio feature (key points of motion). + """ + with open(opt['config']) as f: + para = self.yaml.load(f, Loader=self.yaml.FullLoader) + + if not opt['cpu']: + audio_feature = audio_feature.cuda() + else: + audio_feature = audio_feature.cpu() + + audio_feature.eval() + test_file = opt['in_file'] + pose = all_pose[:, :6] + if len(pose) == 1: + pose = np.repeat(pose, 100, 0) + elif opt['smooth_pose']: + one_euro_filter = self.OneEuroFilter( + mincutoff=0.004, beta=0.7, dcutoff=1.0, freq=100 + ) + for j in range(len(pose)): + pose[j] = one_euro_filter.process(pose[j]) + + example_image = np.array(example_image, dtype="float32").transpose((2, 0, 1)) + + speech, sr = self.librosa.load(test_file, sr=16000) + speech = np.insert(speech, 0, np.zeros(1920)) + speech = np.append(speech, np.zeros(1920)) + mfcc = self.python_speech_features.mfcc(speech, 16000, winstep=0.01) + + print("=======================================") + print("Start to generate images") + + ind = 3 + with torch.no_grad(): + fake_lmark = [] + input_mfcc = [] + while ind <= int(mfcc.shape[0] / 4) - 4: + t_mfcc = mfcc[(ind - 3) * 4 : (ind + 4) * 4, 1:] + t_mfcc = torch.FloatTensor(t_mfcc).cuda() + input_mfcc.append(t_mfcc) + ind += 1 + input_mfcc = torch.stack(input_mfcc, dim=0).cpu() + + if len(pose) < len(input_mfcc): + gap = 
len(input_mfcc) - len(pose) + n = int((gap / len(pose) / 2)) + 2 + pose = np.concatenate((pose, pose[::-1, :]), axis=0) + pose = np.tile(pose, (n, 1)) + if len(pose) > len(input_mfcc): + pose = pose[: len(input_mfcc), :] + + if not opt['cpu']: + example_image = self.Variable( + torch.FloatTensor(example_image.astype(float)) + ).cuda() + example_image = torch.unsqueeze(example_image, 0) + pose = self.Variable(torch.FloatTensor(pose.astype(float))).cuda() + else: + example_image = self.Variable( + torch.FloatTensor(example_image.astype(float)) + ).cpu() + example_image = torch.unsqueeze(example_image, 0).cpu() + pose = self.Variable(torch.FloatTensor(pose.astype(float))).cpu() + + pose = pose.unsqueeze(0) + + input_mfcc = input_mfcc.unsqueeze(0) + + deco_out = audio_feature( + example_image, input_mfcc, pose, para["train_params"]["jaco_net"], 1.6 + ) + + return deco_out + + def save(self, path, frames, format): + """ + save png. + """ + if format == ".png": + if not os.path.exists(path): + os.makedirs(path) + for j, frame in enumerate(frames): + self.imageio.imsave(path + "/" + str(j) + ".png", frame) + else: + print("Unknown format %s" % format) + exit() + + class VideoWriter(object): + """ + VideoWriter. + """ + def __init__(self, path, width, height, fps): + fourcc = EAMMTrainer.cv2.VideoWriter_fourcc(*"XVID") + self.path = path + self.out = EAMMTrainer.cv2.VideoWriter(self.path, fourcc, fps, (width, height)) + + def write_frame(self, frame): + self.out.write(frame) + + def end(self): + self.out.release() + + def concatenate(self, number, imgs, save_path): + """ + concatenate generated frames to a video. 
+ """ + width, height = imgs.shape[-3:-1] + imgs = imgs.reshape(number, -1, width, height, 3) + if number == 2: + left = imgs[0] + right = imgs[1] + + im_all = [] + for i in range(len(left)): + im = np.concatenate((left[i], right[i]), axis=1) + im_all.append(im) + if number == 3: + left = imgs[0] + middle = imgs[1][:,:,:,::-1] + right = imgs[2][:,:,:,::-1] + + im_all = [] + for i in range(len(left)): + im = np.concatenate((left[i], middle[i], right[i]), axis=1) + im_all.append(im) + if number == 4: + left = imgs[0] + left2 = imgs[1] + right = imgs[2] + right2 = imgs[3] + + im_all = [] + for i in range(len(left)): + im = np.concatenate((left[i], left2[i], right[i], right2[i]), axis=1) + im_all.append(im) + if number == 5: + left = imgs[0] + left2 = imgs[1] + middle = imgs[2] + right = imgs[3] + right2 = imgs[4] + + im_all = [] + for i in range(len(left)): + im = np.concatenate( + (left[i], left2[i], middle[i], right[i], right2[i]), axis=1 + ) + im_all.append(im) + + self.imageio.mimsave(save_path, [self.skimage.img_as_ubyte(frame) for frame in im_all], fps=25) + + def add_audio(self, video_name=None, audio_dir=None): + """ + add audio to the generated video. 
+ """ + command = ( + "ffmpeg -i " + + video_name + + " -i " + + audio_dir + + " -vcodec copy -acodec copy -y " + + video_name.replace(".mp4", ".mov") + ) + print(command) + subprocess.call(command) + # os.system(command) + + def crop_image(self, source_image): + """ + All videos are aligned via centering (crop & resize) the location of the first frame’s face and resized to 256 × 256 + """ + template = np.load("checkpoints/EAMM/M003_template.npy") + image = self.cv2.imread(source_image) + gray = self.cv2.cvtColor(image, self.cv2.COLOR_BGR2GRAY) + rects = self.detector(gray, 1) # detect human face + if len(rects) != 1: + return 0 + for j, rect in enumerate(rects): + shape = self.predictor(gray, rect) # detect 68 points + shape = self.shape_to_np(shape) + + pts2 = np.float32(template[:47, :]) + pts1 = np.float32(shape[:47, :]) # eye and nose + tform = self.tf.SimilarityTransform() + tform.estimate( + pts2, pts1 + ) # Set the transformation matrix with the explicit parameters. + + dst = self.tf.warp(image, tform, output_shape=(256, 256)) + + dst = np.array(dst * 255, dtype=np.uint8) + return dst + + def smooth_pose(self, pose_file, pose_long): + """ + smooth pose of the driven video. + """ + start = np.load(pose_file) + video_pose = np.load(pose_long) + delta = video_pose - video_pose[0, :] + print(len(delta)) + + pose = np.repeat(start, len(delta), axis=0) + all_pose = pose + delta + + return all_pose + + def _load_config(self, file): + """ + load evaluate opt. 
+ """ + file = Path(file) + assert file.exists() + with open(file, "r", encoding="utf-8") as f: + self.opt.update( + self.yaml.load(f.read(), Loader=self.yaml.FullLoader) + ) + self.logger.info( + "\n".join( + [ + ( + set_color("{}", "cyan") + " =" + set_color(" {}", "yellow") + ).format(arg, value) + for arg, value in self.opt.items() + ] + ) + ) + + @torch.no_grad() + def evaluate(self, load_best_model=True, model_file=None): + if model_file is None: + return + + self.opt = dict() + self._load_config(model_file) + + all_pose = np.load(self.opt['pose_file']).reshape(-1, 7) + if self.opt['pose_long']: + all_pose = self.smooth_pose(self.opt['pose_file'], self.opt['pose_given']) + + source_image = self.skimage.img_as_float32(self.crop_image(self.opt['source_image'])) + source_image = self.st.resize(source_image, (256, 256))[..., :3] + + reader = self.imageio.get_reader(self.opt['driving_video']) + fps = reader.get_meta_data()["fps"] + driving_video = [] + try: + for im in reader: + driving_video.append(im) + except RuntimeError: + pass + reader.close() + + driving_video = [self.st.resize(frame, (256, 256))[..., :3] for frame in driving_video] + driving_video = self.get_aligned_image(driving_video, self.opt) + transformed_video = self.get_transformed_image(driving_video, self.opt) + transformed_video = np.array(transformed_video) + + ( + generator, + kp_detector, + kp_detector_a, + audio_feature, + emo_detector, + ) = self.load_checkpoints( + opt=self.opt, + checkpoint_path=self.opt['checkpoint'], + audio_checkpoint_path=self.opt['audio_checkpoint'], + emo_checkpoint_path=self.opt['emo_checkpoint'], + cpu=self.opt['cpu'], + ) + + deco_out = self.test_auido(source_image, audio_feature, all_pose, self.opt) + if len(driving_video) < len(deco_out[0]): + driving_video = np.resize(driving_video, (len(deco_out[0]), 256, 256, 3)) + transformed_video = np.resize( + transformed_video, (len(deco_out[0]), 256, 256, 3) + ) + else: + driving_video = driving_video[: 
len(deco_out[0])] + + self.opt['add_emo'] = False + predictions, _ = self.make_animation_smooth( + source_image, + driving_video, + transformed_video, + deco_out, + self.opt['kp_loss'], + generator, + kp_detector, + kp_detector_a, + emo_detector, + self.opt, + relative=self.opt['relative'], + adapt_movement_scale=self.opt['adapt_scale'], + cpu=self.opt['cpu'], + ) + + result_path = Path(self.opt['result_path']) + if not result_path.exists(): + result_path.mkdir(parents=True,exist_ok=True) + + self.imageio.mimsave( + os.path.join(self.opt['result_path'], "neutral.mp4"), + [self.skimage.img_as_ubyte(frame[:,:,::-1]) for frame in predictions], + fps=fps, + ) + predictions = np.array(predictions) + + self.opt['add_emo'] = True + predictions1, _ = self.make_animation_smooth( + source_image, + driving_video, + transformed_video, + deco_out, + self.opt['kp_loss'], + generator, + kp_detector, + kp_detector_a, + emo_detector, + self.opt, + relative=self.opt['relative'], + adapt_movement_scale=self.opt['adapt_scale'], + cpu=self.opt['cpu'], + ) + + self.imageio.mimsave( + os.path.join(self.opt['result_path'], "emotion.mp4"), + [self.skimage.img_as_ubyte(frame[:, :, ::-1]) for frame in predictions1], + fps=fps, + ) + self.add_audio(os.path.join(self.opt['result_path'], "emotion.mp4"), self.opt['in_file']) + predictions1 = np.array(predictions1) + all_imgs = np.concatenate((driving_video, predictions, predictions1), axis=0) + save_path = os.path.join(self.opt['result_path'], "all.mp4") + self.concatenate(3, all_imgs, save_path) + self.add_audio(save_path, self.opt['in_file']) \ No newline at end of file diff --git a/talkingface/utils/augmentation.py b/talkingface/utils/augmentation.py new file mode 100644 index 00000000..2c45db9f --- /dev/null +++ b/talkingface/utils/augmentation.py @@ -0,0 +1,430 @@ +""" +Code from https://github.com/hassony2/torch_videovision +""" + +import numbers +import math +import random +import numpy as np +import PIL +import cv2 +from 
skimage.transform import resize, rotate, AffineTransform, warp +from skimage.util import pad +import torchvision + +import warnings + +from skimage import img_as_ubyte, img_as_float + + +def crop_clip(clip, min_h, min_w, h, w): + if isinstance(clip[0], np.ndarray): + cropped = [img[min_h:min_h + h, min_w:min_w + w, :] for img in clip] + + elif isinstance(clip[0], PIL.Image.Image): + cropped = [ + img.crop((min_w, min_h, min_w + w, min_h + h)) for img in clip + ] + else: + raise TypeError('Expected numpy.ndarray or PIL.Image' + + 'but got list of {0}'.format(type(clip[0]))) + return cropped + + +def pad_clip(clip, h, w): + im_h, im_w = clip[0].shape[:2] + pad_h = (0, 0) if h < im_h else ((h - im_h) // 2, (h - im_h + 1) // 2) + pad_w = (0, 0) if w < im_w else ((w - im_w) // 2, (w - im_w + 1) // 2) + + return pad(clip, ((0, 0), pad_h, pad_w, (0, 0)), mode='edge') + + +def resize_clip(clip, size, interpolation='bilinear'): + if isinstance(clip[0], np.ndarray): + if isinstance(size, numbers.Number): + im_h, im_w, im_c = clip[0].shape + # Min spatial dim already matches minimal size + if (im_w <= im_h and im_w == size) or (im_h <= im_w + and im_h == size): + return clip + new_h, new_w = get_resize_sizes(im_h, im_w, size) + size = (new_w, new_h) + else: + size = size[1], size[0] + + scaled = [ + resize(img, size, order=1 if interpolation == 'bilinear' else 0, preserve_range=True, + mode='constant', anti_aliasing=True) for img in clip + ] + elif isinstance(clip[0], PIL.Image.Image): + if isinstance(size, numbers.Number): + im_w, im_h = clip[0].size + # Min spatial dim already matches minimal size + if (im_w <= im_h and im_w == size) or (im_h <= im_w + and im_h == size): + return clip + new_h, new_w = get_resize_sizes(im_h, im_w, size) + size = (new_w, new_h) + else: + size = size[1], size[0] + if interpolation == 'bilinear': + pil_inter = PIL.Image.NEAREST + else: + pil_inter = PIL.Image.BILINEAR + scaled = [img.resize(size, pil_inter) for img in clip] + else: + raise 
TypeError('Expected numpy.ndarray or PIL.Image' + + 'but got list of {0}'.format(type(clip[0]))) + return scaled + + +def get_resize_sizes(im_h, im_w, size): + if im_w < im_h: + ow = size + oh = int(size * im_h / im_w) + else: + oh = size + ow = int(size * im_w / im_h) + return oh, ow + + +class RandomFlip(object): + def __init__(self, time_flip=False, horizontal_flip=False): + self.time_flip = time_flip + self.horizontal_flip = horizontal_flip + + def __call__(self, clip): + if random.random() < 0.5 and self.time_flip: + return clip[::-1] + if random.random() < 0.5 and self.horizontal_flip: + return [np.fliplr(img) for img in clip] + + return clip + + +class RandomResize(object): + """Resizes a list of (H x W x C) numpy.ndarray to the final size + The larger the original image is, the more times it takes to + interpolate + Args: + interpolation (str): Can be one of 'nearest', 'bilinear' + defaults to nearest + size (tuple): (widht, height) + """ + + def __init__(self, ratio=(3. / 4., 4. 
/ 3.), interpolation='nearest'): + self.ratio = ratio + self.interpolation = interpolation + + def __call__(self, clip): + scaling_factor = random.uniform(self.ratio[0], self.ratio[1]) + + if isinstance(clip[0], np.ndarray): + im_h, im_w, im_c = clip[0].shape + elif isinstance(clip[0], PIL.Image.Image): + im_w, im_h = clip[0].size + + new_w = int(im_w * scaling_factor) + new_h = int(im_h * scaling_factor) + new_size = (new_w, new_h) + resized = resize_clip( + clip, new_size, interpolation=self.interpolation) + + return resized + + +class RandomCrop(object): + """Extract random crop at the same location for a list of videos + Args: + size (sequence or int): Desired output size for the + crop in format (h, w) + """ + + def __init__(self, size): + if isinstance(size, numbers.Number): + size = (size, size) + + self.size = size + + def __call__(self, clip): + """ + Args: + img (PIL.Image or numpy.ndarray): List of videos to be cropped + in format (h, w, c) in numpy.ndarray + Returns: + PIL.Image or numpy.ndarray: Cropped list of videos + """ + h, w = self.size + if isinstance(clip[0], np.ndarray): + im_h, im_w, im_c = clip[0].shape + elif isinstance(clip[0], PIL.Image.Image): + im_w, im_h = clip[0].size + else: + raise TypeError('Expected numpy.ndarray or PIL.Image' + + 'but got list of {0}'.format(type(clip[0]))) + + clip = pad_clip(clip, h, w) + im_h, im_w = clip.shape[1:3] + x1 = 0 if h == im_h else random.randint(0, im_w - w) + y1 = 0 if w == im_w else random.randint(0, im_h - h) + cropped = crop_clip(clip, y1, x1, h, w) + + return cropped + + +class MouthCrop(object): + """Extract random crop at the same location for a list of videos + Args: + size (sequence or int): Desired output size for the + crop in format (h, w) + """ + + def __init__(self, center_x, center_y, mask_width, mask_height): + + + self.center_x = center_x + self.center_y = center_y + self.mask_width = mask_width + self.mask_height = mask_height + + def __call__(self, clip): + """ + Args: + img 
(PIL.Image or numpy.ndarray): List of videos to be cropped + in format (h, w, c) in numpy.ndarray + Returns: + PIL.Image or numpy.ndarray: Cropped list of videos + """ + start_x = self.center_x - int(self.mask_width/2) + start_y = self.center_y - int(self.mask_height/2) + end_x = start_x + self.mask_width + end_y = start_y + self.mask_height + # mask is all white + # mask = 255*np.ones((mask_height, mask_width, 3), dtype=np.uint8) + # mask is uniform noise + cropped = [] + for i in range(len(clip)): + mask = np.random.rand(self.mask_height, self.mask_width, 3) + img = clip[i].copy() + img[start_y:end_y, start_x:end_x, :] = mask + + cropped.append(img) + cropped = np.array(cropped) + return cropped + +class RandomRotation(object): + """Rotate entire clip randomly by a random angle within + given bounds + Args: + degrees (sequence or int): Range of degrees to select from + If degrees is a number instead of sequence like (min, max), + the range of degrees, will be (-degrees, +degrees). + """ + + def __init__(self, degrees): + if isinstance(degrees, numbers.Number): + if degrees < 0: + raise ValueError('If degrees is a single number,' + 'must be positive') + degrees = (-degrees, degrees) + else: + if len(degrees) != 2: + raise ValueError('If degrees is a sequence,' + 'it must be of len 2.') + + self.degrees = degrees + + def __call__(self, clip): + """ + Args: + img (PIL.Image or numpy.ndarray): List of videos to be cropped + in format (h, w, c) in numpy.ndarray + Returns: + PIL.Image or numpy.ndarray: Cropped list of videos + """ + angle = random.uniform(self.degrees[0], self.degrees[1]) + if isinstance(clip[0], np.ndarray): + rotated = [rotate(image=img, angle=angle, preserve_range=True) for img in clip] + elif isinstance(clip[0], PIL.Image.Image): + rotated = [img.rotate(angle) for img in clip] + else: + raise TypeError('Expected numpy.ndarray or PIL.Image' + + 'but got list of {0}'.format(type(clip[0]))) + + return rotated + +class RandomPerspective(object): + 
"""Rotate entire clip randomly by a random angle within + given bounds + Args: + degrees (sequence or int): Range of degrees to select from + If degrees is a number instead of sequence like (min, max), + the range of degrees, will be (-degrees, +degrees). + """ + + def __init__(self, pers_num, enlarge_num): + self.pers_num = pers_num + self.enlarge_num = enlarge_num + + def __call__(self, clip): + """ + Args: + img (PIL.Image or numpy.ndarray): List of videos to be cropped + in format (h, w, c) in numpy.ndarray + Returns: + PIL.Image or numpy.ndarray: Cropped list of videos + """ + out = clip + for i in range(len(clip)): + self.pers_size = np.random.randint(20, self.pers_num) * pow(-1, np.random.randint(2)) + self.enlarge_size = np.random.randint(20, self.enlarge_num) * pow(-1, np.random.randint(2)) + h, w, c = clip[i].shape + crop_size=256 + dst = np.array([ + [-self.enlarge_size, -self.enlarge_size], + [-self.enlarge_size + self.pers_size, w + self.enlarge_size], + [h + self.enlarge_size, -self.enlarge_size], + [h + self.enlarge_size - self.pers_size, w + self.enlarge_size],], dtype=np.float32) + src = np.array([[-self.enlarge_size, -self.enlarge_size], [-self.enlarge_size, w + self.enlarge_size], + [h + self.enlarge_size, -self.enlarge_size], [h + self.enlarge_size, w + self.enlarge_size]]).astype(np.float32()) + M = cv2.getPerspectiveTransform(src, dst) + warped = cv2.warpPerspective(clip[i], M, (crop_size, crop_size), borderMode=cv2.BORDER_REPLICATE) + out[i] = warped + + return out + + +class ColorJitter(object): + """Randomly change the brightness, contrast and saturation and hue of the clip + Args: + brightness (float): How much to jitter brightness. brightness_factor + is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]. + contrast (float): How much to jitter contrast. contrast_factor + is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]. + saturation (float): How much to jitter saturation. 
saturation_factor + is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]. + hue(float): How much to jitter hue. hue_factor is chosen uniformly from + [-hue, hue]. Should be >=0 and <= 0.5. + """ + + def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): + self.brightness = brightness + self.contrast = contrast + self.saturation = saturation + self.hue = hue + + def get_params(self, brightness, contrast, saturation, hue): + if brightness > 0: + brightness_factor = random.uniform( + max(0, 1 - brightness), 1 + brightness) + else: + brightness_factor = None + + if contrast > 0: + contrast_factor = random.uniform( + max(0, 1 - contrast), 1 + contrast) + else: + contrast_factor = None + + if saturation > 0: + saturation_factor = random.uniform( + max(0, 1 - saturation), 1 + saturation) + else: + saturation_factor = None + + if hue > 0: + hue_factor = random.uniform(-hue, hue) + else: + hue_factor = None + return brightness_factor, contrast_factor, saturation_factor, hue_factor + + def __call__(self, clip): + """ + Args: + clip (list): list of PIL.Image + Returns: + list PIL.Image : list of transformed PIL.Image + """ + if isinstance(clip[0], np.ndarray): + brightness, contrast, saturation, hue = self.get_params( + self.brightness, self.contrast, self.saturation, self.hue) + + # Create img transform function sequence + img_transforms = [] + if brightness is not None: + img_transforms.append(lambda img: torchvision.transforms.functional.adjust_brightness(img, brightness)) + if saturation is not None: + img_transforms.append(lambda img: torchvision.transforms.functional.adjust_saturation(img, saturation)) + if hue is not None: + img_transforms.append(lambda img: torchvision.transforms.functional.adjust_hue(img, hue)) + if contrast is not None: + img_transforms.append(lambda img: torchvision.transforms.functional.adjust_contrast(img, contrast)) + random.shuffle(img_transforms) + img_transforms = [img_as_ubyte, torchvision.transforms.ToPILImage()] + 
img_transforms + [np.array, + img_as_float] + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + jittered_clip = [] + for img in clip: + jittered_img = img + for func in img_transforms: + jittered_img = func(jittered_img) + jittered_clip.append(jittered_img.astype('float32')) + elif isinstance(clip[0], PIL.Image.Image): + brightness, contrast, saturation, hue = self.get_params( + self.brightness, self.contrast, self.saturation, self.hue) + + # Create img transform function sequence + img_transforms = [] + if brightness is not None: + img_transforms.append(lambda img: torchvision.transforms.functional.adjust_brightness(img, brightness)) + if saturation is not None: + img_transforms.append(lambda img: torchvision.transforms.functional.adjust_saturation(img, saturation)) + if hue is not None: + img_transforms.append(lambda img: torchvision.transforms.functional.adjust_hue(img, hue)) + if contrast is not None: + img_transforms.append(lambda img: torchvision.transforms.functional.adjust_contrast(img, contrast)) + random.shuffle(img_transforms) + + # Apply to all videos + jittered_clip = [] + for img in clip: + for func in img_transforms: + jittered_img = func(img) + jittered_clip.append(jittered_img) + + else: + raise TypeError('Expected numpy.ndarray or PIL.Image' + + 'but got list of {0}'.format(type(clip[0]))) + return jittered_clip + + +class AllAugmentationTransform: + def __init__(self, crop_mouth_param = None, resize_param=None, rotation_param=None, perspective_param=None, flip_param=None, crop_param=None, jitter_param=None): + self.transforms = [] + if crop_mouth_param is not None: + self.transforms.append(MouthCrop(**crop_mouth_param)) + + if flip_param is not None: + self.transforms.append(RandomFlip(**flip_param)) + + if rotation_param is not None: + self.transforms.append(RandomRotation(**rotation_param)) + + if perspective_param is not None: + self.transforms.append(RandomPerspective(**perspective_param)) + + if resize_param is not None: 
+ self.transforms.append(RandomResize(**resize_param)) + + if crop_param is not None: + self.transforms.append(RandomCrop(**crop_param)) + + if jitter_param is not None: + self.transforms.append(ColorJitter(**jitter_param)) + + def __call__(self, clip): + for t in self.transforms: + clip = t(clip) + return clip diff --git a/talkingface/utils/data_process.py b/talkingface/utils/data_process.py index cbc430ac..29b89283 100644 --- a/talkingface/utils/data_process.py +++ b/talkingface/utils/data_process.py @@ -11,6 +11,11 @@ import librosa.filters from scipy import signal from scipy.io import wavfile +from skimage import transform as tf +import python_speech_features +import dlib +import imageio +from pathlib import Path class lrs2Preprocess: @@ -93,3 +98,197 @@ def run(self): traceback.print_exc() continue +class meadPreprocess: + def __init__(self, config): + self.config = config + self.detector = dlib.get_frontal_face_detector() + self.predictor = dlib.shape_predictor('checkpoints/EAMM/shape_predictor_68_face_landmarks.dat') + + def save(self, path, frames, format): + if format == '.mp4': + imageio.mimsave(path, frames) + elif format == '.png': + if not os.path.exists(path): + os.makedirs(path) + for j, frame in enumerate(frames): + cv2.imwrite(path+'/'+str(j)+'.png',frame) + else: + print ("Unknown format %s" % format) + exit() + + def shape_to_np(self, shape, dtype="int"): + # initialize the list of (x, y)-coordinates + coords = np.zeros((shape.num_parts, 2), dtype=dtype) + # loop over all facial landmarks and convert them + # to a 2-tuple of (x, y)-coordinates + for i in range(0, shape.num_parts): + coords[i] = (shape.part(i).x, shape.part(i).y) + # return the list of (x, y)-coordinates + return coords + + def crop_image(self, image_path, out_path): + template = np.load('./M003_template.npy') + image = cv2.imread(image_path) + gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) + rects = self.detector(gray, 1) #detect human face + if len(rects) != 1: + return 0 + for 
(j, rect) in enumerate(rects): + shape = self.predictor(gray, rect) #detect 68 points + shape = self.shape_to_np(shape) + pts2 = np.float32(template[:47,:]) + # pts2 = np.float32(template[17:35,:]) + # pts1 = np.vstack((landmark[27:36,:], landmark[39,:],landmark[42,:],landmark[45,:])) + pts1 = np.float32(shape[:47,:]) #eye and nose + # pts1 = np.float32(landmark[17:35,:]) + tform = tf.SimilarityTransform() + tform.estimate( pts2, pts1) #Set the transformation matrix with the explicit parameters. + dst = tf.warp(image, tform, output_shape=(256, 256)) + dst = np.array(dst * 255, dtype=np.uint8) + cv2.imwrite(out_path,dst) + + def crop_image_tem(self, video_path, out_path): + """ + video alignment + """ + image_all = [] + videoCapture = cv2.VideoCapture(video_path) + success, frame = videoCapture.read() + n = 0 + while success : + image_all.append(frame) + n = n + 1 + success, frame = videoCapture.read() + if len(image_all)!=0 : + template = np.load('checkpoints/EAMM/M003_template.npy') + image=image_all[0] + gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) + rects = self.detector(gray, 1) #detect human face + if len(rects) != 1: + return 0 + for (j, rect) in enumerate(rects): + shape = self.predictor(gray, rect) #detect 68 points + shape = self.shape_to_np(shape) + pts2 = np.float32(template[:47,:]) + pts1 = np.float32(shape[:47,:]) #eye and nose + tform = tf.SimilarityTransform() + tform.estimate( pts2, pts1) #Set the transformation matrix with the explicit parameters. 
+ out = [] + for i in range(len(image_all)): + image = image_all[i] + dst = tf.warp(image, tform, output_shape=(256, 256)) + dst = np.array(dst * 255, dtype=np.uint8) + out.append(dst) + if not os.path.exists(out_path): + os.makedirs(out_path) + self.save(out_path,out,'.png') + + def proc_audio(self, src_mouth_path, dst_audio_path): + audio_command = 'ffmpeg -i \"{}\" -loglevel error -y -f wav -acodec pcm_s16le ' \ + '-ar 16000 \"{}\"'.format(src_mouth_path, dst_audio_path) + # os.system(audio_command) + subprocess.call(audio_command, shell=True) + + def audio2mfcc(self, audio_file, save, name): + speech, sr = librosa.load(audio_file, sr=16000) + speech = np.insert(speech, 0, np.zeros(1920)) + speech = np.append(speech, np.zeros(1920)) + mfcc = python_speech_features.mfcc(speech,16000,winstep=0.01) + if not os.path.exists(save): + os.makedirs(save) + time_len = mfcc.shape[0] + mfcc_all = [] + for input_idx in range(int((time_len-28)/4)+1): + input_feat = mfcc[4*input_idx:4*input_idx+28,:] + mfcc_all.append(input_feat) + np.save(os.path.join(save,name+'.npy'), mfcc_all) + # print(input_idx) + + def prepare_3dpose(self, filepath, save_path): + from talkingface.utils.pose_3ddfa.FaceBoxes import FaceBoxes + from talkingface.utils.pose_3ddfa.TDDFA import TDDFA + from talkingface.utils.pose_3ddfa.utils.pose import get_pose + import yaml + + pathDir = os.listdir(filepath) + cfg = yaml.load(open('./talkingface/properties/model/EAMM/mb1_120x120.yml'), Loader=yaml.SafeLoader) + + for i in range(len(pathDir)): + image= cv2.imread(os.path.join(filepath,pathDir[i])) + + # Init FaceBoxes and TDDFA, recommend using onnx flag + tddfa = TDDFA(gpu_mode=False, **cfg) + face_boxes = FaceBoxes() + + # Detect faces, get 3DMM params and roi boxes + boxes = face_boxes(image) + n = len(boxes) + if n == 0: + print(f'No face detected, exit') + return None + # print(f'Detect {n} faces') + + param_lst, roi_box_lst = tddfa(image, boxes) + ver_lst = tddfa.recon_vers(param_lst, roi_box_lst, 
dense_flag=False) + all_pose = get_pose(image, param_lst, ver_lst, show_flag=False, wfp=None, wnp = None) + pose = all_pose.reshape(1,7) + + save = Path(save_path) + if not save.exists(): + save.mkdir(parents=True, exist_ok=True) + np.save((save / (pathDir[i].split('.')[0]+'.npy')),pose) + print(i,pathDir[i]) + + + def run(self): + print(f'Started processing for {self.config["data_root"]} with {self.config["ngpu"]} GPUs') + + data_root = Path(self.config['data_root']) + preprocessed_root = Path(self.config['_preprocessed_root']) + + vfilelist = list((data_root / 'crop').glob("*.mp4")) + print('video alignment...') + for vfile in tqdm(vfilelist): + try: + odir = (preprocessed_root / 'crop' / vfile.stem) + if odir.exists(): + continue + self.crop_image_tem(str(vfile), str(odir)) + except KeyboardInterrupt: + exit(0) + except: + traceback.print_exc() + continue + + afilelist = list((data_root / 'audio').glob("*.m4a")) + print('audio2mfcc...') + for afile in tqdm(afilelist): + try: + save_path = (preprocessed_root / 'MEAD_MFCC') + if (save_path / (afile.stem + '.npy')).exists(): + continue + self.audio2mfcc(str(afile), str(save_path), afile.stem) + except KeyboardInterrupt: + exit(0) + except: + traceback.print_exc() + continue + + + pfilelist = list((preprocessed_root / 'crop').glob('*')) + # print(pfilelist) + print('3d pose...') + for pfile in tqdm(pfilelist): + try: + save_path = (preprocessed_root / 'pose' / pfile.stem) + print(save_path) + if save_path.exists(): + continue + self.prepare_3dpose(str(pfile), str(save_path)) + except KeyboardInterrupt: + exit(0) + except: + traceback.print_exc() + continue + + diff --git a/talkingface/utils/eamm_logger.py b/talkingface/utils/eamm_logger.py new file mode 100644 index 00000000..eff06e3d --- /dev/null +++ b/talkingface/utils/eamm_logger.py @@ -0,0 +1,222 @@ +import numpy as np +import torch +import torch.nn.functional as F +import imageio + +import os +from skimage.draw import circle + +import matplotlib.pyplot as plt 
+import collections + + +class Logger: + def __init__(self, log_dir, checkpoint_freq=100, visualizer_params=None, zfill_num=8, log_file_name='log.txt'): + + self.loss_list = [] + self.cpk_dir = log_dir + self.visualizations_dir = os.path.join(log_dir, 'train-vis') + if not os.path.exists(self.visualizations_dir): + os.makedirs(self.visualizations_dir) + self.log_file = open(os.path.join(log_dir, log_file_name), 'a') + self.zfill_num = zfill_num + self.visualizer = Visualizer(**visualizer_params) + self.checkpoint_freq = checkpoint_freq + self.epoch = 0 + self.best_loss = float('inf') + self.names = None + + def log_scores(self, loss_names): + loss_mean = np.array(self.loss_list).mean(axis=0) + + loss_string = "; ".join(["%s - %.5f" % (name, value) for name, value in zip(loss_names, loss_mean)]) + loss_string = str(str(self.epoch)+str(self.step).zfill(self.zfill_num)) + ") " + loss_string + + print(loss_string, file=self.log_file) + self.loss_list = [] + self.log_file.flush() + + def visualize_rec(self, inp, out): + # image = self.visualizer.visualize(inp['driving'], inp['source'], out) + image = self.visualizer.visualize(inp['driving'][:,-1], inp['transformed_driving'][:,-1], inp['example_image'], out) + imageio.imsave(os.path.join(self.visualizations_dir, "%s-%s-rec.png" % (str(self.epoch),str(self.step).zfill(self.zfill_num))), image) + + def save_cpk(self, emergent=False): + cpk = {k: v.state_dict() for k, v in self.models.items()} + cpk['epoch'] = self.epoch + cpk['step'] = self.step + cpk_path = os.path.join(self.cpk_dir, '%s-%s-checkpoint.pth.tar' % (str(self.epoch),str(self.step).zfill(self.zfill_num))) + if not (os.path.exists(cpk_path) and emergent): + torch.save(cpk, cpk_path) + + @staticmethod + def load_cpk(checkpoint_path, generator=None, discriminator=None, kp_detector=None, audio_feature=None, + optimizer_generator=None, optimizer_discriminator=None, optimizer_kp_detector=None, optimizer_audio_feature = None): + checkpoint = 
torch.load(checkpoint_path) + if generator is not None: + generator.load_state_dict(checkpoint['generator']) + if kp_detector is not None: + kp_detector.load_state_dict(checkpoint['kp_detector']) + if discriminator is not None: + try: + discriminator.load_state_dict(checkpoint['discriminator']) + except: + print ('No discriminator in the state-dict. Dicriminator will be randomly initialized') + # if audio_feature is not None: + # audio_feature.load_state_dict(checkpoint['audio_feature']) + if optimizer_generator is not None: + optimizer_generator.load_state_dict(checkpoint['optimizer_generator']) + if optimizer_discriminator is not None: + try: + optimizer_discriminator.load_state_dict(checkpoint['optimizer_discriminator']) + except RuntimeError as e: + print ('No discriminator optimizer in the state-dict. Optimizer will be not initialized') + if optimizer_kp_detector is not None: + optimizer_kp_detector.load_state_dict(checkpoint['optimizer_kp_detector']) + # if optimizer_audio_feature is not None: + # a = checkpoint['optimizer_kp_detector']['param_groups'] + # a[0].pop('params') + # optimizer_audio_feature.load_state_dict(checkpoint['optimizer_audio_feature']) + + return checkpoint['epoch'] + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if 'models' in self.__dict__: + self.save_cpk() + self.log_file.close() + + def log_iter(self, losses): + losses = collections.OrderedDict(losses.items()) + if self.names is None: + self.names = list(losses.keys()) + self.loss_list.append(list(losses.values())) + + def log_epoch(self, epoch, step, models, inp, out): + self.epoch = epoch + self.step = step + self.models = models + if (self.epoch + 1) % self.checkpoint_freq == 0: + self.save_cpk() + self.log_scores(self.names) + self.visualize_rec(inp, out) + + +class Visualizer: + def __init__(self, kp_size=5, draw_border=False, colormap='gist_rainbow'): + self.kp_size = kp_size + self.draw_border = draw_border + self.colormap = 
plt.get_cmap(colormap) + + def draw_image_with_kp(self, image, kp_array): + image = np.copy(image) + spatial_size = np.array(image.shape[:2][::-1])[np.newaxis] + kp_array = spatial_size * (kp_array + 1) / 2 + num_kp = kp_array.shape[0] + for kp_ind, kp in enumerate(kp_array): + rr, cc = circle(kp[1], kp[0], self.kp_size, shape=image.shape[:2]) + image[rr, cc] = np.array(self.colormap(kp_ind / num_kp))[:3] + return image + + def create_image_column_with_kp(self, images, kp): + image_array = np.array([self.draw_image_with_kp(v, k) for v, k in zip(images, kp)]) + return self.create_image_column(image_array) + + def create_image_column(self, images): + if self.draw_border: + images = np.copy(images) + images[:, :, [0, -1]] = (1, 1, 1) + images[:, :, [0, -1]] = (1, 1, 1) + return np.concatenate(list(images), axis=0) + + def create_image_grid(self, *args): + out = [] + for arg in args: + if type(arg) == tuple: + out.append(self.create_image_column_with_kp(arg[0], arg[1])) + else: + out.append(self.create_image_column(arg)) + return np.concatenate(out, axis=1) + + def visualize(self, driving, transformed_driving, source, out): + images = [] + + # Source image with keypoints + source = source.data.cpu() + kp_source = out['kp_source']['value'].data.cpu().numpy() + source = np.transpose(source, [0, 2, 3, 1]) + images.append((source, kp_source)) + + # Equivariance visualization + if 'transformed_frame' in out: + transformed = out['transformed_frame'].data.cpu().numpy() + transformed = np.transpose(transformed, [0, 2, 3, 1]) + transformed_kp = out['transformed_kp']['value'].data.cpu().numpy() + images.append((transformed, transformed_kp)) + + # Equivariance visualization + transformed_driving = transformed_driving.data.cpu().numpy() + transformed_driving = np.transpose(transformed_driving, [0, 2, 3, 1]) + images.append(transformed_driving) + + # Driving image with keypoints + kp_driving = out['kp_driving'][-1]['value'].data.cpu().numpy() #[-1]['value'] + driving = 
driving.data.cpu().numpy() + driving = np.transpose(driving, [0, 2, 3, 1]) + images.append((driving, kp_driving)) + + # Deformed image + if 'deformed' in out: + deformed = out['deformed'].data.cpu().numpy() + deformed = np.transpose(deformed, [0, 2, 3, 1]) + images.append(deformed) + + # Result with and without keypoints + prediction = out['prediction'].data.cpu().numpy() + prediction = np.transpose(prediction, [0, 2, 3, 1]) + if 'kp_norm' in out: + kp_norm = out['kp_norm']['value'].data.cpu().numpy() + images.append((prediction, kp_norm)) + images.append(prediction) + + + ## Occlusion map + if 'occlusion_map' in out: + occlusion_map = out['occlusion_map'].data.cpu().repeat(1, 3, 1, 1) + occlusion_map = F.interpolate(occlusion_map, size=source.shape[1:3]).numpy() + occlusion_map = np.transpose(occlusion_map, [0, 2, 3, 1]) + images.append(occlusion_map) + + # Deformed images according to each individual transform + if 'sparse_deformed' in out: + full_mask = [] + for i in range(out['sparse_deformed'].shape[1]): + image = out['sparse_deformed'][:, i].data.cpu() + image = F.interpolate(image, size=source.shape[1:3]) + mask = out['mask'][:, i:(i+1)].data.cpu().repeat(1, 3, 1, 1) + mask = F.interpolate(mask, size=source.shape[1:3]) + image = np.transpose(image.numpy(), (0, 2, 3, 1)) + mask = np.transpose(mask.numpy(), (0, 2, 3, 1)) + + if i != 0: + color = np.array(self.colormap((i - 1) / (out['sparse_deformed'].shape[1] - 1)))[:3] + else: + color = np.array((0, 0, 0)) + + color = color.reshape((1, 1, 1, 3)) + + images.append(image) + if i != 0: + images.append(mask * color) + else: + images.append(mask) + + full_mask.append(mask * color) + + images.append(sum(full_mask)) + + image = self.create_image_grid(*images) + image = (255 * image).astype(np.uint8) + return image diff --git a/talkingface/utils/filter1.py b/talkingface/utils/filter1.py new file mode 100644 index 00000000..ca234247 --- /dev/null +++ b/talkingface/utils/filter1.py @@ -0,0 +1,48 @@ +import cv2 
+#import pickle +import time +import numpy as np +import copy + +from matplotlib import pyplot as plt +from tqdm import tqdm + + + + +class LowPassFilter: + def __init__(self): + self.prev_raw_value = None + self.prev_filtered_value = None + + def process(self, value, alpha): + if self.prev_raw_value is None: + s = value + else: + s = alpha * value + (1.0 - alpha) * self.prev_filtered_value + self.prev_raw_value = value + self.prev_filtered_value = s + return s + + +class OneEuroFilter: + def __init__(self, mincutoff=1.0, beta=0.0, dcutoff=1.0, freq=30): + self.freq = freq + self.mincutoff = mincutoff + self.beta = beta + self.dcutoff = dcutoff + self.x_filter = LowPassFilter() + self.dx_filter = LowPassFilter() + + def compute_alpha(self, cutoff): + te = 1.0 / self.freq + tau = 1.0 / (2 * np.pi * cutoff) + return 1.0 / (1.0 + tau / te) + + def process(self, x): + prev_x = self.x_filter.prev_raw_value + dx = 0.0 if prev_x is None else (x - prev_x) * self.freq + edx = self.dx_filter.process(dx, self.compute_alpha(self.dcutoff)) + cutoff = self.mincutoff + self.beta * np.abs(edx) + return self.x_filter.process(x, self.compute_alpha(cutoff)) + diff --git a/talkingface/utils/pose_3ddfa/FaceBoxes/.gitignore b/talkingface/utils/pose_3ddfa/FaceBoxes/.gitignore new file mode 100644 index 00000000..523e2009 --- /dev/null +++ b/talkingface/utils/pose_3ddfa/FaceBoxes/.gitignore @@ -0,0 +1,3 @@ +.idea/ +__pycache__ +**/__pycache__ \ No newline at end of file diff --git a/talkingface/utils/pose_3ddfa/FaceBoxes/FaceBoxes.py b/talkingface/utils/pose_3ddfa/FaceBoxes/FaceBoxes.py new file mode 100644 index 00000000..78545584 --- /dev/null +++ b/talkingface/utils/pose_3ddfa/FaceBoxes/FaceBoxes.py @@ -0,0 +1,162 @@ +# coding: utf-8 + +import os.path as osp + +import torch +import numpy as np +import cv2 + +from .utils.prior_box import PriorBox +from .utils.nms_wrapper import nms +from .utils.box_utils import decode +from .utils.timer import Timer +from .utils.functions import 
check_keys, remove_prefix, load_model +from .utils.config import cfg +from .models.faceboxes import FaceBoxesNet + +# some global configs +confidence_threshold = 0.05 +top_k = 5000 +keep_top_k = 750 +nms_threshold = 0.3 +vis_thres = 0.5 +resize = 1 + +scale_flag = True +HEIGHT, WIDTH = 720, 1080 + +make_abs_path = lambda fn: osp.join(osp.dirname(osp.realpath(__file__)), fn) +pretrained_path = make_abs_path('weights/FaceBoxesProd.pth') + + +def viz_bbox(img, dets, wfp='out.jpg'): + # show + for b in dets: + if b[4] < vis_thres: + continue + text = "{:.4f}".format(b[4]) + b = list(map(int, b)) + cv2.rectangle(img, (b[0], b[1]), (b[2], b[3]), (0, 0, 255), 2) + cx = b[0] + cy = b[1] + 12 + cv2.putText(img, text, (cx, cy), cv2.FONT_HERSHEY_DUPLEX, 0.5, (255, 255, 255)) + cv2.imwrite(wfp, img) + print(f'Viz bbox to {wfp}') + + +class FaceBoxes: + def __init__(self, timer_flag=False): + torch.set_grad_enabled(False) + + net = FaceBoxesNet(phase='test', size=None, num_classes=2) # initialize detector + self.net = load_model(net, pretrained_path=pretrained_path, load_to_cpu=True) + self.net.eval() + # print('Finished loading model!') + + self.timer_flag = timer_flag + + def __call__(self, img_): + img_raw = img_.copy() + + # scaling to speed up + scale = 1 + if scale_flag: + h, w = img_raw.shape[:2] + if h > HEIGHT: + scale = HEIGHT / h + if w * scale > WIDTH: + scale *= WIDTH / (w * scale) + # print(scale) + if scale == 1: + img_raw_scale = img_raw + else: + h_s = int(scale * h) + w_s = int(scale * w) + # print(h_s, w_s) + img_raw_scale = cv2.resize(img_raw, dsize=(w_s, h_s)) + # print(img_raw_scale.shape) + + img = np.float32(img_raw_scale) + else: + img = np.float32(img_raw) + + # forward + _t = {'forward_pass': Timer(), 'misc': Timer()} + im_height, im_width, _ = img.shape + scale_bbox = torch.Tensor([img.shape[1], img.shape[0], img.shape[1], img.shape[0]]) + img -= (104, 117, 123) + img = img.transpose(2, 0, 1) + img = torch.from_numpy(img).unsqueeze(0) + + 
_t['forward_pass'].tic() + loc, conf = self.net(img) # forward pass + _t['forward_pass'].toc() + _t['misc'].tic() + priorbox = PriorBox(image_size=(im_height, im_width)) + priors = priorbox.forward() + prior_data = priors.data + boxes = decode(loc.data.squeeze(0), prior_data, cfg['variance']) + if scale_flag: + boxes = boxes * scale_bbox / scale / resize + else: + boxes = boxes * scale_bbox / resize + + boxes = boxes.cpu().numpy() + scores = conf.squeeze(0).data.cpu().numpy()[:, 1] + + # ignore low scores + inds = np.where(scores > confidence_threshold)[0] + boxes = boxes[inds] + scores = scores[inds] + + # keep top-K before NMS + order = scores.argsort()[::-1][:top_k] + boxes = boxes[order] + scores = scores[order] + + # do NMS + dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False) + keep = nms(dets, nms_threshold) + dets = dets[keep, :] + + # keep top-K faster NMS + dets = dets[:keep_top_k, :] + _t['misc'].toc() + + if self.timer_flag: + print('Detection: {:d}/{:d} forward_pass_time: {:.4f}s misc: {:.4f}s'.format(1, 1, _t[ + 'forward_pass'].average_time, _t['misc'].average_time)) + + # filter using vis_thres + det_bboxes = [] + for b in dets: + if b[4] > vis_thres: + xmin, ymin, xmax, ymax, score = b[0], b[1], b[2], b[3], b[4] + bbox = [xmin, ymin, xmax, ymax, score] + det_bboxes.append(bbox) + + return det_bboxes + + +def main(): + face_boxes = FaceBoxes(timer_flag=True) + + fn = 'trump_hillary.jpg' + img_fp = f'../examples/inputs/{fn}' + img = cv2.imread(img_fp) + print(f'input shape: {img.shape}') + dets = face_boxes(img) # xmin, ymin, w, h + # print(dets) + + # repeating inference for `n` times + n = 10 + for i in range(n): + dets = face_boxes(img) + + wfn = fn.replace('.jpg', '_det.jpg') + wfp = osp.join('../examples/results', wfn) + viz_bbox(img, dets, wfp) + + +if __name__ == '__main__': + main() diff --git a/talkingface/utils/pose_3ddfa/FaceBoxes/FaceBoxes_ONNX.py b/talkingface/utils/pose_3ddfa/FaceBoxes/FaceBoxes_ONNX.py new 
file mode 100644 index 00000000..619f5095 --- /dev/null +++ b/talkingface/utils/pose_3ddfa/FaceBoxes/FaceBoxes_ONNX.py @@ -0,0 +1,168 @@ +# coding: utf-8 + +import os.path as osp + +import torch +import numpy as np +import cv2 + +from .utils.prior_box import PriorBox +from .utils.nms_wrapper import nms +from .utils.box_utils import decode +from .utils.timer import Timer +from .utils.config import cfg +from .onnx import convert_to_onnx + +import onnxruntime + +# some global configs +confidence_threshold = 0.05 +top_k = 5000 +keep_top_k = 750 +nms_threshold = 0.3 +vis_thres = 0.5 +resize = 1 + +scale_flag = True +HEIGHT, WIDTH = 720, 1080 + +make_abs_path = lambda fn: osp.join(osp.dirname(osp.realpath(__file__)), fn) +onnx_path = make_abs_path('weights/FaceBoxesProd.onnx') + + +def viz_bbox(img, dets, wfp='out.jpg'): + # show + for b in dets: + if b[4] < vis_thres: + continue + text = "{:.4f}".format(b[4]) + b = list(map(int, b)) + cv2.rectangle(img, (b[0], b[1]), (b[2], b[3]), (0, 0, 255), 2) + cx = b[0] + cy = b[1] + 12 + cv2.putText(img, text, (cx, cy), cv2.FONT_HERSHEY_DUPLEX, 0.5, (255, 255, 255)) + cv2.imwrite(wfp, img) + print(f'Viz bbox to {wfp}') + + +class FaceBoxes_ONNX(object): + def __init__(self, timer_flag=False): + if not osp.exists(onnx_path): + convert_to_onnx(onnx_path) + self.session = onnxruntime.InferenceSession(onnx_path, None) + + self.timer_flag = timer_flag + + def __call__(self, img_): + img_raw = img_.copy() + + # scaling to speed up + scale = 1 + if scale_flag: + h, w = img_raw.shape[:2] + if h > HEIGHT: + scale = HEIGHT / h + if w * scale > WIDTH: + scale *= WIDTH / (w * scale) + # print(scale) + if scale == 1: + img_raw_scale = img_raw + else: + h_s = int(scale * h) + w_s = int(scale * w) + # print(h_s, w_s) + img_raw_scale = cv2.resize(img_raw, dsize=(w_s, h_s)) + # print(img_raw_scale.shape) + + img = np.float32(img_raw_scale) + else: + img = np.float32(img_raw) + + # forward + _t = {'forward_pass': Timer(), 'misc': Timer()} + 
im_height, im_width, _ = img.shape + scale_bbox = torch.Tensor([img.shape[1], img.shape[0], img.shape[1], img.shape[0]]) + + img -= (104, 117, 123) + img = img.transpose(2, 0, 1) + # img = torch.from_numpy(img).unsqueeze(0) + img = img[np.newaxis, ...] + + _t['forward_pass'].tic() + # loc, conf = self.net(img) # forward pass + out = self.session.run(None, {'input': img}) + loc, conf = out[0], out[1] + # for compatibility, may need to optimize + loc = torch.from_numpy(loc) + _t['forward_pass'].toc() + _t['misc'].tic() + + priorbox = PriorBox(image_size=(im_height, im_width)) + priors = priorbox.forward() + prior_data = priors.data + boxes = decode(loc.data.squeeze(0), prior_data, cfg['variance']) + if scale_flag: + boxes = boxes * scale_bbox / scale / resize + else: + boxes = boxes * scale_bbox / resize + + boxes = boxes.cpu().numpy() + scores = conf[0][:, 1] + # scores = conf.squeeze(0).data.cpu().numpy()[:, 1] + + # ignore low scores + inds = np.where(scores > confidence_threshold)[0] + boxes = boxes[inds] + scores = scores[inds] + + # keep top-K before NMS + order = scores.argsort()[::-1][:top_k] + boxes = boxes[order] + scores = scores[order] + + # do NMS + dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False) + keep = nms(dets, nms_threshold) + dets = dets[keep, :] + + # keep top-K faster NMS + dets = dets[:keep_top_k, :] + _t['misc'].toc() + + if self.timer_flag: + print('Detection: {:d}/{:d} forward_pass_time: {:.4f}s misc: {:.4f}s'.format(1, 1, _t[ + 'forward_pass'].average_time, _t['misc'].average_time)) + + # filter using vis_thres + det_bboxes = [] + for b in dets: + if b[4] > vis_thres: + xmin, ymin, xmax, ymax, score = b[0], b[1], b[2], b[3], b[4] + bbox = [xmin, ymin, xmax, ymax, score] + det_bboxes.append(bbox) + + return det_bboxes + + +def main(): + face_boxes = FaceBoxes_ONNX(timer_flag=True) + + fn = 'trump_hillary.jpg' + img_fp = f'../examples/inputs/{fn}' + img = cv2.imread(img_fp) + print(f'input shape: {img.shape}') + 
dets = face_boxes(img) # xmin, ymin, w, h + # print(dets) + + # repeating inference for `n` times + n = 10 + for i in range(n): + dets = face_boxes(img) + + wfn = fn.replace('.jpg', '_det.jpg') + wfp = osp.join('../examples/results', wfn) + viz_bbox(img, dets, wfp) + + +if __name__ == '__main__': + main() diff --git a/talkingface/utils/pose_3ddfa/FaceBoxes/__init__.py b/talkingface/utils/pose_3ddfa/FaceBoxes/__init__.py new file mode 100644 index 00000000..3ab63ff2 --- /dev/null +++ b/talkingface/utils/pose_3ddfa/FaceBoxes/__init__.py @@ -0,0 +1 @@ +from .FaceBoxes import FaceBoxes diff --git a/talkingface/utils/pose_3ddfa/FaceBoxes/build_cpu_nms.sh b/talkingface/utils/pose_3ddfa/FaceBoxes/build_cpu_nms.sh new file mode 100644 index 00000000..98332386 --- /dev/null +++ b/talkingface/utils/pose_3ddfa/FaceBoxes/build_cpu_nms.sh @@ -0,0 +1,3 @@ +cd utils +python3 build.py build_ext --inplace +cd .. \ No newline at end of file diff --git a/talkingface/utils/pose_3ddfa/FaceBoxes/models/__init__.py b/talkingface/utils/pose_3ddfa/FaceBoxes/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/talkingface/utils/pose_3ddfa/FaceBoxes/models/faceboxes.py b/talkingface/utils/pose_3ddfa/FaceBoxes/models/faceboxes.py new file mode 100644 index 00000000..83ca7f0d --- /dev/null +++ b/talkingface/utils/pose_3ddfa/FaceBoxes/models/faceboxes.py @@ -0,0 +1,150 @@ +# coding: utf-8 + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class BasicConv2d(nn.Module): + + def __init__(self, in_channels, out_channels, **kwargs): + super(BasicConv2d, self).__init__() + self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) + self.bn = nn.BatchNorm2d(out_channels, eps=1e-5) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + return F.relu(x, inplace=True) + + +class Inception(nn.Module): + def __init__(self): + super(Inception, self).__init__() + self.branch1x1 = BasicConv2d(128, 32, kernel_size=1, padding=0) + 
self.branch1x1_2 = BasicConv2d(128, 32, kernel_size=1, padding=0) + self.branch3x3_reduce = BasicConv2d(128, 24, kernel_size=1, padding=0) + self.branch3x3 = BasicConv2d(24, 32, kernel_size=3, padding=1) + self.branch3x3_reduce_2 = BasicConv2d(128, 24, kernel_size=1, padding=0) + self.branch3x3_2 = BasicConv2d(24, 32, kernel_size=3, padding=1) + self.branch3x3_3 = BasicConv2d(32, 32, kernel_size=3, padding=1) + + def forward(self, x): + branch1x1 = self.branch1x1(x) + + branch1x1_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1) + branch1x1_2 = self.branch1x1_2(branch1x1_pool) + + branch3x3_reduce = self.branch3x3_reduce(x) + branch3x3 = self.branch3x3(branch3x3_reduce) + + branch3x3_reduce_2 = self.branch3x3_reduce_2(x) + branch3x3_2 = self.branch3x3_2(branch3x3_reduce_2) + branch3x3_3 = self.branch3x3_3(branch3x3_2) + + outputs = [branch1x1, branch1x1_2, branch3x3, branch3x3_3] + return torch.cat(outputs, 1) + + +class CRelu(nn.Module): + + def __init__(self, in_channels, out_channels, **kwargs): + super(CRelu, self).__init__() + self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) + self.bn = nn.BatchNorm2d(out_channels, eps=1e-5) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + x = torch.cat([x, -x], 1) + x = F.relu(x, inplace=True) + return x + + +class FaceBoxesNet(nn.Module): + + def __init__(self, phase, size, num_classes): + super(FaceBoxesNet, self).__init__() + self.phase = phase + self.num_classes = num_classes + self.size = size + + self.conv1 = CRelu(3, 24, kernel_size=7, stride=4, padding=3) + self.conv2 = CRelu(48, 64, kernel_size=5, stride=2, padding=2) + + self.inception1 = Inception() + self.inception2 = Inception() + self.inception3 = Inception() + + self.conv3_1 = BasicConv2d(128, 128, kernel_size=1, stride=1, padding=0) + self.conv3_2 = BasicConv2d(128, 256, kernel_size=3, stride=2, padding=1) + + self.conv4_1 = BasicConv2d(256, 128, kernel_size=1, stride=1, padding=0) + self.conv4_2 = BasicConv2d(128, 
256, kernel_size=3, stride=2, padding=1) + + self.loc, self.conf = self.multibox(self.num_classes) + + if self.phase == 'test': + self.softmax = nn.Softmax(dim=-1) + + if self.phase == 'train': + for m in self.modules(): + if isinstance(m, nn.Conv2d): + if m.bias is not None: + nn.init.xavier_normal_(m.weight.data) + m.bias.data.fill_(0.02) + else: + m.weight.data.normal_(0, 0.01) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def multibox(self, num_classes): + loc_layers = [] + conf_layers = [] + loc_layers += [nn.Conv2d(128, 21 * 4, kernel_size=3, padding=1)] + conf_layers += [nn.Conv2d(128, 21 * num_classes, kernel_size=3, padding=1)] + loc_layers += [nn.Conv2d(256, 1 * 4, kernel_size=3, padding=1)] + conf_layers += [nn.Conv2d(256, 1 * num_classes, kernel_size=3, padding=1)] + loc_layers += [nn.Conv2d(256, 1 * 4, kernel_size=3, padding=1)] + conf_layers += [nn.Conv2d(256, 1 * num_classes, kernel_size=3, padding=1)] + return nn.Sequential(*loc_layers), nn.Sequential(*conf_layers) + + def forward(self, x): + + detection_sources = list() + loc = list() + conf = list() + + x = self.conv1(x) + x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1) + x = self.conv2(x) + x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1) + x = self.inception1(x) + x = self.inception2(x) + x = self.inception3(x) + detection_sources.append(x) + + x = self.conv3_1(x) + x = self.conv3_2(x) + detection_sources.append(x) + + x = self.conv4_1(x) + x = self.conv4_2(x) + detection_sources.append(x) + + for (x, l, c) in zip(detection_sources, self.loc, self.conf): + loc.append(l(x).permute(0, 2, 3, 1).contiguous()) + conf.append(c(x).permute(0, 2, 3, 1).contiguous()) + + loc = torch.cat([o.view(o.size(0), -1) for o in loc], 1) + conf = torch.cat([o.view(o.size(0), -1) for o in conf], 1) + + if self.phase == "test": + output = (loc.view(loc.size(0), -1, 4), + self.softmax(conf.view(conf.size(0), -1, self.num_classes))) + else: + output = 
(loc.view(loc.size(0), -1, 4), + conf.view(conf.size(0), -1, self.num_classes)) + + return output diff --git a/talkingface/utils/pose_3ddfa/FaceBoxes/onnx.py b/talkingface/utils/pose_3ddfa/FaceBoxes/onnx.py new file mode 100644 index 00000000..2664f395 --- /dev/null +++ b/talkingface/utils/pose_3ddfa/FaceBoxes/onnx.py @@ -0,0 +1,35 @@ +# coding: utf-8 + +__author__ = 'cleardusk' + +import torch + +from .models.faceboxes import FaceBoxesNet +from .utils.functions import load_model + + +def convert_to_onnx(onnx_path): + pretrained_path = onnx_path.replace('.onnx', '.pth') + # 1. load model + torch.set_grad_enabled(False) + net = FaceBoxesNet(phase='test', size=None, num_classes=2) # initialize detector + net = load_model(net, pretrained_path=pretrained_path, load_to_cpu=True) + net.eval() + + # 2. convert + batch_size = 1 + dummy_input = torch.randn(batch_size, 3, 720, 1080) + # export with dynamic axes for various input sizes + torch.onnx.export( + net, + (dummy_input,), + onnx_path, + input_names=['input'], + output_names=['output'], + dynamic_axes={ + 'input': [0, 2, 3], + 'output': [0] + }, + do_constant_folding=True + ) + print(f'Convert {pretrained_path} to {onnx_path} done.') diff --git a/talkingface/utils/pose_3ddfa/FaceBoxes/readme.md b/talkingface/utils/pose_3ddfa/FaceBoxes/readme.md new file mode 100644 index 00000000..6ac97022 --- /dev/null +++ b/talkingface/utils/pose_3ddfa/FaceBoxes/readme.md @@ -0,0 +1,18 @@ +## How to fun FaceBoxes + +### Build the cpu version of NMS +```shell script +cd utils +python3 build.py build_ext --inplace +``` + +or just run + +```shell script +sh ./build_cpu_nms.sh +``` + +### Run the demo of face detection +```shell script +python3 FaceBoxes.py +``` \ No newline at end of file diff --git a/talkingface/utils/pose_3ddfa/FaceBoxes/utils/.gitignore b/talkingface/utils/pose_3ddfa/FaceBoxes/utils/.gitignore new file mode 100644 index 00000000..1fde63e2 --- /dev/null +++ b/talkingface/utils/pose_3ddfa/FaceBoxes/utils/.gitignore @@ 
-0,0 +1,4 @@ +utils/build +utils/nms/*.so +utils/*.c +build/ diff --git a/talkingface/utils/pose_3ddfa/FaceBoxes/utils/__init__.py b/talkingface/utils/pose_3ddfa/FaceBoxes/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/talkingface/utils/pose_3ddfa/FaceBoxes/utils/box_utils.py b/talkingface/utils/pose_3ddfa/FaceBoxes/utils/box_utils.py new file mode 100644 index 00000000..50092268 --- /dev/null +++ b/talkingface/utils/pose_3ddfa/FaceBoxes/utils/box_utils.py @@ -0,0 +1,276 @@ +# coding: utf-8 + +import torch +import numpy as np + + +def point_form(boxes): + """ Convert prior_boxes to (xmin, ymin, xmax, ymax) + representation for comparison to point form ground truth data. + Args: + boxes: (tensor) center-size default boxes from priorbox layers. + Return: + boxes: (tensor) Converted xmin, ymin, xmax, ymax form of boxes. + """ + return torch.cat((boxes[:, :2] - boxes[:, 2:] / 2, # xmin, ymin + boxes[:, :2] + boxes[:, 2:] / 2), 1) # xmax, ymax + + +def center_size(boxes): + """ Convert prior_boxes to (cx, cy, w, h) + representation for comparison to center-size form ground truth data. + Args: + boxes: (tensor) point_form boxes + Return: + boxes: (tensor) Converted xmin, ymin, xmax, ymax form of boxes. + """ + return torch.cat((boxes[:, 2:] + boxes[:, :2]) / 2, # cx, cy + boxes[:, 2:] - boxes[:, :2], 1) # w, h + + +def intersect(box_a, box_b): + """ We resize both tensors to [A,B,2] without new malloc: + [A,2] -> [A,1,2] -> [A,B,2] + [B,2] -> [1,B,2] -> [A,B,2] + Then we compute the area of intersect between box_a and box_b. + Args: + box_a: (tensor) bounding boxes, Shape: [A,4]. + box_b: (tensor) bounding boxes, Shape: [B,4]. + Return: + (tensor) intersection area, Shape: [A,B]. 
+ """ + A = box_a.size(0) + B = box_b.size(0) + max_xy = torch.min(box_a[:, 2:].unsqueeze(1).expand(A, B, 2), + box_b[:, 2:].unsqueeze(0).expand(A, B, 2)) + min_xy = torch.max(box_a[:, :2].unsqueeze(1).expand(A, B, 2), + box_b[:, :2].unsqueeze(0).expand(A, B, 2)) + inter = torch.clamp((max_xy - min_xy), min=0) + return inter[:, :, 0] * inter[:, :, 1] + + +def jaccard(box_a, box_b): + """Compute the jaccard overlap of two sets of boxes. The jaccard overlap + is simply the intersection over union of two boxes. Here we operate on + ground truth boxes and default boxes. + E.g.: + A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B) + Args: + box_a: (tensor) Ground truth bounding boxes, Shape: [num_objects,4] + box_b: (tensor) Prior boxes from priorbox layers, Shape: [num_priors,4] + Return: + jaccard overlap: (tensor) Shape: [box_a.size(0), box_b.size(0)] + """ + inter = intersect(box_a, box_b) + area_a = ((box_a[:, 2] - box_a[:, 0]) * + (box_a[:, 3] - box_a[:, 1])).unsqueeze(1).expand_as(inter) # [A,B] + area_b = ((box_b[:, 2] - box_b[:, 0]) * + (box_b[:, 3] - box_b[:, 1])).unsqueeze(0).expand_as(inter) # [A,B] + union = area_a + area_b - inter + return inter / union # [A,B] + + +def matrix_iou(a, b): + """ + return iou of a and b, numpy version for data augenmentation + """ + lt = np.maximum(a[:, np.newaxis, :2], b[:, :2]) + rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:]) + + area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2) + area_a = np.prod(a[:, 2:] - a[:, :2], axis=1) + area_b = np.prod(b[:, 2:] - b[:, :2], axis=1) + return area_i / (area_a[:, np.newaxis] + area_b - area_i) + + +def matrix_iof(a, b): + """ + return iof of a and b, numpy version for data augenmentation + """ + lt = np.maximum(a[:, np.newaxis, :2], b[:, :2]) + rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:]) + + area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2) + area_a = np.prod(a[:, 2:] - a[:, :2], axis=1) + return area_i / np.maximum(area_a[:, np.newaxis], 1) + + +def 
match(threshold, truths, priors, variances, labels, loc_t, conf_t, idx): + """Match each prior box with the ground truth box of the highest jaccard + overlap, encode the bounding boxes, then return the matched indices + corresponding to both confidence and location preds. + Args: + threshold: (float) The overlap threshold used when mathing boxes. + truths: (tensor) Ground truth boxes, Shape: [num_obj, num_priors]. + priors: (tensor) Prior boxes from priorbox layers, Shape: [n_priors,4]. + variances: (tensor) Variances corresponding to each prior coord, + Shape: [num_priors, 4]. + labels: (tensor) All the class labels for the image, Shape: [num_obj]. + loc_t: (tensor) Tensor to be filled w/ endcoded location targets. + conf_t: (tensor) Tensor to be filled w/ matched indices for conf preds. + idx: (int) current batch index + Return: + The matched indices corresponding to 1)location and 2)confidence preds. + """ + # jaccard index + overlaps = jaccard( + truths, + point_form(priors) + ) + # (Bipartite Matching) + # [1,num_objects] best prior for each ground truth + best_prior_overlap, best_prior_idx = overlaps.max(1, keepdim=True) + + # ignore hard gt + valid_gt_idx = best_prior_overlap[:, 0] >= 0.2 + best_prior_idx_filter = best_prior_idx[valid_gt_idx, :] + if best_prior_idx_filter.shape[0] <= 0: + loc_t[idx] = 0 + conf_t[idx] = 0 + return + + # [1,num_priors] best ground truth for each prior + best_truth_overlap, best_truth_idx = overlaps.max(0, keepdim=True) + best_truth_idx.squeeze_(0) + best_truth_overlap.squeeze_(0) + best_prior_idx.squeeze_(1) + best_prior_idx_filter.squeeze_(1) + best_prior_overlap.squeeze_(1) + best_truth_overlap.index_fill_(0, best_prior_idx_filter, 2) # ensure best prior + # TODO refactor: index best_prior_idx with long tensor + # ensure every gt matches with its prior of max overlap + for j in range(best_prior_idx.size(0)): + best_truth_idx[best_prior_idx[j]] = j + matches = truths[best_truth_idx] # Shape: [num_priors,4] + conf = 
labels[best_truth_idx] # Shape: [num_priors] + conf[best_truth_overlap < threshold] = 0 # label as background + loc = encode(matches, priors, variances) + loc_t[idx] = loc # [num_priors,4] encoded offsets to learn + conf_t[idx] = conf # [num_priors] top class label for each prior + + +def encode(matched, priors, variances): + """Encode the variances from the priorbox layers into the ground truth boxes + we have matched (based on jaccard overlap) with the prior boxes. + Args: + matched: (tensor) Coords of ground truth for each prior in point-form + Shape: [num_priors, 4]. + priors: (tensor) Prior boxes in center-offset form + Shape: [num_priors,4]. + variances: (list[float]) Variances of priorboxes + Return: + encoded boxes (tensor), Shape: [num_priors, 4] + """ + + # dist b/t match center and prior's center + g_cxcy = (matched[:, :2] + matched[:, 2:]) / 2 - priors[:, :2] + # encode variance + g_cxcy /= (variances[0] * priors[:, 2:]) + # match wh / prior wh + g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:] + g_wh = torch.log(g_wh) / variances[1] + # return target for smooth_l1_loss + return torch.cat([g_cxcy, g_wh], 1) # [num_priors,4] + + +# Adapted from https://github.com/Hakuyume/chainer-ssd +def decode(loc, priors, variances): + """Decode locations from predictions using priors to undo + the encoding we did for offset regression at train time. + Args: + loc (tensor): location predictions for loc layers, + Shape: [num_priors,4] + priors (tensor): Prior boxes in center-offset form. + Shape: [num_priors,4]. 
+ variances: (list[float]) Variances of priorboxes + Return: + decoded bounding box predictions + """ + + boxes = torch.cat(( + priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:], + priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1) + boxes[:, :2] -= boxes[:, 2:] / 2 + boxes[:, 2:] += boxes[:, :2] + return boxes + + +def log_sum_exp(x): + """Utility function for computing log_sum_exp while determining + This will be used to determine unaveraged confidence loss across + all examples in a batch. + Args: + x (Variable(tensor)): conf_preds from conf layers + """ + x_max = x.data.max() + return torch.log(torch.sum(torch.exp(x - x_max), 1, keepdim=True)) + x_max + + +# Original author: Francisco Massa: +# https://github.com/fmassa/object-detection.torch +# Ported to PyTorch by Max deGroot (02/01/2017) +def nms(boxes, scores, overlap=0.5, top_k=200): + """Apply non-maximum suppression at test time to avoid detecting too many + overlapping bounding boxes for a given object. + Args: + boxes: (tensor) The location preds for the img, Shape: [num_priors,4]. + scores: (tensor) The class predscores for the img, Shape:[num_priors]. + overlap: (float) The overlap thresh for suppressing unnecessary boxes. + top_k: (int) The Maximum number of box preds to consider. + Return: + The indices of the kept boxes with respect to num_priors. 
+ """ + + keep = torch.Tensor(scores.size(0)).fill_(0).long() + if boxes.numel() == 0: + return keep + x1 = boxes[:, 0] + y1 = boxes[:, 1] + x2 = boxes[:, 2] + y2 = boxes[:, 3] + area = torch.mul(x2 - x1, y2 - y1) + v, idx = scores.sort(0) # sort in ascending order + # I = I[v >= 0.01] + idx = idx[-top_k:] # indices of the top-k largest vals + xx1 = boxes.new() + yy1 = boxes.new() + xx2 = boxes.new() + yy2 = boxes.new() + w = boxes.new() + h = boxes.new() + + # keep = torch.Tensor() + count = 0 + while idx.numel() > 0: + i = idx[-1] # index of current largest val + # keep.append(i) + keep[count] = i + count += 1 + if idx.size(0) == 1: + break + idx = idx[:-1] # remove kept element from view + # load bboxes of next highest vals + torch.index_select(x1, 0, idx, out=xx1) + torch.index_select(y1, 0, idx, out=yy1) + torch.index_select(x2, 0, idx, out=xx2) + torch.index_select(y2, 0, idx, out=yy2) + # store element-wise max with next highest score + xx1 = torch.clamp(xx1, min=x1[i]) + yy1 = torch.clamp(yy1, min=y1[i]) + xx2 = torch.clamp(xx2, max=x2[i]) + yy2 = torch.clamp(yy2, max=y2[i]) + w.resize_as_(xx2) + h.resize_as_(yy2) + w = xx2 - xx1 + h = yy2 - yy1 + # check sizes of xx1 and xx2.. 
after each iteration + w = torch.clamp(w, min=0.0) + h = torch.clamp(h, min=0.0) + inter = w * h + # IoU = i / (area(a) + area(b) - i) + rem_areas = torch.index_select(area, 0, idx) # load remaining areas) + union = (rem_areas - inter) + area[i] + IoU = inter / union # store result in iou + # keep only elements with an IoU <= overlap + idx = idx[IoU.le(overlap)] + return keep, count diff --git a/talkingface/utils/pose_3ddfa/FaceBoxes/utils/build.py b/talkingface/utils/pose_3ddfa/FaceBoxes/utils/build.py new file mode 100644 index 00000000..405a1fc8 --- /dev/null +++ b/talkingface/utils/pose_3ddfa/FaceBoxes/utils/build.py @@ -0,0 +1,57 @@ +# coding: utf-8 + +# -------------------------------------------------------- +# Fast R-CNN +# Copyright (c) 2015 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Ross Girshick +# -------------------------------------------------------- + +import os +from os.path import join as pjoin +import numpy as np +from distutils.core import setup +from distutils.extension import Extension +from Cython.Distutils import build_ext + + +def find_in_path(name, path): + "Find a file in a search path" + # adapted fom http://code.activestate.com/recipes/52224-find-a-file-given-a-search-path/ + for dir in path.split(os.pathsep): + binpath = pjoin(dir, name) + if os.path.exists(binpath): + return os.path.abspath(binpath) + return None + + +# Obtain the numpy include directory. This logic works across numpy versions. 
+try: + numpy_include = np.get_include() +except AttributeError: + numpy_include = np.get_numpy_include() + + +# run the customize_compiler +class custom_build_ext(build_ext): + def build_extensions(self): + # customize_compiler_for_nvcc(self.compiler) + build_ext.build_extensions(self) + + +ext_modules = [ + Extension( + "nms.cpu_nms", + ["nms/cpu_nms.pyx"], + # extra_compile_args={'gcc': ["-Wno-cpp", "-Wno-unused-function"]}, + # extra_compile_args=["-Wno-cpp", "-Wno-unused-function"], + include_dirs=[numpy_include] + ) +] + +setup( + name='mot_utils', + ext_modules=ext_modules, + # inject our custom trigger + cmdclass={'build_ext': custom_build_ext}, +) diff --git a/talkingface/utils/pose_3ddfa/FaceBoxes/utils/config.py b/talkingface/utils/pose_3ddfa/FaceBoxes/utils/config.py new file mode 100644 index 00000000..20a29cc4 --- /dev/null +++ b/talkingface/utils/pose_3ddfa/FaceBoxes/utils/config.py @@ -0,0 +1,9 @@ +# coding: utf-8 + +cfg = { + 'name': 'FaceBoxes', + 'min_sizes': [[32, 64, 128], [256], [512]], + 'steps': [32, 64, 128], + 'variance': [0.1, 0.2], + 'clip': False +} diff --git a/talkingface/utils/pose_3ddfa/FaceBoxes/utils/functions.py b/talkingface/utils/pose_3ddfa/FaceBoxes/utils/functions.py new file mode 100644 index 00000000..a882575f --- /dev/null +++ b/talkingface/utils/pose_3ddfa/FaceBoxes/utils/functions.py @@ -0,0 +1,43 @@ +# coding: utf-8 + +import sys +import os.path as osp +import torch + +def check_keys(model, pretrained_state_dict): + ckpt_keys = set(pretrained_state_dict.keys()) + model_keys = set(model.state_dict().keys()) + used_pretrained_keys = model_keys & ckpt_keys + unused_pretrained_keys = ckpt_keys - model_keys + missing_keys = model_keys - ckpt_keys + # print('Missing keys:{}'.format(len(missing_keys))) + # print('Unused checkpoint keys:{}'.format(len(unused_pretrained_keys))) + # print('Used keys:{}'.format(len(used_pretrained_keys))) + assert len(used_pretrained_keys) > 0, 'load NONE from pretrained checkpoint' + return True 
+ + +def remove_prefix(state_dict, prefix): + ''' Old style model is stored with all names of parameters sharing common prefix 'module.' ''' + # print('remove prefix \'{}\''.format(prefix)) + f = lambda x: x.split(prefix, 1)[-1] if x.startswith(prefix) else x + return {f(key): value for key, value in state_dict.items()} + + +def load_model(model, pretrained_path, load_to_cpu): + if not osp.isfile(pretrained_path): + print(f'The pre-trained FaceBoxes model {pretrained_path} does not exist') + sys.exit('-1') + # print('Loading pretrained model from {}'.format(pretrained_path)) + if load_to_cpu: + pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage) + else: + device = torch.cuda.current_device() + pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage.cuda(device)) + if "state_dict" in pretrained_dict.keys(): + pretrained_dict = remove_prefix(pretrained_dict['state_dict'], 'module.') + else: + pretrained_dict = remove_prefix(pretrained_dict, 'module.') + check_keys(model, pretrained_dict) + model.load_state_dict(pretrained_dict, strict=False) + return model diff --git a/talkingface/utils/pose_3ddfa/FaceBoxes/utils/nms/.gitignore b/talkingface/utils/pose_3ddfa/FaceBoxes/utils/nms/.gitignore new file mode 100644 index 00000000..4b8a745c --- /dev/null +++ b/talkingface/utils/pose_3ddfa/FaceBoxes/utils/nms/.gitignore @@ -0,0 +1,2 @@ +*.c +*.so diff --git a/talkingface/utils/pose_3ddfa/FaceBoxes/utils/nms/__init__.py b/talkingface/utils/pose_3ddfa/FaceBoxes/utils/nms/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/talkingface/utils/pose_3ddfa/FaceBoxes/utils/nms/cpu_nms.cp38-win_amd64.pyd b/talkingface/utils/pose_3ddfa/FaceBoxes/utils/nms/cpu_nms.cp38-win_amd64.pyd new file mode 100644 index 00000000..6c8a3b4d Binary files /dev/null and b/talkingface/utils/pose_3ddfa/FaceBoxes/utils/nms/cpu_nms.cp38-win_amd64.pyd differ diff --git 
a/talkingface/utils/pose_3ddfa/FaceBoxes/utils/nms/cpu_nms.pyx b/talkingface/utils/pose_3ddfa/FaceBoxes/utils/nms/cpu_nms.pyx new file mode 100644 index 00000000..898c9acf --- /dev/null +++ b/talkingface/utils/pose_3ddfa/FaceBoxes/utils/nms/cpu_nms.pyx @@ -0,0 +1,163 @@ +# -------------------------------------------------------- +# Fast R-CNN +# Copyright (c) 2015 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Ross Girshick +# -------------------------------------------------------- + +import numpy as np +cimport numpy as np + +cdef inline np.float32_t max(np.float32_t a, np.float32_t b): + return a if a >= b else b + +cdef inline np.float32_t min(np.float32_t a, np.float32_t b): + return a if a <= b else b + +def cpu_nms(np.ndarray[np.float32_t, ndim=2] dets, np.float thresh): + cdef np.ndarray[np.float32_t, ndim=1] x1 = dets[:, 0] + cdef np.ndarray[np.float32_t, ndim=1] y1 = dets[:, 1] + cdef np.ndarray[np.float32_t, ndim=1] x2 = dets[:, 2] + cdef np.ndarray[np.float32_t, ndim=1] y2 = dets[:, 3] + cdef np.ndarray[np.float32_t, ndim=1] scores = dets[:, 4] + + cdef np.ndarray[np.float32_t, ndim=1] areas = (x2 - x1 + 1) * (y2 - y1 + 1) + cdef np.ndarray[np.int64_t, ndim=1] order = scores.argsort()[::-1] + + cdef int ndets = dets.shape[0] + cdef np.ndarray[np.int64_t, ndim=1] suppressed = \ + np.zeros((ndets), dtype=np.int64) + + # nominal indices + cdef int _i, _j + # sorted indices + cdef int i, j + # temp variables for box i's (the box currently under consideration) + cdef np.float32_t ix1, iy1, ix2, iy2, iarea + # variables for computing overlap with box j (lower scoring box) + cdef np.float32_t xx1, yy1, xx2, yy2 + cdef np.float32_t w, h + cdef np.float32_t inter, ovr + + keep = [] + for _i in range(ndets): + i = order[_i] + if suppressed[i] == 1: + continue + keep.append(i) + ix1 = x1[i] + iy1 = y1[i] + ix2 = x2[i] + iy2 = y2[i] + iarea = areas[i] + for _j in range(_i + 1, ndets): + j = order[_j] + if suppressed[j] == 1: + 
continue + xx1 = max(ix1, x1[j]) + yy1 = max(iy1, y1[j]) + xx2 = min(ix2, x2[j]) + yy2 = min(iy2, y2[j]) + w = max(0.0, xx2 - xx1 + 1) + h = max(0.0, yy2 - yy1 + 1) + inter = w * h + ovr = inter / (iarea + areas[j] - inter) + if ovr >= thresh: + suppressed[j] = 1 + + return keep + +def cpu_soft_nms(np.ndarray[float, ndim=2] boxes, float sigma=0.5, float Nt=0.3, float threshold=0.001, unsigned int method=0): + cdef unsigned int N = boxes.shape[0] + cdef float iw, ih, box_area + cdef float ua + cdef int pos = 0 + cdef float maxscore = 0 + cdef int maxpos = 0 + cdef float x1,x2,y1,y2,tx1,tx2,ty1,ty2,ts,area,weight,ov + + for i in range(N): + maxscore = boxes[i, 4] + maxpos = i + + tx1 = boxes[i,0] + ty1 = boxes[i,1] + tx2 = boxes[i,2] + ty2 = boxes[i,3] + ts = boxes[i,4] + + pos = i + 1 + # get max box + while pos < N: + if maxscore < boxes[pos, 4]: + maxscore = boxes[pos, 4] + maxpos = pos + pos = pos + 1 + + # add max box as a detection + boxes[i,0] = boxes[maxpos,0] + boxes[i,1] = boxes[maxpos,1] + boxes[i,2] = boxes[maxpos,2] + boxes[i,3] = boxes[maxpos,3] + boxes[i,4] = boxes[maxpos,4] + + # swap ith box with position of max box + boxes[maxpos,0] = tx1 + boxes[maxpos,1] = ty1 + boxes[maxpos,2] = tx2 + boxes[maxpos,3] = ty2 + boxes[maxpos,4] = ts + + tx1 = boxes[i,0] + ty1 = boxes[i,1] + tx2 = boxes[i,2] + ty2 = boxes[i,3] + ts = boxes[i,4] + + pos = i + 1 + # NMS iterations, note that N changes if detection boxes fall below threshold + while pos < N: + x1 = boxes[pos, 0] + y1 = boxes[pos, 1] + x2 = boxes[pos, 2] + y2 = boxes[pos, 3] + s = boxes[pos, 4] + + area = (x2 - x1 + 1) * (y2 - y1 + 1) + iw = (min(tx2, x2) - max(tx1, x1) + 1) + if iw > 0: + ih = (min(ty2, y2) - max(ty1, y1) + 1) + if ih > 0: + ua = float((tx2 - tx1 + 1) * (ty2 - ty1 + 1) + area - iw * ih) + ov = iw * ih / ua #iou between max box and detection box + + if method == 1: # linear + if ov > Nt: + weight = 1 - ov + else: + weight = 1 + elif method == 2: # gaussian + weight = np.exp(-(ov * 
ov)/sigma) + else: # original NMS + if ov > Nt: + weight = 0 + else: + weight = 1 + + boxes[pos, 4] = weight*boxes[pos, 4] + + # if box score falls below threshold, discard the box by swapping with last box + # update N + if boxes[pos, 4] < threshold: + boxes[pos,0] = boxes[N-1, 0] + boxes[pos,1] = boxes[N-1, 1] + boxes[pos,2] = boxes[N-1, 2] + boxes[pos,3] = boxes[N-1, 3] + boxes[pos,4] = boxes[N-1, 4] + N = N - 1 + pos = pos - 1 + + pos = pos + 1 + + keep = [i for i in range(N)] + return keep diff --git a/talkingface/utils/pose_3ddfa/FaceBoxes/utils/nms/py_cpu_nms.py b/talkingface/utils/pose_3ddfa/FaceBoxes/utils/nms/py_cpu_nms.py new file mode 100644 index 00000000..54e7b25f --- /dev/null +++ b/talkingface/utils/pose_3ddfa/FaceBoxes/utils/nms/py_cpu_nms.py @@ -0,0 +1,38 @@ +# -------------------------------------------------------- +# Fast R-CNN +# Copyright (c) 2015 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Ross Girshick +# -------------------------------------------------------- + +import numpy as np + +def py_cpu_nms(dets, thresh): + """Pure Python NMS baseline.""" + x1 = dets[:, 0] + y1 = dets[:, 1] + x2 = dets[:, 2] + y2 = dets[:, 3] + scores = dets[:, 4] + + areas = (x2 - x1 + 1) * (y2 - y1 + 1) + order = scores.argsort()[::-1] + + keep = [] + while order.size > 0: + i = order[0] + keep.append(i) + xx1 = np.maximum(x1[i], x1[order[1:]]) + yy1 = np.maximum(y1[i], y1[order[1:]]) + xx2 = np.minimum(x2[i], x2[order[1:]]) + yy2 = np.minimum(y2[i], y2[order[1:]]) + + w = np.maximum(0.0, xx2 - xx1 + 1) + h = np.maximum(0.0, yy2 - yy1 + 1) + inter = w * h + ovr = inter / (areas[i] + areas[order[1:]] - inter) + + inds = np.where(ovr <= thresh)[0] + order = order[inds + 1] + + return keep diff --git a/talkingface/utils/pose_3ddfa/FaceBoxes/utils/nms_wrapper.py b/talkingface/utils/pose_3ddfa/FaceBoxes/utils/nms_wrapper.py new file mode 100644 index 00000000..765ac9fe --- /dev/null +++ 
b/talkingface/utils/pose_3ddfa/FaceBoxes/utils/nms_wrapper.py @@ -0,0 +1,19 @@ +# coding: utf-8 + +# -------------------------------------------------------- +# Fast R-CNN +# Copyright (c) 2015 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Ross Girshick +# -------------------------------------------------------- + +from .nms.cpu_nms import cpu_nms, cpu_soft_nms + + +def nms(dets, thresh): + """Dispatch to either CPU or GPU NMS implementations.""" + + if dets.shape[0] == 0: + return [] + return cpu_nms(dets, thresh) + # return gpu_nms(dets, thresh) diff --git a/talkingface/utils/pose_3ddfa/FaceBoxes/utils/prior_box.py b/talkingface/utils/pose_3ddfa/FaceBoxes/utils/prior_box.py new file mode 100644 index 00000000..7410c811 --- /dev/null +++ b/talkingface/utils/pose_3ddfa/FaceBoxes/utils/prior_box.py @@ -0,0 +1,48 @@ +# coding: utf-8 + +from .config import cfg + +import torch +from itertools import product as product +from math import ceil + + +class PriorBox(object): + def __init__(self, image_size=None): + super(PriorBox, self).__init__() + # self.aspect_ratios = cfg['aspect_ratios'] + self.min_sizes = cfg['min_sizes'] + self.steps = cfg['steps'] + self.clip = cfg['clip'] + self.image_size = image_size + self.feature_maps = [[ceil(self.image_size[0] / step), ceil(self.image_size[1] / step)] for step in self.steps] + + def forward(self): + anchors = [] + for k, f in enumerate(self.feature_maps): + min_sizes = self.min_sizes[k] + for i, j in product(range(f[0]), range(f[1])): + for min_size in min_sizes: + s_kx = min_size / self.image_size[1] + s_ky = min_size / self.image_size[0] + if min_size == 32: + dense_cx = [x * self.steps[k] / self.image_size[1] for x in + [j + 0, j + 0.25, j + 0.5, j + 0.75]] + dense_cy = [y * self.steps[k] / self.image_size[0] for y in + [i + 0, i + 0.25, i + 0.5, i + 0.75]] + for cy, cx in product(dense_cy, dense_cx): + anchors += [cx, cy, s_kx, s_ky] + elif min_size == 64: + dense_cx = [x * 
self.steps[k] / self.image_size[1] for x in [j + 0, j + 0.5]] + dense_cy = [y * self.steps[k] / self.image_size[0] for y in [i + 0, i + 0.5]] + for cy, cx in product(dense_cy, dense_cx): + anchors += [cx, cy, s_kx, s_ky] + else: + cx = (j + 0.5) * self.steps[k] / self.image_size[1] + cy = (i + 0.5) * self.steps[k] / self.image_size[0] + anchors += [cx, cy, s_kx, s_ky] + # back to torch land + output = torch.Tensor(anchors).view(-1, 4) + if self.clip: + output.clamp_(max=1, min=0) + return output diff --git a/talkingface/utils/pose_3ddfa/FaceBoxes/utils/timer.py b/talkingface/utils/pose_3ddfa/FaceBoxes/utils/timer.py new file mode 100644 index 00000000..265c2456 --- /dev/null +++ b/talkingface/utils/pose_3ddfa/FaceBoxes/utils/timer.py @@ -0,0 +1,43 @@ +# coding: utf-8 + +# -------------------------------------------------------- +# Fast R-CNN +# Copyright (c) 2015 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Ross Girshick +# -------------------------------------------------------- + +import time + + +class Timer(object): + """A simple timer.""" + + def __init__(self): + self.total_time = 0. + self.calls = 0 + self.start_time = 0. + self.diff = 0. + self.average_time = 0. + + def tic(self): + # using time.time instead of time.clock because time time.clock + # does not normalize for multithreading + self.start_time = time.time() + + def toc(self, average=True): + self.diff = time.time() - self.start_time + self.total_time += self.diff + self.calls += 1 + self.average_time = self.total_time / self.calls + if average: + return self.average_time + else: + return self.diff + + def clear(self): + self.total_time = 0. + self.calls = 0 + self.start_time = 0. + self.diff = 0. + self.average_time = 0. 
# coding: utf-8
# Reconstructed module: talkingface/utils/pose_3ddfa/TDDFA.py
# NOTE(review): this span of the flattened patch also introduced
# FaceBoxes/weights/.gitignore ("*.onnx"), FaceBoxes/weights/readme.md, the
# bfm package (appended below), and talkingface/utils/pose_3ddfa/__init.py --
# that file name looks like a typo for __init__.py; without the correct name
# `pose_3ddfa` is not an importable regular package. TODO confirm and rename.

__author__ = 'cleardusk'

import os.path as osp
import time

import cv2
import numpy as np
import torch
import torch.backends.cudnn as cudnn
from torchvision.transforms import Compose

from . import models
from .bfm import BFMModel
from .utils.io import _load
from .utils.functions import (
    crop_img, parse_roi_box_from_bbox, parse_roi_box_from_landmark,
)
from .utils.tddfa_util import (
    load_model, _parse_param, similar_transform,
    ToTensorGjz, NormalizeGjz
)

# Resolve a path relative to this module's directory.
make_abs_path = lambda fn: osp.join(osp.dirname(osp.realpath(__file__)), fn)


class TDDFA(object):
    """TDDFA: named Three-D Dense Face Alignment (TDDFA)"""

    def __init__(self, **kvs):
        """Build the 3DMM (BFM) and the parameter-regression backbone.

        Recognised keyword options: bfm_fp, shape_dim, exp_dim, gpu_mode,
        gpu_id, size, param_mean_std_fp, arch, num_params, widen_factor,
        mode, checkpoint_fp.
        """
        # Inference only -- gradients are never needed.
        torch.set_grad_enabled(False)

        # Load the Basel Face Model (mean shape + shape/expression bases).
        self.bfm = BFMModel(
            bfm_fp=kvs.get('bfm_fp', make_abs_path('configs/bfm_noneck_v3.pkl')),
            shape_dim=kvs.get('shape_dim', 40),
            exp_dim=kvs.get('exp_dim', 10)
        )
        self.tri = self.bfm.tri

        # config
        self.gpu_mode = kvs.get('gpu_mode', False)
        self.gpu_id = kvs.get('gpu_id', 0)
        self.size = kvs.get('size', 120)

        param_mean_std_fp = kvs.get(
            'param_mean_std_fp', make_abs_path(f'configs/param_mean_std_62d_{self.size}x{self.size}.pkl')
        )

        # Build the backbone; default output length 62 = 12(pose) + 40(shape) + 10(expression).
        model = getattr(models, kvs.get('arch'))(
            num_classes=kvs.get('num_params', 62),
            widen_factor=kvs.get('widen_factor', 1),
            size=self.size,
            mode=kvs.get('mode', 'small')
        )
        model = load_model(model, kvs.get('checkpoint_fp'))

        if self.gpu_mode:
            cudnn.benchmark = True
            model = model.cuda(device=self.gpu_id)

        self.model = model
        self.model.eval()  # eval mode, fix BN

        # Input normalization: HWC->CHW tensor, then (x - 127.5) / 128.
        self.transform = Compose([ToTensorGjz(), NormalizeGjz(mean=127.5, std=128)])

        # De-normalization statistics for the 62-d parameter vector.
        r = _load(param_mean_std_fp)
        self.param_mean = r.get('mean')
        self.param_std = r.get('std')

    def __call__(self, img_ori, objs, **kvs):
        """The main call of TDDFA: given an image and boxes / landmarks,
        regress the 62-d 3DMM parameters for each detected face.

        :param img_ori: the input image (BGR, as read by cv2 -- presumably;
            verify against callers)
        :param objs: list of face boxes or of 68-landmark arrays
        :param kvs: options; 'crop_policy' in {'box', 'landmark'}, 'timer_flag'
        :return: (param_lst, roi_box_lst)
        """
        param_lst = []
        roi_box_lst = []

        crop_policy = kvs.get('crop_policy', 'box')
        for obj in objs:
            if crop_policy == 'box':
                # by face box
                roi_box = parse_roi_box_from_bbox(obj)
            elif crop_policy == 'landmark':
                # by landmarks
                roi_box = parse_roi_box_from_landmark(obj)
            else:
                raise ValueError(f'Unknown crop policy {crop_policy}')

            roi_box_lst.append(roi_box)
            img = crop_img(img_ori, roi_box)
            img = cv2.resize(img, dsize=(self.size, self.size), interpolation=cv2.INTER_LINEAR)
            inp = self.transform(img).unsqueeze(0)

            if self.gpu_mode:
                inp = inp.cuda(device=self.gpu_id)

            if kvs.get('timer_flag', False):
                end = time.time()
                param = self.model(inp)
                elapse = f'Inference: {(time.time() - end) * 1000:.1f}ms'
                print(elapse)
            else:
                param = self.model(inp)

            # De-normalize the regressed 62-d parameter vector.
            param = param.squeeze().cpu().numpy().flatten().astype(np.float32)
            param = param * self.param_std + self.param_mean  # re-scale
            param_lst.append(param)

        return param_lst, roi_box_lst

    def recon_vers(self, param_lst, roi_box_lst, **kvs):
        """Reconstruct 3D vertices for each face.

        :param param_lst: list of 62-d parameter vectors from __call__
        :param roi_box_lst: matching list of roi boxes
        :param kvs: 'dense_flag' -- True for the dense mesh, False for the
            68 sparse landmarks
        :return: list of 3 x N vertex arrays in image coordinates
        """
        dense_flag = kvs.get('dense_flag', False)
        size = self.size

        ver_lst = []
        for param, roi_box in zip(param_lst, roi_box_lst):
            R, offset, alpha_shp, alpha_exp = _parse_param(param)
            # Dense mesh uses the full bases; sparse uses the 68-landmark sub-bases.
            if dense_flag:
                u, w_shp, w_exp = self.bfm.u, self.bfm.w_shp, self.bfm.w_exp
            else:
                u, w_shp, w_exp = self.bfm.u_base, self.bfm.w_shp_base, self.bfm.w_exp_base
            pts3d = R @ (u + w_shp @ alpha_shp + w_exp @ alpha_exp).reshape(3, -1, order='F') + offset
            pts3d = similar_transform(pts3d, roi_box, size)
            ver_lst.append(pts3d)

        return ver_lst


# ---------------------------------------------------------------------------
# Reconstructed package: talkingface/utils/pose_3ddfa/bfm/
#   __init__.py:  from .bfm import BFMModel
#   bfm.py: the two definitions below. In the real bfm.py, make_abs_path
#   resolves relative to the bfm/ directory, so '../configs/tri.pkl' means
#   pose_3ddfa/configs/tri.pkl. The original also did sys.path.append('..')
#   at import time -- unnecessary with package-relative imports; flagged for
#   removal.
# ---------------------------------------------------------------------------


def _to_ctype(arr):
    """Return a C-contiguous version of ``arr`` (copying only if needed)."""
    if not arr.flags.c_contiguous:
        return arr.copy(order='C')
    return arr


class BFMModel(object):
    """Basel Face Model wrapper: mean shape plus shape/expression bases."""

    def __init__(self, bfm_fp, shape_dim=40, exp_dim=10):
        """Load the BFM pickle and pre-slice the bases.

        :param bfm_fp: path to the BFM pickle file
        :param shape_dim: number of shape basis columns to keep
        :param exp_dim: number of expression basis columns to keep
        """
        bfm = _load(bfm_fp)
        self.u = bfm.get('u').astype(np.float32)
        self.w_shp = bfm.get('w_shp').astype(np.float32)[..., :shape_dim]
        self.w_exp = bfm.get('w_exp').astype(np.float32)[..., :exp_dim]
        # bfm_noneck_v3.pkl ships without triangles; use the re-built tri.pkl.
        if osp.split(bfm_fp)[-1] == 'bfm_noneck_v3.pkl':
            self.tri = _load(make_abs_path('../configs/tri.pkl'))
        else:
            self.tri = bfm.get('tri')

        self.tri = _to_ctype(self.tri.T).astype(np.int32)
        # BUG FIX: np.long was deprecated in NumPy 1.20 and removed in 1.24,
        # so this line crashed on modern NumPy; use the explicit 64-bit
        # integer dtype for the landmark index array.
        self.keypoints = bfm.get('keypoints').astype(np.int64)
        w = np.concatenate((self.w_shp, self.w_exp), axis=1)
        self.w_norm = np.linalg.norm(w, axis=0)

        # 68-landmark sub-bases used for the sparse reconstruction path.
        self.u_base = self.u[self.keypoints].reshape(-1, 1)
        self.w_shp_base = self.w_shp[self.keypoints]
        self.w_exp_base = self.w_exp[self.keypoints]
a/talkingface/utils/pose_3ddfa/configs/.gitignore b/talkingface/utils/pose_3ddfa/configs/.gitignore new file mode 100644 index 00000000..5f6ba3bd --- /dev/null +++ b/talkingface/utils/pose_3ddfa/configs/.gitignore @@ -0,0 +1,3 @@ +*.pkl +*.yml +*.onnx \ No newline at end of file diff --git a/talkingface/utils/pose_3ddfa/models/__init__.py b/talkingface/utils/pose_3ddfa/models/__init__.py new file mode 100644 index 00000000..4e86ed62 --- /dev/null +++ b/talkingface/utils/pose_3ddfa/models/__init__.py @@ -0,0 +1,3 @@ +from .mobilenet_v1 import * +from .mobilenet_v3 import * +from .resnet import * \ No newline at end of file diff --git a/talkingface/utils/pose_3ddfa/models/mobilenet_v1.py b/talkingface/utils/pose_3ddfa/models/mobilenet_v1.py new file mode 100644 index 00000000..9c9baf51 --- /dev/null +++ b/talkingface/utils/pose_3ddfa/models/mobilenet_v1.py @@ -0,0 +1,163 @@ +# coding: utf-8 + +from __future__ import division + +""" +Creates a MobileNet Model as defined in: +Andrew G. Howard Menglong Zhu Bo Chen, et.al. (2017). +MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications. 
+Copyright (c) Yang Lu, 2017 + +Modified By cleardusk +""" +import math +import torch.nn as nn + +__all__ = ['MobileNet', 'mobilenet'] + + +# __all__ = ['mobilenet_2', 'mobilenet_1', 'mobilenet_075', 'mobilenet_05', 'mobilenet_025'] + + +class DepthWiseBlock(nn.Module): + def __init__(self, inplanes, planes, stride=1, prelu=False): + super(DepthWiseBlock, self).__init__() + inplanes, planes = int(inplanes), int(planes) + self.conv_dw = nn.Conv2d(inplanes, inplanes, kernel_size=3, padding=1, stride=stride, groups=inplanes, + bias=False) + self.bn_dw = nn.BatchNorm2d(inplanes) + self.conv_sep = nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, padding=0, bias=False) + self.bn_sep = nn.BatchNorm2d(planes) + if prelu: + self.relu = nn.PReLU() + else: + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + out = self.conv_dw(x) + out = self.bn_dw(out) + out = self.relu(out) + + out = self.conv_sep(out) + out = self.bn_sep(out) + out = self.relu(out) + + return out + + +class MobileNet(nn.Module): + def __init__(self, widen_factor=1.0, num_classes=1000, prelu=False, input_channel=3): + """ Constructor + Args: + widen_factor: config of widen_factor + num_classes: number of classes + """ + super(MobileNet, self).__init__() + + block = DepthWiseBlock + self.conv1 = nn.Conv2d(input_channel, int(32 * widen_factor), kernel_size=3, stride=2, padding=1, + bias=False) + + self.bn1 = nn.BatchNorm2d(int(32 * widen_factor)) + if prelu: + self.relu = nn.PReLU() + else: + self.relu = nn.ReLU(inplace=True) + + self.dw2_1 = block(32 * widen_factor, 64 * widen_factor, prelu=prelu) + self.dw2_2 = block(64 * widen_factor, 128 * widen_factor, stride=2, prelu=prelu) + + self.dw3_1 = block(128 * widen_factor, 128 * widen_factor, prelu=prelu) + self.dw3_2 = block(128 * widen_factor, 256 * widen_factor, stride=2, prelu=prelu) + + self.dw4_1 = block(256 * widen_factor, 256 * widen_factor, prelu=prelu) + self.dw4_2 = block(256 * widen_factor, 512 * widen_factor, stride=2, prelu=prelu) 
+ + self.dw5_1 = block(512 * widen_factor, 512 * widen_factor, prelu=prelu) + self.dw5_2 = block(512 * widen_factor, 512 * widen_factor, prelu=prelu) + self.dw5_3 = block(512 * widen_factor, 512 * widen_factor, prelu=prelu) + self.dw5_4 = block(512 * widen_factor, 512 * widen_factor, prelu=prelu) + self.dw5_5 = block(512 * widen_factor, 512 * widen_factor, prelu=prelu) + self.dw5_6 = block(512 * widen_factor, 1024 * widen_factor, stride=2, prelu=prelu) + + self.dw6 = block(1024 * widen_factor, 1024 * widen_factor, prelu=prelu) + + self.avgpool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Linear(int(1024 * widen_factor), num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + + x = self.dw2_1(x) + x = self.dw2_2(x) + x = self.dw3_1(x) + x = self.dw3_2(x) + x = self.dw4_1(x) + x = self.dw4_2(x) + x = self.dw5_1(x) + x = self.dw5_2(x) + x = self.dw5_3(x) + x = self.dw5_4(x) + x = self.dw5_5(x) + x = self.dw5_6(x) + x = self.dw6(x) + + x = self.avgpool(x) + x = x.view(x.size(0), -1) + x = self.fc(x) + + return x + + +def mobilenet(**kwargs): + """ + Construct MobileNet. 
+ widen_factor=1.0 for mobilenet_1 + widen_factor=0.75 for mobilenet_075 + widen_factor=0.5 for mobilenet_05 + widen_factor=0.25 for mobilenet_025 + """ + # widen_factor = 1.0, num_classes = 1000 + # model = MobileNet(widen_factor=widen_factor, num_classes=num_classes) + # return model + + model = MobileNet( + widen_factor=kwargs.get('widen_factor', 1.0), + num_classes=kwargs.get('num_classes', 62) + ) + return model + + +def mobilenet_2(num_classes=62, input_channel=3): + model = MobileNet(widen_factor=2.0, num_classes=num_classes, input_channel=input_channel) + return model + + +def mobilenet_1(num_classes=62, input_channel=3): + model = MobileNet(widen_factor=1.0, num_classes=num_classes, input_channel=input_channel) + return model + + +def mobilenet_075(num_classes=62, input_channel=3): + model = MobileNet(widen_factor=0.75, num_classes=num_classes, input_channel=input_channel) + return model + + +def mobilenet_05(num_classes=62, input_channel=3): + model = MobileNet(widen_factor=0.5, num_classes=num_classes, input_channel=input_channel) + return model + + +def mobilenet_025(num_classes=62, input_channel=3): + model = MobileNet(widen_factor=0.25, num_classes=num_classes, input_channel=input_channel) + return model diff --git a/talkingface/utils/pose_3ddfa/models/mobilenet_v3.py b/talkingface/utils/pose_3ddfa/models/mobilenet_v3.py new file mode 100644 index 00000000..e5eaf7f1 --- /dev/null +++ b/talkingface/utils/pose_3ddfa/models/mobilenet_v3.py @@ -0,0 +1,246 @@ +# coding: utf-8 + + +import torch.nn as nn +import torch.nn.functional as F + +__all__ = ['MobileNetV3', 'mobilenet_v3'] + + +def conv_bn(inp, oup, stride, conv_layer=nn.Conv2d, norm_layer=nn.BatchNorm2d, nlin_layer=nn.ReLU): + return nn.Sequential( + conv_layer(inp, oup, 3, stride, 1, bias=False), + norm_layer(oup), + nlin_layer(inplace=True) + ) + + +def conv_1x1_bn(inp, oup, conv_layer=nn.Conv2d, norm_layer=nn.BatchNorm2d, nlin_layer=nn.ReLU): + return nn.Sequential( + conv_layer(inp, oup, 1, 1, 
0, bias=False), + norm_layer(oup), + nlin_layer(inplace=True) + ) + + +class Hswish(nn.Module): + def __init__(self, inplace=True): + super(Hswish, self).__init__() + self.inplace = inplace + + def forward(self, x): + return x * F.relu6(x + 3., inplace=self.inplace) / 6. + + +class Hsigmoid(nn.Module): + def __init__(self, inplace=True): + super(Hsigmoid, self).__init__() + self.inplace = inplace + + def forward(self, x): + return F.relu6(x + 3., inplace=self.inplace) / 6. + + +class SEModule(nn.Module): + def __init__(self, channel, reduction=4): + super(SEModule, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Sequential( + nn.Linear(channel, channel // reduction, bias=False), + nn.ReLU(inplace=True), + nn.Linear(channel // reduction, channel, bias=False), + Hsigmoid() + # nn.Sigmoid() + ) + + def forward(self, x): + b, c, _, _ = x.size() + y = self.avg_pool(x).view(b, c) + y = self.fc(y).view(b, c, 1, 1) + return x * y.expand_as(x) + + +class Identity(nn.Module): + def __init__(self, channel): + super(Identity, self).__init__() + + def forward(self, x): + return x + + +def make_divisible(x, divisible_by=8): + import numpy as np + return int(np.ceil(x * 1. 
/ divisible_by) * divisible_by) + + +class MobileBottleneck(nn.Module): + def __init__(self, inp, oup, kernel, stride, exp, se=False, nl='RE'): + super(MobileBottleneck, self).__init__() + assert stride in [1, 2] + assert kernel in [3, 5] + padding = (kernel - 1) // 2 + self.use_res_connect = stride == 1 and inp == oup + + conv_layer = nn.Conv2d + norm_layer = nn.BatchNorm2d + if nl == 'RE': + nlin_layer = nn.ReLU # or ReLU6 + elif nl == 'HS': + nlin_layer = Hswish + else: + raise NotImplementedError + if se: + SELayer = SEModule + else: + SELayer = Identity + + self.conv = nn.Sequential( + # pw + conv_layer(inp, exp, 1, 1, 0, bias=False), + norm_layer(exp), + nlin_layer(inplace=True), + # dw + conv_layer(exp, exp, kernel, stride, padding, groups=exp, bias=False), + norm_layer(exp), + SELayer(exp), + nlin_layer(inplace=True), + # pw-linear + conv_layer(exp, oup, 1, 1, 0, bias=False), + norm_layer(oup), + ) + + def forward(self, x): + if self.use_res_connect: + return x + self.conv(x) + else: + return self.conv(x) + + +class MobileNetV3(nn.Module): + def __init__(self, widen_factor=1.0, num_classes=141, num_landmarks=136, input_size=120, mode='small'): + super(MobileNetV3, self).__init__() + input_channel = 16 + last_channel = 1280 + if mode == 'large': + # refer to Table 1 in paper + mobile_setting = [ + # k, exp, c, se, nl, s, + [3, 16, 16, False, 'RE', 1], + [3, 64, 24, False, 'RE', 2], + [3, 72, 24, False, 'RE', 1], + [5, 72, 40, True, 'RE', 2], + [5, 120, 40, True, 'RE', 1], + [5, 120, 40, True, 'RE', 1], + [3, 240, 80, False, 'HS', 2], + [3, 200, 80, False, 'HS', 1], + [3, 184, 80, False, 'HS', 1], + [3, 184, 80, False, 'HS', 1], + [3, 480, 112, True, 'HS', 1], + [3, 672, 112, True, 'HS', 1], + [5, 672, 160, True, 'HS', 2], + [5, 960, 160, True, 'HS', 1], + [5, 960, 160, True, 'HS', 1], + ] + elif mode == 'small': + # refer to Table 2 in paper + mobile_setting = [ + # k, exp, c, se, nl, s, + [3, 16, 16, True, 'RE', 2], + [3, 72, 24, False, 'RE', 2], + [3, 88, 
24, False, 'RE', 1], + [5, 96, 40, True, 'HS', 2], + [5, 240, 40, True, 'HS', 1], + [5, 240, 40, True, 'HS', 1], + [5, 120, 48, True, 'HS', 1], + [5, 144, 48, True, 'HS', 1], + [5, 288, 96, True, 'HS', 2], + [5, 576, 96, True, 'HS', 1], + [5, 576, 96, True, 'HS', 1], + ] + else: + raise NotImplementedError + + # building first layer + assert input_size % 32 == 0 + last_channel = make_divisible(last_channel * widen_factor) if widen_factor > 1.0 else last_channel + self.features = [conv_bn(3, input_channel, 2, nlin_layer=Hswish)] + # self.classifier = [] + + # building mobile blocks + for k, exp, c, se, nl, s in mobile_setting: + output_channel = make_divisible(c * widen_factor) + exp_channel = make_divisible(exp * widen_factor) + self.features.append(MobileBottleneck(input_channel, output_channel, k, s, exp_channel, se, nl)) + input_channel = output_channel + + # building last several layers + if mode == 'large': + last_conv = make_divisible(960 * widen_factor) + self.features.append(conv_1x1_bn(input_channel, last_conv, nlin_layer=Hswish)) + self.features.append(nn.AdaptiveAvgPool2d(1)) + self.features.append(nn.Conv2d(last_conv, last_channel, 1, 1, 0)) + self.features.append(Hswish(inplace=True)) + elif mode == 'small': + last_conv = make_divisible(576 * widen_factor) + self.features.append(conv_1x1_bn(input_channel, last_conv, nlin_layer=Hswish)) + # self.features.append(SEModule(last_conv)) # refer to paper Table2, but I think this is a mistake + self.features.append(nn.AdaptiveAvgPool2d(1)) + self.features.append(nn.Conv2d(last_conv, last_channel, 1, 1, 0)) + self.features.append(Hswish(inplace=True)) + else: + raise NotImplementedError + + # make it nn.Sequential + self.features = nn.Sequential(*self.features) + + # self.fc_param = nn.Linear(int(last_channel), num_classes) + self.fc = nn.Linear(int(last_channel), num_classes) + # self.fc_lm = nn.Linear(int(last_channel), num_landmarks) + + # building classifier + # self.classifier = nn.Sequential( + # 
nn.Dropout(p=dropout), # refer to paper section 6 + # nn.Linear(last_channel, n_class), + # ) + + self._initialize_weights() + + def forward(self, x): + x = self.features(x) + x_share = x.mean(3).mean(2) + + # x = self.classifier(x) + # print(x_share.shape) + # xp = self.fc_param(x_share) # param + # xl = self.fc_lm(x_share) # lm + + xp = self.fc(x_share) # param + + return xp # , xl + + def _initialize_weights(self): + # weight initialization + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out') + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.BatchNorm2d): + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + if m.bias is not None: + nn.init.zeros_(m.bias) + + +def mobilenet_v3(**kwargs): + model = MobileNetV3( + widen_factor=kwargs.get('widen_factor', 1.0), + num_classes=kwargs.get('num_classes', 62), + num_landmarks=kwargs.get('num_landmarks', 136), + input_size=kwargs.get('size', 128), + mode=kwargs.get('mode', 'small') + ) + + return model diff --git a/talkingface/utils/pose_3ddfa/models/resnet.py b/talkingface/utils/pose_3ddfa/models/resnet.py new file mode 100644 index 00000000..87d87273 --- /dev/null +++ b/talkingface/utils/pose_3ddfa/models/resnet.py @@ -0,0 +1,150 @@ +#!/usr/bin/env python3 +# coding: utf-8 + +import torch.nn as nn + +__all__ = ['ResNet', 'resnet22'] + + +def conv3x3(in_planes, out_planes, stride=1): + "3x3 convolution with padding" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes) + self.downsample 
= downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + """Another Strucutre used in caffe-resnet25""" + + def __init__(self, block, layers, num_classes=62, num_landmarks=136, input_channel=3, fc_flg=False): + self.inplanes = 64 + super(ResNet, self).__init__() + self.conv1 = nn.Conv2d(input_channel, 32, kernel_size=5, stride=2, padding=2, bias=False) + self.bn1 = nn.BatchNorm2d(32) # 32 is input channels number + self.relu1 = nn.ReLU(inplace=True) + + self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(64) + self.relu2 = nn.ReLU(inplace=True) + + # self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + self.layer1 = self._make_layer(block, 128, layers[0], stride=2) + self.layer2 = self._make_layer(block, 256, layers[1], stride=2) + self.layer3 = self._make_layer(block, 512, layers[2], stride=2) + + self.conv_param = nn.Conv2d(512, num_classes, 1) + # self.conv_lm = nn.Conv2d(512, num_landmarks, 1) + self.avgpool = nn.AdaptiveAvgPool2d(1) + # self.fc = nn.Linear(512 * block.expansion, num_classes) + self.fc_flg = fc_flg + + # parameter initialization + for m in self.modules(): + if isinstance(m, nn.Conv2d): + # 1. + # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + # m.weight.data.normal_(0, math.sqrt(2. / n)) + + # 2. 
kaiming normal + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu1(x) + + x = self.conv2(x) + x = self.bn2(x) + x = self.relu2(x) + + # x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + + # if self.fc_flg: + # x = self.avgpool(x) + # x = x.view(x.size(0), -1) + # x = self.fc(x) + # else: + xp = self.conv_param(x) + xp = self.avgpool(xp) + xp = xp.view(xp.size(0), -1) + + # xl = self.conv_lm(x) + # xl = self.avgpool(xl) + # xl = xl.view(xl.size(0), -1) + + return xp # , xl + + +def resnet22(**kwargs): + model = ResNet( + BasicBlock, + [3, 4, 3], + num_landmarks=kwargs.get('num_landmarks', 136), + input_channel=kwargs.get('input_channel', 3), + fc_flg=False + ) + return model + + +def main(): + pass + + +if __name__ == '__main__': + main() diff --git a/talkingface/utils/pose_3ddfa/utils/__init__.py b/talkingface/utils/pose_3ddfa/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/talkingface/utils/pose_3ddfa/utils/functions.py b/talkingface/utils/pose_3ddfa/utils/functions.py new file mode 100644 index 00000000..970c668d --- /dev/null +++ b/talkingface/utils/pose_3ddfa/utils/functions.py @@ -0,0 +1,182 @@ +# coding: utf-8 + +__author__ = 'cleardusk' + +import numpy as np +import 
# coding: utf-8
# talkingface/utils/pose_3ddfa/utils/functions.py

__author__ = 'cleardusk'

import numpy as np
from math import sqrt

# cv2 / matplotlib are only needed by the drawing helpers; import defensively so
# the pure-geometry helpers stay usable when the plotting stack is absent.
try:
    import cv2
except ImportError:  # pragma: no cover
    cv2 = None
try:
    import matplotlib.pyplot as plt
except ImportError:  # pragma: no cover
    plt = None

# BGR color constants for the cv2 drawing helpers.
RED = (0, 0, 255)
GREEN = (0, 255, 0)
BLUE = (255, 0, 0)


def get_suffix(filename):
    """'a.jpg' -> '.jpg' (suffix INCLUDING the dot); '' when there is no dot.

    Note: the original docstring claimed 'jpg' but the code keeps the dot;
    behavior is unchanged, only the documentation is corrected.
    """
    pos = filename.rfind('.')
    if pos == -1:
        return ''
    return filename[pos:]


def crop_img(img, roi_box):
    """Crop ``img`` to ``roi_box`` = (sx, sy, ex, ey), zero-padding any part of
    the box that falls outside the image. Returns a new uint8 array of the box
    size (grayscale or 3-channel, matching the input)."""
    h, w = img.shape[:2]

    sx, sy, ex, ey = [int(round(_)) for _ in roi_box]
    dh, dw = ey - sy, ex - sx
    if len(img.shape) == 3:
        res = np.zeros((dh, dw, 3), dtype=np.uint8)
    else:
        res = np.zeros((dh, dw), dtype=np.uint8)

    # Clip each edge against the image and remember the destination offset.
    if sx < 0:
        sx, dsx = 0, -sx
    else:
        dsx = 0

    if ex > w:
        ex, dex = w, dw - (ex - w)
    else:
        dex = dw

    if sy < 0:
        sy, dsy = 0, -sy
    else:
        dsy = 0

    if ey > h:
        ey, dey = h, dh - (ey - h)
    else:
        dey = dh

    res[dsy:dey, dsx:dex] = img[sy:ey, sx:ex]
    return res


def calc_hypotenuse(pts):
    """Return one third of the diagonal of the square hull of ``pts`` (2 x n),
    used as a size scale for pose-box drawing."""
    bbox = [min(pts[0, :]), min(pts[1, :]), max(pts[0, :]), max(pts[1, :])]
    center = [(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2]
    radius = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) / 2
    bbox = [center[0] - radius, center[1] - radius, center[0] + radius, center[1] + radius]
    llength = sqrt((bbox[2] - bbox[0]) ** 2 + (bbox[3] - bbox[1]) ** 2)
    return llength / 3


def parse_roi_box_from_landmark(pts):
    """Compute a square ROI box (sized by the landmark hull diagonal) centered
    on the landmarks ``pts`` (2 x n). Returns [sx, sy, ex, ey]."""
    bbox = [min(pts[0, :]), min(pts[1, :]), max(pts[0, :]), max(pts[1, :])]
    center = [(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2]
    radius = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) / 2
    bbox = [center[0] - radius, center[1] - radius, center[0] + radius, center[1] + radius]

    llength = sqrt((bbox[2] - bbox[0]) ** 2 + (bbox[3] - bbox[1]) ** 2)
    center_x = (bbox[2] + bbox[0]) / 2
    center_y = (bbox[3] + bbox[1]) / 2

    roi_box = [0] * 4
    roi_box[0] = center_x - llength / 2
    roi_box[1] = center_y - llength / 2
    roi_box[2] = roi_box[0] + llength
    roi_box[3] = roi_box[1] + llength

    return roi_box


def parse_roi_box_from_bbox(bbox):
    """Expand a face detector box into the square crop box the 3DDFA model was
    trained on: 1.58x the mean box size, shifted down by 0.14x (to include the
    forehead-to-chin region). Returns [sx, sy, ex, ey]."""
    left, top, right, bottom = bbox[:4]
    old_size = (right - left + bottom - top) / 2
    center_x = right - (right - left) / 2.0
    center_y = bottom - (bottom - top) / 2.0 + old_size * 0.14
    size = int(old_size * 1.58)

    roi_box = [0] * 4
    roi_box[0] = center_x - size / 2
    roi_box[1] = center_y - size / 2
    roi_box[2] = roi_box[0] + size
    roi_box[3] = roi_box[1] + size

    return roi_box


def plot_image(img):
    """Show a BGR image with matplotlib (converted to RGB, no axes)."""
    height, width = img.shape[:2]
    plt.figure(figsize=(12, height / width * 12))

    plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
    plt.axis('off')

    plt.imshow(img[..., ::-1])  # BGR -> RGB
    plt.show()


def draw_landmarks(img, pts, style='fancy', wfp=None, show_flag=False, **kwargs):
    """Draw 68-point (or dense) landmarks on a BGR image using matplotlib.

    Args:
        img: BGR image array.
        pts: one (2|3, n) landmark array or a list/tuple of them.
        wfp: optional path to save the figure (dpi=150).
        show_flag: call plt.show() when True.
        **kwargs: dense_flag, color, markeredgecolor.
    """
    height, width = img.shape[:2]
    plt.figure(figsize=(12, height / width * 12))
    plt.imshow(img[..., ::-1])  # BGR -> RGB
    plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
    plt.axis('off')

    dense_flag = kwargs.get('dense_flag')

    if not isinstance(pts, (tuple, list)):
        pts = [pts]
    for i in range(len(pts)):
        if dense_flag:
            # Dense vertices: plot every 6th point as a small dot.
            plt.plot(pts[i][0, ::6], pts[i][1, ::6], 'o', markersize=0.4, color='c', alpha=0.7)
        else:
            alpha = 0.8
            markersize = 4
            lw = 1.5
            color = kwargs.get('color', 'w')
            markeredgecolor = kwargs.get('markeredgecolor', 'black')

            # Start indices of the 68-point groups: jaw, brows, nose, eyes, lips.
            nums = [0, 17, 22, 27, 31, 36, 42, 48, 60, 68]

            # Close the eye and mouth contours.
            plot_close = lambda i1, i2: plt.plot(
                [pts[i][0, i1], pts[i][0, i2]], [pts[i][1, i1], pts[i][1, i2]],
                color=color, lw=lw, alpha=alpha - 0.1)
            plot_close(41, 36)
            plot_close(47, 42)
            plot_close(59, 48)
            plot_close(67, 60)

            for ind in range(len(nums) - 1):
                lo, hi = nums[ind], nums[ind + 1]
                plt.plot(pts[i][0, lo:hi], pts[i][1, lo:hi], color=color, lw=lw, alpha=alpha - 0.1)

                plt.plot(pts[i][0, lo:hi], pts[i][1, lo:hi], marker='o', linestyle='None',
                         markersize=markersize, color=color,
                         markeredgecolor=markeredgecolor, alpha=alpha)

    if wfp is not None:
        plt.savefig(wfp, dpi=150)
        print(f'Save visualization result to {wfp}')

    if show_flag:
        plt.show()
# BGR constants (duplicated from the functions.py header so this drawing helper
# is self-contained).
GREEN = (0, 255, 0)
BLUE = (255, 0, 0)


def cv_draw_landmark(img_ori, pts, box=None, color=GREEN, size=1):
    """Return a copy of ``img_ori`` with landmarks ``pts`` (2-or-3 x n) drawn
    via cv2; optionally draw the bounding ``box`` = (left, top, right, bottom).

    Sparse sets (n <= 106) use filled circles, dense sets outlined circles.
    """
    img = img_ori.copy()
    n = pts.shape[1]
    if n <= 106:
        for i in range(n):
            cv2.circle(img, (int(round(pts[0, i])), int(round(pts[1, i]))), size, color, -1)
    else:
        # Dense vertices: outlined circles. (Removed the dead `sep = 1` local —
        # the original stepped by 1 anyway.)
        for i in range(n):
            cv2.circle(img, (int(round(pts[0, i])), int(round(pts[1, i]))), size, color, 1)

    if box is not None:
        left, top, right, bottom = np.round(box).astype(np.int32)
        left_top = (left, top)
        right_top = (right, top)
        right_bottom = (right, bottom)
        left_bottom = (left, bottom)
        cv2.line(img, left_top, right_top, BLUE, 1, cv2.LINE_AA)
        cv2.line(img, right_top, right_bottom, BLUE, 1, cv2.LINE_AA)
        cv2.line(img, right_bottom, left_bottom, BLUE, 1, cv2.LINE_AA)
        cv2.line(img, left_bottom, left_top, BLUE, 1, cv2.LINE_AA)

    return img


# --- talkingface/utils/pose_3ddfa/utils/io.py ---
# coding: utf-8

__author__ = 'cleardusk'

import os
import numpy as np
import torch
import pickle


def mkdir(d):
    """Create directory ``d`` (and parents); no error if it already exists."""
    os.makedirs(d, exist_ok=True)


def _get_suffix(filename):
    """'a.jpg' -> 'jpg' (suffix WITHOUT the dot); '' when there is no dot."""
    pos = filename.rfind('.')
    if pos == -1:
        return ''
    return filename[pos + 1:]


def _load(fp):
    """Load a ``.npy`` or ``.pkl`` file.

    Returns None for any other suffix — kept for backward compatibility
    (``_dump``, by contrast, raises for unknown suffixes).
    """
    suffix = _get_suffix(fp)
    if suffix == 'npy':
        return np.load(fp)
    if suffix == 'pkl':
        # use a context manager so the file handle is not leaked
        with open(fp, 'rb') as f:
            return pickle.load(f)
    return None


def _dump(wfp, obj):
    """Save ``obj`` to ``wfp`` as ``.npy`` or ``.pkl``; raise for other suffixes."""
    suffix = _get_suffix(wfp)
    if suffix == 'npy':
        np.save(wfp, obj)
    elif suffix == 'pkl':
        # use a context manager so the file handle is not leaked
        with open(wfp, 'wb') as f:
            pickle.dump(obj, f)
    else:
        raise Exception('Unknown Type: {}'.format(suffix))


def _load_tensor(fp, mode='cpu'):
    """Load ``fp`` via ``_load`` as a torch tensor on CPU or GPU.

    NOTE(review): returns None for any other ``mode`` string — confirm callers
    only ever pass 'cpu'/'gpu'.
    """
    if mode.lower() == 'cpu':
        return torch.from_numpy(_load(fp))
    elif mode.lower() == 'gpu':
        return torch.from_numpy(_load(fp)).cuda()


def _tensor_to_cuda(x):
    """Move tensor ``x`` to CUDA if it is not already there."""
    if x.is_cuda:
        return x
    else:
        return x.cuda()


def _load_gpu(fp):
    """Load ``fp`` with ``_load`` and move the tensor to the default CUDA device."""
    return torch.from_numpy(_load(fp)).cuda()


# Small conversion aliases used across the 3DDFA utilities.
_load_cpu = _load
_numpy_to_tensor = lambda x: torch.from_numpy(x)
_tensor_to_numpy = lambda x: x.numpy()
_numpy_to_cuda = lambda x: _tensor_to_cuda(torch.from_numpy(x))
_cuda_to_tensor = lambda x: x.cpu()
_cuda_to_numpy = lambda x: x.cpu().numpy()


# --- talkingface/utils/pose_3ddfa/utils/pose.py (first part) ---
# Reference: https://github.com/YadiraF/PRNet/blob/master/utils/estimate_pose.py
# (The original file also imports cv2 and `from .functions import
# calc_hypotenuse, plot_image` at this point.)

def P2sRt(P):
    """Decompose an affine camera matrix.

    Args:
        P: (3, 4) affine camera matrix.
    Returns:
        s: scale factor (mean norm of the first two rows' rotation parts).
        R: (3, 3) rotation matrix (third row from the cross product).
        t3d: (3,) 3D translation (last column).
    """
    t3d = P[:, 3]
    R1 = P[0:1, :3]
    R2 = P[1:2, :3]
    s = (np.linalg.norm(R1) + np.linalg.norm(R2)) / 2.0
    r1 = R1 / np.linalg.norm(R1)
    r2 = R2 / np.linalg.norm(R2)
    r3 = np.cross(r1, r2)

    R = np.concatenate((r1, r2, r3), 0)
    return s, R, t3d


def matrix2angle(R):
    """Compute three Euler angles (radians) from a rotation matrix.

    Ref: http://www.gregslabaugh.net/publications/euler.pdf, refined by
    https://stackoverflow.com/questions/43364900. The |R[2,0]| > 0.998 branches
    handle the gimbal-lock singularity.

    Args:
        R: (3, 3) rotation matrix.
    Returns:
        (x, y, z) — labeled yaw/pitch/roll by the original author.
        NOTE(review): axis naming is unverified; confirm against callers.
    """
    if R[2, 0] > 0.998:
        z = 0
        x = np.pi / 2
        y = z + atan2(-R[0, 1], -R[0, 2])
    elif R[2, 0] < -0.998:
        z = 0
        x = -np.pi / 2
        y = -z + atan2(R[0, 1], R[0, 2])
    else:
        x = asin(R[2, 0])
        y = atan2(R[2, 1] / cos(x), R[2, 2] / cos(x))
        z = atan2(R[1, 0] / cos(x), R[0, 0] / cos(x))

    return x, y, z
def angle2matrix(theta):
    """Build a (3, 3) rotation matrix from three Euler angles (radians).

    This is the inverse companion of ``matrix2angle`` — the original docstring
    was copy-pasted from ``matrix2angle`` and wrongly described the opposite
    conversion; the code is unchanged.

    Args:
        theta: sequence of 3 angles. NOTE(review): the implementation reads
            theta[1] for the x-rotation and theta[0] for the y-rotation,
            mirroring the (x, y) ordering returned by matrix2angle — confirm
            this index swap against callers.
    Returns:
        R: (3, 3) rotation matrix, R = Rz @ Ry @ Rx.
    """
    R_x = np.array([[1, 0, 0],
                    [0, cos(theta[1]), -sin(theta[1])],
                    [0, sin(theta[1]), cos(theta[1])]])

    R_y = np.array([[cos(theta[0]), 0, sin(-theta[0])],
                    [0, 1, 0],
                    [-sin(-theta[0]), 0, cos(theta[0])]])

    R_z = np.array([[cos(theta[2]), -sin(theta[2]), 0],
                    [sin(theta[2]), cos(theta[2]), 0],
                    [0, 0, 1]])

    R = np.dot(R_z, np.dot(R_y, R_x))

    return R


def angle2matrix_3ddfa(angles):
    """Rotation matrix from three rotation angles (radians), 3DDFA convention.

    Args:
        angles: [3,] — angles[0]: yaw (used as y), angles[1]: pitch (used as x),
            angles[2]: roll (z). Note the transposed-looking sign convention of
            the per-axis matrices is intentional (matches 3DDFA).
    Returns:
        R: (3, 3) float32 rotation matrix, R = Rx @ Ry @ Rz.
    """
    # x, y, z = np.deg2rad(angles[0]), np.deg2rad(angles[1]), np.deg2rad(angles[2])
    x, y, z = angles[1], angles[0], angles[2]

    # rotation about x
    Rx = np.array([[1, 0, 0],
                   [0, cos(x), sin(x)],
                   [0, -sin(x), cos(x)]])
    # rotation about y
    Ry = np.array([[cos(y), 0, -sin(y)],
                   [0, 1, 0],
                   [sin(y), 0, cos(y)]])
    # rotation about z
    Rz = np.array([[cos(z), sin(z), 0],
                   [-sin(z), cos(z), 0],
                   [0, 0, 1]])
    R = Rx.dot(Ry).dot(Rz)
    return R.astype(np.float32)


def calc_pose(param):
    """Decode the 12 leading 3DMM params into a pose matrix and Euler angles.

    Args:
        param: 1-D parameter vector whose first 12 entries form the (3, 4)
            camera matrix.
    Returns:
        P: (3, 4) pose matrix [R | t3d] without scale.
        pose: [yaw, pitch, roll] in degrees (per matrix2angle's labeling).
    """
    P = param[:12].reshape(3, -1)  # camera matrix
    s, R, t3d = P2sRt(P)
    P = np.concatenate((R, t3d.reshape(3, -1)), axis=1)  # without scale
    pose = matrix2angle(R)
    pose = [p * 180 / np.pi for p in pose]  # radians -> degrees

    return P, pose


def build_camera_box(rear_size=90):
    """Return the 10 corner points (float32, (10, 3)) of the wireframe box used
    to visualize head pose: a rear square at depth 0 and a larger front square
    at depth 4/3 * rear_size, each traced as a closed loop."""
    point_3d = []
    rear_depth = 0
    point_3d.append((-rear_size, -rear_size, rear_depth))
    point_3d.append((-rear_size, rear_size, rear_depth))
    point_3d.append((rear_size, rear_size, rear_depth))
    point_3d.append((rear_size, -rear_size, rear_depth))
    point_3d.append((-rear_size, -rear_size, rear_depth))

    front_size = int(4 / 3 * rear_size)
    front_depth = int(4 / 3 * rear_size)
    point_3d.append((-front_size, -front_size, front_depth))
    point_3d.append((-front_size, front_size, front_depth))
    point_3d.append((front_size, front_size, front_depth))
    point_3d.append((front_size, -front_size, front_depth))
    point_3d.append((-front_size, -front_size, front_depth))
    point_3d = np.array(point_3d, dtype=np.float32).reshape(-1, 3)

    return point_3d


def plot_pose_box(img, P, ver, color=(40, 255, 0), line_width=2):
    """Draw a 3D pose box (projected with ``P``) onto ``img`` in place.

    Ref: https://github.com/yinguobing/head-pose-estimation/blob/master/pose_estimator.py

    Args:
        img: input BGR image.
        P: (3, 4) affine camera matrix.
        ver: (2, 68) or (3, 68) landmarks — used for box size and placement.
    Returns:
        The annotated image.
    """
    llength = calc_hypotenuse(ver)
    point_3d = build_camera_box(llength)
    # Project to 2D image points (homogeneous coordinates).
    point_3d_homo = np.hstack((point_3d, np.ones([point_3d.shape[0], 1])))  # n x 4
    point_2d = point_3d_homo.dot(P.T)[:, :2]

    # Flip y and recenter the rear square onto the upper-face landmarks.
    point_2d[:, 1] = -point_2d[:, 1]
    point_2d[:, :2] = point_2d[:, :2] - np.mean(point_2d[:4, :2], 0) + np.mean(ver[:2, :27], 1)
    point_2d = np.int32(point_2d.reshape(-1, 2))

    # Draw the two squares and the connecting edges.
    cv2.polylines(img, [point_2d], True, color, line_width, cv2.LINE_AA)
    cv2.line(img, tuple(point_2d[1]), tuple(
        point_2d[6]), color, line_width, cv2.LINE_AA)
    cv2.line(img, tuple(point_2d[2]), tuple(
        point_2d[7]), color, line_width, cv2.LINE_AA)
    cv2.line(img, tuple(point_2d[3]), tuple(
        point_2d[8]), color, line_width, cv2.LINE_AA)

    return img
def viz_pose(img, param_lst, ver_lst, show_flag=False, wfp=None):
    """Draw a pose box for every (param, ver) pair, optionally save/show.

    Args:
        img: BGR image (annotated in place / reassigned).
        param_lst: list of 3DMM parameter vectors (first 12 = camera matrix).
        ver_lst: list of matching landmark arrays.
        wfp: optional output image path.
    Returns:
        The annotated image.
    """
    for param, ver in zip(param_lst, ver_lst):
        P, pose = calc_pose(param)
        img = plot_pose_box(img, P, ver)
        print(f'yaw: {pose[0]:.1f}, pitch: {pose[1]:.1f}, roll: {pose[2]:.1f}')

    if wfp is not None:
        cv2.imwrite(wfp, img)
        print(f'Save visualization result to {wfp}')

    if show_flag:
        plot_image(img)

    return img


def pose_6(param):
    """Decompose a 3DMM parameter vector into its pose components.

    Args:
        param: 1-D parameter vector; first 12 entries form the camera matrix.
    Returns:
        s: scale, pose: [yaw, pitch, roll] in degrees,
        t3d: (3,) translation, P: (3, 4) pose matrix without scale.
    """
    P = param[:12].reshape(3, -1)  # camera matrix
    s, R, t3d = P2sRt(P)
    P = np.concatenate((R, t3d.reshape(3, -1)), axis=1)  # without scale
    pose = matrix2angle(R)
    # (Removed leftover debug prints and the unused angle2matrix round-trip.)
    pose = [p * 180 / np.pi for p in pose]  # radians -> degrees

    return s, pose, t3d, P


def smooth_pose(img, param_lst, ver_lst, pose_new, show_flag=False, wfp=None, wnp=None):
    """Like ``get_pose`` but draws the box from an externally supplied
    (e.g. smoothed) pose vector instead of the pose decoded from ``param``.

    Args:
        pose_new: [yaw_deg, pitch_deg, roll_deg, scale, tx, ty, tz].
        wfp: optional output image path; wnp: optional output .npy path.
    Returns:
        The annotated image.
    """
    for param, ver in zip(param_lst, ver_lst):
        t3d = np.array([pose_new[4], pose_new[5], pose_new[6]])
        theta = [p * np.pi / 180 for p in (pose_new[0], pose_new[1], pose_new[2])]
        R = angle2matrix(theta)
        P = np.concatenate((R, t3d.reshape(3, -1)), axis=1)
        img = plot_pose_box(img, P, ver)
        print(P, pose_new)
        # NOTE(review): theta is printed in radians here, unlike get_pose (degrees).
        print(f'yaw: {theta[0]:.1f}, pitch: {theta[1]:.1f}, roll: {theta[2]:.1f}')
        # NOTE(review): the original saves a [0] placeholder, not the pose — kept as-is.
        all_pose = np.array([0])

    if wfp is not None:
        cv2.imwrite(wfp, img)
        print(f'Save visualization result to {wfp}')

    if wnp is not None:
        np.save(wnp, all_pose)
        print(f'Save pose result to {wnp}')  # fixed: previously reported wfp

    if show_flag:
        plot_image(img)

    return img


def get_pose(img, param_lst, ver_lst, show_flag=False, wfp=None, wnp=None):
    """Draw the pose box and return the 7-d pose vector
    [yaw, pitch, roll, scale, tx, ty, tz] of the LAST face in the lists.

    Args:
        wfp: optional output image path; wnp: optional output .npy path.
    """
    for param, ver in zip(param_lst, ver_lst):
        s, pose, t3d, P = pose_6(param)
        img = plot_pose_box(img, P, ver)
        print(f'yaw: {pose[0]:.1f}, pitch: {pose[1]:.1f}, roll: {pose[2]:.1f}')
        all_pose = np.array([pose[0], pose[1], pose[2], s, t3d[0], t3d[1], t3d[2]])

    if wfp is not None:
        cv2.imwrite(wfp, img)
        print(f'Save visualization result to {wfp}')

    if wnp is not None:
        np.save(wnp, all_pose)
        print(f'Save pose result to {wnp}')  # fixed: previously reported wfp

    if show_flag:
        plot_image(img)

    return all_pose


# --- talkingface/utils/pose_3ddfa/utils/tddfa_util.py ---
# coding: utf-8

__author__ = 'cleardusk'

import sys

sys.path.append('..')  # NOTE(review): legacy path hack kept from the original 3DDFA code

import argparse
import numpy as np
import torch


def _to_ctype(arr):
    """Return a C-contiguous copy of ``arr`` (or ``arr`` itself if already contiguous)."""
    if not arr.flags.c_contiguous:
        return arr.copy(order='C')
    return arr


def str2bool(v):
    """argparse-friendly bool parser: yes/true/t/y/1 -> True, no/false/f/n/0 -> False.

    Raises:
        argparse.ArgumentTypeError: for any other string.
    """
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected')


def load_model(model, checkpoint_fp):
    """Load a checkpoint's state_dict into ``model`` and return the model.

    Checkpoints trained with DataParallel carry a 'module.' prefix on every
    key, so it is stripped before matching; 'fc_param.*' keys are additionally
    mapped onto 'fc.*' for the FC-head variant.
    """
    checkpoint = torch.load(checkpoint_fp, map_location=lambda storage, loc: storage)['state_dict']
    model_dict = model.state_dict()
    for k in checkpoint:
        kc = k.replace('module.', '')
        if kc in model_dict:
            model_dict[kc] = checkpoint[k]
        if kc in ['fc_param.bias', 'fc_param.weight']:
            model_dict[kc.replace('_param', '')] = checkpoint[k]

    model.load_state_dict(model_dict)
    return model
should be removed + for k in checkpoint.keys(): + kc = k.replace('module.', '') + if kc in model_dict.keys(): + model_dict[kc] = checkpoint[k] + if kc in ['fc_param.bias', 'fc_param.weight']: + model_dict[kc.replace('_param', '')] = checkpoint[k] + + model.load_state_dict(model_dict) + return model + + +class ToTensorGjz(object): + def __call__(self, pic): + if isinstance(pic, np.ndarray): + img = torch.from_numpy(pic.transpose((2, 0, 1))) + return img.float() + + def __repr__(self): + return self.__class__.__name__ + '()' + + +class NormalizeGjz(object): + def __init__(self, mean, std): + self.mean = mean + self.std = std + + def __call__(self, tensor): + tensor.sub_(self.mean).div_(self.std) + return tensor + + +def similar_transform(pts3d, roi_box, size): + pts3d[0, :] -= 1 # for Python compatibility + pts3d[2, :] -= 1 + pts3d[1, :] = size - pts3d[1, :] + + sx, sy, ex, ey = roi_box + scale_x = (ex - sx) / size + scale_y = (ey - sy) / size + pts3d[0, :] = pts3d[0, :] * scale_x + sx + pts3d[1, :] = pts3d[1, :] * scale_y + sy + s = (scale_x + scale_y) / 2 + pts3d[2, :] *= s + pts3d[2, :] -= np.min(pts3d[2, :]) + return np.array(pts3d, dtype=np.float32) + + +def _parse_param(param): + """matrix pose form + param: shape=(trans_dim+shape_dim+exp_dim,), i.e., 62 = 12 + 40 + 10 + """ + + # pre-defined templates for parameter + n = param.shape[0] + if n == 62: + trans_dim, shape_dim, exp_dim = 12, 40, 10 + elif n == 72: + trans_dim, shape_dim, exp_dim = 12, 40, 20 + elif n == 141: + trans_dim, shape_dim, exp_dim = 12, 100, 29 + else: + raise Exception(f'Undefined templated param parsing rule') + + R_ = param[:trans_dim].reshape(3, -1) + R = R_[:, :3] + offset = R_[:, -1].reshape(3, 1) + alpha_shp = param[trans_dim:trans_dim + shape_dim].reshape(-1, 1) + alpha_exp = param[trans_dim + shape_dim:].reshape(-1, 1) + + return R, offset, alpha_shp, alpha_exp