diff --git a/DINet/Dockerfile b/DINet/Dockerfile new file mode 100644 index 00000000..a83091fa --- /dev/null +++ b/DINet/Dockerfile @@ -0,0 +1,29 @@ +# 使用 Python 3.6.13-slim 作为基础镜像 +FROM python:3.6.13-slim + +# 更新镜像并安装必要的系统库(包括 ffmpeg 和其他依赖) +RUN apt-get update --fix-missing && apt-get install -y \ + ffmpeg \ + libsm6 \ + libxext6 \ + libx264-dev \ + && rm -rf /var/lib/apt/lists/* + +# 设置工作目录 +WORKDIR /app + +# 将当前目录中的所有文件复制到容器的 /app 目录 +COPY . /app + +RUN pip install --upgrade pip +# 安装 PyTorch、TorchVision 和 Torchaudio +#RUN pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113 +RUN pip install torch==1.10.1+cu111 torchvision==0.11.2+cu111 torchaudio==0.10.1 -f https://download.pytorch.org/whl/cu111/torch_stable.html +# 安装项目依赖项 +RUN pip install --no-cache-dir -r requirements.txt + +# 设置 ENTRYPOINT 为 python +ENTRYPOINT ["python"] + +# 设置默认命令为 run_evaluate.py --process_num 1 +CMD ["run_evaluate.py"] diff --git a/DINet/README.md b/DINet/README.md new file mode 100644 index 00000000..0fce21b7 --- /dev/null +++ b/DINet/README.md @@ -0,0 +1,105 @@ +# DINet: Deformation Inpainting Network for Realistic Face Visually Dubbing on High Resolution Video (AAAI2023) +![在这里插入图片描述](https://img-blog.csdnimg.cn/178c6b3ec0074af7a2dcc9ef26450e75.png) +[Paper](https://fuxivirtualhuman.github.io/pdf/AAAI2023_FaceDubbing.pdf)         [demo video](https://www.youtube.com/watch?v=UU344T-9h7M&t=6s)      Supplementary materials + + +这是2024年语音识别课程大作业的仓库,用于[DINet](https://github.com/MRzzm/DINet)的复现 +# 复现注意事项 +首先这里知识对于原项目进行了复述,因此具体操作参考配置文档.txt文件进行使用。 + +## 数据获取 +##### 在 [Google drive](https://drive.google.com/drive/folders/1rPtOo9Uuhc59YfFVv4gBmkh0_oG0nCQb?usp=share_link)中下载资源 (asserts.zip)。解压缩并将 dir 放入 ./ 中 ++ 使用示例视频进行推理。运行 + ```python +python inference.py --mouth_region_size=256 --source_video_path=./asserts/examples/testxxx.mp4 --source_openface_landmark_path=./asserts/examples/testxxx.csv 
--driving_audio_path=./asserts/examples/driving_audio_xxx.wav --pretrained_clip_DINet_path=./asserts/clip_training_DINet_256mouth.pth +``` +结果保存在 ./asserts/inference_result + ++ 使用自定义视频进行推理。 +**Note:** 发布的预训练模型是在 HDTF 数据集上训练的。(视频名称在 ./asserts/training_video_name.txt 中) + +使用 [openface](https://github.com/TadasBaltrusaitis/OpenFace)检测自定义视频的平滑面部特征点。 + + +检测到的人脸特征点保存在 “xxxx.csv” 中。运行 + ```python +python inference.py --mouth_region_size=256 --source_video_path= custom video path --source_openface_landmark_path= detected landmark path --driving_audio_path= driving audio path --pretrained_clip_DINet_path=./asserts/clip_training_DINet_256mouth.pth +``` +在您的自定义视频上实现人脸视觉配音。 +## 训练 +### 数据处理 + + 1. 从[HDTF](https://github.com/MRzzm/HDTF)下载视频。根据 xx_annotion_time.txt 分割视频,不裁剪和调整视频大小。 + 2. 将所有分割的视频重新采样为 25fps,并将视频放入 “./asserts/split_video_25fps”。您可以在 “./asserts/split_video_25fps” 中看到两个示例视频。我们使用[软件](http://www.pcfreetime.com/formatfactory/cn/index.html) 对视频进行重新采样。我们在实验中提供了训练视频的名称列表。(请参阅“./asserts/training_video_name.txt”) + 3. 使用 [openface](https://github.com/TadasBaltrusaitis/OpenFace) 检测所有视频的平滑面部特征点。将所有 “.csv” 结果放入 “./asserts/split_video_25fps_landmark_openface” 中。您可以在 “./asserts/split_video_25fps_landmark_openface” 中看到两个示例 csv 文件。 + + + + 4. 从所有视频中提取帧并将帧保存在 “./asserts/split_video_25fps_frame” 中。运行 +```python +python data_processing.py --extract_video_frame +``` + 5. 从所有视频中提取音频,并将音频保存在 ./asserts/split_video_25fps_audio 中。运行 + ```python +python data_processing.py --extract_audio +``` + 6. 从所有音频中提取 deepspeech 特征并将特征保存在 “./asserts/split_video_25fps_deepspeech” 中。运行 + ```python +python data_processing.py --extract_deep_speech +``` + 7. 裁剪所有视频的人脸并将图像保存在 “./asserts/split_video_25fps_crop_face” 中。运行 + ```python +python data_processing.py --crop_face +``` + 8. 
生成训练 json 文件 “./asserts/training_json.json”。运行 + ```python +python data_processing.py --generate_training_json +``` + +### 训练模型 +训练过程分为帧训练阶段和 clip 训练阶段。在帧训练阶段,我们使用从粗到细的策略,因此您可以在任意分辨率下训练模型。 + +#### 框架训练阶段。 +在帧训练阶段,我们只使用感知损失和 GAN 损失 + + 1. 首先,以 104x80(嘴部区域为 64x64)分辨率训练 DINet。运行 + ```python +python train_DINet_frame.py --augment_num=32 --mouth_region_size=64 --batch_size=24 --result_path=./asserts/training_model_weight/frame_training_64 +``` + + + 2. 加载预训练模型(面部:104x80 & 嘴巴:64x64)并以更高分辨率训练DINet(面部:208x160 & 嘴巴:128x128)。运行 +python train_DINet_frame.py --augment_num=100 --mouth_region_size=128 --batch_size=80 --coarse2fine --coarse_model_path=./asserts/training_model_weight/frame_training_64/xxxxxx.pth --result_path=./asserts/training_model_weight/frame_training_128 +``` + + + 3. 加载预训练模型(面部:208x160 & 嘴巴:128x128)并以更高分辨率训练DINet(面部:416x320 & 嘴巴:256x256)。运行 + ```python +python train_DINet_frame.py --augment_num=20 --mouth_region_size=256 --batch_size=12 --coarse2fine --coarse_model_path=./asserts/training_model_weight/frame_training_128/xxxxxx.pth --result_path=./asserts/training_model_weight/frame_training_256 +``` + + +#### 剪辑训练阶段。 +在剪辑训练阶段,我们使用感知损失、帧/剪辑 GAN 损失和同步损失。加载预训练的帧模型(面部:416x320 & 嘴巴:256x256),预训练的同步网络模型(嘴巴:256x256)并在剪辑设置中训练DINet。运行 + ```python +python train_DINet_clip.py --augment_num=3 --mouth_region_size=256 --batch_size=3 --pretrained_syncnet_path=./asserts/syncnet_256mouth.pth --pretrained_frame_DINet_path=./asserts/training_model_weight/frame_training_256/xxxxx.pth --result_path=./asserts/training_model_weight/clip_training_256 +``` + + +# 改进推理和评估 +1.上述推理过程中过于繁琐不便于我们进行测试集的推理和评估,因此这里对于该部分进行了修改,通过一个统一的脚本进行推理,我们只需要将测试集放到test_data文件夹下面我们就可以很方便的进行推理: +``` +docker run --rm -v path/to/your/DINet:/app --gpus all -it --shm-size=8G -e CUDA_LAUNCH_BLOCKING=1 kevia/dinet:latest run_inference.py --process_num x +``` +process_num 遍历推理个数,结果保存在./asserts/inference_result里面 + 推理阶段默认使用预训练模型,如果使用训练模型加入参数 --model_path ./asserts/training_model_weight/clip_training_256/xxxx.pth + 
可以选择音频路径 --audio_path /path/to/your_audio + +2.由于该项目并没有评估脚本来对于模型进行评估,因此这里加入了专用的评估脚本,实现了NIQE,PSNR,FID,SSIM,LSE-C,LSE-D指标,通过运行 +``` +docker run --rm -v path/to/your/DINet:/app --gpus all -it --shm-size=8G -e CUDA_LAUNCH_BLOCKING=1 kevia/dinet:latest run_evaluate.py +``` +评估结果保存在./asserts/evaluate_result里面。 + +## 声明 +整体实现思路来自[https://github.com/MRzzm/DINet](https://github.com/MRzzm/DINet) \ No newline at end of file diff --git a/DINet/config/__pycache__/config.cpython-36.pyc b/DINet/config/__pycache__/config.cpython-36.pyc new file mode 100644 index 00000000..e4f299cc Binary files /dev/null and b/DINet/config/__pycache__/config.cpython-36.pyc differ diff --git a/DINet/config/config.py b/DINet/config/config.py new file mode 100644 index 00000000..7f449163 --- /dev/null +++ b/DINet/config/config.py @@ -0,0 +1,110 @@ +import argparse + +class DataProcessingOptions(): + def __init__(self): + self.parser = argparse.ArgumentParser() + + def parse_args(self): + self.parser.add_argument('--extract_video_frame', action='store_true', help='extract video frame') + self.parser.add_argument('--extract_audio', action='store_true', help='extract audio files from videos') + self.parser.add_argument('--extract_deep_speech', action='store_true', help='extract deep speech features') + self.parser.add_argument('--crop_face', action='store_true', help='crop face') + self.parser.add_argument('--generate_training_json', action='store_true', help='generate training json file') + + self.parser.add_argument('--source_video_dir', type=str, default="./asserts/training_data/split_video_25fps", + help='path of source video in 25 fps') + self.parser.add_argument('--openface_landmark_dir', type=str, default="./asserts/training_data/split_video_25fps_landmark_openface", + help='path of openface landmark dir') + self.parser.add_argument('--video_frame_dir', type=str, default="./asserts/training_data/split_video_25fps_frame", + help='path of video frames') + self.parser.add_argument('--audio_dir', 
type=str, default="./asserts/training_data/split_video_25fps_audio", + help='path of audios') + self.parser.add_argument('--deep_speech_dir', type=str, default="./asserts/training_data/split_video_25fps_deepspeech", + help='path of deep speech') + self.parser.add_argument('--crop_face_dir', type=str, default="./asserts/training_data/split_video_25fps_crop_face", + help='path of crop face dir') + self.parser.add_argument('--json_path', type=str, default="./asserts/training_data/training_json.json", + help='path of training json') + self.parser.add_argument('--clip_length', type=int, default=9, help='clip length') + self.parser.add_argument('--deep_speech_model', type=str, default="./asserts/output_graph.pb", + help='path of pretrained deepspeech model') + return self.parser.parse_args() + +class DINetTrainingOptions(): + def __init__(self): + self.parser = argparse.ArgumentParser() + + def parse_args(self): + self.parser.add_argument('--seed', type=int, default=456, help='random seed to use.') + self.parser.add_argument('--source_channel', type=int, default=3, help='input source image channels') + self.parser.add_argument('--ref_channel', type=int, default=15, help='input reference image channels') + self.parser.add_argument('--audio_channel', type=int, default=29, help='input audio channels') + self.parser.add_argument('--augment_num', type=int, default=32, help='augment training data') + self.parser.add_argument('--mouth_region_size', type=int, default=64, help='augment training data') + self.parser.add_argument('--train_data', type=str, default=r"./asserts/training_data/training_json.json", + help='path of training json') + self.parser.add_argument('--batch_size', type=int, default=24, help='training batch size') + self.parser.add_argument('--lamb_perception', type=int, default=10, help='weight of perception loss') + self.parser.add_argument('--lamb_syncnet_perception', type=int, default=0.1, help='weight of perception loss') + self.parser.add_argument('--lr_g', 
type=float, default=0.0001, help='initial learning rate for adam') + self.parser.add_argument('--lr_dI', type=float, default=0.0001, help='initial learning rate for adam') + self.parser.add_argument('--start_epoch', default=1, type=int, help='start epoch in training stage') + self.parser.add_argument('--non_decay', default=200, type=int, help='num of epoches with fixed learning rate') + self.parser.add_argument('--decay', default=200, type=int, help='num of linearly decay epochs') + self.parser.add_argument('--checkpoint', type=int, default=2, help='num of checkpoints in training stage') + self.parser.add_argument('--result_path', type=str, default=r"./asserts/training_model_weight/frame_training_64", + help='result path to save model') + self.parser.add_argument('--coarse2fine', action='store_true', help='If true, load pretrained model path.') + self.parser.add_argument('--coarse_model_path', + default='', + type=str, + help='Save data (.pth) of previous training') + self.parser.add_argument('--pretrained_syncnet_path', + default='', + type=str, + help='Save data (.pth) of pretrained syncnet') + self.parser.add_argument('--pretrained_frame_DINet_path', + default='', + type=str, + help='Save data (.pth) of frame trained DINet') + # ========================= Discriminator ========================== + self.parser.add_argument('--D_num_blocks', type=int, default=4, help='num of down blocks in discriminator') + self.parser.add_argument('--D_block_expansion', type=int, default=64, help='block expansion in discriminator') + self.parser.add_argument('--D_max_features', type=int, default=256, help='max channels in discriminator') + return self.parser.parse_args() + + +class DINetInferenceOptions(): + def __init__(self): + self.parser = argparse.ArgumentParser() + + def parse_args(self): + self.parser.add_argument('--source_channel', type=int, default=3, help='channels of source image') + self.parser.add_argument('--ref_channel', type=int, default=15, help='channels of 
reference image') + self.parser.add_argument('--audio_channel', type=int, default=29, help='channels of audio feature') + self.parser.add_argument('--mouth_region_size', type=int, default=256, help='help to resize window') + self.parser.add_argument('--source_video_path', + default='./asserts/examples/test4.mp4', + type=str, + help='path of source video') + self.parser.add_argument('--source_openface_landmark_path', + default='./asserts/examples/test4.csv', + type=str, + help='path of detected openface landmark') + self.parser.add_argument('--driving_audio_path', + default='./asserts/examples/driving_audio_1.wav', + type=str, + help='path of driving audio') + self.parser.add_argument('--pretrained_clip_DINet_path', + default='./asserts/clip_training_DINet_256mouth.pth', + type=str, + help='pretrained model of DINet(clip trained)') + self.parser.add_argument('--deepspeech_model_path', + default='./asserts/output_graph.pb', + type=str, + help='path of deepspeech model') + self.parser.add_argument('--res_video_dir', + default='./asserts/inference_result', + type=str, + help='path of generated videos') + return self.parser.parse_args() \ No newline at end of file diff --git a/DINet/data_processing.py b/DINet/data_processing.py new file mode 100644 index 00000000..1afad399 --- /dev/null +++ b/DINet/data_processing.py @@ -0,0 +1,186 @@ +import glob +import os +import subprocess +import cv2 +import numpy as np +import json + +from utils.data_processing import load_landmark_openface,compute_crop_radius +from utils.deep_speech import DeepSpeech +from config.config import DataProcessingOptions + +def extract_audio(source_video_dir,res_audio_dir): + ''' + extract audio files from videos + ''' + if not os.path.exists(source_video_dir): + raise ('wrong path of video dir') + if not os.path.exists(res_audio_dir): + os.mkdir(res_audio_dir) + video_path_list = glob.glob(os.path.join(source_video_dir, '*.mp4')) + for video_path in video_path_list: + print('extract audio from video: 
{}'.format(os.path.basename(video_path))) + audio_path = os.path.join(res_audio_dir, os.path.basename(video_path).replace('.mp4', '.wav')) + cmd = 'ffmpeg -i {} -f wav -ar 16000 {}'.format(video_path, audio_path) + subprocess.call(cmd, shell=True) + +def extract_deep_speech(audio_dir,res_deep_speech_dir,deep_speech_model_path): + ''' + extract deep speech feature + ''' + if not os.path.exists(res_deep_speech_dir): + os.mkdir(res_deep_speech_dir) + DSModel = DeepSpeech(deep_speech_model_path) + wav_path_list = glob.glob(os.path.join(audio_dir, '*.wav')) + for wav_path in wav_path_list: + video_name = os.path.basename(wav_path).replace('.wav', '') + res_dp_path = os.path.join(res_deep_speech_dir, video_name + '_deepspeech.txt') + if os.path.exists(res_dp_path): + os.remove(res_dp_path) + print('extract deep speech feature from audio:{}'.format(video_name)) + ds_feature = DSModel.compute_audio_feature(wav_path) + np.savetxt(res_dp_path, ds_feature) + +def extract_video_frame(source_video_dir,res_video_frame_dir): + ''' + extract video frames from videos + ''' + if not os.path.exists(source_video_dir): + raise ('wrong path of video dir') + if not os.path.exists(res_video_frame_dir): + os.mkdir(res_video_frame_dir) + video_path_list = glob.glob(os.path.join(source_video_dir, '*.mp4')) + for video_path in video_path_list: + video_name = os.path.basename(video_path) + frame_dir = os.path.join(res_video_frame_dir, video_name.replace('.mp4', '')) + if not os.path.exists(frame_dir): + os.makedirs(frame_dir) + print('extracting frames from {} ...'.format(video_name)) + videoCapture = cv2.VideoCapture(video_path) + fps = videoCapture.get(cv2.CAP_PROP_FPS) + if int(fps) != 25: + raise ('{} video is not in 25 fps'.format(video_path)) + frames = videoCapture.get(cv2.CAP_PROP_FRAME_COUNT) + for i in range(int(frames)): + ret, frame = videoCapture.read() + result_path = os.path.join(frame_dir, str(i).zfill(6) + '.jpg') + cv2.imwrite(result_path, frame) + + +def 
crop_face_according_openfaceLM(openface_landmark_dir,video_frame_dir,res_crop_face_dir,clip_length): + ''' + crop face according to openface landmark + ''' + if not os.path.exists(openface_landmark_dir): + raise ('wrong path of openface landmark dir') + if not os.path.exists(video_frame_dir): + raise ('wrong path of video frame dir') + if not os.path.exists(res_crop_face_dir): + os.mkdir(res_crop_face_dir) + landmark_openface_path_list = glob.glob(os.path.join(openface_landmark_dir, '*.csv')) + for landmark_openface_path in landmark_openface_path_list: + video_name = os.path.basename(landmark_openface_path).replace('.csv', '') + crop_face_video_dir = os.path.join(res_crop_face_dir, video_name) + if not os.path.exists(crop_face_video_dir): + os.makedirs(crop_face_video_dir) + print('cropping face from video: {} ...'.format(video_name)) + landmark_openface_data = load_landmark_openface(landmark_openface_path).astype(np.int) + frame_dir = os.path.join(video_frame_dir, video_name) + if not os.path.exists(frame_dir): + raise ('run last step to extract video frame') + if len(glob.glob(os.path.join(frame_dir, '*.jpg'))) != landmark_openface_data.shape[0]: + raise ('landmark length is different from frame length') + frame_length = min(len(glob.glob(os.path.join(frame_dir, '*.jpg'))), landmark_openface_data.shape[0]) + end_frame_index = list(range(clip_length, frame_length, clip_length)) + video_clip_num = len(end_frame_index) + for i in range(video_clip_num): + first_image = cv2.imread(os.path.join(frame_dir, '000000.jpg')) + video_h,video_w = first_image.shape[0], first_image.shape[1] + crop_flag, radius_clip = compute_crop_radius((video_w,video_h), + landmark_openface_data[end_frame_index[i] - clip_length:end_frame_index[i], :,:]) + if not crop_flag: + continue + radius_clip_1_4 = radius_clip // 4 + print('cropping {}/{} clip from video:{}'.format(i, video_clip_num, video_name)) + res_face_clip_dir = os.path.join(crop_face_video_dir, str(i).zfill(6)) + if not 
os.path.exists(res_face_clip_dir): + os.mkdir(res_face_clip_dir) + for frame_index in range(end_frame_index[i]- clip_length,end_frame_index[i]): + source_frame_path = os.path.join(frame_dir,str(frame_index).zfill(6)+'.jpg') + source_frame_data = cv2.imread(source_frame_path) + frame_landmark = landmark_openface_data[frame_index, :, :] + crop_face_data = source_frame_data[ + frame_landmark[29, 1] - radius_clip:frame_landmark[ + 29, 1] + radius_clip * 2 + radius_clip_1_4, + frame_landmark[33, 0] - radius_clip - radius_clip_1_4:frame_landmark[ + 33, 0] + radius_clip + radius_clip_1_4, + :].copy() + res_crop_face_frame_path = os.path.join(res_face_clip_dir, str(frame_index).zfill(6) + '.jpg') + if os.path.exists(res_crop_face_frame_path): + os.remove(res_crop_face_frame_path) + cv2.imwrite(res_crop_face_frame_path, crop_face_data) + + +def generate_training_json(crop_face_dir,deep_speech_dir,clip_length,res_json_path): + video_name_list = os.listdir(crop_face_dir) + video_name_list.sort() + res_data_dic = {} + for video_index, video_name in enumerate(video_name_list): + print('generate training json file :{} {}/{}'.format(video_name,video_index,len(video_name_list))) + tem_dic = {} + deep_speech_feature_path = os.path.join(deep_speech_dir, video_name + '_deepspeech.txt') + if not os.path.exists(deep_speech_feature_path): + raise ('wrong path of deep speech') + deep_speech_feature = np.loadtxt(deep_speech_feature_path) + video_clip_dir = os.path.join(crop_face_dir, video_name) + clip_name_list = os.listdir(video_clip_dir) + clip_name_list.sort() + video_clip_num = len(clip_name_list) + clip_data_list = [] + for clip_index, clip_name in enumerate(clip_name_list): + tem_tem_dic = {} + clip_frame_dir = os.path.join(video_clip_dir, clip_name) + frame_path_list = glob.glob(os.path.join(clip_frame_dir, '*.jpg')) + frame_path_list.sort() + assert len(frame_path_list) == clip_length + start_index = int(float(clip_name) * clip_length) + assert 
int(float(os.path.basename(frame_path_list[0]).replace('.jpg', ''))) == start_index + frame_name_list = [video_name + '/' + clip_name + '/' + os.path.basename(item) for item in frame_path_list] + deep_speech_list = deep_speech_feature[start_index:start_index + clip_length, :].tolist() + if len(frame_name_list) != len(deep_speech_list): + print(' skip video: {}:{}/{} clip:{}:{}/{} because of different length: {} {}'.format( + video_name,video_index,len(video_name_list),clip_name,clip_index,len(clip_name_list), + len(frame_name_list),len(deep_speech_list))) + tem_tem_dic['frame_name_list'] = frame_name_list + tem_tem_dic['frame_path_list'] = frame_path_list + tem_tem_dic['deep_speech_list'] = deep_speech_list + clip_data_list.append(tem_tem_dic) + tem_dic['video_clip_num'] = video_clip_num + tem_dic['clip_data_list'] = clip_data_list + res_data_dic[video_name] = tem_dic + if os.path.exists(res_json_path): + os.remove(res_json_path) + with open(res_json_path,'w') as f: + + json.dump(res_data_dic,f) + + +if __name__ == '__main__': + opt = DataProcessingOptions().parse_args() + ########## step1: extract video frames + if opt.extract_video_frame: + extract_video_frame(opt.source_video_dir, opt.video_frame_dir) + ########## step2: extract audio files + if opt.extract_audio: + extract_audio(opt.source_video_dir,opt.audio_dir) + ########## step3: extract deep speech features + if opt.extract_deep_speech: + extract_deep_speech(opt.audio_dir, opt.deep_speech_dir,opt.deep_speech_model) + ########## step4: crop face images + if opt.crop_face: + crop_face_according_openfaceLM(opt.openface_landmark_dir,opt.video_frame_dir,opt.crop_face_dir,opt.clip_length) + ########## step5: generate training json file + if opt.generate_training_json: + generate_training_json(opt.crop_face_dir,opt.deep_speech_dir,opt.clip_length,opt.json_path) + + diff --git a/DINet/dataset/__pycache__/dataset_DINet_clip.cpython-36.pyc b/DINet/dataset/__pycache__/dataset_DINet_clip.cpython-36.pyc new file mode 
100644 index 00000000..2340fecf Binary files /dev/null and b/DINet/dataset/__pycache__/dataset_DINet_clip.cpython-36.pyc differ diff --git a/DINet/dataset/__pycache__/dataset_DINet_frame.cpython-36.pyc b/DINet/dataset/__pycache__/dataset_DINet_frame.cpython-36.pyc new file mode 100644 index 00000000..2d3a6f3c Binary files /dev/null and b/DINet/dataset/__pycache__/dataset_DINet_frame.cpython-36.pyc differ diff --git a/DINet/dataset/dataset_DINet_clip.py b/DINet/dataset/dataset_DINet_clip.py new file mode 100644 index 00000000..7d5e8c6a --- /dev/null +++ b/DINet/dataset/dataset_DINet_clip.py @@ -0,0 +1,111 @@ +import torch +import numpy as np +import json +import random +import cv2 + +from torch.utils.data import Dataset + + +def get_data(json_name,augment_num): + print('start loading data') + with open(json_name,'r') as f: + data_dic = json.load(f) + data_dic_name_list = [] + for augment_index in range(augment_num): + for video_name in data_dic.keys(): + data_dic_name_list.append(video_name) + random.shuffle(data_dic_name_list) + print('finish loading') + return data_dic_name_list,data_dic + + +class DINetDataset(Dataset): + def __init__(self,path_json,augment_num,mouth_region_size): + super(DINetDataset, self).__init__() + self.data_dic_name_list,self.data_dic = get_data(path_json,augment_num) + self.mouth_region_size = mouth_region_size + self.radius = mouth_region_size//2 + self.radius_1_4 = self.radius//4 + self.img_h = self.radius * 3 + self.radius_1_4 + self.img_w = self.radius * 2 + self.radius_1_4 * 2 + self.length = len(self.data_dic_name_list) + + def __getitem__(self, index): + video_name = self.data_dic_name_list[index] + video_clip_num = len(self.data_dic[video_name]['clip_data_list']) + source_anchor = random.sample(range(video_clip_num), 1)[0] + source_image_path_list = self.data_dic[video_name]['clip_data_list'][source_anchor]['frame_path_list'] + source_clip_list = [] + source_clip_mask_list = [] + deep_speech_list = [] + reference_clip_list = [] + 
for source_frame_index in range(2, 2 + 5): + ## load source clip + source_image_data = cv2.imread(source_image_path_list[source_frame_index])[:, :, ::-1] + source_image_data = cv2.resize(source_image_data, (self.img_w, self.img_h)) / 255.0 + source_clip_list.append(source_image_data) + source_image_mask = source_image_data.copy() + source_image_mask[self.radius:self.radius + self.mouth_region_size, + self.radius_1_4:self.radius_1_4 + self.mouth_region_size, :] = 0 + source_clip_mask_list.append(source_image_mask) + + ## load deep speech feature + deepspeech_array = np.array(self.data_dic[video_name]['clip_data_list'][source_anchor]['deep_speech_list'][ + source_frame_index - 2:source_frame_index + 3]) + deep_speech_list.append(deepspeech_array) + + ## ## load reference images + reference_frame_list = [] + reference_anchor_list = random.sample(range(video_clip_num), 5) + for reference_anchor in reference_anchor_list: + reference_frame_path_list = self.data_dic[video_name]['clip_data_list'][reference_anchor][ + 'frame_path_list'] + reference_random_index = random.sample(range(9), 1)[0] + reference_frame_path = reference_frame_path_list[reference_random_index] + reference_frame_data = cv2.imread(reference_frame_path)[:, :, ::-1] + reference_frame_data = cv2.resize(reference_frame_data, (self.img_w, self.img_h)) / 255.0 + reference_frame_list.append(reference_frame_data) + reference_clip_list.append(np.concatenate(reference_frame_list, 2)) + + source_clip = np.stack(source_clip_list, 0) + source_clip_mask = np.stack(source_clip_mask_list, 0) + deep_speech_clip = np.stack(deep_speech_list, 0) + reference_clip = np.stack(reference_clip_list, 0) + deep_speech_full = np.array(self.data_dic[video_name]['clip_data_list'][source_anchor]['deep_speech_list']) + + # # display data + # display_source = np.concatenate(source_clip_list,1) + # display_source_mask = np.concatenate(source_clip_mask_list,1) + # display_reference0 = 
np.concatenate([reference_clip_list[0][:,:,:3],reference_clip_list[0][:,:,3:6],reference_clip_list[0][:,:,6:9], + # reference_clip_list[0][:,:,9:12],reference_clip_list[0][:,:,12:15]],1) + # display_reference1 = np.concatenate([reference_clip_list[1][:, :, :3], reference_clip_list[1][:, :, 3:6], + # reference_clip_list[1][:, :, 6:9], + # reference_clip_list[1][:, :, 9:12], reference_clip_list[1][:, :, 12:15]],1) + # display_reference2 = np.concatenate([reference_clip_list[2][:, :, :3], reference_clip_list[2][:, :, 3:6], + # reference_clip_list[2][:, :, 6:9], + # reference_clip_list[2][:, :, 9:12], reference_clip_list[2][:, :, 12:15]],1) + # display_reference3 = np.concatenate([reference_clip_list[3][:, :, :3], reference_clip_list[3][:, :, 3:6], + # reference_clip_list[3][:, :, 6:9], + # reference_clip_list[3][:, :, 9:12], reference_clip_list[3][:, :, 12:15]],1) + # display_reference4 = np.concatenate([reference_clip_list[4][:, :, :3], reference_clip_list[4][:, :, 3:6], + # reference_clip_list[4][:, :, 6:9], + # reference_clip_list[4][:, :, 9:12], reference_clip_list[4][:, :, 12:15]],1) + # merge_img = np.concatenate([display_source,display_source_mask, + # display_reference0,display_reference1,display_reference2,display_reference3, + # display_reference4],0) + # cv2.imshow('test',(merge_img[:,:,::-1] * 255).astype(np.uint8)) + # cv2.waitKey(-1) + + + + # # 2 tensor + source_clip = torch.from_numpy(source_clip).float().permute(0, 3, 1, 2) + source_clip_mask = torch.from_numpy(source_clip_mask).float().permute(0, 3, 1, 2) + reference_clip = torch.from_numpy(reference_clip).float().permute(0, 3, 1, 2) + deep_speech_clip = torch.from_numpy(deep_speech_clip).float().permute(0, 2, 1) + deep_speech_full = torch.from_numpy(deep_speech_full).permute(1, 0) + return source_clip,source_clip_mask, reference_clip,deep_speech_clip,deep_speech_full + + def __len__(self): + return self.length diff --git a/DINet/dataset/dataset_DINet_frame.py b/DINet/dataset/dataset_DINet_frame.py 
new file mode 100644 index 00000000..4a74e01b --- /dev/null +++ b/DINet/dataset/dataset_DINet_frame.py @@ -0,0 +1,76 @@ +import torch +import numpy as np +import json +import random +import cv2 + +from torch.utils.data import Dataset + + +def get_data(json_name,augment_num): + print('start loading data') + with open(json_name,'r') as f: + data_dic = json.load(f) + data_dic_name_list = [] + for augment_index in range(augment_num): + for video_name in data_dic.keys(): + data_dic_name_list.append(video_name) + random.shuffle(data_dic_name_list) + print('finish loading') + return data_dic_name_list,data_dic + + +class DINetDataset(Dataset): + def __init__(self,path_json,augment_num,mouth_region_size): + super(DINetDataset, self).__init__() + self.data_dic_name_list,self.data_dic = get_data(path_json,augment_num) + self.mouth_region_size = mouth_region_size + self.radius = mouth_region_size//2 + self.radius_1_4 = self.radius//4 + self.img_h = self.radius * 3 + self.radius_1_4 + self.img_w = self.radius * 2 + self.radius_1_4 * 2 + self.length = len(self.data_dic_name_list) + + def __getitem__(self, index): + video_name = self.data_dic_name_list[index] + video_clip_num = len(self.data_dic[video_name]['clip_data_list']) + random_anchor = random.sample(range(video_clip_num), 6) + source_anchor, reference_anchor_list = random_anchor[0],random_anchor[1:] + ## load source image + source_image_path_list = self.data_dic[video_name]['clip_data_list'][source_anchor]['frame_path_list'] + source_random_index = random.sample(range(2, 7), 1)[0] + source_image_data = cv2.imread(source_image_path_list[source_random_index])[:, :, ::-1] + source_image_data = cv2.resize(source_image_data, (self.img_w, self.img_h))/ 255.0 + source_image_mask = source_image_data.copy() + source_image_mask[self.radius:self.radius+self.mouth_region_size,self.radius_1_4:self.radius_1_4 +self.mouth_region_size ,:] = 0 + + ## load deep speech feature + deepspeech_feature = 
np.array(self.data_dic[video_name]['clip_data_list'][source_anchor]['deep_speech_list'][source_random_index - 2:source_random_index + 3]) + + ## load reference images + reference_frame_data_list = [] + for reference_anchor in reference_anchor_list: + reference_frame_path_list = self.data_dic[video_name]['clip_data_list'][reference_anchor]['frame_path_list'] + reference_random_index = random.sample(range(9), 1)[0] + reference_frame_path = reference_frame_path_list[reference_random_index] + reference_frame_data = cv2.imread(reference_frame_path)[:, :, ::-1] + reference_frame_data = cv2.resize(reference_frame_data, (self.img_w, self.img_h))/ 255.0 + reference_frame_data_list.append(reference_frame_data) + reference_clip_data = np.concatenate(reference_frame_data_list, 2) + + # display the source image and reference images + # display_img = np.concatenate([source_image_data,source_image_mask]+reference_frame_data_list,1) + # cv2.imshow('image display',(display_img[:,:,::-1] * 255).astype(np.uint8)) + # cv2.waitKey(-1) + + # # to tensor + source_image_data = torch.from_numpy(source_image_data).float().permute(2,0,1) + source_image_mask = torch.from_numpy(source_image_mask).float().permute(2,0,1) + reference_clip_data = torch.from_numpy(reference_clip_data).float().permute(2,0,1) + deepspeech_feature = torch.from_numpy(deepspeech_feature).float().permute(1,0) + return source_image_data,source_image_mask, reference_clip_data,deepspeech_feature + + def __len__(self): + return self.length + + diff --git a/DINet/inference.py b/DINet/inference.py new file mode 100644 index 00000000..ea21b305 --- /dev/null +++ b/DINet/inference.py @@ -0,0 +1,173 @@ +from utils.deep_speech import DeepSpeech +from utils.data_processing import load_landmark_openface,compute_crop_radius +from config.config import DINetInferenceOptions +from models.DINet import DINet + +import numpy as np +import glob +import os +import cv2 +import torch +import subprocess +import random +from collections import 
def extract_frames_from_video(video_path, save_dir):
    '''
    Dump every frame of ``video_path`` into ``save_dir`` as zero-padded
    JPEG files (000000.jpg, 000001.jpg, ...).

    :param video_path: path of the source video
    :param save_dir: existing directory that receives the frame images
    :return: (width, height) of the video frames as ints
    '''
    videoCapture = cv2.VideoCapture(video_path)
    fps = videoCapture.get(cv2.CAP_PROP_FPS)
    if int(fps) != 25:
        print('warning: the input video is not 25 fps, it would be better to trans it to 25 fps!')
    frames = videoCapture.get(cv2.CAP_PROP_FRAME_COUNT)
    frame_height = videoCapture.get(cv2.CAP_PROP_FRAME_HEIGHT)
    frame_width = videoCapture.get(cv2.CAP_PROP_FRAME_WIDTH)
    for i in range(int(frames)):
        ret, frame = videoCapture.read()
        # fix: the original wrote `frame` unconditionally; on a decode failure
        # `frame` is None and cv2.imwrite fails with an obscure error.
        if not ret:
            print('warning: failed to read frame {} of {}'.format(i, video_path))
            break
        result_path = os.path.join(save_dir, str(i).zfill(6) + '.jpg')
        cv2.imwrite(result_path, frame)
    # fix: release the capture handle (the original leaked it)
    videoCapture.release()
    return (int(frame_width), int(frame_height))


if __name__ == '__main__':
    # load config
    opt = DINetInferenceOptions().parse_args()
    # fix: `raise ('msg')` raises a TypeError in Python 3 (a str is not an
    # exception); every such statement below now raises a proper exception.
    if not os.path.exists(opt.source_video_path):
        raise FileNotFoundError('wrong video path : {}'.format(opt.source_video_path))
    ############################## extract frames from source video ##############################
    print('extracting frames from video: {}'.format(opt.source_video_path))
    video_frame_dir = opt.source_video_path.replace('.mp4', '')
    if not os.path.exists(video_frame_dir):
        os.mkdir(video_frame_dir)
    video_size = extract_frames_from_video(opt.source_video_path, video_frame_dir)
    ############################## extract deep speech feature ##############################
    print('extracting deepspeech feature from : {}'.format(opt.driving_audio_path))
    if not os.path.exists(opt.deepspeech_model_path):
        raise FileNotFoundError('pls download pretrained model of deepspeech')
    DSModel = DeepSpeech(opt.deepspeech_model_path)
    if not os.path.exists(opt.driving_audio_path):
        raise FileNotFoundError('wrong audio path :{}'.format(opt.driving_audio_path))
    ds_feature = DSModel.compute_audio_feature(opt.driving_audio_path)
    res_frame_length = ds_feature.shape[0]
    # pad 2 frames on both ends so every output frame has a full 5-frame window
    ds_feature_padding = np.pad(ds_feature, ((2, 2), (0, 0)), mode='edge')
    ############################## load facial landmark ##############################
    print('loading facial landmarks from : {}'.format(opt.source_openface_landmark_path))
    if not os.path.exists(opt.source_openface_landmark_path):
        raise FileNotFoundError('wrong facial landmark path :{}'.format(opt.source_openface_landmark_path))
    # fix: np.int was deprecated in NumPy 1.20 and removed in 1.24; the builtin
    # int is the documented replacement and behaves identically here.
    video_landmark_data = load_landmark_openface(opt.source_openface_landmark_path).astype(int)
    ############################## align frame with driving audio ##############################
    print('aligning frames with driving audio')
    video_frame_path_list = glob.glob(os.path.join(video_frame_dir, '*.jpg'))
    if len(video_frame_path_list) != video_landmark_data.shape[0]:
        raise ValueError('video frames are misaligned with detected landmarks')
    video_frame_path_list.sort()
    # play the clip forward then backward so it can be cycled without a jump cut
    video_frame_path_list_cycle = video_frame_path_list + video_frame_path_list[::-1]
    video_landmark_data_cycle = np.concatenate([video_landmark_data, np.flip(video_landmark_data, 0)], 0)
    video_frame_path_list_cycle_length = len(video_frame_path_list_cycle)
    if video_frame_path_list_cycle_length >= res_frame_length:
        res_video_frame_path_list = video_frame_path_list_cycle[:res_frame_length]
        res_video_landmark_data = video_landmark_data_cycle[:res_frame_length, :, :]
    else:
        divisor = res_frame_length // video_frame_path_list_cycle_length
        remainder = res_frame_length % video_frame_path_list_cycle_length
        res_video_frame_path_list = video_frame_path_list_cycle * divisor + video_frame_path_list_cycle[:remainder]
        res_video_landmark_data = np.concatenate([video_landmark_data_cycle] * divisor + [video_landmark_data_cycle[:remainder, :, :]], 0)
    res_video_frame_path_list_pad = [video_frame_path_list_cycle[0]] * 2 \
                                    + res_video_frame_path_list \
                                    + [video_frame_path_list_cycle[-1]] * 2
    res_video_landmark_data_pad = np.pad(res_video_landmark_data, ((2, 2), (0, 0), (0, 0)), mode='edge')
    assert ds_feature_padding.shape[0] == len(res_video_frame_path_list_pad) == res_video_landmark_data_pad.shape[0]
    pad_length = ds_feature_padding.shape[0]

    ############################## randomly select 5 reference images ##############################
    print('selecting five reference images')
    ref_img_list = []
    resize_w = int(opt.mouth_region_size + opt.mouth_region_size // 4)
    resize_h = int((opt.mouth_region_size // 2) * 3 + opt.mouth_region_size // 8)
    ref_index_list = random.sample(range(5, len(res_video_frame_path_list_pad) - 2), 5)
    for ref_index in ref_index_list:
        crop_flag, crop_radius = compute_crop_radius(video_size, res_video_landmark_data_pad[ref_index - 5:ref_index, :, :])
        if not crop_flag:
            raise ValueError('our method can not handle videos with large change of facial size!!')
        crop_radius_1_4 = crop_radius // 4
        ref_img = cv2.imread(res_video_frame_path_list_pad[ref_index - 3])[:, :, ::-1]
        ref_landmark = res_video_landmark_data_pad[ref_index - 3, :, :]
        # crop around landmark 29 (vertical anchor) and landmark 33 (horizontal anchor)
        ref_img_crop = ref_img[
                       ref_landmark[29, 1] - crop_radius:ref_landmark[29, 1] + crop_radius * 2 + crop_radius_1_4,
                       ref_landmark[33, 0] - crop_radius - crop_radius_1_4:ref_landmark[33, 0] + crop_radius + crop_radius_1_4,
                       :]
        ref_img_crop = cv2.resize(ref_img_crop, (resize_w, resize_h))
        ref_img_crop = ref_img_crop / 255.0
        ref_img_list.append(ref_img_crop)
    ref_video_frame = np.concatenate(ref_img_list, 2)
    ref_img_tensor = torch.from_numpy(ref_video_frame).permute(2, 0, 1).unsqueeze(0).float().cuda()

    ############################## load pretrained model weight ##############################
    print('loading pretrained model from: {}'.format(opt.pretrained_clip_DINet_path))
    model = DINet(opt.source_channel, opt.ref_channel, opt.audio_channel).cuda()
    if not os.path.exists(opt.pretrained_clip_DINet_path):
        raise FileNotFoundError('wrong path of pretrained model weight: {}'.format(opt.pretrained_clip_DINet_path))
    state_dict = torch.load(opt.pretrained_clip_DINet_path)['state_dict']['net_g']
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k[7:]  # strip the DataParallel 'module.' prefix
        new_state_dict[name] = v
    model.load_state_dict(new_state_dict)
    model.eval()
    ############################## inference frame by frame ##############################
    if not os.path.exists(opt.res_video_dir):
        os.mkdir(opt.res_video_dir)
    res_video_path = os.path.join(opt.res_video_dir, os.path.basename(opt.source_video_path)[:-4] + '_facial_dubbing.mp4')
    if os.path.exists(res_video_path):
        os.remove(res_video_path)
    res_face_path = res_video_path.replace('_facial_dubbing.mp4', '_synthetic_face.mp4')
    if os.path.exists(res_face_path):
        os.remove(res_face_path)
    videowriter = cv2.VideoWriter(res_video_path, cv2.VideoWriter_fourcc(*'XVID'), 25, video_size)
    videowriter_face = cv2.VideoWriter(res_face_path, cv2.VideoWriter_fourcc(*'XVID'), 25, (resize_w, resize_h))
    for clip_end_index in range(5, pad_length, 1):
        print('synthesizing {}/{} frame'.format(clip_end_index - 5, pad_length - 5))
        crop_flag, crop_radius = compute_crop_radius(video_size, res_video_landmark_data_pad[clip_end_index - 5:clip_end_index, :, :], random_scale=1.05)
        if not crop_flag:
            raise ValueError('our method can not handle videos with large change of facial size!!')
        crop_radius_1_4 = crop_radius // 4
        frame_data = cv2.imread(res_video_frame_path_list_pad[clip_end_index - 3])[:, :, ::-1]
        frame_landmark = res_video_landmark_data_pad[clip_end_index - 3, :, :]
        crop_frame_data = frame_data[
                          frame_landmark[29, 1] - crop_radius:frame_landmark[29, 1] + crop_radius * 2 + crop_radius_1_4,
                          frame_landmark[33, 0] - crop_radius - crop_radius_1_4:frame_landmark[33, 0] + crop_radius + crop_radius_1_4,
                          :]
        crop_frame_h, crop_frame_w = crop_frame_data.shape[0], crop_frame_data.shape[1]
        crop_frame_data = cv2.resize(crop_frame_data, (resize_w, resize_h))
        crop_frame_data = crop_frame_data / 255.0
        # zero out the mouth region; the network inpaints it from audio + references
        crop_frame_data[opt.mouth_region_size // 2:opt.mouth_region_size // 2 + opt.mouth_region_size,
                        opt.mouth_region_size // 8:opt.mouth_region_size // 8 + opt.mouth_region_size, :] = 0

        crop_frame_tensor = torch.from_numpy(crop_frame_data).float().cuda().permute(2, 0, 1).unsqueeze(0)
        deepspeech_tensor = torch.from_numpy(ds_feature_padding[clip_end_index - 5:clip_end_index, :]).permute(1, 0).unsqueeze(0).float().cuda()
        with torch.no_grad():
            pre_frame = model(crop_frame_tensor, ref_img_tensor, deepspeech_tensor)
            pre_frame = pre_frame.squeeze(0).permute(1, 2, 0).detach().cpu().numpy() * 255
        videowriter_face.write(pre_frame[:, :, ::-1].copy().astype(np.uint8))
        pre_frame_resize = cv2.resize(pre_frame, (crop_frame_w, crop_frame_h))
        frame_data[
            frame_landmark[29, 1] - crop_radius:
            frame_landmark[29, 1] + crop_radius * 2,
            frame_landmark[33, 0] - crop_radius - crop_radius_1_4:
            frame_landmark[33, 0] + crop_radius + crop_radius_1_4,
            :] = pre_frame_resize[:crop_radius * 3, :, :]
        # fix: VideoWriter needs a contiguous uint8 buffer; a negative-stride
        # view can fail on some OpenCV builds, so copy before writing.
        videowriter.write(frame_data[:, :, ::-1].copy())
    videowriter.release()
    videowriter_face.release()
    video_add_audio_path = res_video_path.replace('.mp4', '_add_audio.mp4')
    if os.path.exists(video_add_audio_path):
        os.remove(video_add_audio_path)
    # fix: pass ffmpeg arguments as a list instead of an interpolated shell
    # string, so paths containing spaces cannot break or inject into the shell.
    cmd = ['ffmpeg', '-i', res_video_path,
           '-i', opt.driving_audio_path,
           '-c:v', 'copy', '-c:a', 'aac', '-strict', 'experimental',
           '-map', '0:v:0', '-map', '1:a:0',
           video_add_audio_path]
    subprocess.call(cmd)
def make_coordinate_grid_3d(spatial_size, type):
    '''
    Generate a 3D coordinate grid normalized to [-1, 1].

    :param spatial_size: (d, h, w) tuple
    :param type: tensor type (e.g. 'torch.FloatTensor') used to cast the axes
    :return: (meshed, zz) where meshed is a (d, h, w, 2) grid of (x, y)
             coordinates and zz is the (d, h, w) depth coordinate
    '''
    d, h, w = spatial_size
    x = torch.arange(w).type(type)
    y = torch.arange(h).type(type)
    z = torch.arange(d).type(type)
    # map index range [0, n-1] onto [-1, 1] (grid_sample convention)
    x = (2 * (x / (w - 1)) - 1)
    y = (2 * (y / (h - 1)) - 1)
    z = (2 * (z / (d - 1)) - 1)
    yy = y.view(1, -1, 1).repeat(d, 1, w)
    xx = x.view(1, 1, -1).repeat(d, h, 1)
    zz = z.view(-1, 1, 1).repeat(1, h, w)
    meshed = torch.cat([xx.unsqueeze_(3), yy.unsqueeze_(3)], 3)
    return meshed, zz


class ResBlock1d(nn.Module):
    '''
    Pre-activation 1D residual block (BN -> ReLU -> Conv, twice), with a 1x1
    projection on the skip path when the channel count changes.
    '''
    def __init__(self, in_features, out_features, kernel_size, padding):
        super(ResBlock1d, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.conv1 = nn.Conv1d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
                               padding=padding)
        self.conv2 = nn.Conv1d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
                               padding=padding)
        if out_features != in_features:
            self.channel_conv = nn.Conv1d(in_features, out_features, 1)
        self.norm1 = BatchNorm1d(in_features)
        self.norm2 = BatchNorm1d(in_features)
        self.relu = nn.ReLU()

    def forward(self, x):
        out = self.norm1(x)
        out = self.relu(out)
        out = self.conv1(out)
        out = self.norm2(out)
        out = self.relu(out)
        out = self.conv2(out)
        if self.in_features != self.out_features:
            out += self.channel_conv(x)
        else:
            out += x
        return out


class ResBlock2d(nn.Module):
    '''
    Pre-activation 2D residual block (BN -> ReLU -> Conv, twice), with a 1x1
    projection on the skip path when the channel count changes.
    '''
    def __init__(self, in_features, out_features, kernel_size, padding):
        super(ResBlock2d, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.conv1 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
                               padding=padding)
        self.conv2 = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
                               padding=padding)
        if out_features != in_features:
            self.channel_conv = nn.Conv2d(in_features, out_features, 1)
        self.norm1 = BatchNorm2d(in_features)
        self.norm2 = BatchNorm2d(in_features)
        self.relu = nn.ReLU()

    def forward(self, x):
        out = self.norm1(x)
        out = self.relu(out)
        out = self.conv1(out)
        out = self.norm2(out)
        out = self.relu(out)
        out = self.conv2(out)
        if self.in_features != self.out_features:
            out += self.channel_conv(x)
        else:
            out += x
        return out


class UpBlock2d(nn.Module):
    '''
    2x nearest-neighbor upsampling followed by Conv -> BN -> ReLU.
    '''
    def __init__(self, in_features, out_features, kernel_size=3, padding=1):
        super(UpBlock2d, self).__init__()
        self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
                              padding=padding)
        self.norm = BatchNorm2d(out_features)
        self.relu = nn.ReLU()

    def forward(self, x):
        out = F.interpolate(x, scale_factor=2)
        out = self.conv(out)
        out = self.norm(out)
        # fix: the original built self.relu but then called F.relu here;
        # use the module for consistency (numerically identical).
        out = self.relu(out)
        return out


class DownBlock1d(nn.Module):
    '''
    Strided (stride 2) Conv1d -> BN -> ReLU; halves the temporal length.
    '''
    def __init__(self, in_features, out_features, kernel_size, padding):
        super(DownBlock1d, self).__init__()
        self.conv = nn.Conv1d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
                              padding=padding, stride=2)
        self.norm = BatchNorm1d(out_features)
        self.relu = nn.ReLU()

    def forward(self, x):
        out = self.conv(x)
        out = self.norm(out)
        out = self.relu(out)
        return out


class DownBlock2d(nn.Module):
    '''
    Strided Conv2d -> BN -> ReLU; halves the spatial size with stride 2.
    '''
    def __init__(self, in_features, out_features, kernel_size=3, padding=1, stride=2):
        super(DownBlock2d, self).__init__()
        self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
                              padding=padding, stride=stride)
        self.norm = BatchNorm2d(out_features)
        self.relu = nn.ReLU()

    def forward(self, x):
        out = self.conv(x)
        out = self.norm(out)
        out = self.relu(out)
        return out


class SameBlock1d(nn.Module):
    '''
    Shape-preserving Conv1d -> BN -> ReLU.
    '''
    def __init__(self, in_features, out_features, kernel_size, padding):
        super(SameBlock1d, self).__init__()
        self.conv = nn.Conv1d(in_channels=in_features, out_channels=out_features,
                              kernel_size=kernel_size, padding=padding)
        self.norm = BatchNorm1d(out_features)
        self.relu = nn.ReLU()

    def forward(self, x):
        out = self.conv(x)
        out = self.norm(out)
        out = self.relu(out)
        return out


class SameBlock2d(nn.Module):
    '''
    Shape-preserving Conv2d -> BN -> ReLU.
    '''
    def __init__(self, in_features, out_features, kernel_size=3, padding=1):
        super(SameBlock2d, self).__init__()
        self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features,
                              kernel_size=kernel_size, padding=padding)
        self.norm = BatchNorm2d(out_features)
        self.relu = nn.ReLU()

    def forward(self, x):
        out = self.conv(x)
        out = self.norm(out)
        out = self.relu(out)
        return out


class AdaAT(nn.Module):
    '''
    AdaAT operator: predicts a per-channel affine transform (scale, rotation,
    translation) from a parameter code and warps each channel of the feature
    map with grid_sample.
    '''
    def __init__(self, para_ch, feature_ch):
        super(AdaAT, self).__init__()
        self.para_ch = para_ch
        self.feature_ch = feature_ch
        self.commn_linear = nn.Sequential(
            nn.Linear(para_ch, para_ch),
            nn.ReLU()
        )
        self.scale = nn.Sequential(
            nn.Linear(para_ch, feature_ch),
            nn.Sigmoid()
        )
        self.rotation = nn.Sequential(
            nn.Linear(para_ch, feature_ch),
            nn.Tanh()
        )
        self.translation = nn.Sequential(
            nn.Linear(para_ch, 2 * feature_ch),
            nn.Tanh()
        )
        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()

    def forward(self, feature_map, para_code):
        '''
        :param feature_map: (batch, feature_ch, h, w) tensor to warp
        :param para_code: (batch, para_ch) transform parameter code
        :return: warped feature map, same shape as the input
        '''
        batch, d, h, w = feature_map.size(0), feature_map.size(1), feature_map.size(2), feature_map.size(3)
        para_code = self.commn_linear(para_code)
        # sigmoid * 2 -> scale in (0, 2); tanh * pi -> rotation in (-pi, pi)
        scale = self.scale(para_code).unsqueeze(-1) * 2
        # 3.14159 kept (instead of math.pi) to stay bit-compatible with the
        # released pretrained checkpoints.
        angle = self.rotation(para_code).unsqueeze(-1) * 3.14159
        rotation_matrix = torch.cat([torch.cos(angle), -torch.sin(angle), torch.sin(angle), torch.cos(angle)], -1)
        rotation_matrix = rotation_matrix.view(batch, self.feature_ch, 2, 2)
        translation = self.translation(para_code).view(batch, self.feature_ch, 2)
        grid_xy, grid_z = make_coordinate_grid_3d((d, h, w), feature_map.type())
        grid_xy = grid_xy.unsqueeze(0).repeat(batch, 1, 1, 1, 1)
        grid_z = grid_z.unsqueeze(0).repeat(batch, 1, 1, 1)
        scale = scale.unsqueeze(2).unsqueeze(3).repeat(1, 1, h, w, 1)
        rotation_matrix = rotation_matrix.unsqueeze(2).unsqueeze(3).repeat(1, 1, h, w, 1, 1)
        translation = translation.unsqueeze(2).unsqueeze(3).repeat(1, 1, h, w, 1)
        trans_grid = torch.matmul(rotation_matrix, grid_xy.unsqueeze(-1)).squeeze(-1) * scale + translation
        full_grid = torch.cat([trans_grid, grid_z.unsqueeze(-1)], -1)
        # fix: pass align_corners explicitly (False is the framework default
        # since torch 1.3, so behavior is unchanged but no longer ambiguous).
        trans_feature = F.grid_sample(feature_map.unsqueeze(1), full_grid, align_corners=False).squeeze(1)
        return trans_feature


class DINet(nn.Module):
    '''
    Deformation Inpainting Network: inpaints the (masked) mouth region of a
    source face from reference face crops and a DeepSpeech audio window.
    '''
    def __init__(self, source_channel, ref_channel, audio_channel):
        super(DINet, self).__init__()
        self.source_in_conv = nn.Sequential(
            SameBlock2d(source_channel, 64, kernel_size=7, padding=3),
            DownBlock2d(64, 128, kernel_size=3, padding=1),
            DownBlock2d(128, 256, kernel_size=3, padding=1)
        )
        self.ref_in_conv = nn.Sequential(
            SameBlock2d(ref_channel, 64, kernel_size=7, padding=3),
            DownBlock2d(64, 128, kernel_size=3, padding=1),
            DownBlock2d(128, 256, kernel_size=3, padding=1),
        )
        self.trans_conv = nn.Sequential(
            # 20 → 10
            SameBlock2d(512, 128, kernel_size=3, padding=1),
            SameBlock2d(128, 128, kernel_size=11, padding=5),
            SameBlock2d(128, 128, kernel_size=11, padding=5),
            DownBlock2d(128, 128, kernel_size=3, padding=1),
            # 10 → 5
            SameBlock2d(128, 128, kernel_size=7, padding=3),
            SameBlock2d(128, 128, kernel_size=7, padding=3),
            DownBlock2d(128, 128, kernel_size=3, padding=1),
            # 5 → 3
            SameBlock2d(128, 128, kernel_size=3, padding=1),
            DownBlock2d(128, 128, kernel_size=3, padding=1),
            # 3 → 2
            SameBlock2d(128, 128, kernel_size=3, padding=1),
            DownBlock2d(128, 128, kernel_size=3, padding=1),

        )
        self.audio_encoder = nn.Sequential(
            SameBlock1d(audio_channel, 128, kernel_size=5, padding=2),
            ResBlock1d(128, 128, 3, 1),
            DownBlock1d(128, 128, 3, 1),
            ResBlock1d(128, 128, 3, 1),
            DownBlock1d(128, 128, 3, 1),
            SameBlock1d(128, 128, kernel_size=3, padding=1)
        )

        appearance_conv_list = []
        for i in range(2):
            appearance_conv_list.append(
                nn.Sequential(
                    ResBlock2d(256, 256, 3, 1),
                    ResBlock2d(256, 256, 3, 1),
                    ResBlock2d(256, 256, 3, 1),
                    ResBlock2d(256, 256, 3, 1),
                )
            )
        self.appearance_conv_list = nn.ModuleList(appearance_conv_list)
        self.adaAT = AdaAT(256, 256)
        self.out_conv = nn.Sequential(
            SameBlock2d(512, 128, kernel_size=3, padding=1),
            UpBlock2d(128, 128, kernel_size=3, padding=1),
            ResBlock2d(128, 128, 3, 1),
            UpBlock2d(128, 128, kernel_size=3, padding=1),
            nn.Conv2d(128, 3, kernel_size=(7, 7), padding=(3, 3)),
            nn.Sigmoid()
        )
        self.global_avg2d = nn.AdaptiveAvgPool2d(1)
        self.global_avg1d = nn.AdaptiveAvgPool1d(1)

    def forward(self, source_img, ref_img, audio_feature):
        '''
        :param source_img: masked source face, (b, source_channel, H, W)
        :param ref_img: stacked reference crops, (b, ref_channel, H, W)
        :param audio_feature: DeepSpeech window, (b, audio_channel, T)
        :return: synthesized face in [0, 1], (b, 3, H, W)
        '''
        ## source image encoder
        source_in_feature = self.source_in_conv(source_img)
        ## reference image encoder
        ref_in_feature = self.ref_in_conv(ref_img)
        ## alignment encoder
        img_para = self.trans_conv(torch.cat([source_in_feature, ref_in_feature], 1))
        img_para = self.global_avg2d(img_para).squeeze(3).squeeze(2)
        ## audio encoder
        audio_para = self.audio_encoder(audio_feature)
        audio_para = self.global_avg1d(audio_para).squeeze(2)
        ## concat alignment feature and audio feature
        trans_para = torch.cat([img_para, audio_para], 1)
        ## use AdaAT to do spatial deformation on reference feature maps
        ref_trans_feature = self.appearance_conv_list[0](ref_in_feature)
        ref_trans_feature = self.adaAT(ref_trans_feature, trans_para)
        ref_trans_feature = self.appearance_conv_list[1](ref_trans_feature)
        ## feature decoder
        merge_feature = torch.cat([source_in_feature, ref_trans_feature], 1)
        out = self.out_conv(merge_feature)
        return out
class DownBlock2d(nn.Module):
    '''
    Conv -> LeakyReLU(0.2) block, optionally followed by 2x2 average pooling.
    '''
    def __init__(self, in_features, out_features, kernel_size=4, pool=False):
        super(DownBlock2d, self).__init__()
        self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size)
        self.pool = pool

    def forward(self, x):
        activated = F.leaky_relu(self.conv(x), 0.2)
        if not self.pool:
            return activated
        return F.avg_pool2d(activated, (2, 2))


class Discriminator(nn.Module):
    """
    Discriminator for GAN loss.

    Returns every intermediate feature map (useful for feature-matching
    losses) together with the final single-channel score map.
    """
    def __init__(self, num_channels, block_expansion=64, num_blocks=4, max_features=512):
        super(Discriminator, self).__init__()
        down_blocks = []
        for i in range(num_blocks):
            block_in = num_channels if i == 0 else min(max_features, block_expansion * (2 ** i))
            block_out = min(max_features, block_expansion * (2 ** (i + 1)))
            # the last block keeps its spatial resolution (no pooling)
            down_blocks.append(DownBlock2d(block_in, block_out, kernel_size=4, pool=(i != num_blocks - 1)))
        self.down_blocks = nn.ModuleList(down_blocks)
        self.conv = nn.Conv2d(self.down_blocks[-1].conv.out_channels, out_channels=1, kernel_size=1)

    def forward(self, x):
        feature_maps = []
        current = x
        for down_block in self.down_blocks:
            current = down_block(current)
            feature_maps.append(current)
        score = self.conv(current)
        return feature_maps, score
class ResBlock1d(nn.Module):
    '''
    Pre-activation 1D residual block without normalization; a 1x1 projection
    is applied on the skip path when the channel count changes.
    '''
    def __init__(self, in_features, out_features, kernel_size, padding):
        super(ResBlock1d, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.conv1 = nn.Conv1d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
                               padding=padding)
        self.conv2 = nn.Conv1d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
                               padding=padding)
        if out_features != in_features:
            self.channel_conv = nn.Conv1d(in_features, out_features, 1)
        self.relu = nn.ReLU()

    def forward(self, x):
        y = self.conv1(self.relu(x))
        y = self.conv2(self.relu(y))
        residual = self.channel_conv(x) if self.in_features != self.out_features else x
        return y + residual


class ResBlock2d(nn.Module):
    '''
    Pre-activation 2D residual block without normalization; a 1x1 projection
    is applied on the skip path when the channel count changes.
    '''
    def __init__(self, in_features, out_features, kernel_size, padding):
        super(ResBlock2d, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.conv1 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
                               padding=padding)
        self.conv2 = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
                               padding=padding)
        if out_features != in_features:
            self.channel_conv = nn.Conv2d(in_features, out_features, 1)
        self.relu = nn.ReLU()

    def forward(self, x):
        y = self.conv1(self.relu(x))
        y = self.conv2(self.relu(y))
        residual = self.channel_conv(x) if self.in_features != self.out_features else x
        return y + residual


class DownBlock1d(nn.Module):
    '''
    Strided (stride 2) Conv1d -> ReLU; halves the temporal length.
    '''
    def __init__(self, in_features, out_features, kernel_size, padding):
        super(DownBlock1d, self).__init__()
        self.conv = nn.Conv1d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
                              padding=padding, stride=2)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.conv(x))


class DownBlock2d(nn.Module):
    '''
    Conv2d -> ReLU -> 2x2 average pooling; halves the spatial size.
    '''
    def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
        super(DownBlock2d, self).__init__()
        self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
                              padding=padding, groups=groups)
        self.pool = nn.AvgPool2d(kernel_size=(2, 2))
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.pool(self.relu(self.conv(x)))


class SameBlock1d(nn.Module):
    '''
    Shape-preserving Conv1d -> ReLU.
    '''
    def __init__(self, in_features, out_features, kernel_size, padding):
        super(SameBlock1d, self).__init__()
        self.conv = nn.Conv1d(in_channels=in_features, out_channels=out_features,
                              kernel_size=kernel_size, padding=padding)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.conv(x))


class SameBlock2d(nn.Module):
    '''
    Shape-preserving Conv2d -> ReLU.
    '''
    def __init__(self, in_features, out_features, groups=1, kernel_size=3, padding=1):
        super(SameBlock2d, self).__init__()
        self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features,
                              kernel_size=kernel_size, padding=padding, groups=groups)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.conv(x))


class FaceEncoder(nn.Module):
    '''
    Image encoder: five halving stages, so a 64x64 input becomes a 2x2
    feature map with ``out_dim`` channels.
    '''
    def __init__(self, in_channel, out_dim):
        super(FaceEncoder, self).__init__()
        self.in_channel = in_channel
        self.out_dim = out_dim
        self.face_conv = nn.Sequential(
            SameBlock2d(in_channel, 64, kernel_size=7, padding=3),
            # 64 → 32
            ResBlock2d(64, 64, kernel_size=3, padding=1),
            DownBlock2d(64, 64, 3, 1),
            SameBlock2d(64, 128),
            # 32 → 16
            ResBlock2d(128, 128, kernel_size=3, padding=1),
            DownBlock2d(128, 128, 3, 1),
            SameBlock2d(128, 128),
            # 16 → 8
            ResBlock2d(128, 128, kernel_size=3, padding=1),
            DownBlock2d(128, 128, 3, 1),
            SameBlock2d(128, 128),
            # 8 → 4
            ResBlock2d(128, 128, kernel_size=3, padding=1),
            DownBlock2d(128, 128, 3, 1),
            SameBlock2d(128, 128),
            # 4 → 2
            ResBlock2d(128, 128, kernel_size=3, padding=1),
            DownBlock2d(128, 128, 3, 1),
            SameBlock2d(128, out_dim, kernel_size=1, padding=0)
        )

    def forward(self, x):
        # x: b x c x h x w
        return self.face_conv(x)


class AudioEncoder(nn.Module):
    '''
    Audio encoder: three temporal halvings followed by global average
    pooling, producing one ``out_dim`` vector per sample.
    '''
    def __init__(self, in_channel, out_dim):
        super(AudioEncoder, self).__init__()
        self.in_channel = in_channel
        self.out_dim = out_dim
        self.audio_conv = nn.Sequential(
            SameBlock1d(in_channel, 128, kernel_size=7, padding=3),
            ResBlock1d(128, 128, 3, 1),
            # 9 → 5
            DownBlock1d(128, 128, 3, 1),
            ResBlock1d(128, 128, 3, 1),
            # 5 → 3
            DownBlock1d(128, 128, 3, 1),
            ResBlock1d(128, 128, 3, 1),
            # 3 → 2
            DownBlock1d(128, 128, 3, 1),
            SameBlock1d(128, out_dim, kernel_size=3, padding=1)
        )
        self.global_avg = nn.AdaptiveAvgPool1d(1)

    def forward(self, x):
        # x: b x c x t
        features = self.audio_conv(x)
        return self.global_avg(features).squeeze(2)


class SyncNet(nn.Module):
    '''
    SyncNet: scores the audio-visual synchronization of a face crop and a
    DeepSpeech feature window.
    '''
    def __init__(self, in_channel_image, in_channel_audio, out_dim):
        super(SyncNet, self).__init__()
        self.in_channel_image = in_channel_image
        self.in_channel_audio = in_channel_audio
        self.out_dim = out_dim
        self.face_encoder = FaceEncoder(in_channel_image, out_dim)
        self.audio_encoder = AudioEncoder(in_channel_audio, out_dim)
        self.merge_encoder = nn.Sequential(
            nn.Conv2d(out_dim * 2, out_dim, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(out_dim, 1, kernel_size=3, padding=1),
        )

    def forward(self, image, audio):
        image_embedding = self.face_encoder(image)
        # broadcast the pooled audio vector over the image feature grid
        audio_vector = self.audio_encoder(audio)
        audio_embedding = audio_vector.unsqueeze(2).unsqueeze(3).repeat(
            1, 1, image_embedding.size(2), image_embedding.size(3))
        concat_embedding = torch.cat([image_embedding, audio_embedding], 1)
        return self.merge_encoder(concat_embedding)


class SyncNetPerception(nn.Module):
    '''
    Frozen pretrained SyncNet used as a lip-sync perception loss.
    '''
    def __init__(self, pretrain_path):
        super(SyncNetPerception, self).__init__()
        self.model = SyncNet(15, 29, 128)
        print('load lip sync model : {}'.format(pretrain_path))
        self.model.load_state_dict(torch.load(pretrain_path)['state_dict']['net'])
        # freeze: this network only provides a loss signal
        for param in self.model.parameters():
            param.requires_grad = False
        self.model.eval()

    def forward(self, image, audio):
        score = self.model(image, audio)
        return score
class Vgg19(torch.nn.Module):
    """
    VGG19 feature extractor for the perceptual loss.

    Exposes five sequential slices of torchvision's pretrained
    ``vgg19().features`` (attributes ``slice1`` .. ``slice5``); ``forward``
    returns the activations after each slice. Inputs are normalized with the
    standard ImageNet mean/std before the first slice.
    """
    def __init__(self, requires_grad=False):
        super(Vgg19, self).__init__()
        vgg_pretrained_features = models.vgg19(pretrained=True).features
        # layer-index boundaries of the five relu stages in vgg19.features
        bounds = [(0, 2), (2, 7), (7, 12), (12, 21), (21, 30)]
        for slice_idx, (start, end) in enumerate(bounds, start=1):
            stage = torch.nn.Sequential()
            for layer_idx in range(start, end):
                stage.add_module(str(layer_idx), vgg_pretrained_features[layer_idx])
            setattr(self, 'slice%d' % slice_idx, stage)

        # ImageNet normalization constants, stored as frozen parameters so
        # they follow the module across devices
        self.mean = torch.nn.Parameter(data=torch.Tensor(np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))),
                                       requires_grad=False)
        self.std = torch.nn.Parameter(data=torch.Tensor(np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))),
                                      requires_grad=False)

        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        X = (X - self.mean) / self.std
        activations = []
        h = X
        for slice_idx in range(1, 6):
            h = getattr(self, 'slice%d' % slice_idx)(h)
            activations.append(h)
        return activations
class ResBlock1d(nn.Module):
    '''
    Pre-activation 1D residual block with BatchNorm (legacy BN variant);
    the skip path gets a 1x1 projection when the channel count changes.
    '''
    def __init__(self, in_features, out_features, kernel_size, padding):
        super(ResBlock1d, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.conv1 = nn.Conv1d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
                               padding=padding)
        self.conv2 = nn.Conv1d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
                               padding=padding)
        if out_features != in_features:
            self.channel_conv = nn.Conv1d(in_features, out_features, 1)
        self.norm1 = nn.BatchNorm1d(in_features)
        self.norm2 = nn.BatchNorm1d(in_features)
        self.relu = nn.ReLU()

    def forward(self, x):
        y = self.conv1(self.relu(self.norm1(x)))
        y = self.conv2(self.relu(self.norm2(y)))
        residual = self.channel_conv(x) if self.in_features != self.out_features else x
        return y + residual


class ResBlock2d(nn.Module):
    '''
    Pre-activation 2D residual block with BatchNorm (legacy BN variant);
    the skip path gets a 1x1 projection when the channel count changes.
    '''
    def __init__(self, in_features, out_features, kernel_size, padding):
        super(ResBlock2d, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.conv1 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
                               padding=padding)
        self.conv2 = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
                               padding=padding)
        if out_features != in_features:
            self.channel_conv = nn.Conv2d(in_features, out_features, 1)
        self.norm1 = nn.BatchNorm2d(in_features)
        self.norm2 = nn.BatchNorm2d(in_features)
        self.relu = nn.ReLU()

    def forward(self, x):
        y = self.conv1(self.relu(self.norm1(x)))
        y = self.conv2(self.relu(self.norm2(y)))
        residual = self.channel_conv(x) if self.in_features != self.out_features else x
        return y + residual


class DownBlock1d(nn.Module):
    '''
    Strided (stride 2) Conv1d -> BN -> ReLU; halves the temporal length.
    '''
    def __init__(self, in_features, out_features, kernel_size, padding):
        super(DownBlock1d, self).__init__()
        self.conv = nn.Conv1d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
                              padding=padding, stride=2)
        self.norm = nn.BatchNorm1d(out_features)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.norm(self.conv(x)))


class DownBlock2d(nn.Module):
    '''
    Conv2d -> BN -> ReLU -> 2x2 average pooling; halves the spatial size.
    '''
    def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
        super(DownBlock2d, self).__init__()
        self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
                              padding=padding, groups=groups)
        self.norm = nn.BatchNorm2d(out_features)
        self.pool = nn.AvgPool2d(kernel_size=(2, 2))
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.pool(self.relu(self.norm(self.conv(x))))


class SameBlock1d(nn.Module):
    '''
    Shape-preserving Conv1d -> BN -> ReLU.
    '''
    def __init__(self, in_features, out_features, kernel_size, padding):
        super(SameBlock1d, self).__init__()
        self.conv = nn.Conv1d(in_channels=in_features, out_channels=out_features,
                              kernel_size=kernel_size, padding=padding)
        self.norm = nn.BatchNorm1d(out_features)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.norm(self.conv(x)))


class SameBlock2d(nn.Module):
    '''
    Shape-preserving Conv2d -> BN -> ReLU.
    '''
    def __init__(self, in_features, out_features, groups=1, kernel_size=3, padding=1):
        super(SameBlock2d, self).__init__()
        self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features,
                              kernel_size=kernel_size, padding=padding, groups=groups)
        self.norm = nn.BatchNorm2d(out_features)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.norm(self.conv(x)))


class FaceEncoder(nn.Module):
    '''
    Image encoder (BN variant): five halving stages, so a 64x64 input
    becomes a 2x2 feature map with ``out_dim`` channels.
    '''
    def __init__(self, in_channel, out_dim):
        super(FaceEncoder, self).__init__()
        self.in_channel = in_channel
        self.out_dim = out_dim
        self.face_conv = nn.Sequential(
            SameBlock2d(in_channel, 64, kernel_size=7, padding=3),
            # 64 → 32
            ResBlock2d(64, 64, kernel_size=3, padding=1),
            DownBlock2d(64, 64, 3, 1),
            SameBlock2d(64, 128),
            # 32 → 16
            ResBlock2d(128, 128, kernel_size=3, padding=1),
            DownBlock2d(128, 128, 3, 1),
            SameBlock2d(128, 128),
            # 16 → 8
            ResBlock2d(128, 128, kernel_size=3, padding=1),
            DownBlock2d(128, 128, 3, 1),
            SameBlock2d(128, 128),
            # 8 → 4
            ResBlock2d(128, 128, kernel_size=3, padding=1),
            DownBlock2d(128, 128, 3, 1),
            SameBlock2d(128, 128),
            # 4 → 2
            ResBlock2d(128, 128, kernel_size=3, padding=1),
            DownBlock2d(128, 128, 3, 1),
            SameBlock2d(128, out_dim, kernel_size=1, padding=0)
        )

    def forward(self, x):
        # x: b x c x h x w
        return self.face_conv(x)


class AudioEncoder(nn.Module):
    '''
    Audio encoder (BN variant): three temporal halvings followed by global
    average pooling, producing one ``out_dim`` vector per sample.
    '''
    def __init__(self, in_channel, out_dim):
        super(AudioEncoder, self).__init__()
        self.in_channel = in_channel
        self.out_dim = out_dim
        self.audio_conv = nn.Sequential(
            SameBlock1d(in_channel, 128, kernel_size=7, padding=3),
            ResBlock1d(128, 128, 3, 1),
            # 9 → 5
            DownBlock1d(128, 128, 3, 1),
            ResBlock1d(128, 128, 3, 1),
            # 5 → 3
            DownBlock1d(128, 128, 3, 1),
            ResBlock1d(128, 128, 3, 1),
            # 3 → 2
            DownBlock1d(128, 128, 3, 1),
            SameBlock1d(128, out_dim, kernel_size=3, padding=1)
        )
        self.global_avg = nn.AdaptiveAvgPool1d(1)

    def forward(self, x):
        # x: b x c x t
        features = self.audio_conv(x)
        return self.global_avg(features).squeeze(2)


class SyncNet(nn.Module):
    '''
    Legacy BN SyncNet: scores audio-visual synchronization of a face crop
    and a DeepSpeech feature window.
    '''
    def __init__(self, in_channel_image, in_channel_audio, out_dim):
        super(SyncNet, self).__init__()
        self.in_channel_image = in_channel_image
        self.in_channel_audio = in_channel_audio
        self.out_dim = out_dim
        self.face_encoder = FaceEncoder(in_channel_image, out_dim)
        self.audio_encoder = AudioEncoder(in_channel_audio, out_dim)
        self.merge_encoder = nn.Sequential(
            nn.Conv2d(out_dim * 2, out_dim, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(out_dim, 1, kernel_size=3, padding=1),
        )

    def forward(self, image, audio):
        image_embedding = self.face_encoder(image)
        # broadcast the pooled audio vector over the image feature grid
        audio_vector = self.audio_encoder(audio)
        audio_embedding = audio_vector.unsqueeze(2).unsqueeze(3).repeat(
            1, 1, image_embedding.size(2), image_embedding.size(3))
        concat_embedding = torch.cat([image_embedding, audio_embedding], 1)
        return self.merge_encoder(concat_embedding)
kernel_size=kernel_size, + padding=padding) + if out_features != in_features: + self.channel_conv = nn.Conv1d(in_features,out_features,1) + self.relu = nn.ReLU() + def forward(self, x): + out = self.relu(x) + out = self.conv1(out) + out = self.relu(out) + out = self.conv2(out) + if self.in_features != self.out_features: + out += self.channel_conv(x) + else: + out += x + return out + +class ResBlock2d(nn.Module): + ''' + basic block (BN) + ''' + def __init__(self, in_features,out_features, kernel_size, padding): + super(ResBlock2d, self).__init__() + self.in_features = in_features + self.out_features = out_features + self.conv1 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, + padding=padding) + self.conv2 = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, + padding=padding) + if out_features != in_features: + self.channel_conv = nn.Conv2d(in_features,out_features,1) + self.norm1 = nn.BatchNorm2d(in_features) + self.norm2 = nn.BatchNorm2d(in_features) + self.relu = nn.ReLU() + def forward(self, x): + out = self.norm1(x) + out = self.relu(out) + out = self.conv1(out) + out = self.norm2(out) + out = self.relu(out) + out = self.conv2(out) + if self.in_features != self.out_features: + out += self.channel_conv(x) + else: + out += x + return out + +class DownBlock1d(nn.Module): + ''' + basic block (no BN) + ''' + def __init__(self, in_features, out_features, kernel_size, padding): + super(DownBlock1d, self).__init__() + self.conv = nn.Conv1d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, + padding=padding,stride=2) + self.relu = nn.ReLU() + def forward(self, x): + out = self.conv(x) + out = self.relu(out) + return out + +class DownBlock2d(nn.Module): + ''' + basic block (BN) + ''' + def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1): + super(DownBlock2d, self).__init__() + self.conv = nn.Conv2d(in_channels=in_features, 
out_channels=out_features, kernel_size=kernel_size, + padding=padding, groups=groups) + self.norm = nn.BatchNorm2d(out_features) + self.pool = nn.AvgPool2d(kernel_size=(2, 2)) + self.relu = nn.ReLU() + def forward(self, x): + out = self.conv(x) + out = self.norm(out) + out = self.relu(out) + out = self.pool(out) + return out + +class SameBlock1d(nn.Module): + ''' + basic block (no BN) + ''' + def __init__(self, in_features, out_features, kernel_size, padding): + super(SameBlock1d, self).__init__() + self.conv = nn.Conv1d(in_channels=in_features, out_channels=out_features, + kernel_size=kernel_size, padding=padding) + self.relu = nn.ReLU() + def forward(self, x): + out = self.conv(x) + out = self.relu(out) + return out + +class SameBlock2d(nn.Module): + ''' + basic block (BN) + ''' + def __init__(self, in_features, out_features, groups=1, kernel_size=3, padding=1): + super(SameBlock2d, self).__init__() + self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, + kernel_size=kernel_size, padding=padding, groups=groups) + self.norm = nn.BatchNorm2d(out_features) + self.relu = nn.ReLU() + def forward(self, x): + out = self.conv(x) + out = self.norm(out) + out = self.relu(out) + return out + + +class FaceEncoder(nn.Module): + def __init__(self, in_channel, out_dim): + super(FaceEncoder, self).__init__() + self.in_channel = in_channel + self.out_dim = out_dim + self.face_conv = nn.Sequential( + SameBlock2d(in_channel,64,kernel_size=7,padding=3), + # # 64 → 32 + ResBlock2d(64, 64, kernel_size=3, padding=1), + DownBlock2d(64,64,3,1), + SameBlock2d(64, 128), + # 32 → 16 + ResBlock2d(128, 128, kernel_size=3, padding=1), + DownBlock2d(128,128,3,1), + SameBlock2d(128, 128), + # 16 → 8 + ResBlock2d(128, 128, kernel_size=3, padding=1), + DownBlock2d(128,128,3,1), + SameBlock2d(128, 128), + # 8 → 4 + ResBlock2d(128, 128, kernel_size=3, padding=1), + DownBlock2d(128,128,3,1), + SameBlock2d(128, 128), + # 4 → 2 + ResBlock2d(128, 128, kernel_size=3, padding=1), + 
DownBlock2d(128,128,3,1), + SameBlock2d(128,out_dim,kernel_size=1,padding=0) + ) + def forward(self, x): + ## b x c x h x w + out = self.face_conv(x) + return out + +class AudioEncoder(nn.Module): + def __init__(self, in_channel, out_dim): + super(AudioEncoder, self).__init__() + self.in_channel = in_channel + self.out_dim = out_dim + self.audio_conv = nn.Sequential( + SameBlock1d(in_channel,128,kernel_size=7,padding=3), + ResBlock1d(128, 128, 3, 1), + # 9-5 + DownBlock1d(128, 128, 3, 1), + ResBlock1d(128, 128, 3, 1), + # 5 -3 + DownBlock1d(128, 128, 3, 1), + ResBlock1d(128, 128, 3, 1), + # 3-2 + DownBlock1d(128, 128, 3, 1), + SameBlock1d(128,out_dim,kernel_size=3,padding=1) + ) + self.global_avg = nn.AdaptiveAvgPool1d(1) + def forward(self, x): + ## b x c x t + out = self.audio_conv(x) + return self.global_avg(out).squeeze(2) + +class SyncNet(nn.Module): + def __init__(self, in_channel_image,in_channel_audio, out_dim): + super(SyncNet, self).__init__() + self.in_channel_image = in_channel_image + self.in_channel_audio = in_channel_audio + self.out_dim = out_dim + self.face_encoder = FaceEncoder(in_channel_image,out_dim) + self.audio_encoder = AudioEncoder(in_channel_audio,out_dim) + self.merge_encoder = nn.Sequential( + nn.Conv2d(out_dim * 2, out_dim, kernel_size=3, padding=1), + nn.LeakyReLU(0.2), + nn.Conv2d(out_dim, 1, kernel_size=3, padding=1), + ) + def forward(self, image,audio): + image_embedding = self.face_encoder(image) + audio_embedding = self.audio_encoder(audio).unsqueeze(2).unsqueeze(3).repeat(1,1,image_embedding.size(2),image_embedding.size(3)) + concat_embedding = torch.cat([image_embedding,audio_embedding],1) + out_score = self.merge_encoder(concat_embedding) + return out_score + diff --git a/DINet/requirements.txt b/DINet/requirements.txt new file mode 100644 index 00000000..d11ef7a5 --- /dev/null +++ b/DINet/requirements.txt @@ -0,0 +1,11 @@ +opencv_python == 4.6.0.66 +numpy == 1.19.2 +python_speech_features == 0.6 +resampy == 0.2.2 +scipy == 
1.5.4 +tensorflow == 1.15.2 +pandas==1.1.5 +scikit-image==0.17.2 +torchmetrics==0.8.2 +pytorch-fid==0.3.0 +lpips==0.1.4 \ No newline at end of file diff --git a/DINet/run_evaluate.py b/DINet/run_evaluate.py new file mode 100644 index 00000000..0de63972 --- /dev/null +++ b/DINet/run_evaluate.py @@ -0,0 +1,197 @@ +import os +import cv2 +import numpy as np +import torch +import lpips +from skimage.metrics import structural_similarity as ssim +from torchmetrics.image.psnr import PeakSignalNoiseRatio as PSNR +from pytorch_fid import fid_score +import pandas as pd + +# 计算PSNR +def compute_psnr(image1, image2): + global psnr_metric # 声明为全局变量 + image1 = torch.tensor(image1).unsqueeze(0) if isinstance(image1, np.ndarray) else image1 + image2 = torch.tensor(image2).unsqueeze(0) if isinstance(image2, np.ndarray) else image2 + return psnr_metric(image1, image2) + +# 计算SSIM +def compute_ssim(image1, image2): + image1 = np.array(image1) if isinstance(image1, torch.Tensor) else image1 + image2 = np.array(image2) if isinstance(image2, torch.Tensor) else image2 + return ssim(image1, image2, multichannel=True) + +# 计算LPIPS +def compute_lpips(image1, image2): + global lpips_model # 声明为全局变量 + # 将图像转换为 PyTorch tensor 并确保其形状符合要求 (BGR -> RGB, HWC -> CHW) + image1 = torch.tensor(image1).permute(2, 0, 1).unsqueeze(0).float() / 255.0 + image2 = torch.tensor(image2).permute(2, 0, 1).unsqueeze(0).float() / 255.0 + + # 计算 LPIPS 值 + return lpips_model(image1, image2).item() + +# 计算FID +def compute_fid(real_images, fake_images): + print("Computing FID...") + real_dir = './real_images' + fake_dir = './fake_images' + os.makedirs(real_dir, exist_ok=True) + os.makedirs(fake_dir, exist_ok=True) + + for i, img in enumerate(real_images): + cv2.imwrite(os.path.join(real_dir, f"real_{i}.png"), cv2.cvtColor(img, cv2.COLOR_RGB2BGR)) + + for i, img in enumerate(fake_images): + cv2.imwrite(os.path.join(fake_dir, f"fake_{i}.png"), cv2.cvtColor(img, cv2.COLOR_RGB2BGR)) + + fid_value = 
fid_score.calculate_fid_given_paths([real_dir, fake_dir], batch_size=1, device='cuda', dims=2048) + + # 删除临时文件夹 + for img in os.listdir(real_dir): + os.remove(os.path.join(real_dir, img)) + for img in os.listdir(fake_dir): + os.remove(os.path.join(fake_dir, img)) + os.rmdir(real_dir) + os.rmdir(fake_dir) + + print(f"FID computed: {fid_value}") + return fid_value + +def compute_lse_c(real_images, fake_images): + print("Computing LSE-C...") + return np.mean((real_images - fake_images) ** 2) + +def compute_lse_d(real_images, fake_images): + print("Computing LSE-D...") + return np.mean(np.abs(real_images - fake_images)) + +# 评估单个视频的质量(PSNR, SSIM, LPIPS, FID, LSE-C, LSE-D) +def evaluate_video_quality(real_video_path, fake_video_path): + print(f"Evaluating: {real_video_path} vs {fake_video_path}") + real_video = cv2.VideoCapture(real_video_path) + fake_video = cv2.VideoCapture(fake_video_path) + + real_frame_count = int(real_video.get(cv2.CAP_PROP_FRAME_COUNT)) + fake_frame_count = int(fake_video.get(cv2.CAP_PROP_FRAME_COUNT)) + + min_frame_count = min(real_frame_count, fake_frame_count) + + psnr_scores = [] + ssim_scores = [] + lpips_scores = [] # 修改为 LPIPS + frame_idx = 0 + real_images = [] + fake_images = [] + + while True: + ret_real, frame_real = real_video.read() + ret_fake, frame_fake = fake_video.read() + + if ret_real and ret_fake: + frame_real_rgb = cv2.cvtColor(frame_real, cv2.COLOR_BGR2RGB) + frame_fake_rgb = cv2.cvtColor(frame_fake, cv2.COLOR_BGR2RGB) + + # 计算PSNR + psnr_scores.append(compute_psnr(frame_real_rgb, frame_fake_rgb)) + # 计算SSIM + ssim_scores.append(compute_ssim(frame_real_rgb, frame_fake_rgb)) + real_images.append(frame_real_rgb) + fake_images.append(frame_fake_rgb) + # 计算LPIPS + lpips_scores.append(compute_lpips(frame_real_rgb, frame_fake_rgb)) + frame_idx += 1 + + elif ret_fake: + lpips_scores.append(compute_lpips(frame_fake_rgb, frame_fake_rgb)) + else: + break + + avg_psnr = np.mean(psnr_scores) + avg_ssim = np.mean(ssim_scores) + avg_lpips 
= np.mean(lpips_scores) # 修改为 LPIPS 平均值 + + fid_value = compute_fid(np.array(real_images), np.array(fake_images)) + lse_c_score = compute_lse_c(np.array(real_images), np.array(fake_images)) + lse_d_score = compute_lse_d(np.array(real_images), np.array(fake_images)) + + real_video.release() + fake_video.release() + + print(f"Video evaluated. PSNR: {avg_psnr}, SSIM: {avg_ssim}, LPIPS: {avg_lpips}, FID: {fid_value}, LSE-C: {lse_c_score}, LSE-D: {lse_d_score}") + + return avg_psnr, avg_ssim, avg_lpips, fid_value, lse_c_score, lse_d_score + +# 评估整个测试集 +def evaluate_test_set(real_video_folder, fake_video_folder,evaluate_dir): + print("Evaluating test set...") + video_type = "_facial_dubbing_add_audio.mp4" + psnr_scores = [] + ssim_scores = [] + lpips_scores = [] # 修改为 LPIPS + fid_scores = [] + lse_c_scores = [] + lse_d_scores = [] + + fake_video_list = sorted(os.listdir(fake_video_folder)) + + for fake_video_file in fake_video_list: + if not fake_video_file.endswith(video_type): + continue + + part1 = fake_video_file[:-len(video_type)] + + real_video_path = os.path.join(real_video_folder, part1 + '.mp4') + fake_video_path = os.path.join(fake_video_folder, fake_video_file) + + if not os.path.exists(real_video_path): + print(f"Warning: Real video {real_video_path} not found for {fake_video_path}. 
Skipping...") + continue + + avg_psnr, avg_ssim, avg_lpips, fid_score, lse_c_score, lse_d_score = evaluate_video_quality(real_video_path, fake_video_path) + + psnr_scores.append(avg_psnr) + ssim_scores.append(avg_ssim) + lpips_scores.append(avg_lpips) # 修改为 LPIPS + fid_scores.append(fid_score) + lse_c_scores.append(lse_c_score) + lse_d_scores.append(lse_d_score) + + avg_psnr = np.mean(psnr_scores) + avg_ssim = np.mean(ssim_scores) + avg_lpips = np.mean(lpips_scores) # 修改为 LPIPS + avg_fid = np.mean(fid_scores) + avg_lse_c = np.mean(lse_c_scores) + avg_lse_d = np.mean(lse_d_scores) + + print(f"Final Results for Test Set:") + print(f"PSNR: {avg_psnr}, SSIM: {avg_ssim}, LPIPS: {avg_lpips}, FID: {avg_fid}, LSE-C: {avg_lse_c}, LSE-D: {avg_lse_d}") + data = { + 'PSNR': [avg_psnr], + 'SSIM': [avg_ssim], + 'LPIPS': [avg_lpips], + 'FID': [avg_fid], + 'LSE-C': [avg_lse_c], + 'LSE-D': [avg_lse_d] + } + df = pd.DataFrame(data) + evaluate_path=os.path.join(evaluate_dir,'evaluate.csv') + df.to_csv(evaluate_path, index=False) + +# 主程序 +if __name__ == "__main__": + print("Starting evaluation...") + # 初始化全局变量 + lpips_model = lpips.LPIPS(net='alex') # 加载 LPIPS 模型 + psnr_metric = PSNR() # 加载 PSNR 计算工具 + + dir = os.path.dirname(os.path.abspath(__file__)) + + real_video_folder = os.path.join(dir, 'asserts', 'test_data', "split_video_25fps") + fake_video_folder = os.path.join(dir, 'asserts', 'inference_result') + evaluate_dir=os.path.join(dir, 'asserts', 'evaluate_result') + if not os.path.exists(evaluate_dir): + os.makedirs(evaluate_dir) + print(f'create: {evaluate_dir}') + # 进行测试集的定性评估 + evaluate_test_set(real_video_folder, fake_video_folder,evaluate_dir=evaluate_dir) diff --git a/DINet/run_inference.py b/DINet/run_inference.py new file mode 100644 index 00000000..79e2230d --- /dev/null +++ b/DINet/run_inference.py @@ -0,0 +1,70 @@ +import os +import subprocess +import argparse + +def get_command_line_args(): + parser = argparse.ArgumentParser(description="Process video files with 
OpenFace landmarks and audio.") + + # 添加命令行参数 + parser.add_argument('--process_num', type=int, default=1000, help="num for deciding to process videos.") + parser.add_argument('--model_path', type=str, default='./asserts/training_model_weight/clip_training_256/netG_model_epoch_200.pth', help="num for deciding to process videos.") + parser.add_argument('--audio_path', type=str, default="./asserts/examples/driving_audio_1.wav" , help="path of audio") + + # 解析命令行参数 + return parser.parse_args() + +if __name__ =="__main__": + arg= get_command_line_args() + model_path=arg.model_path + if not os.path.exists(model_path): + model_path="./asserts/clip_training_DINet_256mouth.pth" + + print(f"model_path: {model_path}") + + audio_path=arg.audio_path + if not os.path.exists(audio_path): + audio_path="./asserts/examples/driving_audio_1.wav" + print(f"audio_path: {audio_path}") + + # 视频文件夹路径 + video_folder = "./asserts/test_data/split_video_25fps" # 替换为你的实际视频文件夹路径 + landmark_folder = "./asserts/test_data/split_video_25fps_landmark_openface" # 替换为你的landmark文件夹路径 + output_folder = "./asserts/inference_result" # 替换为你希望输出结果的文件夹路径 + #audio_folder = "./asserts/examples/driving_audio_1.wav" # 替换为你的音频文件夹路径 + audio_folder=audio_path + pretrained_model_path = model_path # 替换为你的预训练模型路径 + + # 确保输出文件夹存在 + os.makedirs(output_folder, exist_ok=True) + # 获取所有的视频文件 + video_files = [f for f in os.listdir(video_folder) if f.endswith('.mp4')] + max=len(video_files) + num=0 + + # 批量处理每个视频文件 + for video_file in video_files: + # 这里假设输入视频文件与landmark和音频文件具有相同的文件名 + video_name = os.path.splitext(video_file)[0] # 获取文件名(不带扩展名) + + # 构建输入文件路径 + input_video_path = os.path.join(video_folder, video_file) + input_landmark_path = os.path.join(landmark_folder, video_name + '.csv') + input_audio_path = audio_folder + + # 构建命令并调用模型推理的 Python 脚本 + try: + print(f"Processing video: {video_file}") + subprocess.run([ + "python", "inference.py", + "--mouth_region_size", "256", # 这里的参数根据需求调整 + "--source_video_path", 
input_video_path, + "--source_openface_landmark_path", input_landmark_path, + "--driving_audio_path", input_audio_path, + "--pretrained_clip_DINet_path", pretrained_model_path, + ], check=True) + print(f"Successfully processed: {video_file}") + except subprocess.CalledProcessError as e: + print(f"Error occurred while processing {video_file}: {e}") + num+=1 + if num>=arg.process_num or num>max: + break \ No newline at end of file diff --git a/DINet/sync_batchnorm/__init__.py b/DINet/sync_batchnorm/__init__.py new file mode 100644 index 00000000..6d9b36c7 --- /dev/null +++ b/DINet/sync_batchnorm/__init__.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# File : __init__.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +from .batchnorm import set_sbn_eps_mode +from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d +from .batchnorm import patch_sync_batchnorm, convert_model +from .replicate import DataParallelWithCallback, patch_replication_callback diff --git a/DINet/sync_batchnorm/__pycache__/__init__.cpython-36.pyc b/DINet/sync_batchnorm/__pycache__/__init__.cpython-36.pyc new file mode 100644 index 00000000..287dc98f Binary files /dev/null and b/DINet/sync_batchnorm/__pycache__/__init__.cpython-36.pyc differ diff --git a/DINet/sync_batchnorm/__pycache__/batchnorm.cpython-36.pyc b/DINet/sync_batchnorm/__pycache__/batchnorm.cpython-36.pyc new file mode 100644 index 00000000..f362924b Binary files /dev/null and b/DINet/sync_batchnorm/__pycache__/batchnorm.cpython-36.pyc differ diff --git a/DINet/sync_batchnorm/__pycache__/comm.cpython-36.pyc b/DINet/sync_batchnorm/__pycache__/comm.cpython-36.pyc new file mode 100644 index 00000000..481ad881 Binary files /dev/null and b/DINet/sync_batchnorm/__pycache__/comm.cpython-36.pyc differ diff --git 
a/DINet/sync_batchnorm/__pycache__/replicate.cpython-36.pyc b/DINet/sync_batchnorm/__pycache__/replicate.cpython-36.pyc new file mode 100644 index 00000000..bb76b5c8 Binary files /dev/null and b/DINet/sync_batchnorm/__pycache__/replicate.cpython-36.pyc differ diff --git a/DINet/sync_batchnorm/batchnorm.py b/DINet/sync_batchnorm/batchnorm.py new file mode 100644 index 00000000..bf8d7a73 --- /dev/null +++ b/DINet/sync_batchnorm/batchnorm.py @@ -0,0 +1,412 @@ +# -*- coding: utf-8 -*- +# File : batchnorm.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import collections +import contextlib + +import torch +import torch.nn.functional as F + +from torch.nn.modules.batchnorm import _BatchNorm + +try: + from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast +except ImportError: + ReduceAddCoalesced = Broadcast = None + +try: + from jactorch.parallel.comm import SyncMaster + from jactorch.parallel.data_parallel import JacDataParallel as DataParallelWithCallback +except ImportError: + from .comm import SyncMaster + from .replicate import DataParallelWithCallback + +__all__ = [ + 'set_sbn_eps_mode', + 'SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d', + 'patch_sync_batchnorm', 'convert_model' +] + + +SBN_EPS_MODE = 'clamp' + + +def set_sbn_eps_mode(mode): + global SBN_EPS_MODE + assert mode in ('clamp', 'plus') + SBN_EPS_MODE = mode + + +def _sum_ft(tensor): + """sum over the first and last dimention""" + return tensor.sum(dim=0).sum(dim=-1) + + +def _unsqueeze_ft(tensor): + """add new dimensions at the front and the tail""" + return tensor.unsqueeze(0).unsqueeze(-1) + + +_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size']) +_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std']) + + +class 
_SynchronizedBatchNorm(_BatchNorm): + def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True): + assert ReduceAddCoalesced is not None, 'Can not use Synchronized Batch Normalization without CUDA support.' + + super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine, + track_running_stats=track_running_stats) + + if not self.track_running_stats: + import warnings + warnings.warn('track_running_stats=False is not supported by the SynchronizedBatchNorm.') + + self._sync_master = SyncMaster(self._data_parallel_master) + + self._is_parallel = False + self._parallel_id = None + self._slave_pipe = None + + def forward(self, input): + # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation. + if not (self._is_parallel and self.training): + return F.batch_norm( + input, self.running_mean, self.running_var, self.weight, self.bias, + self.training, self.momentum, self.eps) + + # Resize the input to (B, C, -1). + input_shape = input.size() + assert input.size(1) == self.num_features, 'Channel size mismatch: got {}, expect {}.'.format(input.size(1), self.num_features) + input = input.view(input.size(0), self.num_features, -1) + + # Compute the sum and square-sum. + sum_size = input.size(0) * input.size(2) + input_sum = _sum_ft(input) + input_ssum = _sum_ft(input ** 2) + + # Reduce-and-broadcast the statistics. + if self._parallel_id == 0: + mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size)) + else: + mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size)) + + # Compute the output. + if self.affine: + # MJY:: Fuse the multiplication for speed. + output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias) + else: + output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std) + + # Reshape it. 
+ return output.view(input_shape) + + def __data_parallel_replicate__(self, ctx, copy_id): + self._is_parallel = True + self._parallel_id = copy_id + + # parallel_id == 0 means master device. + if self._parallel_id == 0: + ctx.sync_master = self._sync_master + else: + self._slave_pipe = ctx.sync_master.register_slave(copy_id) + + def _data_parallel_master(self, intermediates): + """Reduce the sum and square-sum, compute the statistics, and broadcast it.""" + + # Always using same "device order" makes the ReduceAdd operation faster. + # Thanks to:: Tete Xiao (http://tetexiao.com/) + intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device()) + + to_reduce = [i[1][:2] for i in intermediates] + to_reduce = [j for i in to_reduce for j in i] # flatten + target_gpus = [i[1].sum.get_device() for i in intermediates] + + sum_size = sum([i[1].sum_size for i in intermediates]) + sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce) + mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size) + + broadcasted = Broadcast.apply(target_gpus, mean, inv_std) + + outputs = [] + for i, rec in enumerate(intermediates): + outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2]))) + + return outputs + + def _compute_mean_std(self, sum_, ssum, size): + """Compute the mean and standard-deviation with sum and square-sum. This method + also maintains the moving average on the master device.""" + assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.' 
+ mean = sum_ / size + sumvar = ssum - sum_ * mean + unbias_var = sumvar / (size - 1) + bias_var = sumvar / size + + if hasattr(torch, 'no_grad'): + with torch.no_grad(): + self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data + self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data + else: + self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data + self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data + + if SBN_EPS_MODE == 'clamp': + return mean, bias_var.clamp(self.eps) ** -0.5 + elif SBN_EPS_MODE == 'plus': + return mean, (bias_var + self.eps) ** -0.5 + else: + raise ValueError('Unknown EPS mode: {}.'.format(SBN_EPS_MODE)) + + +class SynchronizedBatchNorm1d(_SynchronizedBatchNorm): + r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a + mini-batch. + + .. math:: + + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + + This module differs from the built-in PyTorch BatchNorm1d as the mean and + standard-deviation are reduced across all devices during training. + + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device using + the statistics only on that device, which accelerated the computation and + is also easy to implement, but the statistics might be inaccurate. + Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + + During training, this layer keeps a running estimate of its computed mean + and variance. 
The running sum is kept with a default momentum of 0.1. + + During evaluation, this running mean/variance is used for normalization. + + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm + + Args: + num_features: num_features from an expected input of size + `batch_size x num_features [x width]` + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. Default: ``True`` + + Shape:: + - Input: :math:`(N, C)` or :math:`(N, C, L)` + - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) + + Examples: + >>> # With Learnable Parameters + >>> m = SynchronizedBatchNorm1d(100) + >>> # Without Learnable Parameters + >>> m = SynchronizedBatchNorm1d(100, affine=False) + >>> input = torch.autograd.Variable(torch.randn(20, 100)) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 2 and input.dim() != 3: + raise ValueError('expected 2D or 3D input (got {}D input)' + .format(input.dim())) + + +class SynchronizedBatchNorm2d(_SynchronizedBatchNorm): + r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch + of 3d inputs + + .. math:: + + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + + This module differs from the built-in PyTorch BatchNorm2d as the mean and + standard-deviation are reduced across all devices during training. + + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device using + the statistics only on that device, which accelerated the computation and + is also easy to implement, but the statistics might be inaccurate. 
+ Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + + During training, this layer keeps a running estimate of its computed mean + and variance. The running sum is kept with a default momentum of 0.1. + + During evaluation, this running mean/variance is used for normalization. + + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm + + Args: + num_features: num_features from an expected input of + size batch_size x num_features x height x width + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. Default: ``True`` + + Shape:: + - Input: :math:`(N, C, H, W)` + - Output: :math:`(N, C, H, W)` (same shape as input) + + Examples: + >>> # With Learnable Parameters + >>> m = SynchronizedBatchNorm2d(100) + >>> # Without Learnable Parameters + >>> m = SynchronizedBatchNorm2d(100, affine=False) + >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45)) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 4: + raise ValueError('expected 4D input (got {}D input)' + .format(input.dim())) + + +class SynchronizedBatchNorm3d(_SynchronizedBatchNorm): + r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch + of 4d inputs + + .. 
math:: + + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + + This module differs from the built-in PyTorch BatchNorm3d as the mean and + standard-deviation are reduced across all devices during training. + + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device using + the statistics only on that device, which accelerated the computation and + is also easy to implement, but the statistics might be inaccurate. + Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + + During training, this layer keeps a running estimate of its computed mean + and variance. The running sum is kept with a default momentum of 0.1. + + During evaluation, this running mean/variance is used for normalization. + + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm + or Spatio-temporal BatchNorm + + Args: + num_features: num_features from an expected input of + size batch_size x num_features x depth x height x width + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. 
Default: ``True`` + + Shape:: + - Input: :math:`(N, C, D, H, W)` + - Output: :math:`(N, C, D, H, W)` (same shape as input) + + Examples: + >>> # With Learnable Parameters + >>> m = SynchronizedBatchNorm3d(100) + >>> # Without Learnable Parameters + >>> m = SynchronizedBatchNorm3d(100, affine=False) + >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10)) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 5: + raise ValueError('expected 5D input (got {}D input)' + .format(input.dim())) + + +@contextlib.contextmanager +def patch_sync_batchnorm(): + import torch.nn as nn + + backup = nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d + + nn.BatchNorm1d = SynchronizedBatchNorm1d + nn.BatchNorm2d = SynchronizedBatchNorm2d + nn.BatchNorm3d = SynchronizedBatchNorm3d + + yield + + nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d = backup + + +def convert_model(module): + """Traverse the input module and its child recursively + and replace all instance of torch.nn.modules.batchnorm.BatchNorm*N*d + to SynchronizedBatchNorm*N*d + + Args: + module: the input module needs to be convert to SyncBN model + + Examples: + >>> import torch.nn as nn + >>> import torchvision + >>> # m is a standard pytorch model + >>> m = torchvision.models.resnet18(True) + >>> m = nn.DataParallel(m) + >>> # after convert, m is using SyncBN + >>> m = convert_model(m) + """ + if isinstance(module, torch.nn.DataParallel): + mod = module.module + mod = convert_model(mod) + mod = DataParallelWithCallback(mod, device_ids=module.device_ids) + return mod + + mod = module + for pth_module, sync_module in zip([torch.nn.modules.batchnorm.BatchNorm1d, + torch.nn.modules.batchnorm.BatchNorm2d, + torch.nn.modules.batchnorm.BatchNorm3d], + [SynchronizedBatchNorm1d, + SynchronizedBatchNorm2d, + SynchronizedBatchNorm3d]): + if isinstance(module, pth_module): + mod = sync_module(module.num_features, module.eps, module.momentum, module.affine) + mod.running_mean = 
module.running_mean + mod.running_var = module.running_var + if module.affine: + mod.weight.data = module.weight.data.clone().detach() + mod.bias.data = module.bias.data.clone().detach() + + for name, child in module.named_children(): + mod.add_module(name, convert_model(child)) + + return mod diff --git a/DINet/sync_batchnorm/batchnorm_reimpl.py b/DINet/sync_batchnorm/batchnorm_reimpl.py new file mode 100644 index 00000000..18145c33 --- /dev/null +++ b/DINet/sync_batchnorm/batchnorm_reimpl.py @@ -0,0 +1,74 @@ +#! /usr/bin/env python3 +# -*- coding: utf-8 -*- +# File : batchnorm_reimpl.py +# Author : acgtyrant +# Date : 11/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import torch +import torch.nn as nn +import torch.nn.init as init + +__all__ = ['BatchNorm2dReimpl'] + + +class BatchNorm2dReimpl(nn.Module): + """ + A re-implementation of batch normalization, used for testing the numerical + stability. 
+ + Author: acgtyrant + See also: + https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14 + """ + def __init__(self, num_features, eps=1e-5, momentum=0.1): + super().__init__() + + self.num_features = num_features + self.eps = eps + self.momentum = momentum + self.weight = nn.Parameter(torch.empty(num_features)) + self.bias = nn.Parameter(torch.empty(num_features)) + self.register_buffer('running_mean', torch.zeros(num_features)) + self.register_buffer('running_var', torch.ones(num_features)) + self.reset_parameters() + + def reset_running_stats(self): + self.running_mean.zero_() + self.running_var.fill_(1) + + def reset_parameters(self): + self.reset_running_stats() + init.uniform_(self.weight) + init.zeros_(self.bias) + + def forward(self, input_): + batchsize, channels, height, width = input_.size() + numel = batchsize * height * width + input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel) + sum_ = input_.sum(1) + sum_of_square = input_.pow(2).sum(1) + mean = sum_ / numel + sumvar = sum_of_square - sum_ * mean + + self.running_mean = ( + (1 - self.momentum) * self.running_mean + + self.momentum * mean.detach() + ) + unbias_var = sumvar / (numel - 1) + self.running_var = ( + (1 - self.momentum) * self.running_var + + self.momentum * unbias_var.detach() + ) + + bias_var = sumvar / numel + inv_std = 1 / (bias_var + self.eps).pow(0.5) + output = ( + (input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) * + self.weight.unsqueeze(1) + self.bias.unsqueeze(1)) + + return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous() + diff --git a/DINet/sync_batchnorm/comm.py b/DINet/sync_batchnorm/comm.py new file mode 100644 index 00000000..922f8c4a --- /dev/null +++ b/DINet/sync_batchnorm/comm.py @@ -0,0 +1,137 @@ +# -*- coding: utf-8 -*- +# File : comm.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. 
+# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import queue +import collections +import threading + +__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] + + +class FutureResult(object): + """A thread-safe future implementation. Used only as one-to-one pipe.""" + + def __init__(self): + self._result = None + self._lock = threading.Lock() + self._cond = threading.Condition(self._lock) + + def put(self, result): + with self._lock: + assert self._result is None, 'Previous result has\'t been fetched.' + self._result = result + self._cond.notify() + + def get(self): + with self._lock: + if self._result is None: + self._cond.wait() + + res = self._result + self._result = None + return res + + +_MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) +_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) + + +class SlavePipe(_SlavePipeBase): + """Pipe for master-slave communication.""" + + def run_slave(self, msg): + self.queue.put((self.identifier, msg)) + ret = self.result.get() + self.queue.put(True) + return ret + + +class SyncMaster(object): + """An abstract `SyncMaster` object. + + - During the replication, as the data parallel will trigger an callback of each module, all slave devices should + call `register(id)` and obtain an `SlavePipe` to communicate with the master. + - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, + and passed to a registered callback. + - After receiving the messages, the master device should gather the information and determine to message passed + back to each slave devices. + """ + + def __init__(self, master_callback): + """ + + Args: + master_callback: a callback to be invoked after having collected messages from slave devices. 
+ """ + self._master_callback = master_callback + self._queue = queue.Queue() + self._registry = collections.OrderedDict() + self._activated = False + + def __getstate__(self): + return {'master_callback': self._master_callback} + + def __setstate__(self, state): + self.__init__(state['master_callback']) + + def register_slave(self, identifier): + """ + Register an slave device. + + Args: + identifier: an identifier, usually is the device id. + + Returns: a `SlavePipe` object which can be used to communicate with the master device. + + """ + if self._activated: + assert self._queue.empty(), 'Queue is not clean before next initialization.' + self._activated = False + self._registry.clear() + future = FutureResult() + self._registry[identifier] = _MasterRegistry(future) + return SlavePipe(identifier, self._queue, future) + + def run_master(self, master_msg): + """ + Main entry for the master device in each forward pass. + The messages were first collected from each devices (including the master device), and then + an callback will be invoked to compute the message to be sent back to each devices + (including the master device). + + Args: + master_msg: the message that the master want to send to itself. This will be placed as the first + message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. + + Returns: the message to be sent back to the master device. + + """ + self._activated = True + + intermediates = [(0, master_msg)] + for i in range(self.nr_slaves): + intermediates.append(self._queue.get()) + + results = self._master_callback(intermediates) + assert results[0][0] == 0, 'The first result should belongs to the master.' 
+ + for i, res in results: + if i == 0: + continue + self._registry[i].result.put(res) + + for i in range(self.nr_slaves): + assert self._queue.get() is True + + return results[0][1] + + @property + def nr_slaves(self): + return len(self._registry) diff --git a/DINet/sync_batchnorm/replicate.py b/DINet/sync_batchnorm/replicate.py new file mode 100644 index 00000000..b71c7b8e --- /dev/null +++ b/DINet/sync_batchnorm/replicate.py @@ -0,0 +1,94 @@ +# -*- coding: utf-8 -*- +# File : replicate.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import functools + +from torch.nn.parallel.data_parallel import DataParallel + +__all__ = [ + 'CallbackContext', + 'execute_replication_callbacks', + 'DataParallelWithCallback', + 'patch_replication_callback' +] + + +class CallbackContext(object): + pass + + +def execute_replication_callbacks(modules): + """ + Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. + + The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` + + Note that, as all modules are isomorphism, we assign each sub-module with a context + (shared among multiple copies of this module on different devices). + Through this context, different copies can share some information. + + We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback + of any slave copies. 
+ """ + master_copy = modules[0] + nr_modules = len(list(master_copy.modules())) + ctxs = [CallbackContext() for _ in range(nr_modules)] + + for i, module in enumerate(modules): + for j, m in enumerate(module.modules()): + if hasattr(m, '__data_parallel_replicate__'): + m.__data_parallel_replicate__(ctxs[j], i) + + +class DataParallelWithCallback(DataParallel): + """ + Data Parallel with a replication callback. + + An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by + original `replicate` function. + The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` + + Examples: + > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) + # sync_bn.__data_parallel_replicate__ will be invoked. + """ + + def replicate(self, module, device_ids): + modules = super(DataParallelWithCallback, self).replicate(module, device_ids) + execute_replication_callbacks(modules) + return modules + + +def patch_replication_callback(data_parallel): + """ + Monkey-patch an existing `DataParallel` object. Add the replication callback. + Useful when you have customized `DataParallel` implementation. 
+ + Examples: + > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) + > patch_replication_callback(sync_bn) + # this is equivalent to + > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) + """ + + assert isinstance(data_parallel, DataParallel) + + old_replicate = data_parallel.replicate + + @functools.wraps(old_replicate) + def new_replicate(module, device_ids): + modules = old_replicate(module, device_ids) + execute_replication_callbacks(modules) + return modules + + data_parallel.replicate = new_replicate diff --git a/DINet/sync_batchnorm/unittest.py b/DINet/sync_batchnorm/unittest.py new file mode 100644 index 00000000..998223a0 --- /dev/null +++ b/DINet/sync_batchnorm/unittest.py @@ -0,0 +1,29 @@ +# -*- coding: utf-8 -*- +# File : unittest.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. 
+ +import unittest +import torch + + +class TorchTestCase(unittest.TestCase): + def assertTensorClose(self, x, y): + adiff = float((x - y).abs().max()) + if (y == 0).all(): + rdiff = 'NaN' + else: + rdiff = float((adiff / y).abs().max()) + + message = ( + 'Tensor close check failed\n' + 'adiff={}\n' + 'rdiff={}\n' + ).format(adiff, rdiff) + self.assertTrue(torch.allclose(x, y, atol=1e-5, rtol=1e-3), message) + diff --git a/DINet/train_DINet_clip.py b/DINet/train_DINet_clip.py new file mode 100644 index 00000000..79dc8822 --- /dev/null +++ b/DINet/train_DINet_clip.py @@ -0,0 +1,156 @@ +from models.Discriminator import Discriminator +from models.VGG19 import Vgg19 +from models.DINet import DINet +from models.Syncnet import SyncNetPerception +from utils.training_utils import get_scheduler, update_learning_rate,GANLoss +from config.config import DINetTrainingOptions +from sync_batchnorm import convert_model +from torch.utils.data import DataLoader +from dataset.dataset_DINet_clip import DINetDataset + + +import random +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +import os +import torch.nn.functional as F + +torch.cuda.set_per_process_memory_fraction(1.0) # 使用最多 90% 的 GPU 内存 + + +if __name__ == "__main__": + ''' + clip training code of DINet + in the resolution you want, using clip training code after frame training + + ''' + # load config + opt = DINetTrainingOptions().parse_args() + random.seed(opt.seed) + np.random.seed(opt.seed) + torch.cuda.manual_seed(opt.seed) + # load training data + train_data = DINetDataset(opt.train_data,opt.augment_num,opt.mouth_region_size) + training_data_loader = DataLoader(dataset=train_data, batch_size=opt.batch_size, shuffle=True,drop_last=True) + train_data_length = len(training_data_loader) + # init network + net_g = DINet(opt.source_channel,opt.ref_channel,opt.audio_channel).cuda() + net_dI = Discriminator(opt.source_channel ,opt.D_block_expansion, opt.D_num_blocks, 
opt.D_max_features).cuda() + net_dV = Discriminator(opt.source_channel * 5, opt.D_block_expansion, opt.D_num_blocks, opt.D_max_features).cuda() + net_vgg = Vgg19().cuda() + net_lipsync = SyncNetPerception(opt.pretrained_syncnet_path).cuda() + # parallel + net_g = nn.DataParallel(net_g) + net_g = convert_model(net_g) + net_dI = nn.DataParallel(net_dI) + net_dV = nn.DataParallel(net_dV) + net_vgg = nn.DataParallel(net_vgg) + # setup optimizer + optimizer_g = optim.Adam(net_g.parameters(), lr=opt.lr_g) + optimizer_dI = optim.Adam(net_dI.parameters(), lr=opt.lr_dI) + optimizer_dV = optim.Adam(net_dV.parameters(), lr=opt.lr_dI) + ## load frame trained DInet weight + print('loading frame trained DINet weight from: {}'.format(opt.pretrained_frame_DINet_path)) + checkpoint = torch.load(opt.pretrained_frame_DINet_path) + net_g.load_state_dict(checkpoint['state_dict']['net_g']) + # set criterion + criterionGAN = GANLoss().cuda() + criterionL1 = nn.L1Loss().cuda() + criterionMSE = nn.MSELoss().cuda() + # set scheduler + net_g_scheduler = get_scheduler(optimizer_g, opt.non_decay, opt.decay) + net_dI_scheduler = get_scheduler(optimizer_dI, opt.non_decay, opt.decay) + net_dV_scheduler = get_scheduler(optimizer_dV, opt.non_decay, opt.decay) + # set label of syncnet perception loss + real_tensor = torch.tensor(1.0).cuda() + # start train + for epoch in range(opt.start_epoch, opt.non_decay+opt.decay+1): + net_g.train() + for iteration, data in enumerate(training_data_loader): + # forward + source_clip,source_clip_mask, reference_clip,deep_speech_clip,deep_speech_full = data + source_clip = torch.cat(torch.split(source_clip, 1, dim=1), 0).squeeze(1).float().cuda() + source_clip_mask = torch.cat(torch.split(source_clip_mask, 1, dim=1), 0).squeeze(1).float().cuda() + reference_clip = torch.cat(torch.split(reference_clip, 1, dim=1), 0).squeeze(1).float().cuda() + deep_speech_clip = torch.cat(torch.split(deep_speech_clip, 1, dim=1), 0).squeeze(1).float().cuda() + deep_speech_full = 
deep_speech_full.float().cuda() + fake_out = net_g(source_clip_mask,reference_clip,deep_speech_clip) + fake_out_half = F.avg_pool2d(fake_out, 3, 2, 1, count_include_pad=False) + source_clip_half = F.interpolate(source_clip, scale_factor=0.5, mode='bilinear') + # (1) Update DI network + optimizer_dI.zero_grad() + _,pred_fake_dI = net_dI(fake_out) + loss_dI_fake = criterionGAN(pred_fake_dI, False) + _,pred_real_dI = net_dI(source_clip) + loss_dI_real = criterionGAN(pred_real_dI, True) + # Combined DI loss + loss_dI = (loss_dI_fake + loss_dI_real) * 0.5 + loss_dI.backward(retain_graph=True) + optimizer_dI.step() + + # (2) Update DV network + optimizer_dV.zero_grad() + condition_fake_dV = torch.cat(torch.split(fake_out, opt.batch_size, dim=0), 1) + _, pred_fake_dV = net_dV(condition_fake_dV) + loss_dV_fake = criterionGAN(pred_fake_dV, False) + condition_real_dV = torch.cat(torch.split(source_clip, opt.batch_size, dim=0), 1) + _, pred_real_dV = net_dV(condition_real_dV) + loss_dV_real = criterionGAN(pred_real_dV, True) + # Combined DV loss + loss_dV = (loss_dV_fake + loss_dV_real) * 0.5 + loss_dV.backward(retain_graph=True) + optimizer_dV.step() + + # (2) Update DINet + _, pred_fake_dI = net_dI(fake_out) + _, pred_fake_dV = net_dV(condition_fake_dV) + optimizer_g.zero_grad() + # compute perception loss + perception_real = net_vgg(source_clip) + perception_fake = net_vgg(fake_out) + perception_real_half = net_vgg(source_clip_half) + perception_fake_half = net_vgg(fake_out_half) + loss_g_perception = 0 + for i in range(len(perception_real)): + loss_g_perception += criterionL1(perception_fake[i], perception_real[i]) + loss_g_perception += criterionL1(perception_fake_half[i], perception_real_half[i]) + loss_g_perception = (loss_g_perception / (len(perception_real) * 2)) * opt.lamb_perception + # # gan dI loss + loss_g_dI = criterionGAN(pred_fake_dI, True) + # # gan dV loss + loss_g_dV = criterionGAN(pred_fake_dV, True) + ## sync perception loss + fake_out_clip = 
torch.cat(torch.split(fake_out, opt.batch_size, dim=0), 1) + fake_out_clip_mouth = fake_out_clip[:, :, train_data.radius:train_data.radius + train_data.mouth_region_size, + train_data.radius_1_4:train_data.radius_1_4 + train_data.mouth_region_size] + sync_score = net_lipsync(fake_out_clip_mouth, deep_speech_full) + loss_sync = criterionMSE(sync_score, real_tensor.expand_as(sync_score)) * opt.lamb_syncnet_perception + # combine all losses + loss_g = loss_g_perception + loss_g_dI +loss_g_dV + loss_sync + loss_g.backward() + optimizer_g.step() + + print( + "===> Epoch[{}]({}/{}): Loss_DI: {:.4f} Loss_GI: {:.4f} Loss_DV: {:.4f} Loss_GV: {:.4f} Loss_perception: {:.4f} Loss_sync: {:.4f} lr_g = {:.7f} ".format( + epoch, iteration, len(training_data_loader), float(loss_dI), float(loss_g_dI),float(loss_dV), float(loss_g_dV), float(loss_g_perception),float(loss_sync), + optimizer_g.param_groups[0]['lr'])) + + update_learning_rate(net_g_scheduler, optimizer_g) + update_learning_rate(net_dI_scheduler, optimizer_dI) + update_learning_rate(net_dV_scheduler, optimizer_dV) + # checkpoint + if epoch % opt.checkpoint == 0: + if not os.path.exists(opt.result_path): + os.mkdir(opt.result_path) + model_out_path = os.path.join(opt.result_path, 'netG_model_epoch_{}.pth'.format(epoch)) + states = { + 'epoch': epoch + 1, + 'state_dict': {'net_g': net_g.state_dict(),'net_dI': net_dI.state_dict(),'net_dV': net_dV.state_dict()}, + 'optimizer': {'net_g': optimizer_g.state_dict(), 'net_dI': optimizer_dI.state_dict(), 'net_dV': optimizer_dV.state_dict()} + } + torch.save(states, model_out_path) + print("Checkpoint saved to {}".format(epoch)) + torch.cuda.empty_cache() + diff --git a/DINet/train_DINet_frame.py b/DINet/train_DINet_frame.py new file mode 100644 index 00000000..70bd3b09 --- /dev/null +++ b/DINet/train_DINet_frame.py @@ -0,0 +1,121 @@ +from models.Discriminator import Discriminator +from models.VGG19 import Vgg19 +from models.DINet import DINet +from utils.training_utils import 
get_scheduler, update_learning_rate,GANLoss +from torch.utils.data import DataLoader +from dataset.dataset_DINet_frame import DINetDataset +from sync_batchnorm import convert_model +from config.config import DINetTrainingOptions + +import random +import numpy as np +import os +import torch.nn.functional as F +import torch +import torch.nn as nn +import torch.optim as optim + +if __name__ == "__main__": + ''' + frame training code of DINet + we use coarse-to-fine training strategy + so you can use this code to train the model in arbitrary resolution + ''' + # load config + opt = DINetTrainingOptions().parse_args() + # set seed + random.seed(opt.seed) + np.random.seed(opt.seed) + torch.cuda.manual_seed(opt.seed) + # load training data in memory + train_data = DINetDataset(opt.train_data,opt.augment_num,opt.mouth_region_size) + training_data_loader = DataLoader(dataset=train_data, batch_size=opt.batch_size, shuffle=True,drop_last=True) + train_data_length = len(training_data_loader) + # init network + net_g = DINet(opt.source_channel,opt.ref_channel,opt.audio_channel).cuda() + net_dI = Discriminator(opt.source_channel ,opt.D_block_expansion, opt.D_num_blocks, opt.D_max_features).cuda() + net_vgg = Vgg19().cuda() + # parallel + net_g = nn.DataParallel(net_g) + net_g = convert_model(net_g) + net_dI = nn.DataParallel(net_dI) + net_vgg = nn.DataParallel(net_vgg) + # setup optimizer + optimizer_g = optim.Adam(net_g.parameters(), lr=opt.lr_g) + optimizer_dI = optim.Adam(net_dI.parameters(), lr=opt.lr_dI) + # coarse2fine + if opt.coarse2fine: + print('loading checkpoint for coarse2fine training: {}'.format(opt.coarse_model_path)) + checkpoint = torch.load(opt.coarse_model_path) + net_g.load_state_dict(checkpoint['state_dict']['net_g']) + # set criterion + criterionGAN = GANLoss().cuda() + criterionL1 = nn.L1Loss().cuda() + # set scheduler + net_g_scheduler = get_scheduler(optimizer_g, opt.non_decay, opt.decay) + net_dI_scheduler = get_scheduler(optimizer_dI, opt.non_decay, 
opt.decay) + # start train + for epoch in range(opt.start_epoch, opt.non_decay+opt.decay+1): + net_g.train() + for iteration, data in enumerate(training_data_loader): + # read data + source_image_data,source_image_mask, reference_clip_data,deepspeech_feature = data + source_image_data = source_image_data.float().cuda() + source_image_mask = source_image_mask.float().cuda() + reference_clip_data = reference_clip_data.float().cuda() + deepspeech_feature = deepspeech_feature.float().cuda() + # network forward + fake_out = net_g(source_image_mask,reference_clip_data,deepspeech_feature) + # down sample output image and real image + fake_out_half = F.avg_pool2d(fake_out, 3, 2, 1, count_include_pad=False) + target_tensor_half = F.interpolate(source_image_data, scale_factor=0.5, mode='bilinear') + # (1) Update D network + optimizer_dI.zero_grad() + # compute fake loss + _,pred_fake_dI = net_dI(fake_out) + loss_dI_fake = criterionGAN(pred_fake_dI, False) + # compute real loss + _,pred_real_dI = net_dI(source_image_data) + loss_dI_real = criterionGAN(pred_real_dI, True) + # Combined DI loss + loss_dI = (loss_dI_fake + loss_dI_real) * 0.5 + loss_dI.backward(retain_graph=True) + optimizer_dI.step() + # (2) Update G network + _, pred_fake_dI = net_dI(fake_out) + optimizer_g.zero_grad() + # compute perception loss + perception_real = net_vgg(source_image_data) + perception_fake = net_vgg(fake_out) + perception_real_half = net_vgg(target_tensor_half) + perception_fake_half = net_vgg(fake_out_half) + loss_g_perception = 0 + for i in range(len(perception_real)): + loss_g_perception += criterionL1(perception_fake[i], perception_real[i]) + loss_g_perception += criterionL1(perception_fake_half[i], perception_real_half[i]) + loss_g_perception = (loss_g_perception / (len(perception_real) * 2)) * opt.lamb_perception + # # gan dI loss + loss_g_dI = criterionGAN(pred_fake_dI, True) + # combine perception loss and gan loss + loss_g = loss_g_perception + loss_g_dI + loss_g.backward() + 
optimizer_g.step() + + print( + "===> Epoch[{}]({}/{}): Loss_DI: {:.4f} Loss_GI: {:.4f} Loss_perception: {:.4f} lr_g = {:.7f} ".format( + epoch, iteration, len(training_data_loader), float(loss_dI), float(loss_g_dI), float(loss_g_perception),optimizer_g.param_groups[0]['lr'])) + + update_learning_rate(net_g_scheduler, optimizer_g) + update_learning_rate(net_dI_scheduler, optimizer_dI) + #checkpoint + if epoch % opt.checkpoint == 0: + if not os.path.exists(opt.result_path): + os.mkdir(opt.result_path) + model_out_path = os.path.join(opt.result_path, 'netG_model_epoch_{}.pth'.format(epoch)) + states = { + 'epoch': epoch + 1, + 'state_dict': {'net_g': net_g.state_dict(), 'net_dI': net_dI.state_dict()},# + 'optimizer': {'net_g': optimizer_g.state_dict(), 'net_dI': optimizer_dI.state_dict()}# + } + torch.save(states, model_out_path) + print("Checkpoint saved to {}".format(epoch)) diff --git a/DINet/utils/__pycache__/data_processing.cpython-36.pyc b/DINet/utils/__pycache__/data_processing.cpython-36.pyc new file mode 100644 index 00000000..cdaea793 Binary files /dev/null and b/DINet/utils/__pycache__/data_processing.cpython-36.pyc differ diff --git a/DINet/utils/__pycache__/deep_speech.cpython-36.pyc b/DINet/utils/__pycache__/deep_speech.cpython-36.pyc new file mode 100644 index 00000000..c0d1e629 Binary files /dev/null and b/DINet/utils/__pycache__/deep_speech.cpython-36.pyc differ diff --git a/DINet/utils/__pycache__/training_utils.cpython-36.pyc b/DINet/utils/__pycache__/training_utils.cpython-36.pyc new file mode 100644 index 00000000..79b48d04 Binary files /dev/null and b/DINet/utils/__pycache__/training_utils.cpython-36.pyc differ diff --git a/DINet/utils/data_processing.py b/DINet/utils/data_processing.py new file mode 100644 index 00000000..e13daf38 --- /dev/null +++ b/DINet/utils/data_processing.py @@ -0,0 +1,60 @@ +import csv +import numpy as np +import random + + +def load_landmark_openface(csv_path): + ''' + load openface landmark from .csv file + ''' + with 
open(csv_path, 'r') as f: + reader = csv.reader(f) + data_all = [row for row in reader] + x_list = [] + y_list = [] + for row_index,row in enumerate(data_all[1:]): + frame_num = float(row[0]) + if int(frame_num)!= row_index+1: + return None + x_list.append([float(x) for x in row[5:5+68]]) + y_list.append([float(y) for y in row[5+68:5+68 + 68]]) + x_array = np.array(x_list) + y_array = np.array(y_list) + landmark_array = np.stack([x_array,y_array],2) + return landmark_array + + +def compute_crop_radius(video_size,landmark_data_clip,random_scale = None): + ''' + judge if crop face and compute crop radius + ''' + video_w, video_h = video_size[0], video_size[1] + landmark_max_clip = np.max(landmark_data_clip, axis=1) + if random_scale is None: + random_scale = random.random() / 10 + 1.05 + else: + random_scale = random_scale + radius_h = (landmark_max_clip[:,1] - landmark_data_clip[:,29, 1]) * random_scale + radius_w = (landmark_data_clip[:,54, 0] - landmark_data_clip[:,48, 0]) * random_scale + radius_clip = np.max(np.stack([radius_h, radius_w],1),1) // 2 + radius_max = np.max(radius_clip) + radius_max = (np.int(radius_max/4) + 1 ) * 4 + radius_max_1_4 = radius_max//4 + clip_min_h = landmark_data_clip[:, 29, + 1] - radius_max + clip_max_h = landmark_data_clip[:, 29, + 1] + radius_max * 2 + radius_max_1_4 + clip_min_w = landmark_data_clip[:, 33, + 0] - radius_max - radius_max_1_4 + clip_max_w = landmark_data_clip[:, 33, + 0] + radius_max + radius_max_1_4 + if min(clip_min_h.tolist() + clip_min_w.tolist()) < 0: + return False,None + elif max(clip_max_h.tolist()) > video_h: + return False,None + elif max(clip_max_w.tolist()) > video_w: + return False,None + elif max(radius_clip) > min(radius_clip) * 1.5: + return False, None + else: + return True,radius_max \ No newline at end of file diff --git a/DINet/utils/deep_speech.py b/DINet/utils/deep_speech.py new file mode 100644 index 00000000..f26a5124 --- /dev/null +++ b/DINet/utils/deep_speech.py @@ -0,0 +1,99 @@ + +import 
numpy as np +import warnings +import resampy +from scipy.io import wavfile +from python_speech_features import mfcc +import tensorflow as tf + + +class DeepSpeech(): + def __init__(self,model_path): + self.graph, self.logits_ph, self.input_node_ph, self.input_lengths_ph \ + = self._prepare_deepspeech_net(model_path) + self.target_sample_rate = 16000 + + def _prepare_deepspeech_net(self,deepspeech_pb_path): + with tf.io.gfile.GFile(deepspeech_pb_path, "rb") as f: + graph_def = tf.compat.v1.GraphDef() + graph_def.ParseFromString(f.read()) + graph = tf.compat.v1.get_default_graph() + tf.import_graph_def(graph_def, name="deepspeech") + logits_ph = graph.get_tensor_by_name("deepspeech/logits:0") + input_node_ph = graph.get_tensor_by_name("deepspeech/input_node:0") + input_lengths_ph = graph.get_tensor_by_name("deepspeech/input_lengths:0") + + return graph, logits_ph, input_node_ph, input_lengths_ph + + def conv_audio_to_deepspeech_input_vector(self,audio, + sample_rate, + num_cepstrum, + num_context): + # Get mfcc coefficients: + features = mfcc( + signal=audio, + samplerate=sample_rate, + numcep=num_cepstrum) + + # We only keep every second feature (BiRNN stride = 2): + features = features[::2] + + # One stride per time step in the input: + num_strides = len(features) + + # Add empty initial and final contexts: + empty_context = np.zeros((num_context, num_cepstrum), dtype=features.dtype) + features = np.concatenate((empty_context, features, empty_context)) + + # Create a view into the array with overlapping strides of size + # numcontext (past) + 1 (present) + numcontext (future): + window_size = 2 * num_context + 1 + train_inputs = np.lib.stride_tricks.as_strided( + features, + shape=(num_strides, window_size, num_cepstrum), + strides=(features.strides[0], + features.strides[0], features.strides[1]), + writeable=False) + + # Flatten the second and third dimensions: + train_inputs = np.reshape(train_inputs, [num_strides, -1]) + + train_inputs = np.copy(train_inputs) + 
train_inputs = (train_inputs - np.mean(train_inputs)) / \ + np.std(train_inputs) + + return train_inputs + + def compute_audio_feature(self,audio_path): + audio_sample_rate, audio = wavfile.read(audio_path) + if audio.ndim != 1: + warnings.warn( + "Audio has multiple channels, the first channel is used") + audio = audio[:, 0] + if audio_sample_rate != self.target_sample_rate: + resampled_audio = resampy.resample( + x=audio.astype(np.float), + sr_orig=audio_sample_rate, + sr_new=self.target_sample_rate) + else: + resampled_audio = audio.astype(np.float) + with tf.compat.v1.Session(graph=self.graph) as sess: + input_vector = self.conv_audio_to_deepspeech_input_vector( + audio=resampled_audio.astype(np.int16), + sample_rate=self.target_sample_rate, + num_cepstrum=26, + num_context=9) + network_output = sess.run( + self.logits_ph, + feed_dict={ + self.input_node_ph: input_vector[np.newaxis, ...], + self.input_lengths_ph: [input_vector.shape[0]]}) + ds_features = network_output[::2,0,:] + return ds_features + +if __name__ == '__main__': + audio_path = r'./00168.wav' + model_path = r'./output_graph.pb' + DSModel = DeepSpeech(model_path) + ds_feature = DSModel.compute_audio_feature(audio_path) + print(ds_feature) \ No newline at end of file diff --git a/DINet/utils/training_utils.py b/DINet/utils/training_utils.py new file mode 100644 index 00000000..30f186f6 --- /dev/null +++ b/DINet/utils/training_utils.py @@ -0,0 +1,51 @@ +from torch.optim import lr_scheduler +import torch.nn as nn +import torch + +def get_scheduler(optimizer, niter,niter_decay,lr_policy='lambda',lr_decay_iters=50): + ''' + scheduler in training stage + ''' + if lr_policy == 'lambda': + def lambda_rule(epoch): + lr_l = 1.0 - max(0, epoch - niter) / float(niter_decay + 1) + return lr_l + scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) + elif lr_policy == 'step': + scheduler = lr_scheduler.StepLR(optimizer, step_size=lr_decay_iters, gamma=0.1) + elif lr_policy == 'plateau': + 
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5) + elif lr_policy == 'cosine': + scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=niter, eta_min=0) + else: + return NotImplementedError('learning rate policy [%s] is not implemented', lr_policy) + return scheduler + +def update_learning_rate(scheduler, optimizer): + scheduler.step() + lr = optimizer.param_groups[0]['lr'] + print('learning rate = %.7f' % lr) + +class GANLoss(nn.Module): + ''' + GAN loss + ''' + def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0): + super(GANLoss, self).__init__() + self.register_buffer('real_label', torch.tensor(target_real_label)) + self.register_buffer('fake_label', torch.tensor(target_fake_label)) + if use_lsgan: + self.loss = nn.MSELoss() + else: + self.loss = nn.BCELoss() + + def get_target_tensor(self, input, target_is_real): + if target_is_real: + target_tensor = self.real_label + else: + target_tensor = self.fake_label + return target_tensor.expand_as(input) + + def forward(self, input, target_is_real): + target_tensor = self.get_target_tensor(input, target_is_real) + return self.loss(input, target_tensor) \ No newline at end of file diff --git "a/DINet/\351\205\215\347\275\256\346\226\207\346\241\243.txt" "b/DINet/\351\205\215\347\275\256\346\226\207\346\241\243.txt" new file mode 100644 index 00000000..ce6b0887 --- /dev/null +++ "b/DINet/\351\205\215\347\275\256\346\226\207\346\241\243.txt" @@ -0,0 +1,81 @@ +在 Google Drive 中下载资源 (asserts.zip)。解压缩并将 dir 放入 ./ 中 + +(注意:要在该asserts文件下创建一个文件夹test_data,然后再test_data下面创建两个文件夹split_video_25fps和split_video_25fps_landmark_openface,用来存放测试集和推理数据) + + +为了便于挂载获取需要提前宿主机下载好整个项目DINet + +拉取项目镜像: +docker pull kevia/dinet:latest + +使用镜像: +1.数据处理: + + (1)从 HDTF 数据集下载视频。根据 xx_annotion_time.txt 分割视频,不裁剪和调整视频大小。 + (2)将所有分割的视频重新采样为 25fps,并将视频放入 “./asserts/training_data/split_video_25fps”。 + (3)使用 openface 检测所有视频的平滑面部特征点。将所有 “.csv” 结果放入 
“./asserts/training_data/split_video_25fps_landmark_openface” 中。 + 注意:(1),(2),(3)需要处理好数据加入上述对应项目文件夹下面,这里面处理了HDTF数据集,因此可以将./asserts/training_data/split_video_25fps和 + ./asserts/training_data/split_video_25fps_landmark_openface里相对应部分数据移动到./asserts/test_data/split_video_25fps和 + ./asserts/test_data/split_video_25fps_landmark_openface下面作为测试集数据 + + 并且完成上述可以使用预训练模型推理,如果要训练必须进行下面操作对训练数据进行处理(由于提交下面处理数据占用空间过大没法提交) + + (4)从训练集视频中提取帧并将帧保存在 “./asserts/training_data/split_video_25fps_frame” 中。跑 + docker run --rm -v path/to/your/DINet:/app --gpus all -it --shm-size=8G -e CUDA_LAUNCH_BLOCKING=1 kevia/dinet:latest data_processing.py --extract_video_frame + + (5)从训练集视频中提取音频,并将音频保存在 ./asserts/training_data/split_video_25fps_audio 中。跑 + docker run --rm -v path/to/your/DINet:/app --gpus all -it --shm-size=8G -e CUDA_LAUNCH_BLOCKING=1 kevia/dinet:latest data_processing.py --extract_audio + + (6)从训练集音频中提取 deepspeech 特征并将特征保存在 “./asserts/training_data/split_video_25fps_deepspeech” 中。跑 + docker run --rm -v path/to/your/DINet:/app --gpus all -it --shm-size=8G -e CUDA_LAUNCH_BLOCKING=1 kevia/dinet:latest data_processing.py --extract_deep_speech + + (7)裁剪训练集视频的人脸并将图像保存在 “./asserts/training_data/split_video_25fps_crop_face” 中。跑 + docker run --rm -v path/to/your/DINet:/app --gpus all -it --shm-size=8G -e CUDA_LAUNCH_BLOCKING=1 kevia/dinet:latest data_processing.py --crop_face + + (8)生成训练 json 文件 “./asserts/training_data/training_json.json”。跑 + docker run --rm -v path/to/your/DINet:/app --gpus all -it --shm-size=8G -e CUDA_LAUNCH_BLOCKING=1 kevia/dinet:latest data_processing.py --generate_training_json + +2.训练过程: + + 我们将训练过程分为帧训练阶段和 clip 训练阶段。在帧训练阶段,我们使用从粗到细的策略,因此您可以在任意分辨率下训练模型。 + (1)帧训练阶段。 + 在帧训练阶段,我们只使用感知损失和 GAN 损失。 + 首先,以 104x80(嘴部区域为 64x64)分辨率训练 DINet。跑 + docker run --rm -v path/to/your/DINet:/app --gpus all -it --shm-size=8G -e CUDA_LAUNCH_BLOCKING=1 kevia/dinet:latest train_DINet_frame.py \ + --augment_num=32 --mouth_region_size=64 --batch_size=24 
--result_path=./asserts/training_model_weight/frame_training_64 + 您可以在损失收敛时停止训练(我们在大约 270 个 epoch 中停止)。 + + (2)加载预训练模型(面部:104x80 & 嘴巴:64x64)并以更高分辨率训练DINet(面部:208x160 & 嘴巴:128x128)。跑 + docker run --rm -v path/to/your/DINet:/app --gpus all -it --shm-size=8G -e CUDA_LAUNCH_BLOCKING=1 kevia/dinet:latest train_DINet_frame.py \ + --augment_num=100 --mouth_region_size=128 --batch_size=80 --coarse2fine \ + --coarse_model_path=./asserts/training_model_weight/frame_training_64/xxxxxx.pth --result_path=./asserts/training_model_weight/frame_training_128 + 您可以在损失收敛时停止训练(我们在大约 200 个 epoch 中停止)。 + + (3)加载预训练模型(面部:208x160 & 嘴巴:128x128)并以更高分辨率训练DINet(面部:416x320 & 嘴巴:256x256)。跑 + docker run --rm -v path/to/your/DINet:/app --gpus all -it --shm-size=8G -e CUDA_LAUNCH_BLOCKING=1 kevia/dinet:latest train_DINet_frame.py \ + --augment_num=20 --mouth_region_size=256 --batch_size=12 --coarse2fine \ + --coarse_model_path=./asserts/training_model_weight/frame_training_128/xxxxxx.pth --result_path=./asserts/training_model_weight/frame_training_256 + 您可以在损失收敛时停止训练(我们在大约 200 个 epoch 中停止)。 + + (4)在剪辑训练阶段,我们使用感知损失、帧/剪辑 GAN 损失和同步损失。加载预训练的帧模型(面部:416x320 & 嘴巴:256x256),预训练的同步网络模型(嘴巴:256x256)并在剪辑设置中训练DINet。跑 + + docker run --rm -v path/to/your/DINet:/app --gpus all -it --shm-size=8G -e CUDA_LAUNCH_BLOCKING=1 kevia/dinet:latest train_DINet_clip.py \ + --augment_num=3 --mouth_region_size=256 --batch_size=3 \ + --pretrained_syncnet_path=./asserts/syncnet_256mouth.pth --pretrained_frame_DINet_path=./asserts/training_model_weight/frame_training_256/xxxxx.pth \ + --result_path=./asserts/training_model_weight/clip_training_256 + 您可以在损失收敛时停止训练并选择最佳模型(我们的最佳模型为 160 epoch)。 + +3.推理评估阶段 + + (1)docker run --rm -v path/to/your/DINet:/app --gpus all -it --shm-size=8G -e CUDA_LAUNCH_BLOCKING=1 kevia/dinet:latest run_inference.py --process_num x + process_num 遍历推理个数,结果保存在./asserts/inference_result里面 + 推理阶段默认使用预训练模型,如果使用训练模型加入参数 --model_path ./asserts/training_model_weight/clip_training_256/xxxx.pth + 可以选择音频路径 
 --audio_path /path/to/your_audio + + + (2)docker run --rm -v path/to/your/DINet:/app --gpus all -it --shm-size=8G -e CUDA_LAUNCH_BLOCKING=1 kevia/dinet:latest run_evaluate.py + 评估结果保存在./asserts/evaluate_result里面。 + + 注意:如果想推理自己的数据: + 和上述数据处理一样的步骤(1),(2),(3)放到./asserts/test_data文件里相应位置。 +