From b68e6b9de596257e5d69a7196708b938dec0ae0e Mon Sep 17 00:00:00 2001 From: KeviaWang <104704083+KeviaWang@users.noreply.github.com> Date: Sun, 22 Dec 2024 16:14:14 +0800 Subject: [PATCH 1/3] Add files via upload --- DINet/Dockerfile | 29 ++ DINet/README.md | 88 ++++ .../config/__pycache__/config.cpython-36.pyc | Bin 0 -> 5725 bytes DINet/config/config.py | 110 +++++ DINet/data_processing.py | 186 ++++++++ .../dataset_DINet_clip.cpython-36.pyc | Bin 0 -> 2956 bytes .../dataset_DINet_frame.cpython-36.pyc | Bin 0 -> 2658 bytes DINet/dataset/dataset_DINet_clip.py | 111 +++++ DINet/dataset/dataset_DINet_frame.py | 76 ++++ DINet/inference.py | 173 ++++++++ DINet/models/DINet.py | 299 +++++++++++++ DINet/models/Discriminator.py | 39 ++ DINet/models/Syncnet.py | 215 +++++++++ DINet/models/VGG19.py | 47 ++ DINet/models/__pycache__/DINet.cpython-36.pyc | Bin 0 -> 9388 bytes .../__pycache__/Discriminator.cpython-36.pyc | Bin 0 -> 1910 bytes .../models/__pycache__/Syncnet.cpython-36.pyc | Bin 0 -> 7737 bytes DINet/models/__pycache__/VGG19.cpython-36.pyc | Bin 0 -> 1751 bytes DINet/models/old/Syncnet_BN.py | 213 +++++++++ DINet/models/old/Syncnet_halfBN.py | 200 +++++++++ DINet/requirements.txt | 11 + DINet/run_evaluate.py | 197 +++++++++ DINet/run_inference.py | 70 +++ DINet/sync_batchnorm/__init__.py | 14 + .../__pycache__/__init__.cpython-36.pyc | Bin 0 -> 465 bytes .../__pycache__/batchnorm.cpython-36.pyc | Bin 0 -> 15307 bytes .../__pycache__/comm.cpython-36.pyc | Bin 0 -> 4744 bytes .../__pycache__/replicate.cpython-36.pyc | Bin 0 -> 3420 bytes DINet/sync_batchnorm/batchnorm.py | 412 ++++++++++++++++++ DINet/sync_batchnorm/batchnorm_reimpl.py | 74 ++++ DINet/sync_batchnorm/comm.py | 137 ++++++ DINet/sync_batchnorm/replicate.py | 94 ++++ DINet/sync_batchnorm/unittest.py | 29 ++ DINet/train_DINet_clip.py | 156 +++++++ DINet/train_DINet_frame.py | 121 +++++ .../data_processing.cpython-36.pyc | Bin 0 -> 2098 bytes .../__pycache__/deep_speech.cpython-36.pyc | Bin 0 -> 2904 
bytes .../__pycache__/training_utils.cpython-36.pyc | Bin 0 -> 2300 bytes DINet/utils/data_processing.py | 60 +++ DINet/utils/deep_speech.py | 99 +++++ DINet/utils/training_utils.py | 51 +++ ...5\347\275\256\346\226\207\346\241\243.txt" | 81 ++++ 42 files changed, 3392 insertions(+) create mode 100644 DINet/Dockerfile create mode 100644 DINet/README.md create mode 100644 DINet/config/__pycache__/config.cpython-36.pyc create mode 100644 DINet/config/config.py create mode 100644 DINet/data_processing.py create mode 100644 DINet/dataset/__pycache__/dataset_DINet_clip.cpython-36.pyc create mode 100644 DINet/dataset/__pycache__/dataset_DINet_frame.cpython-36.pyc create mode 100644 DINet/dataset/dataset_DINet_clip.py create mode 100644 DINet/dataset/dataset_DINet_frame.py create mode 100644 DINet/inference.py create mode 100644 DINet/models/DINet.py create mode 100644 DINet/models/Discriminator.py create mode 100644 DINet/models/Syncnet.py create mode 100644 DINet/models/VGG19.py create mode 100644 DINet/models/__pycache__/DINet.cpython-36.pyc create mode 100644 DINet/models/__pycache__/Discriminator.cpython-36.pyc create mode 100644 DINet/models/__pycache__/Syncnet.cpython-36.pyc create mode 100644 DINet/models/__pycache__/VGG19.cpython-36.pyc create mode 100644 DINet/models/old/Syncnet_BN.py create mode 100644 DINet/models/old/Syncnet_halfBN.py create mode 100644 DINet/requirements.txt create mode 100644 DINet/run_evaluate.py create mode 100644 DINet/run_inference.py create mode 100644 DINet/sync_batchnorm/__init__.py create mode 100644 DINet/sync_batchnorm/__pycache__/__init__.cpython-36.pyc create mode 100644 DINet/sync_batchnorm/__pycache__/batchnorm.cpython-36.pyc create mode 100644 DINet/sync_batchnorm/__pycache__/comm.cpython-36.pyc create mode 100644 DINet/sync_batchnorm/__pycache__/replicate.cpython-36.pyc create mode 100644 DINet/sync_batchnorm/batchnorm.py create mode 100644 DINet/sync_batchnorm/batchnorm_reimpl.py create mode 100644 
DINet/sync_batchnorm/comm.py create mode 100644 DINet/sync_batchnorm/replicate.py create mode 100644 DINet/sync_batchnorm/unittest.py create mode 100644 DINet/train_DINet_clip.py create mode 100644 DINet/train_DINet_frame.py create mode 100644 DINet/utils/__pycache__/data_processing.cpython-36.pyc create mode 100644 DINet/utils/__pycache__/deep_speech.cpython-36.pyc create mode 100644 DINet/utils/__pycache__/training_utils.cpython-36.pyc create mode 100644 DINet/utils/data_processing.py create mode 100644 DINet/utils/deep_speech.py create mode 100644 DINet/utils/training_utils.py create mode 100644 "DINet/\351\205\215\347\275\256\346\226\207\346\241\243.txt" diff --git a/DINet/Dockerfile b/DINet/Dockerfile new file mode 100644 index 00000000..a83091fa --- /dev/null +++ b/DINet/Dockerfile @@ -0,0 +1,29 @@ +# 使用 Python 3.6.13-slim 作为基础镜像 +FROM python:3.6.13-slim + +# 更新镜像并安装必要的系统库(包括 ffmpeg 和其他依赖) +RUN apt-get update --fix-missing && apt-get install -y \ + ffmpeg \ + libsm6 \ + libxext6 \ + libx264-dev \ + && rm -rf /var/lib/apt/lists/* + +# 设置工作目录 +WORKDIR /app + +# 将当前目录中的所有文件复制到容器的 /app 目录 +COPY . 
/app + +RUN pip install --upgrade pip +# 安装 PyTorch、TorchVision 和 Torchaudio +#RUN pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113 +RUN pip install torch==1.10.1+cu111 torchvision==0.11.2+cu111 torchaudio==0.10.1 -f https://download.pytorch.org/whl/cu111/torch_stable.html +# 安装项目依赖项 +RUN pip install --no-cache-dir -r requirements.txt + +# 设置 ENTRYPOINT 为 python +ENTRYPOINT ["python"] + +# 设置默认命令为 run_evaluate.py --process_num 1 +CMD ["run_evaluate.py"] diff --git a/DINet/README.md b/DINet/README.md new file mode 100644 index 00000000..d093c347 --- /dev/null +++ b/DINet/README.md @@ -0,0 +1,88 @@ +# DINet: Deformation Inpainting Network for Realistic Face Visually Dubbing on High Resolution Video (AAAI2023) +![在这里插入图片描述](https://img-blog.csdnimg.cn/178c6b3ec0074af7a2dcc9ef26450e75.png) +[Paper](https://fuxivirtualhuman.github.io/pdf/AAAI2023_FaceDubbing.pdf)         [demo video](https://www.youtube.com/watch?v=UU344T-9h7M&t=6s)      Supplementary materials + +## Inference +##### Download resources (asserts.zip) in [Google drive](https://drive.google.com/drive/folders/1rPtOo9Uuhc59YfFVv4gBmkh0_oG0nCQb?usp=share_link). unzip and put dir in ./. ++ Inference with example videos. Run + ```python +python inference.py --mouth_region_size=256 --source_video_path=./asserts/examples/testxxx.mp4 --source_openface_landmark_path=./asserts/examples/testxxx.csv --driving_audio_path=./asserts/examples/driving_audio_xxx.wav --pretrained_clip_DINet_path=./asserts/clip_training_DINet_256mouth.pth +``` +The results are saved in ./asserts/inference_result + ++ Inference with custom videos. +**Note:** The released pretrained model is trained on HDTF dataset with 363 training videos (video names are in ./asserts/training_video_name.txt), so the generalization is limited. It would be better to test custom videos with normal lighting, frontal view etc.(see the limitation section in the paper). 
**We also release the training code**, so if a larger high resolution audio-visual dataset is proposed in the further, you can use the training code to train a model with greater generalization. Besides, we release coarse-to-fine training strategy, **so you can use the training code to train a model in arbitrary resolution** (larger than 416x320 if gpu memory and training dataset are available). + +Using [openface](https://github.com/TadasBaltrusaitis/OpenFace) to detect smooth facial landmarks of your custom video. We run the **OpenFaceOffline.exe** on windows 10 system with this setting: + +| Record | Recording settings | OpenFace setting | View | Face Detector | Landmark Detector | +|--|--|--|--|--|--| +| 2D landmark & tracked videos | Mask aligned image | Use dynamic AU models | Show video | Openface (MTCNN)| CE-CLM | + +The detected facial landmarks are saved in "xxxx.csv". Run + ```python +python inference.py --mouth_region_size=256 --source_video_path= custom video path --source_openface_landmark_path= detected landmark path --driving_audio_path= driving audio path --pretrained_clip_DINet_path=./asserts/clip_training_DINet_256mouth.pth +``` +to realize face visually dubbing on your custom videos. +## Training +### Data Processing +We release the code of video processing on [HDTF dataset](https://github.com/MRzzm/HDTF). You can also use this code to process custom videos. + + 1. Downloading videos from [HDTF dataset](https://github.com/MRzzm/HDTF). Splitting videos according to xx_annotion_time.txt and **do not** crop&resize videos. + 2. Resampling all split videos into **25fps** and put videos into "./asserts/split_video_25fps". You can see the two example videos in "./asserts/split_video_25fps". We use [software](http://www.pcfreetime.com/formatfactory/cn/index.html) to resample videos. We provide the name list of training videos in our experiment. (pls see "./asserts/training_video_name.txt") + 3. 
Using [openface](https://github.com/TadasBaltrusaitis/OpenFace) to detect smooth facial landmarks of all videos. Putting all ".csv" results into "./asserts/split_video_25fps_landmark_openface". You can see the two example csv files in "./asserts/split_video_25fps_landmark_openface". + + 4. Extracting frames from all videos and saving frames in "./asserts/split_video_25fps_frame". Run +```python +python data_processing.py --extract_video_frame +``` + 5. Extracting audios from all videos and saving audios in "./asserts/split_video_25fps_audio". Run + ```python +python data_processing.py --extract_audio +``` + 6. Extracting deepspeech features from all audios and saving features in "./asserts/split_video_25fps_deepspeech". Run + ```python +python data_processing.py --extract_deep_speech +``` + 7. Cropping faces from all videos and saving images in "./asserts/split_video_25fps_crop_face". Run + ```python +python data_processing.py --crop_face +``` + 8. Generating training json file "./asserts/training_json.json". Run + ```python +python data_processing.py --generate_training_json +``` + +### Training models +We split the training process into **frame training stage** and **clip training stage**. In frame training stage, we use coarse-to-fine strategy, **so you can train the model in arbitrary resolution**. + +#### Frame training stage. +In frame training stage, we only use perception loss and GAN loss. + + 1. Firstly, train the DINet in 104x80 (mouth region is 64x64) resolution. Run + ```python +python train_DINet_frame.py --augment_num=32 --mouth_region_size=64 --batch_size=24 --result_path=./asserts/training_model_weight/frame_training_64 +``` +You can stop the training when the loss converges (we stop in about 270 epoch). + + 2. Loading the pretrained model (face:104x80 & mouth:64x64) and train the DINet in higher resolution (face:208x160 & mouth:128x128). 
Run + ```python +python train_DINet_frame.py --augment_num=100 --mouth_region_size=128 --batch_size=80 --coarse2fine --coarse_model_path=./asserts/training_model_weight/frame_training_64/xxxxxx.pth --result_path=./asserts/training_model_weight/frame_training_128 +``` +You can stop the training when the loss converges (we stop in about 200 epoch). + + 3. Loading the pretrained model (face:208x160 & mouth:128x128) and train the DINet in higher resolution (face:416x320 & mouth:256x256). Run + ```python +python train_DINet_frame.py --augment_num=20 --mouth_region_size=256 --batch_size=12 --coarse2fine --coarse_model_path=./asserts/training_model_weight/frame_training_128/xxxxxx.pth --result_path=./asserts/training_model_weight/frame_training_256 +``` +You can stop the training when the loss converges (we stop in about 200 epoch). + +#### Clip training stage. +In clip training stage, we use perception loss, frame/clip GAN loss and sync loss. Loading the pretrained frame model (face:416x320 & mouth:256x256), pretrained syncnet model (mouth:256x256) and train the DINet in clip setting. Run + ```python +python train_DINet_clip.py --augment_num=3 --mouth_region_size=256 --batch_size=3 --pretrained_syncnet_path=./asserts/syncnet_256mouth.pth --pretrained_frame_DINet_path=./asserts/training_model_weight/frame_training_256/xxxxx.pth --result_path=./asserts/training_model_weight/clip_training_256 +``` +You can stop the training when the loss converges and select the best model (our best model is at 160 epoch). + +## Acknowledge +The AdaAT is borrowed from [AdaAT](https://github.com/MRzzm/AdaAT). The deepspeech feature is borrowed from [AD-NeRF](https://github.com/YudongGuo/AD-NeRF). The basic module is borrowed from [first-order](https://github.com/AliaksandrSiarohin/first-order-model). Thanks for their released code. 
\ No newline at end of file diff --git a/DINet/config/__pycache__/config.cpython-36.pyc b/DINet/config/__pycache__/config.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e4f299cc66789c1276764e3892be30b00429f265 GIT binary patch literal 5725 zcmcIo%W~Vu6$KxXXuWLfVLdF#@Jli$CKb!Fl6WR#J66?<7f~iPNoC5VDin-vlCVJl z=?3-anw@7mACUY(J|TZF>nyoU76~hVkWF%K10Vp28avK}EHsJJeQ)3HzPRUJJe--C z{QKO$ov$_w<8MapQ^EhIxa1xhX0VcHyfi*DSeaG6F<9ko`DF=rRW^pZG2E4LH_j$- zH-WnfD}8CyYXis{2M)FAr;I;Oamh6_+;~}H#>+A*vGO+tq$;e6UucE&czi?Ygp7v{i@_%>raN5fR2APpFckQE&Ao*@17le_SNTy`&_<>g79$v#Q~2F zoxpFo?c}E!p7v{&<@#=9SyeKF>@iAeW2EKHbZB7-9v>&t=-&RT~ zG6rP4fJ^|X6_86nCJV?Ekm&+46P4L4yUeb9TOI1Y3T%$e7qA6ji|kqfTLQMsRtnfE zu^Y&7^uBy107hqdR#&_ z=*DVferCpbXvvUsr=!+VIRecXG+W$`dV)(e)o3^(2(6awaCM3Pn)>f@qtWI*7j{HO z+i)JgZM~L(udb)Knc^nl+8{O8Xvm-^9G)zmxkA;O+qQ(SMshoquo#}1Z5euS$#|D` z9<@TLHbOh9*g0I(_D9A)O-U2>M9$5&yS3V>g*i zb?RG5cILc6>)d^RJ4rP0qxn>E&X34Thwh_wKYV}PLqWBesvi2E-eUMBd6q@J>Z6Mk z4}FTpHnocm^`(8SII>K4O$*^(QS0M)y4WKW$mR z@7Z45BWSa)qq{n9GGEKz^2(uVcaY*a9{z8S+I!B zgP|6{gZVkGL3Rl|n8#Ltt%3*h*c!0)GuTE{1uvR~aW_X`x7ck+gBx>e+hkk7?iH}Q z#y%=wKfnLPOyzJ90Yl9LqGv{4j>=9+PG#QuFNYD=XwkHFUx~n54lg$J}jqBFZ4LF4GI7 zR4a*yJoO!)M>)~3>N~t&EY26L$t5xGzN6%V|y8Q?9;(_ zq+tt3Lu=70knJ72$c|d>37oB95ZR%{)LJ6pxjroRPC-Y|7Auf+UQUNQZ^8h~GoTQy z?j!;kF3;NfET_Dit35Tjo+tAn zF^d&!mAal6KSR${A{AI!tdht|UL@w`7Ez+9-$}HN seo2-MOH4xQ)wt55XeTH_3qnWLsQua}acTQ0C1;Wf?OJ(uWA^@k0HO&>_W%F@ literal 0 HcmV?d00001 diff --git a/DINet/config/config.py b/DINet/config/config.py new file mode 100644 index 00000000..7f449163 --- /dev/null +++ b/DINet/config/config.py @@ -0,0 +1,110 @@ +import argparse + +class DataProcessingOptions(): + def __init__(self): + self.parser = argparse.ArgumentParser() + + def parse_args(self): + self.parser.add_argument('--extract_video_frame', action='store_true', help='extract video frame') + self.parser.add_argument('--extract_audio', action='store_true', help='extract audio files from videos') + 
self.parser.add_argument('--extract_deep_speech', action='store_true', help='extract deep speech features') + self.parser.add_argument('--crop_face', action='store_true', help='crop face') + self.parser.add_argument('--generate_training_json', action='store_true', help='generate training json file') + + self.parser.add_argument('--source_video_dir', type=str, default="./asserts/training_data/split_video_25fps", + help='path of source video in 25 fps') + self.parser.add_argument('--openface_landmark_dir', type=str, default="./asserts/training_data/split_video_25fps_landmark_openface", + help='path of openface landmark dir') + self.parser.add_argument('--video_frame_dir', type=str, default="./asserts/training_data/split_video_25fps_frame", + help='path of video frames') + self.parser.add_argument('--audio_dir', type=str, default="./asserts/training_data/split_video_25fps_audio", + help='path of audios') + self.parser.add_argument('--deep_speech_dir', type=str, default="./asserts/training_data/split_video_25fps_deepspeech", + help='path of deep speech') + self.parser.add_argument('--crop_face_dir', type=str, default="./asserts/training_data/split_video_25fps_crop_face", + help='path of crop face dir') + self.parser.add_argument('--json_path', type=str, default="./asserts/training_data/training_json.json", + help='path of training json') + self.parser.add_argument('--clip_length', type=int, default=9, help='clip length') + self.parser.add_argument('--deep_speech_model', type=str, default="./asserts/output_graph.pb", + help='path of pretrained deepspeech model') + return self.parser.parse_args() + +class DINetTrainingOptions(): + def __init__(self): + self.parser = argparse.ArgumentParser() + + def parse_args(self): + self.parser.add_argument('--seed', type=int, default=456, help='random seed to use.') + self.parser.add_argument('--source_channel', type=int, default=3, help='input source image channels') + self.parser.add_argument('--ref_channel', type=int, default=15, 
help='input reference image channels') + self.parser.add_argument('--audio_channel', type=int, default=29, help='input audio channels') + self.parser.add_argument('--augment_num', type=int, default=32, help='augment training data') + self.parser.add_argument('--mouth_region_size', type=int, default=64, help='augment training data') + self.parser.add_argument('--train_data', type=str, default=r"./asserts/training_data/training_json.json", + help='path of training json') + self.parser.add_argument('--batch_size', type=int, default=24, help='training batch size') + self.parser.add_argument('--lamb_perception', type=int, default=10, help='weight of perception loss') + self.parser.add_argument('--lamb_syncnet_perception', type=int, default=0.1, help='weight of perception loss') + self.parser.add_argument('--lr_g', type=float, default=0.0001, help='initial learning rate for adam') + self.parser.add_argument('--lr_dI', type=float, default=0.0001, help='initial learning rate for adam') + self.parser.add_argument('--start_epoch', default=1, type=int, help='start epoch in training stage') + self.parser.add_argument('--non_decay', default=200, type=int, help='num of epoches with fixed learning rate') + self.parser.add_argument('--decay', default=200, type=int, help='num of linearly decay epochs') + self.parser.add_argument('--checkpoint', type=int, default=2, help='num of checkpoints in training stage') + self.parser.add_argument('--result_path', type=str, default=r"./asserts/training_model_weight/frame_training_64", + help='result path to save model') + self.parser.add_argument('--coarse2fine', action='store_true', help='If true, load pretrained model path.') + self.parser.add_argument('--coarse_model_path', + default='', + type=str, + help='Save data (.pth) of previous training') + self.parser.add_argument('--pretrained_syncnet_path', + default='', + type=str, + help='Save data (.pth) of pretrained syncnet') + self.parser.add_argument('--pretrained_frame_DINet_path', + 
default='', + type=str, + help='Save data (.pth) of frame trained DINet') + # ========================= Discriminator ========================== + self.parser.add_argument('--D_num_blocks', type=int, default=4, help='num of down blocks in discriminator') + self.parser.add_argument('--D_block_expansion', type=int, default=64, help='block expansion in discriminator') + self.parser.add_argument('--D_max_features', type=int, default=256, help='max channels in discriminator') + return self.parser.parse_args() + + +class DINetInferenceOptions(): + def __init__(self): + self.parser = argparse.ArgumentParser() + + def parse_args(self): + self.parser.add_argument('--source_channel', type=int, default=3, help='channels of source image') + self.parser.add_argument('--ref_channel', type=int, default=15, help='channels of reference image') + self.parser.add_argument('--audio_channel', type=int, default=29, help='channels of audio feature') + self.parser.add_argument('--mouth_region_size', type=int, default=256, help='help to resize window') + self.parser.add_argument('--source_video_path', + default='./asserts/examples/test4.mp4', + type=str, + help='path of source video') + self.parser.add_argument('--source_openface_landmark_path', + default='./asserts/examples/test4.csv', + type=str, + help='path of detected openface landmark') + self.parser.add_argument('--driving_audio_path', + default='./asserts/examples/driving_audio_1.wav', + type=str, + help='path of driving audio') + self.parser.add_argument('--pretrained_clip_DINet_path', + default='./asserts/clip_training_DINet_256mouth.pth', + type=str, + help='pretrained model of DINet(clip trained)') + self.parser.add_argument('--deepspeech_model_path', + default='./asserts/output_graph.pb', + type=str, + help='path of deepspeech model') + self.parser.add_argument('--res_video_dir', + default='./asserts/inference_result', + type=str, + help='path of generated videos') + return self.parser.parse_args() \ No newline at end of file 
diff --git a/DINet/data_processing.py b/DINet/data_processing.py new file mode 100644 index 00000000..1afad399 --- /dev/null +++ b/DINet/data_processing.py @@ -0,0 +1,186 @@ +import glob +import os +import subprocess +import cv2 +import numpy as np +import json + +from utils.data_processing import load_landmark_openface,compute_crop_radius +from utils.deep_speech import DeepSpeech +from config.config import DataProcessingOptions + +def extract_audio(source_video_dir,res_audio_dir): + ''' + extract audio files from videos + ''' + if not os.path.exists(source_video_dir): + raise ('wrong path of video dir') + if not os.path.exists(res_audio_dir): + os.mkdir(res_audio_dir) + video_path_list = glob.glob(os.path.join(source_video_dir, '*.mp4')) + for video_path in video_path_list: + print('extract audio from video: {}'.format(os.path.basename(video_path))) + audio_path = os.path.join(res_audio_dir, os.path.basename(video_path).replace('.mp4', '.wav')) + cmd = 'ffmpeg -i {} -f wav -ar 16000 {}'.format(video_path, audio_path) + subprocess.call(cmd, shell=True) + +def extract_deep_speech(audio_dir,res_deep_speech_dir,deep_speech_model_path): + ''' + extract deep speech feature + ''' + if not os.path.exists(res_deep_speech_dir): + os.mkdir(res_deep_speech_dir) + DSModel = DeepSpeech(deep_speech_model_path) + wav_path_list = glob.glob(os.path.join(audio_dir, '*.wav')) + for wav_path in wav_path_list: + video_name = os.path.basename(wav_path).replace('.wav', '') + res_dp_path = os.path.join(res_deep_speech_dir, video_name + '_deepspeech.txt') + if os.path.exists(res_dp_path): + os.remove(res_dp_path) + print('extract deep speech feature from audio:{}'.format(video_name)) + ds_feature = DSModel.compute_audio_feature(wav_path) + np.savetxt(res_dp_path, ds_feature) + +def extract_video_frame(source_video_dir,res_video_frame_dir): + ''' + extract video frames from videos + ''' + if not os.path.exists(source_video_dir): + raise ('wrong path of video dir') + if not 
os.path.exists(res_video_frame_dir): + os.mkdir(res_video_frame_dir) + video_path_list = glob.glob(os.path.join(source_video_dir, '*.mp4')) + for video_path in video_path_list: + video_name = os.path.basename(video_path) + frame_dir = os.path.join(res_video_frame_dir, video_name.replace('.mp4', '')) + if not os.path.exists(frame_dir): + os.makedirs(frame_dir) + print('extracting frames from {} ...'.format(video_name)) + videoCapture = cv2.VideoCapture(video_path) + fps = videoCapture.get(cv2.CAP_PROP_FPS) + if int(fps) != 25: + raise ('{} video is not in 25 fps'.format(video_path)) + frames = videoCapture.get(cv2.CAP_PROP_FRAME_COUNT) + for i in range(int(frames)): + ret, frame = videoCapture.read() + result_path = os.path.join(frame_dir, str(i).zfill(6) + '.jpg') + cv2.imwrite(result_path, frame) + + +def crop_face_according_openfaceLM(openface_landmark_dir,video_frame_dir,res_crop_face_dir,clip_length): + ''' + crop face according to openface landmark + ''' + if not os.path.exists(openface_landmark_dir): + raise ('wrong path of openface landmark dir') + if not os.path.exists(video_frame_dir): + raise ('wrong path of video frame dir') + if not os.path.exists(res_crop_face_dir): + os.mkdir(res_crop_face_dir) + landmark_openface_path_list = glob.glob(os.path.join(openface_landmark_dir, '*.csv')) + for landmark_openface_path in landmark_openface_path_list: + video_name = os.path.basename(landmark_openface_path).replace('.csv', '') + crop_face_video_dir = os.path.join(res_crop_face_dir, video_name) + if not os.path.exists(crop_face_video_dir): + os.makedirs(crop_face_video_dir) + print('cropping face from video: {} ...'.format(video_name)) + landmark_openface_data = load_landmark_openface(landmark_openface_path).astype(np.int) + frame_dir = os.path.join(video_frame_dir, video_name) + if not os.path.exists(frame_dir): + raise ('run last step to extract video frame') + if len(glob.glob(os.path.join(frame_dir, '*.jpg'))) != landmark_openface_data.shape[0]: + raise 
('landmark length is different from frame length') + frame_length = min(len(glob.glob(os.path.join(frame_dir, '*.jpg'))), landmark_openface_data.shape[0]) + end_frame_index = list(range(clip_length, frame_length, clip_length)) + video_clip_num = len(end_frame_index) + for i in range(video_clip_num): + first_image = cv2.imread(os.path.join(frame_dir, '000000.jpg')) + video_h,video_w = first_image.shape[0], first_image.shape[1] + crop_flag, radius_clip = compute_crop_radius((video_w,video_h), + landmark_openface_data[end_frame_index[i] - clip_length:end_frame_index[i], :,:]) + if not crop_flag: + continue + radius_clip_1_4 = radius_clip // 4 + print('cropping {}/{} clip from video:{}'.format(i, video_clip_num, video_name)) + res_face_clip_dir = os.path.join(crop_face_video_dir, str(i).zfill(6)) + if not os.path.exists(res_face_clip_dir): + os.mkdir(res_face_clip_dir) + for frame_index in range(end_frame_index[i]- clip_length,end_frame_index[i]): + source_frame_path = os.path.join(frame_dir,str(frame_index).zfill(6)+'.jpg') + source_frame_data = cv2.imread(source_frame_path) + frame_landmark = landmark_openface_data[frame_index, :, :] + crop_face_data = source_frame_data[ + frame_landmark[29, 1] - radius_clip:frame_landmark[ + 29, 1] + radius_clip * 2 + radius_clip_1_4, + frame_landmark[33, 0] - radius_clip - radius_clip_1_4:frame_landmark[ + 33, 0] + radius_clip + radius_clip_1_4, + :].copy() + res_crop_face_frame_path = os.path.join(res_face_clip_dir, str(frame_index).zfill(6) + '.jpg') + if os.path.exists(res_crop_face_frame_path): + os.remove(res_crop_face_frame_path) + cv2.imwrite(res_crop_face_frame_path, crop_face_data) + + +def generate_training_json(crop_face_dir,deep_speech_dir,clip_length,res_json_path): + video_name_list = os.listdir(crop_face_dir) + video_name_list.sort() + res_data_dic = {} + for video_index, video_name in enumerate(video_name_list): + print('generate training json file :{} {}/{}'.format(video_name,video_index,len(video_name_list))) + 
tem_dic = {} + deep_speech_feature_path = os.path.join(deep_speech_dir, video_name + '_deepspeech.txt') + if not os.path.exists(deep_speech_feature_path): + raise ('wrong path of deep speech') + deep_speech_feature = np.loadtxt(deep_speech_feature_path) + video_clip_dir = os.path.join(crop_face_dir, video_name) + clip_name_list = os.listdir(video_clip_dir) + clip_name_list.sort() + video_clip_num = len(clip_name_list) + clip_data_list = [] + for clip_index, clip_name in enumerate(clip_name_list): + tem_tem_dic = {} + clip_frame_dir = os.path.join(video_clip_dir, clip_name) + frame_path_list = glob.glob(os.path.join(clip_frame_dir, '*.jpg')) + frame_path_list.sort() + assert len(frame_path_list) == clip_length + start_index = int(float(clip_name) * clip_length) + assert int(float(os.path.basename(frame_path_list[0]).replace('.jpg', ''))) == start_index + frame_name_list = [video_name + '/' + clip_name + '/' + os.path.basename(item) for item in frame_path_list] + deep_speech_list = deep_speech_feature[start_index:start_index + clip_length, :].tolist() + if len(frame_name_list) != len(deep_speech_list): + print(' skip video: {}:{}/{} clip:{}:{}/{} because of different length: {} {}'.format( + video_name,video_index,len(video_name_list),clip_name,clip_index,len(clip_name_list), + len(frame_name_list),len(deep_speech_list))) + tem_tem_dic['frame_name_list'] = frame_name_list + tem_tem_dic['frame_path_list'] = frame_path_list + tem_tem_dic['deep_speech_list'] = deep_speech_list + clip_data_list.append(tem_tem_dic) + tem_dic['video_clip_num'] = video_clip_num + tem_dic['clip_data_list'] = clip_data_list + res_data_dic[video_name] = tem_dic + if os.path.exists(res_json_path): + os.remove(res_json_path) + with open(res_json_path,'w') as f: + + json.dump(res_data_dic,f) + + +if __name__ == '__main__': + opt = DataProcessingOptions().parse_args() + ########## step1: extract video frames + if opt.extract_video_frame: + extract_video_frame(opt.source_video_dir, 
opt.video_frame_dir) + ########## step2: extract audio files + if opt.extract_audio: + extract_audio(opt.source_video_dir,opt.audio_dir) + ########## step3: extract deep speech features + if opt.extract_deep_speech: + extract_deep_speech(opt.audio_dir, opt.deep_speech_dir,opt.deep_speech_model) + ########## step4: crop face images + if opt.crop_face: + crop_face_according_openfaceLM(opt.openface_landmark_dir,opt.video_frame_dir,opt.crop_face_dir,opt.clip_length) + ########## step5: generate training json file + if opt.generate_training_json: + generate_training_json(opt.crop_face_dir,opt.deep_speech_dir,opt.clip_length,opt.json_path) + + diff --git a/DINet/dataset/__pycache__/dataset_DINet_clip.cpython-36.pyc b/DINet/dataset/__pycache__/dataset_DINet_clip.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2340fecf4d099dedadabf51313b91248c98072ca GIT binary patch literal 2956 zcmZuzTW{mW6`tW;)Wxza>*8$NZ3+Ys8n|c@d)*i>`OLf8|V7{(otA_#=d&Gaj#F5ox59xuhzy>PVOm>B)vGhU@tztV5Of2 z$}zH`F3rkO_QZmM&iPHJC}UYZoh{=eTTGuO&{K1H)0$+9thl;$t)fDQH}P9MLbV+Jzh`e>7o>i z^<49b4(M2sWMkdA!^}c4%Zf5;-mT6SNqVirud^gw8vZ|-?7(*a`}MP5UH+{|WpR1& z+w;%A{KMs0T0Ae8tIM-Lo~Pv{$%OImW?`seJj+(6tCu>Mrlp`IA^D<s5fbsba&aA|rCYV-AZviVNK^Dx6h% zjVTK0p_!x-?1F;AVQb~qmLW83#SsqHrKOl~0Isqzddg1t2?ORUmivii@W9}97xX;n z6Srn4yrLwS4K*z5pd&o6qLDBY6t@(mx1`N2sWNJ6w^*-IDUb1`10fI_r4Ys>vJIL% zAjaW*xh}7SOs5$fU1T>Yg+Y?73msOo_@($ryV-myuBz!9ZO;%+9)@XoHFC65q_as> zTg4DW@gkcf{v8a+qMCry;-bLOm#Up!gTA4O=k9hYZOFi{zHFhDgjIqa-Zt+$mx6E&RK+5*w@Fh5I$i!CJt^MUr{| z9}oa)ule=gOBZoa-&$wvHUDa{2}>{cm8(LMjf%p)rFyd)lg799gyh$;Dp2Gh@?JQh zbNzQfS6*IMb(75Xys>4QMq;Z5)--p}>O&UrgB_Uf>V%2?2~$4yH%w7z=Z9GH=BDNr z!H}siaR)TkNRqKYQw7-7q+PgGYtJ6u*4i9gzfuP%LTb}~Ag{f%%3-xW3HR*RQHQF7 zdmb5U9KFR7&7CE>#u8mferSoVYC)$i>GYbv_}v9lwY*-o^NwmI?iR%pg)l)8Z0ASH zL;15d#|HNe-ZOY$@V>!o1|I-_t5pyO!?K$ntK%)tds~+Gk>$0-!dVU^6Q}nUPS0=# zZ{Z9Kr@zN>5ncACZ|FnN`;bn04;5sqf$FQ_#6_OA&Z8F9FjKosrM%?`V@X4FFp(&< z!sHPZQP~6a*2ORdF!t}|XS$uF=}Hur`zxG$Puv7pi1+nz)h}8 
z>Rs9^;`wTpO3EN@kH7vvd)Zv3=n0-oDgWh9Xv2?bpyH~X@pAQ2KA~?)Y55}}T8ou- zV=3d8x;|bm#&MZ00H`@-Jbtd-aw*4GI-JPmocaYSwmU)NDYcK1G+&qL=tRCtf?U%d z!K7j}AGK z=_Hlu0&?#+baopm2N;#2!w1JvuX#iw!IAOD&K<@cQTP6UN?fJVJ+-<{)h64I$;6#l{U87*>)EU->u+ki zt_y+YoR#TZ2y(IcO=GC;9dzd&`g8Smt@rqVy0mr5Lwf*TXgjhDC3|mEuL*-cnLa)W<%ghX*W{G^1#n4AOlr%Y6RT}f?{r;ev*MIo iF6j^haf-F2hlFMKp2plvfAR0|d(`~)X zZYB1hb&`V|cmGEI5>Em4u$aHFm*gw9G#cC6LKVqku~;ltebmR@Zv5ZD@7WhU#{R=@ zeIEK>fy%#uNhbM(B~8l{$Bx~^wXK(Uw)GPqtuqOxVG?pCUFp4GN%S?7z6@S48K`jW zCN0VT$%f%Itg`bVH{n@YrKPGeha~tyP#g3JR90XWHeYkjDqh3<+WA+1W%VQUS4kH!n=jPN zcttq}r&Y$+>1?cw_e4D}O^_}yFHHcUoKH?=S76XX_CWhQX}zyVIE5YN(|GnKAbgz&5;kzg;O)(f&AJ94{BHPdknp>W<@9++mCGDJ%qf7?$T__ zN_ZoYXstA^+K?kZpwW@cX2mt7={0$CO|A^P#xIu(rS%<*Oe6%7qY}b~ME4=ndo-~> zoiD3Pq18CYO_%wVqCAlKvNUniiVwvn#?Pl?aoIFK8Fzww3c$wI<qm~faBrQv9P4})3h}{D_29@Lpv-cI?#yYsSBYt(~cA!tUp5N3VS)f{c z1S)$tHe$hQe!@_Fh)Cho&Iu~>gi#e%ZsApk`5O1Tc2-{PN%sWRBmoUSL(gF=9~F3O zRqbOnXcabVWgx?~BRvpSuA(X^!aA%YvJY5+l^vIzpWP8@wV)N%bbCbc2z$%u4an`Y=>thS&h-7d~Jf`*b#BHEW=h=%o z0JhYt?n|q>o0Ii6Cu`#b+jH&Jy?PIx_wD)m|L1Jo&Dr+t+4f=i<+JVAonpV(-hVch;;pLQGL-YZq}MdHk9@w(`Mi^2QGQ-q^uy!g?E# zj_X0)9(kyR-g(l?Ci%kRgN38(0I8xm*fp_a&*F%%2}^LLk7$Z=Y7AoRxA|8|Un;c_ zqi@{e438cE`G_uM!vDxfd2kV$^_QtWzjS(pWyXCoHt<}skB;Zm7MH$Z*Gbf zHy-qNQyb4;Q@**@r(6T!A|!ksTw4^fiu$<0`E~?si_bY`0^t# z20$C|K0k7f0BY~@V?0Qw?gQ6@{0I+K#NURj&pXjE+G7H~<5z(1#%M-qOCPmF7W*ix zQH6y69RFDKTNtAu4_@^!6asK9#J89$0hzFJb@<9|RsruE-U@p0g)p%Y)45zusNWId z=`x)(GZw)0XSC7hV5YkP`AkhFTcD@#z?*kS$$~yTM0Y_Y`Hjaqj^4A6(=gT_LSlTI z%lZ_3Z3~uaNq<07eVg*9%POCgr*tiv&d|JNWcJo^cefHfY@Gj+0#@R0BsSgroQ@lH F{{V?prse= res_frame_length: + res_video_frame_path_list = video_frame_path_list_cycle[:res_frame_length] + res_video_landmark_data = video_landmark_data_cycle[:res_frame_length, :, :] + else: + divisor = res_frame_length // video_frame_path_list_cycle_length + remainder = res_frame_length % video_frame_path_list_cycle_length + res_video_frame_path_list = video_frame_path_list_cycle * divisor + 
video_frame_path_list_cycle[:remainder] + res_video_landmark_data = np.concatenate([video_landmark_data_cycle]* divisor + [video_landmark_data_cycle[:remainder, :, :]],0) + res_video_frame_path_list_pad = [video_frame_path_list_cycle[0]] * 2 \ + + res_video_frame_path_list \ + + [video_frame_path_list_cycle[-1]] * 2 + res_video_landmark_data_pad = np.pad(res_video_landmark_data, ((2, 2), (0, 0), (0, 0)), mode='edge') + assert ds_feature_padding.shape[0] == len(res_video_frame_path_list_pad) == res_video_landmark_data_pad.shape[0] + pad_length = ds_feature_padding.shape[0] + + ############################################## randomly select 5 reference images ############################################## + print('selecting five reference images') + ref_img_list = [] + resize_w = int(opt.mouth_region_size + opt.mouth_region_size // 4) + resize_h = int((opt.mouth_region_size // 2) * 3 + opt.mouth_region_size // 8) + ref_index_list = random.sample(range(5, len(res_video_frame_path_list_pad) - 2), 5) + for ref_index in ref_index_list: + crop_flag,crop_radius = compute_crop_radius(video_size,res_video_landmark_data_pad[ref_index - 5:ref_index, :, :]) + if not crop_flag: + raise ('our method can not handle videos with large change of facial size!!') + crop_radius_1_4 = crop_radius // 4 + ref_img = cv2.imread(res_video_frame_path_list_pad[ref_index- 3])[:, :, ::-1] + ref_landmark = res_video_landmark_data_pad[ref_index - 3, :, :] + ref_img_crop = ref_img[ + ref_landmark[29, 1] - crop_radius:ref_landmark[29, 1] + crop_radius * 2 + crop_radius_1_4, + ref_landmark[33, 0] - crop_radius - crop_radius_1_4:ref_landmark[33, 0] + crop_radius +crop_radius_1_4, + :] + ref_img_crop = cv2.resize(ref_img_crop,(resize_w,resize_h)) + ref_img_crop = ref_img_crop / 255.0 + ref_img_list.append(ref_img_crop) + ref_video_frame = np.concatenate(ref_img_list, 2) + ref_img_tensor = torch.from_numpy(ref_video_frame).permute(2, 0, 1).unsqueeze(0).float().cuda() + + 
############################################## load pretrained model weight ############################################## + print('loading pretrained model from: {}'.format(opt.pretrained_clip_DINet_path)) + model = DINet(opt.source_channel, opt.ref_channel, opt.audio_channel).cuda() + if not os.path.exists(opt.pretrained_clip_DINet_path): + raise ('wrong path of pretrained model weight: {}'.format(opt.pretrained_clip_DINet_path)) + state_dict = torch.load(opt.pretrained_clip_DINet_path)['state_dict']['net_g'] + new_state_dict = OrderedDict() + for k, v in state_dict.items(): + name = k[7:] # remove module. + new_state_dict[name] = v + model.load_state_dict(new_state_dict) + model.eval() + ############################################## inference frame by frame ############################################## + if not os.path.exists(opt.res_video_dir): + os.mkdir(opt.res_video_dir) + res_video_path = os.path.join(opt.res_video_dir,os.path.basename(opt.source_video_path)[:-4] + '_facial_dubbing.mp4') + if os.path.exists(res_video_path): + os.remove(res_video_path) + res_face_path = res_video_path.replace('_facial_dubbing.mp4', '_synthetic_face.mp4') + if os.path.exists(res_face_path): + os.remove(res_face_path) + videowriter = cv2.VideoWriter(res_video_path, cv2.VideoWriter_fourcc(*'XVID'), 25, video_size) + videowriter_face = cv2.VideoWriter(res_face_path, cv2.VideoWriter_fourcc(*'XVID'), 25, (resize_w, resize_h)) + for clip_end_index in range(5, pad_length, 1): + print('synthesizing {}/{} frame'.format(clip_end_index - 5, pad_length - 5)) + crop_flag, crop_radius = compute_crop_radius(video_size,res_video_landmark_data_pad[clip_end_index - 5:clip_end_index, :, :],random_scale = 1.05) + if not crop_flag: + raise ('our method can not handle videos with large change of facial size!!') + crop_radius_1_4 = crop_radius // 4 + frame_data = cv2.imread(res_video_frame_path_list_pad[clip_end_index - 3])[:, :, ::-1] + frame_landmark = res_video_landmark_data_pad[clip_end_index 
- 3, :, :] + crop_frame_data = frame_data[ + frame_landmark[29, 1] - crop_radius:frame_landmark[29, 1] + crop_radius * 2 + crop_radius_1_4, + frame_landmark[33, 0] - crop_radius - crop_radius_1_4:frame_landmark[33, 0] + crop_radius +crop_radius_1_4, + :] + crop_frame_h,crop_frame_w = crop_frame_data.shape[0],crop_frame_data.shape[1] + crop_frame_data = cv2.resize(crop_frame_data, (resize_w,resize_h)) # [32:224, 32:224, :] + crop_frame_data = crop_frame_data / 255.0 + crop_frame_data[opt.mouth_region_size//2:opt.mouth_region_size//2 + opt.mouth_region_size, + opt.mouth_region_size//8:opt.mouth_region_size//8 + opt.mouth_region_size, :] = 0 + + crop_frame_tensor = torch.from_numpy(crop_frame_data).float().cuda().permute(2, 0, 1).unsqueeze(0) + deepspeech_tensor = torch.from_numpy(ds_feature_padding[clip_end_index - 5:clip_end_index, :]).permute(1, 0).unsqueeze(0).float().cuda() + with torch.no_grad(): + pre_frame = model(crop_frame_tensor, ref_img_tensor, deepspeech_tensor) + pre_frame = pre_frame.squeeze(0).permute(1, 2, 0).detach().cpu().numpy() * 255 + videowriter_face.write(pre_frame[:, :, ::-1].copy().astype(np.uint8)) + pre_frame_resize = cv2.resize(pre_frame, (crop_frame_w,crop_frame_h)) + frame_data[ + frame_landmark[29, 1] - crop_radius: + frame_landmark[29, 1] + crop_radius * 2, + frame_landmark[33, 0] - crop_radius - crop_radius_1_4: + frame_landmark[33, 0] + crop_radius + crop_radius_1_4, + :] = pre_frame_resize[:crop_radius * 3,:,:] + videowriter.write(frame_data[:, :, ::-1]) + videowriter.release() + videowriter_face.release() + video_add_audio_path = res_video_path.replace('.mp4', '_add_audio.mp4') + if os.path.exists(video_add_audio_path): + os.remove(video_add_audio_path) + cmd = 'ffmpeg -i {} -i {} -c:v copy -c:a aac -strict experimental -map 0:v:0 -map 1:a:0 {}'.format( + res_video_path, + opt.driving_audio_path, + video_add_audio_path) + subprocess.call(cmd, shell=True) + + + + + + + diff --git a/DINet/models/DINet.py b/DINet/models/DINet.py new 
file mode 100644 index 00000000..5124bf05 --- /dev/null +++ b/DINet/models/DINet.py @@ -0,0 +1,299 @@ +import torch +from torch import nn +import torch.nn.functional as F +import math +import cv2 +import numpy as np +from sync_batchnorm import SynchronizedBatchNorm2d as BatchNorm2d +from sync_batchnorm import SynchronizedBatchNorm1d as BatchNorm1d + +def make_coordinate_grid_3d(spatial_size, type): + ''' + generate 3D coordinate grid + ''' + d, h, w = spatial_size + x = torch.arange(w).type(type) + y = torch.arange(h).type(type) + z = torch.arange(d).type(type) + x = (2 * (x / (w - 1)) - 1) + y = (2 * (y / (h - 1)) - 1) + z = (2 * (z / (d - 1)) - 1) + yy = y.view(1,-1, 1).repeat(d,1, w) + xx = x.view(1,1, -1).repeat(d,h, 1) + zz = z.view(-1,1,1).repeat(1,h,w) + meshed = torch.cat([xx.unsqueeze_(3), yy.unsqueeze_(3)], 3) + return meshed,zz + +class ResBlock1d(nn.Module): + ''' + basic block + ''' + def __init__(self, in_features,out_features, kernel_size, padding): + super(ResBlock1d, self).__init__() + self.in_features = in_features + self.out_features = out_features + self.conv1 = nn.Conv1d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, + padding=padding) + self.conv2 = nn.Conv1d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, + padding=padding) + if out_features != in_features: + self.channel_conv = nn.Conv1d(in_features,out_features,1) + self.norm1 = BatchNorm1d(in_features) + self.norm2 = BatchNorm1d(in_features) + self.relu = nn.ReLU() + def forward(self, x): + out = self.norm1(x) + out = self.relu(out) + out = self.conv1(out) + out = self.norm2(out) + out = self.relu(out) + out = self.conv2(out) + if self.in_features != self.out_features: + out += self.channel_conv(x) + else: + out += x + return out + +class ResBlock2d(nn.Module): + ''' + basic block + ''' + def __init__(self, in_features,out_features, kernel_size, padding): + super(ResBlock2d, self).__init__() + self.in_features = in_features + 
self.out_features = out_features + self.conv1 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, + padding=padding) + self.conv2 = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, + padding=padding) + if out_features != in_features: + self.channel_conv = nn.Conv2d(in_features,out_features,1) + self.norm1 = BatchNorm2d(in_features) + self.norm2 = BatchNorm2d(in_features) + self.relu = nn.ReLU() + def forward(self, x): + out = self.norm1(x) + out = self.relu(out) + out = self.conv1(out) + out = self.norm2(out) + out = self.relu(out) + out = self.conv2(out) + if self.in_features != self.out_features: + out += self.channel_conv(x) + else: + out += x + return out + +class UpBlock2d(nn.Module): + ''' + basic block + ''' + def __init__(self, in_features, out_features, kernel_size=3, padding=1): + super(UpBlock2d, self).__init__() + self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, + padding=padding) + self.norm = BatchNorm2d(out_features) + self.relu = nn.ReLU() + def forward(self, x): + out = F.interpolate(x, scale_factor=2) + out = self.conv(out) + out = self.norm(out) + out = F.relu(out) + return out + +class DownBlock1d(nn.Module): + ''' + basic block + ''' + def __init__(self, in_features, out_features, kernel_size, padding): + super(DownBlock1d, self).__init__() + self.conv = nn.Conv1d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, + padding=padding,stride=2) + self.norm = BatchNorm1d(out_features) + self.relu = nn.ReLU() + def forward(self, x): + out = self.conv(x) + out = self.norm(out) + out = self.relu(out) + return out + +class DownBlock2d(nn.Module): + ''' + basic block + ''' + def __init__(self, in_features, out_features, kernel_size=3, padding=1, stride=2): + super(DownBlock2d, self).__init__() + self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, + 
padding=padding, stride=stride) + self.norm = BatchNorm2d(out_features) + self.relu = nn.ReLU() + def forward(self, x): + out = self.conv(x) + out = self.norm(out) + out = self.relu(out) + return out + +class SameBlock1d(nn.Module): + ''' + basic block + ''' + def __init__(self, in_features, out_features, kernel_size, padding): + super(SameBlock1d, self).__init__() + self.conv = nn.Conv1d(in_channels=in_features, out_channels=out_features, + kernel_size=kernel_size, padding=padding) + self.norm = BatchNorm1d(out_features) + self.relu = nn.ReLU() + def forward(self, x): + out = self.conv(x) + out = self.norm(out) + out = self.relu(out) + return out + +class SameBlock2d(nn.Module): + ''' + basic block + ''' + def __init__(self, in_features, out_features, kernel_size=3, padding=1): + super(SameBlock2d, self).__init__() + self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, + kernel_size=kernel_size, padding=padding) + self.norm = BatchNorm2d(out_features) + self.relu = nn.ReLU() + def forward(self, x): + out = self.conv(x) + out = self.norm(out) + out = self.relu(out) + return out + +class AdaAT(nn.Module): + ''' + AdaAT operator + ''' + def __init__(self, para_ch,feature_ch): + super(AdaAT, self).__init__() + self.para_ch = para_ch + self.feature_ch = feature_ch + self.commn_linear = nn.Sequential( + nn.Linear(para_ch, para_ch), + nn.ReLU() + ) + self.scale = nn.Sequential( + nn.Linear(para_ch, feature_ch), + nn.Sigmoid() + ) + self.rotation = nn.Sequential( + nn.Linear(para_ch, feature_ch), + nn.Tanh() + ) + self.translation = nn.Sequential( + nn.Linear(para_ch, 2 * feature_ch), + nn.Tanh() + ) + self.tanh = nn.Tanh() + self.sigmoid = nn.Sigmoid() + + def forward(self, feature_map,para_code): + batch,d, h, w = feature_map.size(0), feature_map.size(1), feature_map.size(2), feature_map.size(3) + para_code = self.commn_linear(para_code) + scale = self.scale(para_code).unsqueeze(-1) * 2 + angle = self.rotation(para_code).unsqueeze(-1) * 3.14159# + 
rotation_matrix = torch.cat([torch.cos(angle), -torch.sin(angle), torch.sin(angle), torch.cos(angle)], -1) + rotation_matrix = rotation_matrix.view(batch, self.feature_ch, 2, 2) + translation = self.translation(para_code).view(batch, self.feature_ch, 2) + grid_xy, grid_z = make_coordinate_grid_3d((d, h, w), feature_map.type()) + grid_xy = grid_xy.unsqueeze(0).repeat(batch, 1, 1, 1, 1) + grid_z = grid_z.unsqueeze(0).repeat(batch, 1, 1, 1) + scale = scale.unsqueeze(2).unsqueeze(3).repeat(1, 1, h, w, 1) + rotation_matrix = rotation_matrix.unsqueeze(2).unsqueeze(3).repeat(1, 1, h, w, 1, 1) + translation = translation.unsqueeze(2).unsqueeze(3).repeat(1, 1, h, w, 1) + trans_grid = torch.matmul(rotation_matrix, grid_xy.unsqueeze(-1)).squeeze(-1) * scale + translation + full_grid = torch.cat([trans_grid, grid_z.unsqueeze(-1)], -1) + trans_feature = F.grid_sample(feature_map.unsqueeze(1), full_grid).squeeze(1) + return trans_feature + +class DINet(nn.Module): + def __init__(self, source_channel,ref_channel,audio_channel): + super(DINet, self).__init__() + self.source_in_conv = nn.Sequential( + SameBlock2d(source_channel,64,kernel_size=7, padding=3), + DownBlock2d(64, 128, kernel_size=3, padding=1), + DownBlock2d(128,256,kernel_size=3, padding=1) + ) + self.ref_in_conv = nn.Sequential( + SameBlock2d(ref_channel, 64, kernel_size=7, padding=3), + DownBlock2d(64, 128, kernel_size=3, padding=1), + DownBlock2d(128, 256, kernel_size=3, padding=1), + ) + self.trans_conv = nn.Sequential( + # 20 →10 + SameBlock2d(512, 128, kernel_size=3, padding=1), + SameBlock2d(128, 128, kernel_size=11, padding=5), + SameBlock2d(128, 128, kernel_size=11, padding=5), + DownBlock2d(128, 128, kernel_size=3, padding=1), + # 10 →5 + SameBlock2d(128, 128, kernel_size=7, padding=3), + SameBlock2d(128, 128, kernel_size=7, padding=3), + DownBlock2d(128, 128, kernel_size=3, padding=1), + # 5 →3 + SameBlock2d(128, 128, kernel_size=3, padding=1), + DownBlock2d(128, 128, kernel_size=3, padding=1), + # 3 →2 + 
SameBlock2d(128, 128, kernel_size=3, padding=1), + DownBlock2d(128, 128, kernel_size=3, padding=1), + + ) + self.audio_encoder = nn.Sequential( + SameBlock1d(audio_channel, 128, kernel_size=5, padding=2), + ResBlock1d(128, 128, 3, 1), + DownBlock1d(128, 128, 3, 1), + ResBlock1d(128, 128, 3, 1), + DownBlock1d(128, 128, 3, 1), + SameBlock1d(128, 128, kernel_size=3, padding=1) + ) + + appearance_conv_list = [] + for i in range(2): + appearance_conv_list.append( + nn.Sequential( + ResBlock2d(256, 256, 3, 1), + ResBlock2d(256, 256, 3, 1), + ResBlock2d(256, 256, 3, 1), + ResBlock2d(256, 256, 3, 1), + ) + ) + self.appearance_conv_list = nn.ModuleList(appearance_conv_list) + self.adaAT = AdaAT(256, 256) + self.out_conv = nn.Sequential( + SameBlock2d(512, 128, kernel_size=3, padding=1), + UpBlock2d(128,128,kernel_size=3, padding=1), + ResBlock2d(128, 128, 3, 1), + UpBlock2d(128, 128, kernel_size=3, padding=1), + nn.Conv2d(128, 3, kernel_size=(7, 7), padding=(3, 3)), + nn.Sigmoid() + ) + self.global_avg2d = nn.AdaptiveAvgPool2d(1) + self.global_avg1d = nn.AdaptiveAvgPool1d(1) + def forward(self, source_img,ref_img,audio_feature): + ## source image encoder + source_in_feature = self.source_in_conv(source_img) + ## reference image encoder + ref_in_feature = self.ref_in_conv(ref_img) + ## alignment encoder + img_para = self.trans_conv(torch.cat([source_in_feature,ref_in_feature],1)) + img_para = self.global_avg2d(img_para).squeeze(3).squeeze(2) + ## audio encoder + audio_para = self.audio_encoder(audio_feature) + audio_para = self.global_avg1d(audio_para).squeeze(2) + ## concat alignment feature and audio feature + trans_para = torch.cat([img_para,audio_para],1) + ## use AdaAT do spatial deformation on reference feature maps + ref_trans_feature = self.appearance_conv_list[0](ref_in_feature) + ref_trans_feature = self.adaAT(ref_trans_feature, trans_para) + ref_trans_feature = self.appearance_conv_list[1](ref_trans_feature) + ## feature decoder + merge_feature = 
torch.cat([source_in_feature,ref_trans_feature],1) + out = self.out_conv(merge_feature) + return out + + + diff --git a/DINet/models/Discriminator.py b/DINet/models/Discriminator.py new file mode 100644 index 00000000..0aa737ab --- /dev/null +++ b/DINet/models/Discriminator.py @@ -0,0 +1,39 @@ +from torch import nn +import torch.nn.functional as F + +class DownBlock2d(nn.Module): + def __init__(self, in_features, out_features, kernel_size=4, pool=False): + super(DownBlock2d, self).__init__() + self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size) + self.pool = pool + def forward(self, x): + out = x + out = self.conv(out) + out = F.leaky_relu(out, 0.2) + if self.pool: + out = F.avg_pool2d(out, (2, 2)) + return out + + +class Discriminator(nn.Module): + """ + Discriminator for GAN loss + """ + def __init__(self, num_channels, block_expansion=64, num_blocks=4, max_features=512): + super(Discriminator, self).__init__() + down_blocks = [] + for i in range(num_blocks): + down_blocks.append( + DownBlock2d(num_channels if i == 0 else min(max_features, block_expansion * (2 ** i)), + min(max_features, block_expansion * (2 ** (i + 1))), + kernel_size=4, pool=(i != num_blocks - 1))) + self.down_blocks = nn.ModuleList(down_blocks) + self.conv = nn.Conv2d(self.down_blocks[-1].conv.out_channels, out_channels=1, kernel_size=1) + def forward(self, x): + feature_maps = [] + out = x + for down_block in self.down_blocks: + feature_maps.append(down_block(out)) + out = feature_maps[-1] + out = self.conv(out) + return feature_maps, out diff --git a/DINet/models/Syncnet.py b/DINet/models/Syncnet.py new file mode 100644 index 00000000..44043670 --- /dev/null +++ b/DINet/models/Syncnet.py @@ -0,0 +1,215 @@ +import torch +from torch import nn +class ResBlock1d(nn.Module): + ''' + basic block (no BN) + ''' + def __init__(self, in_features,out_features, kernel_size, padding): + super(ResBlock1d, self).__init__() + self.in_features = in_features + 
self.out_features = out_features + self.conv1 = nn.Conv1d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, + padding=padding) + self.conv2 = nn.Conv1d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, + padding=padding) + if out_features != in_features: + self.channel_conv = nn.Conv1d(in_features,out_features,1) + self.relu = nn.ReLU() + def forward(self, x): + out = self.relu(x) + out = self.conv1(out) + out = self.relu(out) + out = self.conv2(out) + if self.in_features != self.out_features: + out += self.channel_conv(x) + else: + out += x + return out + +class ResBlock2d(nn.Module): + ''' + basic block (no BN) + ''' + def __init__(self, in_features,out_features, kernel_size, padding): + super(ResBlock2d, self).__init__() + self.in_features = in_features + self.out_features = out_features + self.conv1 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, + padding=padding) + self.conv2 = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, + padding=padding) + if out_features != in_features: + self.channel_conv = nn.Conv2d(in_features,out_features,1) + self.relu = nn.ReLU() + def forward(self, x): + out = self.relu(x) + out = self.conv1(out) + out = self.relu(out) + out = self.conv2(out) + if self.in_features != self.out_features: + out += self.channel_conv(x) + else: + out += x + return out + +class DownBlock1d(nn.Module): + ''' + basic block (no BN) + ''' + def __init__(self, in_features, out_features, kernel_size, padding): + super(DownBlock1d, self).__init__() + self.conv = nn.Conv1d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, + padding=padding,stride=2) + self.relu = nn.ReLU() + def forward(self, x): + out = self.conv(x) + out = self.relu(out) + return out + +class DownBlock2d(nn.Module): + ''' + basic block (no BN) + ''' + def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1): + 
super(DownBlock2d, self).__init__() + self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, + padding=padding, groups=groups) + self.pool = nn.AvgPool2d(kernel_size=(2, 2)) + self.relu = nn.ReLU() + def forward(self, x): + out = self.conv(x) + out = self.relu(out) + out = self.pool(out) + return out + +class SameBlock1d(nn.Module): + ''' + basic block (no BN) + ''' + def __init__(self, in_features, out_features, kernel_size, padding): + super(SameBlock1d, self).__init__() + self.conv = nn.Conv1d(in_channels=in_features, out_channels=out_features, + kernel_size=kernel_size, padding=padding) + self.relu = nn.ReLU() + def forward(self, x): + out = self.conv(x) + out = self.relu(out) + return out + +class SameBlock2d(nn.Module): + ''' + basic block (no BN) + ''' + def __init__(self, in_features, out_features, groups=1, kernel_size=3, padding=1): + super(SameBlock2d, self).__init__() + self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, + kernel_size=kernel_size, padding=padding, groups=groups) + self.relu = nn.ReLU() + def forward(self, x): + out = self.conv(x) + out = self.relu(out) + return out + +class FaceEncoder(nn.Module): + ''' + image encoder + ''' + def __init__(self, in_channel, out_dim): + super(FaceEncoder, self).__init__() + self.in_channel = in_channel + self.out_dim = out_dim + self.face_conv = nn.Sequential( + SameBlock2d(in_channel,64,kernel_size=7,padding=3), + # # 64 → 32 + ResBlock2d(64, 64, kernel_size=3, padding=1), + DownBlock2d(64,64,3,1), + SameBlock2d(64, 128), + # 32 → 16 + ResBlock2d(128, 128, kernel_size=3, padding=1), + DownBlock2d(128,128,3,1), + SameBlock2d(128, 128), + # 16 → 8 + ResBlock2d(128, 128, kernel_size=3, padding=1), + DownBlock2d(128,128,3,1), + SameBlock2d(128, 128), + # 8 → 4 + ResBlock2d(128, 128, kernel_size=3, padding=1), + DownBlock2d(128,128,3,1), + SameBlock2d(128, 128), + # 4 → 2 + ResBlock2d(128, 128, kernel_size=3, padding=1), + 
DownBlock2d(128,128,3,1), + SameBlock2d(128,out_dim,kernel_size=1,padding=0) + ) + def forward(self, x): + ## b x c x h x w + out = self.face_conv(x) + return out + +class AudioEncoder(nn.Module): + ''' + audio encoder + ''' + def __init__(self, in_channel, out_dim): + super(AudioEncoder, self).__init__() + self.in_channel = in_channel + self.out_dim = out_dim + self.audio_conv = nn.Sequential( + SameBlock1d(in_channel,128,kernel_size=7,padding=3), + ResBlock1d(128, 128, 3, 1), + # 9-5 + DownBlock1d(128, 128, 3, 1), + ResBlock1d(128, 128, 3, 1), + # 5 -3 + DownBlock1d(128, 128, 3, 1), + ResBlock1d(128, 128, 3, 1), + # 3-2 + DownBlock1d(128, 128, 3, 1), + SameBlock1d(128,out_dim,kernel_size=3,padding=1) + ) + self.global_avg = nn.AdaptiveAvgPool1d(1) + def forward(self, x): + ## b x c x t + out = self.audio_conv(x) + return self.global_avg(out).squeeze(2) + +class SyncNet(nn.Module): + ''' + syncnet + ''' + def __init__(self, in_channel_image,in_channel_audio, out_dim): + super(SyncNet, self).__init__() + self.in_channel_image = in_channel_image + self.in_channel_audio = in_channel_audio + self.out_dim = out_dim + self.face_encoder = FaceEncoder(in_channel_image,out_dim) + self.audio_encoder = AudioEncoder(in_channel_audio,out_dim) + self.merge_encoder = nn.Sequential( + nn.Conv2d(out_dim * 2, out_dim, kernel_size=3, padding=1), + nn.LeakyReLU(0.2), + nn.Conv2d(out_dim, 1, kernel_size=3, padding=1), + ) + def forward(self, image,audio): + image_embedding = self.face_encoder(image) + audio_embedding = self.audio_encoder(audio).unsqueeze(2).unsqueeze(3).repeat(1,1,image_embedding.size(2),image_embedding.size(3)) + concat_embedding = torch.cat([image_embedding,audio_embedding],1) + out_score = self.merge_encoder(concat_embedding) + return out_score + +class SyncNetPerception(nn.Module): + ''' + use syncnet to compute perception loss + ''' + def __init__(self,pretrain_path): + super(SyncNetPerception, self).__init__() + self.model = SyncNet(15,29,128) + print('load lip 
sync model : {}'.format(pretrain_path)) + self.model.load_state_dict(torch.load(pretrain_path)['state_dict']['net']) + for param in self.model.parameters(): + param.requires_grad = False + self.model.eval() + + def forward(self, image,audio): + score = self.model(image,audio) + return score + diff --git a/DINet/models/VGG19.py b/DINet/models/VGG19.py new file mode 100644 index 00000000..a0bf89d4 --- /dev/null +++ b/DINet/models/VGG19.py @@ -0,0 +1,47 @@ +import torch +from torchvision import models +import numpy as np + + +class Vgg19(torch.nn.Module): + """ + Vgg19 network for perceptual loss + """ + def __init__(self, requires_grad=False): + super(Vgg19, self).__init__() + vgg_model = models.vgg19(pretrained=True) + vgg_pretrained_features = vgg_model.features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + for x in range(2): + self.slice1.add_module(str(x), vgg_pretrained_features[x]) + for x in range(2, 7): + self.slice2.add_module(str(x), vgg_pretrained_features[x]) + for x in range(7, 12): + self.slice3.add_module(str(x), vgg_pretrained_features[x]) + for x in range(12, 21): + self.slice4.add_module(str(x), vgg_pretrained_features[x]) + for x in range(21, 30): + self.slice5.add_module(str(x), vgg_pretrained_features[x]) + + self.mean = torch.nn.Parameter(data=torch.Tensor(np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))), + requires_grad=False) + self.std = torch.nn.Parameter(data=torch.Tensor(np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))), + requires_grad=False) + + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + X = (X - self.mean) / self.std + h_relu1 = self.slice1(X) + h_relu2 = self.slice2(h_relu1) + h_relu3 = self.slice3(h_relu2) + h_relu4 = self.slice4(h_relu3) + h_relu5 = self.slice5(h_relu4) + out = [h_relu1, h_relu2, h_relu3, 
h_relu4, h_relu5] + return out diff --git a/DINet/models/__pycache__/DINet.cpython-36.pyc b/DINet/models/__pycache__/DINet.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e0e0bf51e7c717a103b49b47270fe602b23465c6 GIT binary patch literal 9388 zcmdT~Ta(;I71n)bG?%@2vtGwZupvMO5)-e1KnP%am)y$PO2u&{hy@{!w6imw(QH~; z$DYzq6q6!6kSdBN9w;7o3ROJtANUVYrAK(6_KjC~Abej-nv1>bTMVg4HAmg2Tivbh zug^Js&YY{)?SCwP)N;=#%HNfl%Ru@(e$jtGe8pFLio>VYQ~Ii-su@q~s$JbNWVz{> z@@YAie3qOto_ep+x1Fl2Zy;TB>PVZu^`7E1UQ+y$Uw%*V%YoWm80*d=I2GRp$IdxR z;8gt@IJKO!3{Kr|fYZo1$G}v2YO1LghM>nU-wa#3Vi0zQfq&79Tf18W(O>tEEIRAA zG&X^&h9Bp41Hb5RAc?ZAkF=3KGGe8xCdOFpYKb~lN2bi{le{JK+9Y4{)x=B;Ut3d7 zDl1r-HDzRiXC?X^&jPQMXmh+$Vs%l=$J=AQYp@hO>!WgFO1|cs8J}%|Ur9>fTX;j( zM{Ql!w=pI&sdOzL<4nrgv#lYgjjBmCv3E4gUuoF1PvuuT2m|58!Krhbr&@!7@H-)M zZPD>%odZtPfl8v2e1uDsk&>iYO3ip6TDz&`2`_91sS)4Y1HIk}-c2nL>;+z&>MbvJ z?ENr$XFmvrf!nM(wP??aJ6_L?ux6r)_EObP)!kHmH&t(>>djOgrrOP$sdnQ=stt#Z z)eoZG!2g3H)=>H1&ul!ueI*J+wEfnFt(V?@d3!U6uEm4B?af!Vf_S?>@Pl5IWzOu~ zOpo`yYk@ocK9}F;p7U+A6|q6Zu+>F*uc*A&>GOc`Z{m zt|)w3__4bv0g07ZO_Z_L)y9gdBuKk@qJbK**|m^Yi2+~zzyPn5^CT{EUKzF3Yr1fX zGGYgHE3r*{BUPQGYN%5mV=uT3U%~M8t)`JyJE7a!^}-NSo7M*V@ibSx76`EAYNn+< z&j*a!g5z&4q-M0g7YK0@xwP!Mov;(Tu54`w=-3xQ-dd3pgj{NE48rSY{nTvnnLXCi zT0Q`*Wq3WUPBuF$hS>C|$j5qWyb)Zwk{TlD?KjQTh=SgZ_$=BLpCfsKhU^6~QqI9utkR(_& zNES#AnbIaHtmkc|kOeu<#9J2D?`^Jnx5Sp7Q$|`S*OjS+dt2RVmPHf2iKj_EPx2*_ zFOW=U;ETvMb@2?j%OI(a1QSw6uDm~ zk<&iK#B|@S`x$a1Xb6>=t%fp*FNvxzvl+spX=YfOwl?XvSh@+WWWV3WOZT|D*8NZW z?h;?&Jjn@|&I0lNDre#DI}0-~*;(Q&+gc|%NAgvY$p#W%V|LP0G51oaam~Mp-|V&_ zZ*JSr7~&OieP{KXUzQqw$OE~Zfg zkvXH{%-Eu6l*&%;mZh?)NM)r#SsjM|rkz2bQ;>r?! 
z$OciMHZqklT%JtL;YglXpu~i*XKF<;%+jF1>37lF z3{Ib9RxThd&ukSR9Hl+9F|({Snlmdi^K&>(xnVZ=CdK?+?3vse(H{hKf_?CmQ5YP< zW3yu`K=(0##>w4fF9pzQmLP8qXh+M!ij;*sHd7YLL-UNaZ>LkJ-G?-kyST*x@#6Cy zk(c0{c^BkCC>{zraUR_ih`WIF4B{vmv(=+uycKcQkqSR+GdI|#`i?q^Cz+#Rca)<4 zf5Gl&9NiJHn^D2~RMusJ_haCYQ+|Wp7C3~!BX16e6h6G}*OaYZh6Vlr0*_;1pV9wV z6JJChGkANr67HA$fI<(Kaaf^*o0*VLR>J+7qnn4??ARWHgyZD?g1r>Tf`lV)4p~P_ zxVu?oM%Pbe9|K3Gx{CKywP^6^9;@| zxKrRz(U8JPVobMMg>JIEM3W_UUQ)fSq?VV|rzLRmYAa|7oT9XmrP1~RQfLjS#r(vQ z>5--EfU=1ZOTf~VCOo*=Qag8!A)UF$?8^bXsUdx+j&-RM2A+^-7I?@8{eI~7Br7dl z?zH=Z4!mRXWRMzfdf{$b7K0e>*Fl(8W4LW2oDG;UViuO7?3Kbh^gWK|MUdhwIj_Rl3qD=*Fv0jU4gQm^~+xk$2nV@zN;=t*SPoDDlq)VGm zxa4D^FIbJXVn|CwNwsB86v(oVMYY z;Wn?ponFDYQ|a2ijxiYEDBUX0o>M+t@^)NH3R z5tZUKl1E7xMi6z7)arY2f4>JTXCVhcZJMVq!u*uxd&!V$GF|loM0!Ek*at?8xEbP+BH~=w5FRGRRF$Ku7ULKjacw4j6yGlD@a-`~ zrdw|-Bjd*Fal^+}-&B6MIYR8RyCCOd5i=|27OmeydJI9OD!)?S`N>-f-d|$O6g@1* z$JmZkrYncc^CPokLO+~8%n@Fjf}$SjN0C06ryr9_ zwXt<@9^g8V&PI}^HuZ%c{1go7f;>PpP3b5=-~r{w0qurN8*75DJW25Cw3){;7#O5M!IEaMo{T9)QWG=ApT?%1fqfmx4LZGCf zu&Jm?p`<0ZlrTjzq@2J2$xEcH!aoB_Sqq5sE8U83oKO-IwP|43P}G_E!4efEP}!qO z-g6~?&qhr|^-+T8KE>>o#B>H$iDYW$W=;GU58;tqC7Ebc(P8!qNsFXQ!bK8ol3ftI z%_yN`7l`h+(~?AUzwPA9o2xNrsi+gW1*?$*zsQ$6{kBUz>)098vvj5oWqOHk%&AAG z-Va0@e(fw?e59{>1z#J{6i_r9_o})It8-CX1W>U28l$$udbaq5sfpQThtJ_7Lb5dP!$iwq&DSV287_tJ$5I?;Cy5>GY$NwLlk7 zV+gwl9H#Qq7|#SqHlObq*Doyt|H)=^#UHYhPjNNMR}fN zrfwJ0a=mJe5)Ec)h8IcT)sC4?60Py3x*_E-AC^*ThD-OT$e$mmTEy_J#iS^*z^R2v zvVpx|kVNHFCnmHPT(7hXztt@ekgBGXaTXaP<)7rAFOR+{oSj^{ZoiH4Y30 z6M6*Gd=8?bm#iWwsn{H}lBTR;iq4s}a{U=t1X#GWV68>(?Cl>o{`z)-UIV5#;bTSM zXdm>NJ`MdWiAFD_PO_2paQPj_SmfEPyoGwzqjHJn!Q1xUqo6pVPWtf zp@RCfO)}0C&?u9ZAy+h4NJ66@c7A-GyZO zb>~)5FCP0{7`JR)2(o+fq}R;~W9;_g5U#QjF9?)i#SIV4n13v?c?;q;H_phVGYd;z z@i{#vzx6KN%1K4(RIZ{UrVlEbZ&TP zK(!YHo<(EW&Zpy5F9`40voe`YqTCFNJoIsU=^$*6qnTW>`eS?w9bRK*1!Viid$4J? 
zKp;JZosZ}a-9;$hyH)4^k`=?ptXQ%hf&vv}Kv+I|jFg2MKDdNxr4(ugHN`q$?e@qe zKi@nha0dDy(3)dE93XjlPk#h{^dk_!LVUUIqCnUHdy#E=8${T?3QvxsiP1Y~^4B?B zGi6B;BvTir!Qhdu0^-|%>C=n5x1+qKEY}~xQHx^z5Hu(Q=1Sj3u?wR1q*O%=w8UCl zzUn>H{-33Y3e2O^+N$8Sc;d=$ZZSdi@Py5+xx6$cZ&>Ml?>}q*3;z6qVf@P|Tn@??@Q?lhA`D^nj27=!&)6|rrfFO?ge{zhhHyf& zYfbHz4UQ{H;FL7S0jDe~;8Zln1*a-Ja6HW^fm0K8aO#><24_hugR`tT6>wI>Dmbf} zQw8UUSOaHGb3Aa4igj?-HKzv7F>xH6+rY&Olr>h}Ydv&+_n=rt5V zqh$)CWeKxoe`4G=g(d8Vm6jtM;XX84?k9$0v`Xl&lz1P6(Yw7taQB=Tzvwmeb*mk9 zg2pZK8?W^TjdwSj>RI6GVR$d$&taoGAh8jf6Ju&&IStd8pzPZd3)G3-ZV7dBVow}l zK61b-Yo1l$RnXeJYs*&lOh7PbTSeWQ8YtC-fhpUYO(&^!`hIYy-S3CJD5(!d@vK(6 z8%nVKs51_e@~|yLr@y^tVv5ar;zpxkC?#i|RD8eF@5H{Zdb=IAY5WvmyEn* zp()6Dc7#cx85>#t-g{ROA`z7F*lnfxrUzQGRX?bY#~-rZQAlEZi)GEVncND z7@H_^ob2-Lf&8c~1^bBD-pIeNZmwDtbG%mgLfyDorab%Yoe=W$d>;!p>akq+{d=Q! zPd_R9z8D0)ua;4s;&4k^96{LY;Xyu`rYX6^l5(WW{#OdQI^HRP&1scs;cW zZiw+Kxe8JfzvOBtX=o))!3^NCf@-ratK@!-gorI5V5U(Uq5u&JS{@;x+?%d~Na3}D z2BmVz0iy=5d|^?lgW%6*o|r_d2oNb>L^TR}t1D7hb~! zx1{_>S^JMYLJ$aBfZn7O#j+#9W+eqMVU?~K-;qloi9>t;3{d?ZI*F(ywBDXsH^=LR zC0UG5j(&;n5gQ9lwl=+%d>JiL9ZJpY!I-+hPN`d|ksTi`PUR4w#07jS8$%gRSCTge zB?^L9^bL?=>>?^EdSTR3Mf6e<<+Ve4pt3)#CN4sl z99D>@>_-fB$^cnKfvl=P=5FqXuWdOP4I}vyMuE<8_ z``gzCgWiTnoFVQ79G3Pqneyv+HPi5UyZJtzL^KK+7?+E1+|OL*aO7t!grm{T;sy;& zdj?F|#9KNDI2fHjMnkPCX^2z6XTj9j@>MicGELo2`RQ#RSN=T?OVJRp%Oy$P9AF6{ zUeOl}wre*Kpge=^$rY5voW--BX15m&tU8F={~WoOI3o(vDnrf;6j~= zJF$zluHZreq=nuywKLjveiNe>(9~d6`KB3~9()Iy6AezO+3 zrdpu}vnn|l}VwR;R4lx*S0i3GLHSBkgyD^^XAW-WE=jNp|{ zB;G86k(8Og5S^Vg+2Gv>?~THK+-dieWYbL)T~JcJ4f&{aL*_l{2jn@9pm(eD-hG58 z%13310=P448D_Y$4;`-;An-Z5#O-~-_iI0I$~QTkCy70O&Dkk?Xn9Cm6h}S;Ys|I2 z#}OzLWsT2B@&f$xD+BQtb>#JnBheXTs`64%Rk&|&^XdNT!j)6%@>>+1PF(>1qO#;_ z^n`pU7eK9*YFZ(vk!Gt2cgkqCt^i)zl&@oc@+`@#AddmOxs<~C+ibxgQs#J6SeEoJ zQiT=#Tow2yo?4tVX?EnIXb$gz3by}_+^&u`z~&OnrRx)|i2qU(rz-2%wX zx>e8r>(=;KQC^>eYj#z@wa$ulkxZR31#4}lk-9CvgNBluo1`2;Wy5ip&G^7={2%bf z4EF1DGdl#_Qy9PD5IOGGBzcE|`!X|;NKB58D9A@?UNUtES!BaIvp)*q5{3N{j`S}# 
z=u8jWQznbcF_VdDxf*-j8vb?0A?pSx5lg4xb(e@#{VeCJbn$YAU2aVjY?Uo8Vt4y@ zpYrxKx-nG%)fNthrwz^MNTkLOZdB^dO&>6q&#a-*wSLNkz44!r27 zh>1)1M?V3{LC(akO8urF2SpmVYh$(T2~sKqzPf`ub#~K)K9c#{kjK>8#csG;otPJkfb{5Yh_7FG8`i7j$}Dg^=>?n!JWho+Od2GEfFQDEU}x) z!Bq>}T2VDW+_@F1#C~fM8H{gGqj`dF9b9g034(G$HM${>U#9rVXse($)00QCgTnCUaao&TIfi9G{TO~)hr*M#V%HWfa zPpW4)!MSYw+}pK8d1Cpsd)Gfq->LJoD2u0fRlgiuavCj_}b4T2iFy*jz8o%YWRs5INUS2!1cH;kgB9=n{ literal 0 HcmV?d00001 diff --git a/DINet/models/__pycache__/VGG19.cpython-36.pyc b/DINet/models/__pycache__/VGG19.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..048b20c045e89e7ab59c8d4afe2447871290892d GIT binary patch literal 1751 zcmZ`(O>5jn7@m)1hi+CUB^V0!4a&_bcmU(iGEh2&g1P;y8i>|baP%{#Kz*|hE~^UR~!XXc%GKW4d7 zah^~9(_FIAg0;Tia>#T8Hr_>!jTvuElL3a@BfM%Z(vq!U$JII&|>^ zMkqxo&QM>Av_8Tp)gyyqsy)IzEz=^>(^E6DBD;?=Q2tyf!FeOf z@)^oML3tIdo7Bpcm-2H~{;NF3`_*lf-);hMGM$R1sZC4N3E_s(c{vc@upFh2>bNrmWw^7!&%CM zR>J5&MRcG-IVdYQui$ln(E08A&il^Odid+XFW=6+{YO3gX7#K1`QN|S2M}*GP@!LU zHEGZw4P=G0{cejh;fFk+;k-xRefrsxdU!NgShRn7To0c-&Rg9zcnMZ=M?uaM{Lk#1 zB%%wM$dxf^`+h5FrM}PAp*?b-5?IbQ_p9ehHV=3d(++vAnb1>17B;Y(ak``Fr z4LVF3dn`!Q3M!q>aE0K?tF9>x!Q!3#!z;nLl>rueGm5?phJSh$d?-o236y=w_nUDb zgzx`|{(5WmoyHe}anacRXmjJv$Bi{6?xyWdW9^eomNrgah{o-Wjm4YSJ4d_<)|IEl z5&*(gT*fB!;jfD4@c1kJl2*eGhOU{QzF*BBhU-JepVq$(yAkNC(8U%&2Gt*-S*Y&_ zzlM71OUodWq5e^%XBsdan2|x2=e0pp{q_jVnHd3QW}2$zjAUj=dRAsZw&G3K;x*Wv zzoB4C!L$OErTl_|835Pjmtm0jj%V%qoWv9cL5b9f^3+$)tx#a=2CmNMb8u^QJ~@_W{Xt8#+DTp*k|f RGB, HWC -> CHW) + image1 = torch.tensor(image1).permute(2, 0, 1).unsqueeze(0).float() / 255.0 + image2 = torch.tensor(image2).permute(2, 0, 1).unsqueeze(0).float() / 255.0 + + # 计算 LPIPS 值 + return lpips_model(image1, image2).item() + +# 计算FID +def compute_fid(real_images, fake_images): + print("Computing FID...") + real_dir = './real_images' + fake_dir = './fake_images' + os.makedirs(real_dir, exist_ok=True) + os.makedirs(fake_dir, exist_ok=True) + + for i, img in enumerate(real_images): + 
cv2.imwrite(os.path.join(real_dir, f"real_{i}.png"), cv2.cvtColor(img, cv2.COLOR_RGB2BGR)) + + for i, img in enumerate(fake_images): + cv2.imwrite(os.path.join(fake_dir, f"fake_{i}.png"), cv2.cvtColor(img, cv2.COLOR_RGB2BGR)) + + fid_value = fid_score.calculate_fid_given_paths([real_dir, fake_dir], batch_size=1, device='cuda', dims=2048) + + # 删除临时文件夹 + for img in os.listdir(real_dir): + os.remove(os.path.join(real_dir, img)) + for img in os.listdir(fake_dir): + os.remove(os.path.join(fake_dir, img)) + os.rmdir(real_dir) + os.rmdir(fake_dir) + + print(f"FID computed: {fid_value}") + return fid_value + +def compute_lse_c(real_images, fake_images): + print("Computing LSE-C...") + return np.mean((real_images - fake_images) ** 2) + +def compute_lse_d(real_images, fake_images): + print("Computing LSE-D...") + return np.mean(np.abs(real_images - fake_images)) + +# 评估单个视频的质量(PSNR, SSIM, LPIPS, FID, LSE-C, LSE-D) +def evaluate_video_quality(real_video_path, fake_video_path): + print(f"Evaluating: {real_video_path} vs {fake_video_path}") + real_video = cv2.VideoCapture(real_video_path) + fake_video = cv2.VideoCapture(fake_video_path) + + real_frame_count = int(real_video.get(cv2.CAP_PROP_FRAME_COUNT)) + fake_frame_count = int(fake_video.get(cv2.CAP_PROP_FRAME_COUNT)) + + min_frame_count = min(real_frame_count, fake_frame_count) + + psnr_scores = [] + ssim_scores = [] + lpips_scores = [] # 修改为 LPIPS + frame_idx = 0 + real_images = [] + fake_images = [] + + while True: + ret_real, frame_real = real_video.read() + ret_fake, frame_fake = fake_video.read() + + if ret_real and ret_fake: + frame_real_rgb = cv2.cvtColor(frame_real, cv2.COLOR_BGR2RGB) + frame_fake_rgb = cv2.cvtColor(frame_fake, cv2.COLOR_BGR2RGB) + + # 计算PSNR + psnr_scores.append(compute_psnr(frame_real_rgb, frame_fake_rgb)) + # 计算SSIM + ssim_scores.append(compute_ssim(frame_real_rgb, frame_fake_rgb)) + real_images.append(frame_real_rgb) + fake_images.append(frame_fake_rgb) + # 计算LPIPS + 
lpips_scores.append(compute_lpips(frame_real_rgb, frame_fake_rgb)) + frame_idx += 1 + + elif ret_fake: + lpips_scores.append(compute_lpips(frame_fake_rgb, frame_fake_rgb)) + else: + break + + avg_psnr = np.mean(psnr_scores) + avg_ssim = np.mean(ssim_scores) + avg_lpips = np.mean(lpips_scores) # 修改为 LPIPS 平均值 + + fid_value = compute_fid(np.array(real_images), np.array(fake_images)) + lse_c_score = compute_lse_c(np.array(real_images), np.array(fake_images)) + lse_d_score = compute_lse_d(np.array(real_images), np.array(fake_images)) + + real_video.release() + fake_video.release() + + print(f"Video evaluated. PSNR: {avg_psnr}, SSIM: {avg_ssim}, LPIPS: {avg_lpips}, FID: {fid_value}, LSE-C: {lse_c_score}, LSE-D: {lse_d_score}") + + return avg_psnr, avg_ssim, avg_lpips, fid_value, lse_c_score, lse_d_score + +# 评估整个测试集 +def evaluate_test_set(real_video_folder, fake_video_folder,evaluate_dir): + print("Evaluating test set...") + video_type = "_facial_dubbing_add_audio.mp4" + psnr_scores = [] + ssim_scores = [] + lpips_scores = [] # 修改为 LPIPS + fid_scores = [] + lse_c_scores = [] + lse_d_scores = [] + + fake_video_list = sorted(os.listdir(fake_video_folder)) + + for fake_video_file in fake_video_list: + if not fake_video_file.endswith(video_type): + continue + + part1 = fake_video_file[:-len(video_type)] + + real_video_path = os.path.join(real_video_folder, part1 + '.mp4') + fake_video_path = os.path.join(fake_video_folder, fake_video_file) + + if not os.path.exists(real_video_path): + print(f"Warning: Real video {real_video_path} not found for {fake_video_path}. 
Skipping...") + continue + + avg_psnr, avg_ssim, avg_lpips, fid_score, lse_c_score, lse_d_score = evaluate_video_quality(real_video_path, fake_video_path) + + psnr_scores.append(avg_psnr) + ssim_scores.append(avg_ssim) + lpips_scores.append(avg_lpips) # 修改为 LPIPS + fid_scores.append(fid_score) + lse_c_scores.append(lse_c_score) + lse_d_scores.append(lse_d_score) + + avg_psnr = np.mean(psnr_scores) + avg_ssim = np.mean(ssim_scores) + avg_lpips = np.mean(lpips_scores) # 修改为 LPIPS + avg_fid = np.mean(fid_scores) + avg_lse_c = np.mean(lse_c_scores) + avg_lse_d = np.mean(lse_d_scores) + + print(f"Final Results for Test Set:") + print(f"PSNR: {avg_psnr}, SSIM: {avg_ssim}, LPIPS: {avg_lpips}, FID: {avg_fid}, LSE-C: {avg_lse_c}, LSE-D: {avg_lse_d}") + data = { + 'PSNR': [avg_psnr], + 'SSIM': [avg_ssim], + 'LPIPS': [avg_lpips], + 'FID': [avg_fid], + 'LSE-C': [avg_lse_c], + 'LSE-D': [avg_lse_d] + } + df = pd.DataFrame(data) + evaluate_path=os.path.join(evaluate_dir,'evaluate.csv') + df.to_csv(evaluate_path, index=False) + +# 主程序 +if __name__ == "__main__": + print("Starting evaluation...") + # 初始化全局变量 + lpips_model = lpips.LPIPS(net='alex') # 加载 LPIPS 模型 + psnr_metric = PSNR() # 加载 PSNR 计算工具 + + dir = os.path.dirname(os.path.abspath(__file__)) + + real_video_folder = os.path.join(dir, 'asserts', 'test_data', "split_video_25fps") + fake_video_folder = os.path.join(dir, 'asserts', 'inference_result') + evaluate_dir=os.path.join(dir, 'asserts', 'evaluate_result') + if not os.path.exists(evaluate_dir): + os.makedirs(evaluate_dir) + print(f'create: {evaluate_dir}') + # 进行测试集的定性评估 + evaluate_test_set(real_video_folder, fake_video_folder,evaluate_dir=evaluate_dir) diff --git a/DINet/run_inference.py b/DINet/run_inference.py new file mode 100644 index 00000000..79e2230d --- /dev/null +++ b/DINet/run_inference.py @@ -0,0 +1,70 @@ +import os +import subprocess +import argparse + +def get_command_line_args(): + parser = argparse.ArgumentParser(description="Process video files with 
OpenFace landmarks and audio.") + + # 添加命令行参数 + parser.add_argument('--process_num', type=int, default=1000, help="num for deciding to process videos.") + parser.add_argument('--model_path', type=str, default='./asserts/training_model_weight/clip_training_256/netG_model_epoch_200.pth', help="num for deciding to process videos.") + parser.add_argument('--audio_path', type=str, default="./asserts/examples/driving_audio_1.wav" , help="path of audio") + + # 解析命令行参数 + return parser.parse_args() + +if __name__ =="__main__": + arg= get_command_line_args() + model_path=arg.model_path + if not os.path.exists(model_path): + model_path="./asserts/clip_training_DINet_256mouth.pth" + + print(f"model_path: {model_path}") + + audio_path=arg.audio_path + if not os.path.exists(audio_path): + audio_path="./asserts/examples/driving_audio_1.wav" + print(f"audio_path: {audio_path}") + + # 视频文件夹路径 + video_folder = "./asserts/test_data/split_video_25fps" # 替换为你的实际视频文件夹路径 + landmark_folder = "./asserts/test_data/split_video_25fps_landmark_openface" # 替换为你的landmark文件夹路径 + output_folder = "./asserts/inference_result" # 替换为你希望输出结果的文件夹路径 + #audio_folder = "./asserts/examples/driving_audio_1.wav" # 替换为你的音频文件夹路径 + audio_folder=audio_path + pretrained_model_path = model_path # 替换为你的预训练模型路径 + + # 确保输出文件夹存在 + os.makedirs(output_folder, exist_ok=True) + # 获取所有的视频文件 + video_files = [f for f in os.listdir(video_folder) if f.endswith('.mp4')] + max=len(video_files) + num=0 + + # 批量处理每个视频文件 + for video_file in video_files: + # 这里假设输入视频文件与landmark和音频文件具有相同的文件名 + video_name = os.path.splitext(video_file)[0] # 获取文件名(不带扩展名) + + # 构建输入文件路径 + input_video_path = os.path.join(video_folder, video_file) + input_landmark_path = os.path.join(landmark_folder, video_name + '.csv') + input_audio_path = audio_folder + + # 构建命令并调用模型推理的 Python 脚本 + try: + print(f"Processing video: {video_file}") + subprocess.run([ + "python", "inference.py", + "--mouth_region_size", "256", # 这里的参数根据需求调整 + "--source_video_path", 
input_video_path, + "--source_openface_landmark_path", input_landmark_path, + "--driving_audio_path", input_audio_path, + "--pretrained_clip_DINet_path", pretrained_model_path, + ], check=True) + print(f"Successfully processed: {video_file}") + except subprocess.CalledProcessError as e: + print(f"Error occurred while processing {video_file}: {e}") + num+=1 + if num>=arg.process_num or num>max: + break \ No newline at end of file diff --git a/DINet/sync_batchnorm/__init__.py b/DINet/sync_batchnorm/__init__.py new file mode 100644 index 00000000..6d9b36c7 --- /dev/null +++ b/DINet/sync_batchnorm/__init__.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# File : __init__.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +from .batchnorm import set_sbn_eps_mode +from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d +from .batchnorm import patch_sync_batchnorm, convert_model +from .replicate import DataParallelWithCallback, patch_replication_callback diff --git a/DINet/sync_batchnorm/__pycache__/__init__.cpython-36.pyc b/DINet/sync_batchnorm/__pycache__/__init__.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..287dc98fc09f8672fb18092241d4ad5db948c461 GIT binary patch literal 465 zcmZ{gO-sWt7{`;ot(_YT87N*o^&;plA|mPv>ZJ%nCZQ0zHn6g!ElCvSH}Wgx>OuSp zUi3+QK@SV$$Nzbr@C(UgIPAY}e-!%;LNDmgf*PH}Rtca)CFV$Ag0amGWQRDyC9d#@ z*Y#ZJ`NS6i3A)87ZczTGW)aii)OWv1@i_>$+IVrmmh? zd70~CnKn`!mKz2~n+B6IcFef70;)hs7dF=5GQ*;&ADPtMl^l-kCb`aUbD49__`K4~ zDWpYStjyk*K(dxsMXsx+rp5Oxjr%70l}h_f(zWm`e2c)M2Qbl(fyvM|Bfv9~1Dh@0 zkEbWuonlgD^NaNA{yIyTTIr_ElADz2?EecgN~^lklpeHe6Q07)J#com1s-Lu9N;k? 
F;ZLSzho=Al literal 0 HcmV?d00001 diff --git a/DINet/sync_batchnorm/__pycache__/batchnorm.cpython-36.pyc b/DINet/sync_batchnorm/__pycache__/batchnorm.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f362924b6b6523f1022cd768afce00107d8c9c3e GIT binary patch literal 15307 zcmeHOO>i8?b)KL7!|nnEL687{M7AVL0%%~(A2E&*2@(lNqEZqpnUtukArA)I4PX|t zGs~V?5`Y$|Sj&`ih#kAiha`vGoqG;B<&txbIm{_3r%J`DTyjaJd~v?lJv+Mq3H?M; zqyu+rdZvG-`*rtwufKl%=IZEZ`ICv?H{N~KFn((c{Z8Q-X-0J0Gz@NV(>Ln+Zq`i| zx9XO<+s#}vUoYU<@{6rfy<{4@lZe^=NUK~gw<`6DY1}b_b2&p)x65+PH0aOGoKY{= zD}Q1H)(6J!IlPhcM_Xg{G2|IM?;mT8*N>|==h8R-K7F%*k`rmk0x$X}TF0rRu8EiU z2ru&rukz8&LVc2txhMHCUV32EPw{bn9QP?c!B5~mEsW-AKFLq=Q+#UE?pgISDrFie zr#CIc&%R;sGyLoWgP#?po?SnO_&Gj;_=q@%_)~~K#mk77Rs1~S=XnM3ii%$lW|Q#? zy!ycCnf0eRWUMgLT8;G1N15xZG^!+ zA!F5?pX3|9+iEw6Qu@)|#zK9ZiQqSQz8FiMn zJfFWUqR8D8a+=a|5&D&ssB;lmG#PU=`=^DHq8HqAqL?>OIQyw7LM&QDAPmx4&{$(X zGd4`jNB+S`-R969kjEV94BYdiH*hbI?!dkHGb3lzOBj`7$tmZtGV;lPZT{#V$vh>VbU4(Z(@yoF(^8r;scM~+#k?ph zg1oG>$CHW^`R+a8w7s^NErR)kzmZe}k}ToI9VsID92%0(Q}6->M2!3{0^)zifp&`` z&`8=wl3>b<2#5`pb~DUzYr-s91+!$9tt##XbMM0A`K#9Na6^eRsvs3amlEL_=m`?x zu#C`*uEbW;1}_=lC~y$20JILlH1_TISQqu}=jhIJC$Q-46BXl7=a&3uJRNUN#3zW) zyLJ!nu3Gzr{bDeUoXO_NM{|&V>X)W0cg=XJTj-h72JY0K!X)aRq!{|wGpCIGQZRY2 zl%lBAErE9Q0zzC^QHpABc4^BE0^u{#C9LH|En24**d`?Aqy0H1cG{vrk2S^D4{7AJ zcAJ)@S(6eN9c&nECg$r&SuGHbTIQ;Ksb~o|NGko5^m zAVOZHCKR_9Hb^*8?w%*^E9EUOQD%(_mr-NN&+;isy^NqfLaB;nH8+wPogI-^=#9J= zv^#N9Q8y>rav@3=)Jthfr&TYfkECOBM7eyLUM)Z=;N8ROy9fdyY)W@XQ^aw(wBS#}SYwJO? 
z8}UxP*a+LZj>ix8|EzOJzev&cy@neL#|cqMR7PMxqzf2>vbA^Rh`xVHHJPlBqNWlq zH6H(hhZwR;paKBX{<*cU7c!(RlcZT#&hig1T4}7PVmi;pno}CYm5h?Eu=&VofY%MH zYyF+|3j^|)$9hBhBi}7_b3Gd=#r;y8r?QYdfZ+X+O=G{jvltgh)|ZVBub@n2zuGNB z(^RQGtWqQ03NKF@I^^Xk-|{q-HATZTkd_HJshs}=+BpMpYw53UsR zh|aM_*a8m%YXEQ1QiCb;8qu68x8BFZ%!_M}K%83`u4oJaz@&K^IY+wys#DlH~&Dh&w zHzSxBo)#`eFv;#FNd9n=6f8F|Ihbx5ykA*&b0FVVI&4K=LqSC=(63dyB;V}Ebq8qG z`Ca)XALB(sxcFy?E0|R<_FnzJ(%-KvBc0@xV3s1CfTkjSf=CBnl(52Bzq0REiSWU= zq6=@KsG>MAn}ZC!RKAV8pW;fX&Dx@2I-)HaCOyCDAem~A(kYfV@g|`zJBnyD1)XEpR-(2&N?G?B9-Nda&^7F;o<0zXOKZMXg1;Z2t*hRBVobhH}s3BPXCPZ$X*z;rN50$;YTpH4u#4}o{#@J%)eS^l>Y^2Mojjl_hp+3|k46s^B)$Dz&B&*Ub#u_Z51?IsqXeYz3g!4^cM9 z?r(XGEhfd?jt97{f!ZRwTuX{U=xj=tYg@o3aff02+bhZruy^vkU^@u!2MjxZP{+jr zfw+F5f65*pBMNI3x7^5$W0~aRP{RIEM_n44{Qyr%j`X=)pu`CZ$oim=e^P$m^*chD z8YIE`*^&}E`976Ua`ZAi5x>?87-0awdRceaAuw2?uG7_>Y2Ra{DOWZQHTn{$XNHsLwN)_F+^z`cW=&vmFp4@@zVDCF{W2$Fm4Kxpt;l>-co?t~B!^6hc%>o3 z2z~VZR3Ib9JCdZfzlBM=`#Oe5?7%4J3z9wzO&AEWA4F^|2x`NER&?7no{ikd*g z_d~gzR;y_0S3TG7#S0PhsAjSxXw^fzC9*(hV5f~)q(Fp=K5Ny1#zCj9@BsKjIe%Ac z8_c9G+}9Xx15BwFMXspRf}q=kmQl4~?v*kVZUiGlI3R6~t-}m*5NTq2h+P||58Q*w zJG9=`R5fn~*gd#>4jFid7;4Cd%Aga;ayECky89lc0=1QPWH+8;*ioUNn1AD)_lWD4-g$3c4Hrzum{>)p7IzS_ z1MYIdLrRbZ0!qJdE9jj0NHD>LcxrB4}0}cMyj50LoXFeYz91xrk6<4X-2i| z(_VWvrF{9I9HfKeySu1!TR?`H%hC-z6WDb?8(@vwz+NNDW>0lp4;sjXKKI<9A!q@G z2xGS(#&R&91Qs z+g^dam={C9!O$&20KlAH5c=V!k}XU|?6SAyQYFbx3kh^GzfPP(iMsJak3F{`kMZbCWCfT4C*n(P675glKeK4rQ zDT=3tUizW3K7FpY*bFHzrmS1EN-EBn)zrME7quZCTU=aZ&oVj@V{c{pgignm-OY?n1juY$O9aYl2w-G0de|pz+wx-i zY|%4JN;VndC^y}Kq~~y>2%!e#+}BNz!`+1-obkE?adS=BP{^mHXsW){shP>E9Mg*`vk}i7Sty|_8S4RH|N-$FN$Sv!N zI-qO@Cul)d@bzdUyTi8Bi6To8ZT9XHJo^pc**T^eDVv3P_Ba6BFZK-})+=8VV%7D2 z#fWu|UA;EPUcB}tkn1zRS3}74JkYB37_@4yJfvfWy{9vv`r=mrQ5AT+fRl<8l~#{N zrC0dpK&5g6G#-se2}^#0>&qe1Dr#`|J|7~z_C*kBAHIG|5h(!k6GZw1kv=i4J~6F+ z8%(Rq8A$j-2G;jOALdD%7Bq%Ttb`+BXH$SRoR7aY8*B!9J{NH1u!U~{w>|-|bkO$g zFv!lajPJ6^JPN~RCBFd}d-ZFAvAW)`3C5BY_WuRSk{{U+l)d&tE+~8TD}b^D zLZ7472$h$OqiLxGF=&{-Fpyei`Ik<;of?34lu? 
z9qwOE+>!!)Y13&B*^kL1Pm)6?r{0lcqG6S-apnD8unJb@=tg#Ix|IF>zwrR4Pa`&) z(~igDWAQkB z^#oVF4Y+VxaNow;aHh57I-EV>tc_~|*9lycJ+zAZNjTn~#QhZRrzjm|jpkHyx_P>1 zG|%*C&o$+cL-*MOX=k13yX$w1o$ny$T+cw8(`YjXKP|KQR5yni%i|2ef^+D|QI@~k*I!hCoO9g<&h>B(3BM({ z8X*EBd4HyBy0WOI3@vnkTFa*W;>yL${|P5%k3CaRWmJO)4>EPN`o!-5)qxNkr){Bj z0#F}CQH(0Tv7(%o{7*Z$!XZ5T1T*q!;}Qh@4yz-p;S%>eI%hwS{VdM3a2)GqZda_m zs|F)&n-*~T^-sT4p00<+O88L?9$E1ExcrARbb@&BXo=3t4!MiWvYzdpFpmbLQ-oV26a^GO+)LDoIZAt;f)^+t&oS*Y_DA&eAp-i` z4TtSXfy0jhZe}l0+8B9Rahykid8G)FGESaTPpsL3lvKn9LG?1AQkvd(U{}6FNscN& z)vAKQF_ls?)V+w*%p4w!<72Qszy-_9!3}c^u47fx!hal3r;uJk?xZ;lSItRl_SvM; z2=N(VgFX$4k}`%q7CW);t*bG`*FUL>tZj4xRmk=AK$nLSzK>SxXR@N{zz$?gk}j{G z%L??rzNtC=N+SIM5b67*@}^>Xq9J>UTo2TzKlKZbW*+d#mZVqJyeBz& zD;de0_vzdZN&@G8%GVE85JP``K_1yjSIdB3DDq2ORvW_$h4~uCQ`MZXEt<((&Mp;7`BI*~ Pxfu7y7x8~(?AZSTLaYU{ literal 0 HcmV?d00001 diff --git a/DINet/sync_batchnorm/__pycache__/comm.cpython-36.pyc b/DINet/sync_batchnorm/__pycache__/comm.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..481ad881ad054fd591ccb983595ddae7bfcb51eb GIT binary patch literal 4744 zcmai2+iu&)8RkKfB{_CnZ<39-0Xog5yH;VV=mBjJEQ&aBvu)9+o5XDvf(we~j3mkw zsmxGzEXX%;fL?ZAqc6~hA$r$~Vqc**?f1`!6n#J|gF|vSGyieE?;rnkZLRsQ&Hn`c zcP#6lR_Ry8^AlX^1P!;i9a&v-x4X7}ce)ODxXUX8x9fgs@hW#dSlkhn3%6Sp_Nc~d zd}Ux?INcTfR_6`gM30M}I>xVZ7vo(~$M^=`Z}AG=D|&vD+pn$G+P9Exb+TK}r|DFR zS3*srG;6$${1fq2I1yR%^;sOe@Kq{gKtuDNgNG-$)HWJnk-|{c?z*J#2di7*6<)=? 
zN-E-B<8_E@NSZf$y2sL?6h7ZoeqXS@md(QPBobp0r+ykHahv^73C@x@I%B@VUF@dG zE^cfBz1o3|h3H3jTJ#w%^$-oW=9Xur_Q<)QCl^nZxixYxY`j&u1vTtWtCp>Ja$&8} zC5#7I<@ZT&oK?MuZjEP2%tO)zJprCuc2-d$>i^A>w;<@>KYRB05K2mQ_-3#3<=4MC z+!yLNO(uu?zwLt;r@XXYSzqH0 z^PSG@)32mB36rT}#tCfbtItz*Bt*>m0xApMZdGL+VkG5hw!WueDLLcUQYKP1sk0U* zY1lt&IdWCMnzgiZOtVJa?qr&Nf(J#bS$5stguSzkOV(mY5z$MA{v1=Zp`R@cl@HM4 zAfhmhj6a9@+F)dK$xKAKvb4ZxTsif_G}qUKzSoo;h?GI5q6m|+O}pvTt}BdToz@Bk zGxo>A^RlMrjT1hN==qlCy`B2e;>(KX@g(p(xsDa(CN&iLl4O!4$ya)N&=EsK)=wlG>#e`52oM;h@H7qsvq6i@q$zYg!=-l6P{2l%f6muS2n6t>SO@!^1J^QHe zP`en_fv|hD`p49eGvpRE-=jtcV%g}BkY#OLs)B|>65w9^=<5?)Y9GzbZoP_wya7C1 z0Z^~vUKb6(^4bUM!tOTtI{a~4k|3X98|=l*KSIR)AZ5KJ5cgPeG!jADZZ;WyyKH|d zfnLB=f$cnr76IBJ+@wD4r#_qb(vKn$vC}Y$SSrK8KuG4tEI{{>9~`r!&x9WgnMq^J z#yAsq4y*`-VrrO7Bd&$gEY_1cBZb@!d8>!TfY(Q< z?`Wx{GYJ9JKuI6|qdg`h-gC#~>h1N>j{$)QPjYwX5_d}G!(qXw23Zql80|_$edWK6qLZca=OoEkBRnuXO6F01I z+IjBmy~!ZTqCGiKkF`{Kqn~lDm%HIHypunZzVM}q1ZTeX%+S|vDIA&0q3(2h`E2Nsr>B!H~)3PNqHzLO?GKnHU zAquq)2&01misEP4vyU%{U5>dFYZpVc#>5pQcy4oh3m9VO44Gcq6$r3 zV))IPEXNd@kQRhhT9e=5fi#iZXm#tU5qpvy(Z7Z!^Sm#cigLb;V+MEG=`aLN23QuGV7kDLB{^OLxCJ7V)}OU?YLg16p;!BiQ^Jnm|e26Biit!{pleekk2fBk6kh*%vZhJ2@XB)4fTA(m3XjBrp zW;IiAUrmJd9QzEN9$=`V^s$;XU9vS%1+Ck&tx|$so>3$Q6}2Q)C&u6Lpg{*G7b*t^ z_PO)!VM@V9`1h?p-#K^ZsOIwYpd^0#;0>iqjS}R%GOuuF3$y7ucGREJ6^?5Dz5Q-y zUcvlFht9csUY+CUaqP(W?IC8?=kAs@YK)rms$TE#jrHzPN#;GsT%;qnQ96Hi= z&l$GLq9bg#;j8xIzUV4qS-i1>9o~m2SaR=+?q+` z2O#G~X`?+@s7OeV)95MOqJ2TW>s?`hUYk8faSQ(S!w9QQschJRV8MDG<@Zn~aX1tF zDIMQBAnl$GIovDW`E*V#8C+n+|M9B+Tn=DqjRQ%{GHV|oFeWP&=dwi+aU`N79ssYB zlEd2H(xld^Muw~^{dgdZVYXs)Qt}QhbGb~;n#*dekI#xueAar+Y1$k19klnIb)2Rg`<}CbYj(Gc zB1IinoB=PL>WN=d!UG%{uq$z{b22v)vBo~o><#=Xe}XxG!DYzIncXSLD?HSJ?`}E^ z5b8<#$A{!9D literal 0 HcmV?d00001 diff --git a/DINet/sync_batchnorm/__pycache__/replicate.cpython-36.pyc b/DINet/sync_batchnorm/__pycache__/replicate.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bb76b5c890c7f83024b409c03b74da1f5cbfa999 GIT binary patch literal 3420 
zcmc&$U2oeq6eT5DcHAT_vY}m9bmKYvy&q(cfXg`$YGtwfd_ zl1`k$ewhojd)({(!Tyq7_b}`)^kuu3l47T5ieUpb1!O3iy1e)DoOAib++6kdg+GH| zW(oO?Og#>a??F>fpyPzoh&1uP716HMv?$>=cX)}rZKr8JB)rV+7ceKBvE3~33U^+R zrVH~kyae+lVZnTvSGf!G6;8KE?ZO{mnKWuNo!#`c|D7-WC=yY6aRdJE`N6?P5^FKg z>2)y>L0=0`ie3~3z7CVv3-U!O{b2g*?NE1$@6!)^z79HX{Vc$}i-*GC9yA3>Ac8b0 zCr!)vXh|EU8Z|qucwX#xh3BPJ&+8_PE6kyTi2y4PQA+}-(J38{9r z*BcL?KHAw7>Od#Goy~6=Lhq=PIPmt+y*QEGogy+(@12+}ub?{?bQKz)d-znKm7r}v zQv>LRa4c8Jh-#wgknUS!GNxlXvbcqZq3xkHrkgM~vImc}GlbD5dA71+jhsHN|a(HJOjuMW}^&f+8RR+)f>};4H zvmyvM9#$L=+x0vSr|b3d<2^kMNZ&4lmBJ4?EK4X0r0})iZ1036GHi#jA5B}WSIw4B zI$~n~IE*5;Cs-ICB?s{7F#v%1vfb~BSgUt#zZmGj5(|>vi5K$PZf3lZXu)*H*GtS- z3_zCq24e}ksHB_7UME!DC3Y-ei3;0!4646(Gmj&40!$kK8>bgl2MloLcaylyy8TFp zy$Iq+dZAD(X~CYMD)Inn;bE&KB*ehDI0^%yYV|yLne^M8YskDLz_}=xE7SiO1E#gEgluDcT8l*Z)J)Sj2))79( zNbo&`Ag2u)-{ChuK4B{IkBZ2%pd>>46z2$gq}>)YwE$fkkp8TSI3JUrtwU!k<5-Ov zs?@C1s^3bfH({T=0AIbXu0)~ILDKE5eU7Hku1lUTzf(@?D^UUgs5KZAU#HjJ0pdoj zoZ69yQwP_l(wWGVI#hRtd$b8S&g z2hgcf8m4Y$F{S${?YA4zBfoCI+oDVo<=VV7A{vDDDY_)(ujBKqqs%D(^WV-N4So(`Vl+sh?t!uebzBD z+*dm3A~fGL93rH#icXh}{~qao3*Bq%{HtDJV44f2*h0M(hyC!#2lduWCvPR)|4!Zj z%N9Nb^WKB8x(wYbK-drvHnIR~aMc6WSsCkSMlahCOJ0Yy1*7JYh)=FySv5gcrE{@3 zp8W6>_b~jjs>rY9^V2&im}e|BYHnH$RTwMn$AOR-MLNH(;7W|5d|%2$X2hE@LY*WL z)XZb)_Y@TLIX-=7ls7PfoQ$8~*a!jo_W+tgN|X6{%e5}MwHf&tyfr`8i3~dRIIb5r z&-z)L#vzaZxU9fM-py!aE)*tthUyZxm@2vEdT-KIS;?yM8orrJO2wMBEsB4xbIGY( Hx%B>D-(btD literal 0 HcmV?d00001 diff --git a/DINet/sync_batchnorm/batchnorm.py b/DINet/sync_batchnorm/batchnorm.py new file mode 100644 index 00000000..bf8d7a73 --- /dev/null +++ b/DINet/sync_batchnorm/batchnorm.py @@ -0,0 +1,412 @@ +# -*- coding: utf-8 -*- +# File : batchnorm.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. 
+ +import collections +import contextlib + +import torch +import torch.nn.functional as F + +from torch.nn.modules.batchnorm import _BatchNorm + +try: + from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast +except ImportError: + ReduceAddCoalesced = Broadcast = None + +try: + from jactorch.parallel.comm import SyncMaster + from jactorch.parallel.data_parallel import JacDataParallel as DataParallelWithCallback +except ImportError: + from .comm import SyncMaster + from .replicate import DataParallelWithCallback + +__all__ = [ + 'set_sbn_eps_mode', + 'SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d', + 'patch_sync_batchnorm', 'convert_model' +] + + +SBN_EPS_MODE = 'clamp' + + +def set_sbn_eps_mode(mode): + global SBN_EPS_MODE + assert mode in ('clamp', 'plus') + SBN_EPS_MODE = mode + + +def _sum_ft(tensor): + """sum over the first and last dimention""" + return tensor.sum(dim=0).sum(dim=-1) + + +def _unsqueeze_ft(tensor): + """add new dimensions at the front and the tail""" + return tensor.unsqueeze(0).unsqueeze(-1) + + +_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size']) +_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std']) + + +class _SynchronizedBatchNorm(_BatchNorm): + def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True): + assert ReduceAddCoalesced is not None, 'Can not use Synchronized Batch Normalization without CUDA support.' 
+ + super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine, + track_running_stats=track_running_stats) + + if not self.track_running_stats: + import warnings + warnings.warn('track_running_stats=False is not supported by the SynchronizedBatchNorm.') + + self._sync_master = SyncMaster(self._data_parallel_master) + + self._is_parallel = False + self._parallel_id = None + self._slave_pipe = None + + def forward(self, input): + # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation. + if not (self._is_parallel and self.training): + return F.batch_norm( + input, self.running_mean, self.running_var, self.weight, self.bias, + self.training, self.momentum, self.eps) + + # Resize the input to (B, C, -1). + input_shape = input.size() + assert input.size(1) == self.num_features, 'Channel size mismatch: got {}, expect {}.'.format(input.size(1), self.num_features) + input = input.view(input.size(0), self.num_features, -1) + + # Compute the sum and square-sum. + sum_size = input.size(0) * input.size(2) + input_sum = _sum_ft(input) + input_ssum = _sum_ft(input ** 2) + + # Reduce-and-broadcast the statistics. + if self._parallel_id == 0: + mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size)) + else: + mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size)) + + # Compute the output. + if self.affine: + # MJY:: Fuse the multiplication for speed. + output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias) + else: + output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std) + + # Reshape it. + return output.view(input_shape) + + def __data_parallel_replicate__(self, ctx, copy_id): + self._is_parallel = True + self._parallel_id = copy_id + + # parallel_id == 0 means master device. 
+ if self._parallel_id == 0: + ctx.sync_master = self._sync_master + else: + self._slave_pipe = ctx.sync_master.register_slave(copy_id) + + def _data_parallel_master(self, intermediates): + """Reduce the sum and square-sum, compute the statistics, and broadcast it.""" + + # Always using same "device order" makes the ReduceAdd operation faster. + # Thanks to:: Tete Xiao (http://tetexiao.com/) + intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device()) + + to_reduce = [i[1][:2] for i in intermediates] + to_reduce = [j for i in to_reduce for j in i] # flatten + target_gpus = [i[1].sum.get_device() for i in intermediates] + + sum_size = sum([i[1].sum_size for i in intermediates]) + sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce) + mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size) + + broadcasted = Broadcast.apply(target_gpus, mean, inv_std) + + outputs = [] + for i, rec in enumerate(intermediates): + outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2]))) + + return outputs + + def _compute_mean_std(self, sum_, ssum, size): + """Compute the mean and standard-deviation with sum and square-sum. This method + also maintains the moving average on the master device.""" + assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.' 
+ mean = sum_ / size + sumvar = ssum - sum_ * mean + unbias_var = sumvar / (size - 1) + bias_var = sumvar / size + + if hasattr(torch, 'no_grad'): + with torch.no_grad(): + self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data + self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data + else: + self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data + self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data + + if SBN_EPS_MODE == 'clamp': + return mean, bias_var.clamp(self.eps) ** -0.5 + elif SBN_EPS_MODE == 'plus': + return mean, (bias_var + self.eps) ** -0.5 + else: + raise ValueError('Unknown EPS mode: {}.'.format(SBN_EPS_MODE)) + + +class SynchronizedBatchNorm1d(_SynchronizedBatchNorm): + r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a + mini-batch. + + .. math:: + + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + + This module differs from the built-in PyTorch BatchNorm1d as the mean and + standard-deviation are reduced across all devices during training. + + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device using + the statistics only on that device, which accelerated the computation and + is also easy to implement, but the statistics might be inaccurate. + Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + + During training, this layer keeps a running estimate of its computed mean + and variance. 
The running sum is kept with a default momentum of 0.1. + + During evaluation, this running mean/variance is used for normalization. + + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm + + Args: + num_features: num_features from an expected input of size + `batch_size x num_features [x width]` + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. Default: ``True`` + + Shape:: + - Input: :math:`(N, C)` or :math:`(N, C, L)` + - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) + + Examples: + >>> # With Learnable Parameters + >>> m = SynchronizedBatchNorm1d(100) + >>> # Without Learnable Parameters + >>> m = SynchronizedBatchNorm1d(100, affine=False) + >>> input = torch.autograd.Variable(torch.randn(20, 100)) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 2 and input.dim() != 3: + raise ValueError('expected 2D or 3D input (got {}D input)' + .format(input.dim())) + + +class SynchronizedBatchNorm2d(_SynchronizedBatchNorm): + r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch + of 3d inputs + + .. math:: + + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + + This module differs from the built-in PyTorch BatchNorm2d as the mean and + standard-deviation are reduced across all devices during training. + + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device using + the statistics only on that device, which accelerated the computation and + is also easy to implement, but the statistics might be inaccurate. 
+ Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + + During training, this layer keeps a running estimate of its computed mean + and variance. The running sum is kept with a default momentum of 0.1. + + During evaluation, this running mean/variance is used for normalization. + + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm + + Args: + num_features: num_features from an expected input of + size batch_size x num_features x height x width + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. Default: ``True`` + + Shape:: + - Input: :math:`(N, C, H, W)` + - Output: :math:`(N, C, H, W)` (same shape as input) + + Examples: + >>> # With Learnable Parameters + >>> m = SynchronizedBatchNorm2d(100) + >>> # Without Learnable Parameters + >>> m = SynchronizedBatchNorm2d(100, affine=False) + >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45)) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 4: + raise ValueError('expected 4D input (got {}D input)' + .format(input.dim())) + + +class SynchronizedBatchNorm3d(_SynchronizedBatchNorm): + r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch + of 4d inputs + + .. 
math:: + + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + + This module differs from the built-in PyTorch BatchNorm3d as the mean and + standard-deviation are reduced across all devices during training. + + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device using + the statistics only on that device, which accelerated the computation and + is also easy to implement, but the statistics might be inaccurate. + Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + + During training, this layer keeps a running estimate of its computed mean + and variance. The running sum is kept with a default momentum of 0.1. + + During evaluation, this running mean/variance is used for normalization. + + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm + or Spatio-temporal BatchNorm + + Args: + num_features: num_features from an expected input of + size batch_size x num_features x depth x height x width + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. 
Default: ``True`` + + Shape:: + - Input: :math:`(N, C, D, H, W)` + - Output: :math:`(N, C, D, H, W)` (same shape as input) + + Examples: + >>> # With Learnable Parameters + >>> m = SynchronizedBatchNorm3d(100) + >>> # Without Learnable Parameters + >>> m = SynchronizedBatchNorm3d(100, affine=False) + >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10)) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 5: + raise ValueError('expected 5D input (got {}D input)' + .format(input.dim())) + + +@contextlib.contextmanager +def patch_sync_batchnorm(): + import torch.nn as nn + + backup = nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d + + nn.BatchNorm1d = SynchronizedBatchNorm1d + nn.BatchNorm2d = SynchronizedBatchNorm2d + nn.BatchNorm3d = SynchronizedBatchNorm3d + + yield + + nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d = backup + + +def convert_model(module): + """Traverse the input module and its child recursively + and replace all instance of torch.nn.modules.batchnorm.BatchNorm*N*d + to SynchronizedBatchNorm*N*d + + Args: + module: the input module needs to be convert to SyncBN model + + Examples: + >>> import torch.nn as nn + >>> import torchvision + >>> # m is a standard pytorch model + >>> m = torchvision.models.resnet18(True) + >>> m = nn.DataParallel(m) + >>> # after convert, m is using SyncBN + >>> m = convert_model(m) + """ + if isinstance(module, torch.nn.DataParallel): + mod = module.module + mod = convert_model(mod) + mod = DataParallelWithCallback(mod, device_ids=module.device_ids) + return mod + + mod = module + for pth_module, sync_module in zip([torch.nn.modules.batchnorm.BatchNorm1d, + torch.nn.modules.batchnorm.BatchNorm2d, + torch.nn.modules.batchnorm.BatchNorm3d], + [SynchronizedBatchNorm1d, + SynchronizedBatchNorm2d, + SynchronizedBatchNorm3d]): + if isinstance(module, pth_module): + mod = sync_module(module.num_features, module.eps, module.momentum, module.affine) + mod.running_mean = 
module.running_mean + mod.running_var = module.running_var + if module.affine: + mod.weight.data = module.weight.data.clone().detach() + mod.bias.data = module.bias.data.clone().detach() + + for name, child in module.named_children(): + mod.add_module(name, convert_model(child)) + + return mod diff --git a/DINet/sync_batchnorm/batchnorm_reimpl.py b/DINet/sync_batchnorm/batchnorm_reimpl.py new file mode 100644 index 00000000..18145c33 --- /dev/null +++ b/DINet/sync_batchnorm/batchnorm_reimpl.py @@ -0,0 +1,74 @@ +#! /usr/bin/env python3 +# -*- coding: utf-8 -*- +# File : batchnorm_reimpl.py +# Author : acgtyrant +# Date : 11/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import torch +import torch.nn as nn +import torch.nn.init as init + +__all__ = ['BatchNorm2dReimpl'] + + +class BatchNorm2dReimpl(nn.Module): + """ + A re-implementation of batch normalization, used for testing the numerical + stability. 
+ + Author: acgtyrant + See also: + https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14 + """ + def __init__(self, num_features, eps=1e-5, momentum=0.1): + super().__init__() + + self.num_features = num_features + self.eps = eps + self.momentum = momentum + self.weight = nn.Parameter(torch.empty(num_features)) + self.bias = nn.Parameter(torch.empty(num_features)) + self.register_buffer('running_mean', torch.zeros(num_features)) + self.register_buffer('running_var', torch.ones(num_features)) + self.reset_parameters() + + def reset_running_stats(self): + self.running_mean.zero_() + self.running_var.fill_(1) + + def reset_parameters(self): + self.reset_running_stats() + init.uniform_(self.weight) + init.zeros_(self.bias) + + def forward(self, input_): + batchsize, channels, height, width = input_.size() + numel = batchsize * height * width + input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel) + sum_ = input_.sum(1) + sum_of_square = input_.pow(2).sum(1) + mean = sum_ / numel + sumvar = sum_of_square - sum_ * mean + + self.running_mean = ( + (1 - self.momentum) * self.running_mean + + self.momentum * mean.detach() + ) + unbias_var = sumvar / (numel - 1) + self.running_var = ( + (1 - self.momentum) * self.running_var + + self.momentum * unbias_var.detach() + ) + + bias_var = sumvar / numel + inv_std = 1 / (bias_var + self.eps).pow(0.5) + output = ( + (input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) * + self.weight.unsqueeze(1) + self.bias.unsqueeze(1)) + + return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous() + diff --git a/DINet/sync_batchnorm/comm.py b/DINet/sync_batchnorm/comm.py new file mode 100644 index 00000000..922f8c4a --- /dev/null +++ b/DINet/sync_batchnorm/comm.py @@ -0,0 +1,137 @@ +# -*- coding: utf-8 -*- +# File : comm.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. 
+# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import queue +import collections +import threading + +__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] + + +class FutureResult(object): + """A thread-safe future implementation. Used only as one-to-one pipe.""" + + def __init__(self): + self._result = None + self._lock = threading.Lock() + self._cond = threading.Condition(self._lock) + + def put(self, result): + with self._lock: + assert self._result is None, 'Previous result has\'t been fetched.' + self._result = result + self._cond.notify() + + def get(self): + with self._lock: + if self._result is None: + self._cond.wait() + + res = self._result + self._result = None + return res + + +_MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) +_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) + + +class SlavePipe(_SlavePipeBase): + """Pipe for master-slave communication.""" + + def run_slave(self, msg): + self.queue.put((self.identifier, msg)) + ret = self.result.get() + self.queue.put(True) + return ret + + +class SyncMaster(object): + """An abstract `SyncMaster` object. + + - During the replication, as the data parallel will trigger an callback of each module, all slave devices should + call `register(id)` and obtain an `SlavePipe` to communicate with the master. + - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, + and passed to a registered callback. + - After receiving the messages, the master device should gather the information and determine to message passed + back to each slave devices. + """ + + def __init__(self, master_callback): + """ + + Args: + master_callback: a callback to be invoked after having collected messages from slave devices. 
+ """ + self._master_callback = master_callback + self._queue = queue.Queue() + self._registry = collections.OrderedDict() + self._activated = False + + def __getstate__(self): + return {'master_callback': self._master_callback} + + def __setstate__(self, state): + self.__init__(state['master_callback']) + + def register_slave(self, identifier): + """ + Register an slave device. + + Args: + identifier: an identifier, usually is the device id. + + Returns: a `SlavePipe` object which can be used to communicate with the master device. + + """ + if self._activated: + assert self._queue.empty(), 'Queue is not clean before next initialization.' + self._activated = False + self._registry.clear() + future = FutureResult() + self._registry[identifier] = _MasterRegistry(future) + return SlavePipe(identifier, self._queue, future) + + def run_master(self, master_msg): + """ + Main entry for the master device in each forward pass. + The messages were first collected from each devices (including the master device), and then + an callback will be invoked to compute the message to be sent back to each devices + (including the master device). + + Args: + master_msg: the message that the master want to send to itself. This will be placed as the first + message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. + + Returns: the message to be sent back to the master device. + + """ + self._activated = True + + intermediates = [(0, master_msg)] + for i in range(self.nr_slaves): + intermediates.append(self._queue.get()) + + results = self._master_callback(intermediates) + assert results[0][0] == 0, 'The first result should belongs to the master.' 
+ + for i, res in results: + if i == 0: + continue + self._registry[i].result.put(res) + + for i in range(self.nr_slaves): + assert self._queue.get() is True + + return results[0][1] + + @property + def nr_slaves(self): + return len(self._registry) diff --git a/DINet/sync_batchnorm/replicate.py b/DINet/sync_batchnorm/replicate.py new file mode 100644 index 00000000..b71c7b8e --- /dev/null +++ b/DINet/sync_batchnorm/replicate.py @@ -0,0 +1,94 @@ +# -*- coding: utf-8 -*- +# File : replicate.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import functools + +from torch.nn.parallel.data_parallel import DataParallel + +__all__ = [ + 'CallbackContext', + 'execute_replication_callbacks', + 'DataParallelWithCallback', + 'patch_replication_callback' +] + + +class CallbackContext(object): + pass + + +def execute_replication_callbacks(modules): + """ + Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. + + The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` + + Note that, as all modules are isomorphism, we assign each sub-module with a context + (shared among multiple copies of this module on different devices). + Through this context, different copies can share some information. + + We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback + of any slave copies. 
+ """ + master_copy = modules[0] + nr_modules = len(list(master_copy.modules())) + ctxs = [CallbackContext() for _ in range(nr_modules)] + + for i, module in enumerate(modules): + for j, m in enumerate(module.modules()): + if hasattr(m, '__data_parallel_replicate__'): + m.__data_parallel_replicate__(ctxs[j], i) + + +class DataParallelWithCallback(DataParallel): + """ + Data Parallel with a replication callback. + + An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by + original `replicate` function. + The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` + + Examples: + > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) + # sync_bn.__data_parallel_replicate__ will be invoked. + """ + + def replicate(self, module, device_ids): + modules = super(DataParallelWithCallback, self).replicate(module, device_ids) + execute_replication_callbacks(modules) + return modules + + +def patch_replication_callback(data_parallel): + """ + Monkey-patch an existing `DataParallel` object. Add the replication callback. + Useful when you have customized `DataParallel` implementation. 
+ + Examples: + > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) + > patch_replication_callback(sync_bn) + # this is equivalent to + > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) + """ + + assert isinstance(data_parallel, DataParallel) + + old_replicate = data_parallel.replicate + + @functools.wraps(old_replicate) + def new_replicate(module, device_ids): + modules = old_replicate(module, device_ids) + execute_replication_callbacks(modules) + return modules + + data_parallel.replicate = new_replicate diff --git a/DINet/sync_batchnorm/unittest.py b/DINet/sync_batchnorm/unittest.py new file mode 100644 index 00000000..998223a0 --- /dev/null +++ b/DINet/sync_batchnorm/unittest.py @@ -0,0 +1,29 @@ +# -*- coding: utf-8 -*- +# File : unittest.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. 
+ +import unittest +import torch + + +class TorchTestCase(unittest.TestCase): + def assertTensorClose(self, x, y): + adiff = float((x - y).abs().max()) + if (y == 0).all(): + rdiff = 'NaN' + else: + rdiff = float((adiff / y).abs().max()) + + message = ( + 'Tensor close check failed\n' + 'adiff={}\n' + 'rdiff={}\n' + ).format(adiff, rdiff) + self.assertTrue(torch.allclose(x, y, atol=1e-5, rtol=1e-3), message) + diff --git a/DINet/train_DINet_clip.py b/DINet/train_DINet_clip.py new file mode 100644 index 00000000..79dc8822 --- /dev/null +++ b/DINet/train_DINet_clip.py @@ -0,0 +1,156 @@ +from models.Discriminator import Discriminator +from models.VGG19 import Vgg19 +from models.DINet import DINet +from models.Syncnet import SyncNetPerception +from utils.training_utils import get_scheduler, update_learning_rate,GANLoss +from config.config import DINetTrainingOptions +from sync_batchnorm import convert_model +from torch.utils.data import DataLoader +from dataset.dataset_DINet_clip import DINetDataset + + +import random +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +import os +import torch.nn.functional as F + +torch.cuda.set_per_process_memory_fraction(1.0) # 使用最多 90% 的 GPU 内存 + + +if __name__ == "__main__": + ''' + clip training code of DINet + in the resolution you want, using clip training code after frame training + + ''' + # load config + opt = DINetTrainingOptions().parse_args() + random.seed(opt.seed) + np.random.seed(opt.seed) + torch.cuda.manual_seed(opt.seed) + # load training data + train_data = DINetDataset(opt.train_data,opt.augment_num,opt.mouth_region_size) + training_data_loader = DataLoader(dataset=train_data, batch_size=opt.batch_size, shuffle=True,drop_last=True) + train_data_length = len(training_data_loader) + # init network + net_g = DINet(opt.source_channel,opt.ref_channel,opt.audio_channel).cuda() + net_dI = Discriminator(opt.source_channel ,opt.D_block_expansion, opt.D_num_blocks, 
opt.D_max_features).cuda() + net_dV = Discriminator(opt.source_channel * 5, opt.D_block_expansion, opt.D_num_blocks, opt.D_max_features).cuda() + net_vgg = Vgg19().cuda() + net_lipsync = SyncNetPerception(opt.pretrained_syncnet_path).cuda() + # parallel + net_g = nn.DataParallel(net_g) + net_g = convert_model(net_g) + net_dI = nn.DataParallel(net_dI) + net_dV = nn.DataParallel(net_dV) + net_vgg = nn.DataParallel(net_vgg) + # setup optimizer + optimizer_g = optim.Adam(net_g.parameters(), lr=opt.lr_g) + optimizer_dI = optim.Adam(net_dI.parameters(), lr=opt.lr_dI) + optimizer_dV = optim.Adam(net_dV.parameters(), lr=opt.lr_dI) + ## load frame trained DInet weight + print('loading frame trained DINet weight from: {}'.format(opt.pretrained_frame_DINet_path)) + checkpoint = torch.load(opt.pretrained_frame_DINet_path) + net_g.load_state_dict(checkpoint['state_dict']['net_g']) + # set criterion + criterionGAN = GANLoss().cuda() + criterionL1 = nn.L1Loss().cuda() + criterionMSE = nn.MSELoss().cuda() + # set scheduler + net_g_scheduler = get_scheduler(optimizer_g, opt.non_decay, opt.decay) + net_dI_scheduler = get_scheduler(optimizer_dI, opt.non_decay, opt.decay) + net_dV_scheduler = get_scheduler(optimizer_dV, opt.non_decay, opt.decay) + # set label of syncnet perception loss + real_tensor = torch.tensor(1.0).cuda() + # start train + for epoch in range(opt.start_epoch, opt.non_decay+opt.decay+1): + net_g.train() + for iteration, data in enumerate(training_data_loader): + # forward + source_clip,source_clip_mask, reference_clip,deep_speech_clip,deep_speech_full = data + source_clip = torch.cat(torch.split(source_clip, 1, dim=1), 0).squeeze(1).float().cuda() + source_clip_mask = torch.cat(torch.split(source_clip_mask, 1, dim=1), 0).squeeze(1).float().cuda() + reference_clip = torch.cat(torch.split(reference_clip, 1, dim=1), 0).squeeze(1).float().cuda() + deep_speech_clip = torch.cat(torch.split(deep_speech_clip, 1, dim=1), 0).squeeze(1).float().cuda() + deep_speech_full = 
deep_speech_full.float().cuda() + fake_out = net_g(source_clip_mask,reference_clip,deep_speech_clip) + fake_out_half = F.avg_pool2d(fake_out, 3, 2, 1, count_include_pad=False) + source_clip_half = F.interpolate(source_clip, scale_factor=0.5, mode='bilinear') + # (1) Update DI network + optimizer_dI.zero_grad() + _,pred_fake_dI = net_dI(fake_out) + loss_dI_fake = criterionGAN(pred_fake_dI, False) + _,pred_real_dI = net_dI(source_clip) + loss_dI_real = criterionGAN(pred_real_dI, True) + # Combined DI loss + loss_dI = (loss_dI_fake + loss_dI_real) * 0.5 + loss_dI.backward(retain_graph=True) + optimizer_dI.step() + + # (2) Update DV network + optimizer_dV.zero_grad() + condition_fake_dV = torch.cat(torch.split(fake_out, opt.batch_size, dim=0), 1) + _, pred_fake_dV = net_dV(condition_fake_dV) + loss_dV_fake = criterionGAN(pred_fake_dV, False) + condition_real_dV = torch.cat(torch.split(source_clip, opt.batch_size, dim=0), 1) + _, pred_real_dV = net_dV(condition_real_dV) + loss_dV_real = criterionGAN(pred_real_dV, True) + # Combined DV loss + loss_dV = (loss_dV_fake + loss_dV_real) * 0.5 + loss_dV.backward(retain_graph=True) + optimizer_dV.step() + + # (2) Update DINet + _, pred_fake_dI = net_dI(fake_out) + _, pred_fake_dV = net_dV(condition_fake_dV) + optimizer_g.zero_grad() + # compute perception loss + perception_real = net_vgg(source_clip) + perception_fake = net_vgg(fake_out) + perception_real_half = net_vgg(source_clip_half) + perception_fake_half = net_vgg(fake_out_half) + loss_g_perception = 0 + for i in range(len(perception_real)): + loss_g_perception += criterionL1(perception_fake[i], perception_real[i]) + loss_g_perception += criterionL1(perception_fake_half[i], perception_real_half[i]) + loss_g_perception = (loss_g_perception / (len(perception_real) * 2)) * opt.lamb_perception + # # gan dI loss + loss_g_dI = criterionGAN(pred_fake_dI, True) + # # gan dV loss + loss_g_dV = criterionGAN(pred_fake_dV, True) + ## sync perception loss + fake_out_clip = 
torch.cat(torch.split(fake_out, opt.batch_size, dim=0), 1) + fake_out_clip_mouth = fake_out_clip[:, :, train_data.radius:train_data.radius + train_data.mouth_region_size, + train_data.radius_1_4:train_data.radius_1_4 + train_data.mouth_region_size] + sync_score = net_lipsync(fake_out_clip_mouth, deep_speech_full) + loss_sync = criterionMSE(sync_score, real_tensor.expand_as(sync_score)) * opt.lamb_syncnet_perception + # combine all losses + loss_g = loss_g_perception + loss_g_dI +loss_g_dV + loss_sync + loss_g.backward() + optimizer_g.step() + + print( + "===> Epoch[{}]({}/{}): Loss_DI: {:.4f} Loss_GI: {:.4f} Loss_DV: {:.4f} Loss_GV: {:.4f} Loss_perception: {:.4f} Loss_sync: {:.4f} lr_g = {:.7f} ".format( + epoch, iteration, len(training_data_loader), float(loss_dI), float(loss_g_dI),float(loss_dV), float(loss_g_dV), float(loss_g_perception),float(loss_sync), + optimizer_g.param_groups[0]['lr'])) + + update_learning_rate(net_g_scheduler, optimizer_g) + update_learning_rate(net_dI_scheduler, optimizer_dI) + update_learning_rate(net_dV_scheduler, optimizer_dV) + # checkpoint + if epoch % opt.checkpoint == 0: + if not os.path.exists(opt.result_path): + os.mkdir(opt.result_path) + model_out_path = os.path.join(opt.result_path, 'netG_model_epoch_{}.pth'.format(epoch)) + states = { + 'epoch': epoch + 1, + 'state_dict': {'net_g': net_g.state_dict(),'net_dI': net_dI.state_dict(),'net_dV': net_dV.state_dict()}, + 'optimizer': {'net_g': optimizer_g.state_dict(), 'net_dI': optimizer_dI.state_dict(), 'net_dV': optimizer_dV.state_dict()} + } + torch.save(states, model_out_path) + print("Checkpoint saved to {}".format(epoch)) + torch.cuda.empty_cache() + diff --git a/DINet/train_DINet_frame.py b/DINet/train_DINet_frame.py new file mode 100644 index 00000000..70bd3b09 --- /dev/null +++ b/DINet/train_DINet_frame.py @@ -0,0 +1,121 @@ +from models.Discriminator import Discriminator +from models.VGG19 import Vgg19 +from models.DINet import DINet +from utils.training_utils import 
get_scheduler, update_learning_rate,GANLoss +from torch.utils.data import DataLoader +from dataset.dataset_DINet_frame import DINetDataset +from sync_batchnorm import convert_model +from config.config import DINetTrainingOptions + +import random +import numpy as np +import os +import torch.nn.functional as F +import torch +import torch.nn as nn +import torch.optim as optim + +if __name__ == "__main__": + ''' + frame training code of DINet + we use coarse-to-fine training strategy + so you can use this code to train the model in arbitrary resolution + ''' + # load config + opt = DINetTrainingOptions().parse_args() + # set seed + random.seed(opt.seed) + np.random.seed(opt.seed) + torch.cuda.manual_seed(opt.seed) + # load training data in memory + train_data = DINetDataset(opt.train_data,opt.augment_num,opt.mouth_region_size) + training_data_loader = DataLoader(dataset=train_data, batch_size=opt.batch_size, shuffle=True,drop_last=True) + train_data_length = len(training_data_loader) + # init network + net_g = DINet(opt.source_channel,opt.ref_channel,opt.audio_channel).cuda() + net_dI = Discriminator(opt.source_channel ,opt.D_block_expansion, opt.D_num_blocks, opt.D_max_features).cuda() + net_vgg = Vgg19().cuda() + # parallel + net_g = nn.DataParallel(net_g) + net_g = convert_model(net_g) + net_dI = nn.DataParallel(net_dI) + net_vgg = nn.DataParallel(net_vgg) + # setup optimizer + optimizer_g = optim.Adam(net_g.parameters(), lr=opt.lr_g) + optimizer_dI = optim.Adam(net_dI.parameters(), lr=opt.lr_dI) + # coarse2fine + if opt.coarse2fine: + print('loading checkpoint for coarse2fine training: {}'.format(opt.coarse_model_path)) + checkpoint = torch.load(opt.coarse_model_path) + net_g.load_state_dict(checkpoint['state_dict']['net_g']) + # set criterion + criterionGAN = GANLoss().cuda() + criterionL1 = nn.L1Loss().cuda() + # set scheduler + net_g_scheduler = get_scheduler(optimizer_g, opt.non_decay, opt.decay) + net_dI_scheduler = get_scheduler(optimizer_dI, opt.non_decay, 
opt.decay) + # start train + for epoch in range(opt.start_epoch, opt.non_decay+opt.decay+1): + net_g.train() + for iteration, data in enumerate(training_data_loader): + # read data + source_image_data,source_image_mask, reference_clip_data,deepspeech_feature = data + source_image_data = source_image_data.float().cuda() + source_image_mask = source_image_mask.float().cuda() + reference_clip_data = reference_clip_data.float().cuda() + deepspeech_feature = deepspeech_feature.float().cuda() + # network forward + fake_out = net_g(source_image_mask,reference_clip_data,deepspeech_feature) + # down sample output image and real image + fake_out_half = F.avg_pool2d(fake_out, 3, 2, 1, count_include_pad=False) + target_tensor_half = F.interpolate(source_image_data, scale_factor=0.5, mode='bilinear') + # (1) Update D network + optimizer_dI.zero_grad() + # compute fake loss + _,pred_fake_dI = net_dI(fake_out) + loss_dI_fake = criterionGAN(pred_fake_dI, False) + # compute real loss + _,pred_real_dI = net_dI(source_image_data) + loss_dI_real = criterionGAN(pred_real_dI, True) + # Combined DI loss + loss_dI = (loss_dI_fake + loss_dI_real) * 0.5 + loss_dI.backward(retain_graph=True) + optimizer_dI.step() + # (2) Update G network + _, pred_fake_dI = net_dI(fake_out) + optimizer_g.zero_grad() + # compute perception loss + perception_real = net_vgg(source_image_data) + perception_fake = net_vgg(fake_out) + perception_real_half = net_vgg(target_tensor_half) + perception_fake_half = net_vgg(fake_out_half) + loss_g_perception = 0 + for i in range(len(perception_real)): + loss_g_perception += criterionL1(perception_fake[i], perception_real[i]) + loss_g_perception += criterionL1(perception_fake_half[i], perception_real_half[i]) + loss_g_perception = (loss_g_perception / (len(perception_real) * 2)) * opt.lamb_perception + # # gan dI loss + loss_g_dI = criterionGAN(pred_fake_dI, True) + # combine perception loss and gan loss + loss_g = loss_g_perception + loss_g_dI + loss_g.backward() + 
optimizer_g.step() + + print( + "===> Epoch[{}]({}/{}): Loss_DI: {:.4f} Loss_GI: {:.4f} Loss_perception: {:.4f} lr_g = {:.7f} ".format( + epoch, iteration, len(training_data_loader), float(loss_dI), float(loss_g_dI), float(loss_g_perception),optimizer_g.param_groups[0]['lr'])) + + update_learning_rate(net_g_scheduler, optimizer_g) + update_learning_rate(net_dI_scheduler, optimizer_dI) + #checkpoint + if epoch % opt.checkpoint == 0: + if not os.path.exists(opt.result_path): + os.mkdir(opt.result_path) + model_out_path = os.path.join(opt.result_path, 'netG_model_epoch_{}.pth'.format(epoch)) + states = { + 'epoch': epoch + 1, + 'state_dict': {'net_g': net_g.state_dict(), 'net_dI': net_dI.state_dict()},# + 'optimizer': {'net_g': optimizer_g.state_dict(), 'net_dI': optimizer_dI.state_dict()}# + } + torch.save(states, model_out_path) + print("Checkpoint saved to {}".format(epoch)) diff --git a/DINet/utils/__pycache__/data_processing.cpython-36.pyc b/DINet/utils/__pycache__/data_processing.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cdaea793a7a3e089bf2496040519282bd06d5eb8 GIT binary patch literal 2098 zcmbVN&2Jk;6rY*>@Oteyshc)U3YED4Mko}hdZ?-v)Tp!VQ|eiq?vr+90UgrmQFUO|u^ZTT zu+!fGG9vNfoPB>Krx_hGN#`_Y)6fj0`d5=01_;(B>88E94wrSNh7zoKhk!+_}rU zHNJlK?(w|?l2wL|H1Tp(V-_$MSXEY~J7ByvXL(K55cS5%2W|^1P6FOars1R?^uksg zO!znuqgG!eZ(;$0P(7sSX7LG1b5XXf*^Hq{P`&gev6&3fR2@hvN zrNFE9MKBJ1SoA70Uq9?sj+aPdRXct0N|?pP*Vk9&JG}xx%sO@HC2HcgL+f+{=Nq&I zo(o(kk)HY()l!L34(aCp8Q+0w?QI3oWQ^oSZf0bN&seqhGnQMDadYH|8yOq2IpNl2 z!tE`RvxzBL&LrAfH8^K|=tHZdDxi&Ra*0T`NpgF|cQ{r26#x|raD`-|W zFSolxFesE)7!_!2GBjnC*DhluaJEHw4K~o=yaqXU_4d*SS-nKi@~?OXX?K%s5$Ph9 z`s%j6^bb1Hh2GkKZr#Ge9brZNk(N|@iE%WkqK_h&uMOB z$u%_hTi(=M#8nPL-qM~A5v*S#+Ae?i)`FoRi?b`~zE!Ee8+_1Epio1hlg6}So@)aYNR?q0kF zaCZQJ_4n;B+w5-VrCkwQxG9WrFhf(piDWFc!4Vfqb~GvMEKzOIUN77y5f2kTjZQ+Z zx%%F9|McQ$TGY!2{1hWA^eUzF(JAuCJGat-IbW%GP8p{EXoVelwIxprxJv>WUgHJs 
zf9PNHTqWU;qY2cPINn|1;2)jgpPAlz&K2jE`Sw*6{POL~2s`}a!E=QEnichMD7FvT6O@VFp3~S1iO8x&{hp%3^oQ(g9}@kAyd>w$N~V&gKzA``&fG8OJKvd4yWRNv{!io8uH*d2 zx$(K6e-23-5W*2;?l59=%Bam<=0Z;A-oj@-!9MOhScELJI+;hT_1qDj@V|A0zo9Ia zZ4sPNRf>>xM8vu>mfa0uJ8~!6l^xJ}a!>X)1jdtNXV|)gm7Sv@sl)T+kUW7UUpO!qJxc%ind1zD`Ytb(EE6TUfXTI; zoafTi{!}I9tZvPVX=*eFwVUSU((t?x5{td2IG6d<%xs6gNz_ytu9HPMms}-A4r%Ra zIiJ+=0*1|bnV8wX9My%5{`mOIrzc-&sr1R$UmQJu@u!nRsn1MNo*e%9NSc$SN$2_m zcg34+jmlLWah~R>;k=FK^*}iE0a@SL-puH_$J;B4pc8htSF{GjI2p-k=2k?I4gGF# zL7~RW{4Cfw7jETdVMViOXgOX%SM#s$wIDgZR;L3m;Ur%efLJ?_WGgW!GASe z8RF@|uv1erscBj;e~Lh4{_|&$)xo$}fB|ay@)3)kqj?YIgzf${QCdDz#p2kgG@sU< zl8In@Xh|U_$#QPE#RBW4i?UExC2V7T>||uF3&l@Y9M3U~6kay#Z#T%xQ)~E|OjyfK zg@aW9-BgGIEu>l~5EzZvK@E}JgQWW)9MUHpau3oDiSeiFcMsos#OQwjd5m-57em6; zEDZh)iUB{VD43DFalv%p8AuSX-TTgkZ%Ad$Xb>h|3c;x1#6_bE zT6z0G9e16J2-<=y%pz;@2*$ZEt|k2Y&M$KBzB6$}Ai_7^MGMBnRjUd_G^HDA%k8QK zGq);i>!5zm0j-PLI}X+(;T#R!I?(AfPv$Icd=6cX1^IF@mgOZ<1E9+nW>v~d0wsmw z84UCFEGcCjYOuJFnzhbVYGiUcmzO`mH{4M!^sHULR88}e`D>{Pt?pxGdtBsWSSN=B zJv`j%TrqDUlbpHQh+%B8Qv1nLq($AnK^%1l zPQl?AFw=5Tw{5A&jeKQTG?9r}!lcYrUmv0?7o}No*q-gxK2LK|oO7M7CF>ZKq&aX0 zFqr-o%IXhb$V+bbz)f+J2pgAsDaWR;N6{a{c^sb-{3CbAjof|G^ExELGW>n0?a>%8 z+M(-%8|eJMmUv59=nLoo7OOvif|O-honBB&Q4llET=m$vfJ%x)1vm!$z6MH~P(ZK? 
z$aXRx`n)iJNFcSyw6b`Eh=+RN0p`8Dv#keGYiFITo9#gRuJGTG?`Gf?PSyh|!?AlX z7DobVMFhAF0MUJ@xdmyz3ZTbr(ZaQ%e5VSsy8=9EN>VCf^qMNP{?&esQO@pkf}R#_ z3ATy=)tw{t0M5I9@&!U-FiZ4c0SuHvKpBi@NuJBO{$yZgaxh7iHrI87R1cO~3X5<6 zIC;fFt$3l*DGNZ!<@u0U#<|4lFY$tho_Y^5g;7&|fZ{_Cqak6fiIjqibZpc;s2bi= z=u7H02q0^bF6!t!Q8|#Mu07Ngl1q&^htRiTQKKN`CK|m@{T|!>$-GF6+QYeSq1Xq( z!egm*TI3BMeuPzinwv*Yn5&lgmLOROhe%$fx?#$9pyROB_)pwm7lSL?YxEngWKYR< zulFPxq%(-iTC>hgh_fAt{O5%_mL!A!h>>5eg+Lw16cq7 literal 0 HcmV?d00001 diff --git a/DINet/utils/__pycache__/training_utils.cpython-36.pyc b/DINet/utils/__pycache__/training_utils.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..79b48d04ad6b1a533f961043b66af642eb3f8e9d GIT binary patch literal 2300 zcmZWqOOG2x5bmCr$Fufsb`ym}yi7y_Mv=uU3WShNvVkSi#zaXFXfTW>({1lK8|gq>hgQdX7F_FpX8?+ zA%BrW&4Kw<=<-t#j4+y!5e+HDxs_Q%JG9Nr2_3e?T;}c4&}B8|?~>5lB&^PYUBUux z4{K99^qF;+v>W?y7HL!6$V4QQZO+CS7hUaT@#sE__umC=;-U>SaL9BUM)@O%%4YO| zCC=@UJ=(XJ#q3>w$5muVD^j`4d1CE&c@L~yu(}HS7VJCBW3?wXXs(0C2hF#ju`A;s zINo5*CoX7T2kkOwAA;7Y$Ya7**fH~6pj@_E*;4>ec>rQEGhF7W;zBQ(A!0mUD9(qY3)0`XDkdu26 zQdZ+Cj^OCY`HaWHgiFAMt+L3{TIs0F_=x9Zp_k4B3TzoChd7J?VusYr?)XjtAJ`ti2?V7F^rAkMDUWhsz zFOc34`I>nXF2mGBFT!|A(L^8W&y!9p!Mbf7d@CAg&^!_rBut(JK^&hh+sCL+*Px%o z(F&+|4KxXY!-naKeFM5Y2cjZFG9{E$bVwOJ4cU5{>{tpOUS*wxOvPGYveeM5txQZ# z93}SS*4fU*e%lr+aGY+Gv4}^}KosLrYPS?=u0YoUop==mJ`00uCJdkxZd;+&GZp0BgGAVC^P((q2np#84a0$f+ zUcNJQD{D%p7O;yumf5lAJ5*d(9@I_HMSdQ@ygr*7zU{6K1Y{RHtouCEK|g)~^x(W5 zXjhI)$S4e-_M-@Dl8Pd|BKRPM=O&{2<9;8ig9|Sw+18%oxr8@l<+-liy88(N)(+y- zwd)&ZiSJGFtRL3L60VVhI1g7jLP;nky!@QFiUINiu zgP%hi*5ufn^3KBV1WCk(gkqYS)O-mhl_eUOo;!B#!AN)P%7%JRr&u;=7qY4iR^lxX zy3QY#an7PxwjJ?4?zm}Qj+I`X!$_q;qMP&ejFkT-;wmCOVu7bi9|QsaeaoYh(+dRu zvpoU@VHqfx#|WdMkPQ}qz&|rO`?En9mbe5K0+V>|ySRuOAE5XU1y+Fra&@g=i0@;; z-hu6Tt!~iC%EGO=wr$@NZ{UY-f@u4gD0w{MkSIYEL2-p*5A#M8eK(G?xuzCHtVke| z7{8fYAHhsF=P^!rmM!9JN)r_q&=R9;((z@O7`A-dq04o@W_AH0vLP_S?YdqvdD$_g zab9(2J9#eNM!OCQOds>+wRg*SYmPd^g<%D6*8b%go-6p*DY3>jXaK video_h: + return False,None + 
elif max(clip_max_w.tolist()) > video_w: + return False,None + elif max(radius_clip) > min(radius_clip) * 1.5: + return False, None + else: + return True,radius_max \ No newline at end of file diff --git a/DINet/utils/deep_speech.py b/DINet/utils/deep_speech.py new file mode 100644 index 00000000..f26a5124 --- /dev/null +++ b/DINet/utils/deep_speech.py @@ -0,0 +1,99 @@ + +import numpy as np +import warnings +import resampy +from scipy.io import wavfile +from python_speech_features import mfcc +import tensorflow as tf + + +class DeepSpeech(): + def __init__(self,model_path): + self.graph, self.logits_ph, self.input_node_ph, self.input_lengths_ph \ + = self._prepare_deepspeech_net(model_path) + self.target_sample_rate = 16000 + + def _prepare_deepspeech_net(self,deepspeech_pb_path): + with tf.io.gfile.GFile(deepspeech_pb_path, "rb") as f: + graph_def = tf.compat.v1.GraphDef() + graph_def.ParseFromString(f.read()) + graph = tf.compat.v1.get_default_graph() + tf.import_graph_def(graph_def, name="deepspeech") + logits_ph = graph.get_tensor_by_name("deepspeech/logits:0") + input_node_ph = graph.get_tensor_by_name("deepspeech/input_node:0") + input_lengths_ph = graph.get_tensor_by_name("deepspeech/input_lengths:0") + + return graph, logits_ph, input_node_ph, input_lengths_ph + + def conv_audio_to_deepspeech_input_vector(self,audio, + sample_rate, + num_cepstrum, + num_context): + # Get mfcc coefficients: + features = mfcc( + signal=audio, + samplerate=sample_rate, + numcep=num_cepstrum) + + # We only keep every second feature (BiRNN stride = 2): + features = features[::2] + + # One stride per time step in the input: + num_strides = len(features) + + # Add empty initial and final contexts: + empty_context = np.zeros((num_context, num_cepstrum), dtype=features.dtype) + features = np.concatenate((empty_context, features, empty_context)) + + # Create a view into the array with overlapping strides of size + # numcontext (past) + 1 (present) + numcontext (future): + window_size 
= 2 * num_context + 1 + train_inputs = np.lib.stride_tricks.as_strided( + features, + shape=(num_strides, window_size, num_cepstrum), + strides=(features.strides[0], + features.strides[0], features.strides[1]), + writeable=False) + + # Flatten the second and third dimensions: + train_inputs = np.reshape(train_inputs, [num_strides, -1]) + + train_inputs = np.copy(train_inputs) + train_inputs = (train_inputs - np.mean(train_inputs)) / \ + np.std(train_inputs) + + return train_inputs + + def compute_audio_feature(self,audio_path): + audio_sample_rate, audio = wavfile.read(audio_path) + if audio.ndim != 1: + warnings.warn( + "Audio has multiple channels, the first channel is used") + audio = audio[:, 0] + if audio_sample_rate != self.target_sample_rate: + resampled_audio = resampy.resample( + x=audio.astype(np.float), + sr_orig=audio_sample_rate, + sr_new=self.target_sample_rate) + else: + resampled_audio = audio.astype(np.float) + with tf.compat.v1.Session(graph=self.graph) as sess: + input_vector = self.conv_audio_to_deepspeech_input_vector( + audio=resampled_audio.astype(np.int16), + sample_rate=self.target_sample_rate, + num_cepstrum=26, + num_context=9) + network_output = sess.run( + self.logits_ph, + feed_dict={ + self.input_node_ph: input_vector[np.newaxis, ...], + self.input_lengths_ph: [input_vector.shape[0]]}) + ds_features = network_output[::2,0,:] + return ds_features + +if __name__ == '__main__': + audio_path = r'./00168.wav' + model_path = r'./output_graph.pb' + DSModel = DeepSpeech(model_path) + ds_feature = DSModel.compute_audio_feature(audio_path) + print(ds_feature) \ No newline at end of file diff --git a/DINet/utils/training_utils.py b/DINet/utils/training_utils.py new file mode 100644 index 00000000..30f186f6 --- /dev/null +++ b/DINet/utils/training_utils.py @@ -0,0 +1,51 @@ +from torch.optim import lr_scheduler +import torch.nn as nn +import torch + +def get_scheduler(optimizer, niter,niter_decay,lr_policy='lambda',lr_decay_iters=50): + ''' + 
scheduler in training stage + ''' + if lr_policy == 'lambda': + def lambda_rule(epoch): + lr_l = 1.0 - max(0, epoch - niter) / float(niter_decay + 1) + return lr_l + scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) + elif lr_policy == 'step': + scheduler = lr_scheduler.StepLR(optimizer, step_size=lr_decay_iters, gamma=0.1) + elif lr_policy == 'plateau': + scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5) + elif lr_policy == 'cosine': + scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=niter, eta_min=0) + else: + return NotImplementedError('learning rate policy [%s] is not implemented', lr_policy) + return scheduler + +def update_learning_rate(scheduler, optimizer): + scheduler.step() + lr = optimizer.param_groups[0]['lr'] + print('learning rate = %.7f' % lr) + +class GANLoss(nn.Module): + ''' + GAN loss + ''' + def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0): + super(GANLoss, self).__init__() + self.register_buffer('real_label', torch.tensor(target_real_label)) + self.register_buffer('fake_label', torch.tensor(target_fake_label)) + if use_lsgan: + self.loss = nn.MSELoss() + else: + self.loss = nn.BCELoss() + + def get_target_tensor(self, input, target_is_real): + if target_is_real: + target_tensor = self.real_label + else: + target_tensor = self.fake_label + return target_tensor.expand_as(input) + + def forward(self, input, target_is_real): + target_tensor = self.get_target_tensor(input, target_is_real) + return self.loss(input, target_tensor) \ No newline at end of file diff --git "a/DINet/\351\205\215\347\275\256\346\226\207\346\241\243.txt" "b/DINet/\351\205\215\347\275\256\346\226\207\346\241\243.txt" new file mode 100644 index 00000000..ce6b0887 --- /dev/null +++ "b/DINet/\351\205\215\347\275\256\346\226\207\346\241\243.txt" @@ -0,0 +1,81 @@ +在 Google Drive 中下载资源 (asserts.zip)。解压缩并将 dir 放入 ./ 中 + 
+(注意:要在该asserts文件下创建一个文件夹test_data,然后再test_data下面创建两个文件夹split_video_25fps和split_video_25fps_landmark_openface,用来存放测试集和推理数据) + + +为了便于挂载获取需要提前宿主机下载好整个项目DINet + +拉取项目镜像: +docker pull kevia/dinet:latest + +使用镜像: +1.数据处理: + + (1)从 HDTF 数据集下载视频。根据 xx_annotion_time.txt 分割视频,不裁剪和调整视频大小。 + (2)将所有分割的视频重新采样为 25fps,并将视频放入 “./asserts/training_data/split_video_25fps”。 + (3)使用 openface 检测所有视频的平滑面部特征点。将所有 “.csv” 结果放入 “./asserts/training_data/split_video_25fps_landmark_openface” 中。 + 注意:(1),(2),(3)需要处理好数据加入上述对应项目文件夹下面,这里面处理了HDTF数据集,因此可以将./asserts/training_data/split_video_25fps和 + ./asserts/training_data/split_video_25fps_landmark_openface里相对应部分数据移动到./asserts/test_data/split_video_25fps和 + ./asserts/test_data/split_video_25fps_landmark_openface下面作为测试集数据 + + 并且完成上述可以使用预训练模型推理,如果要训练必须进行下面操作对训练数据进行处理(由于提交下面处理数据占用空间过大没法提交) + + (4)从训练集视频中提取帧并将帧保存在 “./asserts/trainning_data/split_video_25fps_frame” 中。跑 + docker run --rm -v path/to/your/DINet:/app --gpus all -it --shm-size=8G -e CUDA_LAUNCH_BLOCKING=1 kevia/dinet:latest data_processing.py --extract_video_frame + + (5)从训练集视频中提取音频,并将音频保存在 ./asserts/trainning_data/split_video_25fps_audio 中。跑 + docker run --rm -v path/to/your/DINet:/app --gpus all -it --shm-size=8G -e CUDA_LAUNCH_BLOCKING=1 kevia/dinet:latest data_processing.py --extract_audio + + (6)从训练集音频中提取 deepspeech 特征并将特征保存在 “./asserts/trainning_data/split_video_25fps_deepspeech” 中。跑 + docker run --rm -v path/to/your/DINet:/app --gpus all -it --shm-size=8G -e CUDA_LAUNCH_BLOCKING=1 kevia/dinet:latest data_processing.py --extract_deep_speech + + (7)裁剪训练集视频的人脸并将图像保存在 “./asserts/trainning_data/split_video_25fps_crop_face” 中。跑 + docker run --rm -v path/to/your/DINet:/app --gpus all -it --shm-size=8G -e CUDA_LAUNCH_BLOCKING=1 kevia/dinet:latest data_processing.py --crop_face + + (8)生成训练 json 文件 “./asserts/trainning_data/training_json.json”。跑 + docker run --rm -v path/to/your/DINet:/app --gpus all -it --shm-size=8G -e CUDA_LAUNCH_BLOCKING=1 kevia/dinet:latest data_processing.py 
--generate_training_json + +2.训练过程: + + 我们将训练过程分为帧训练阶段和 clip 训练阶段。在帧训练阶段,我们使用从粗到细的策略,因此您可以在任意分辨率下训练模型。 + (1)框架训练阶段。 + 在帧训练阶段,我们只使用感知损失和 GAN 损失。 + 首先,以 104x80(嘴部区域为 64x64)分辨率训练 DINet。跑 + docker run --rm -v path/to/your/DINet:/app --gpus all -it --shm-size=8G -e CUDA_LAUNCH_BLOCKING=1 kevia/dinet:latest train_DINet_frame.py \ + --augment_num=32 --mouth_region_size=64 --batch_size=24 --result_path=./asserts/training_model_weight/frame_training_64 + 您可以在损失收敛时停止训练(我们在大约 270 个 epoch 中停止)。 + + (2)加载预训练模型(面部:104x80 & 嘴巴:64x64)并以更高分辨率训练DINet(面部:208x160 & 嘴巴:128x128)。跑 + docker run --rm -v path/to/your/DINet:/app --gpus all -it --shm-size=8G -e CUDA_LAUNCH_BLOCKING=1 kevia/dinet:latest train_DINet_frame.py \ + --augment_num=100 --mouth_region_size=128 --batch_size=80 --coarse2fine \ + --coarse_model_path=./asserts/training_model_weight/frame_training_64/xxxxxx.pth --result_path=./asserts/training_model_weight/frame_training_128 + 您可以在损失收敛时停止训练(我们在大约 200 个 epoch 中停止)。 + + (3)加载预训练模型(面部:208x160 & 嘴巴:128x128)并以更高分辨率训练DINet(面部:416x320 & 嘴巴:256x256)。跑 + docker run --rm -v path/to/your/DINet:/app --gpus all -it --shm-size=8G -e CUDA_LAUNCH_BLOCKING=1 kevia/dinet:latest train_DINet_frame.py \ + --augment_num=20 --mouth_region_size=256 --batch_size=12 --coarse2fine \ + --coarse_model_path=./asserts/training_model_weight/frame_training_128/xxxxxx.pth --result_path=./asserts/training_model_weight/frame_training_256 + 您可以在损失收敛时停止训练(我们在大约 200 个 epoch 中停止)。 + + (4)在剪辑训练阶段,我们使用感知损失、帧/剪辑 GAN 损失和同步损失。加载预训练的帧模型(面部:416x320 & 嘴巴:256x256),预训练的同步网络模型(嘴巴:256x256)并在剪辑设置中训练DINet。跑 + + docker run --rm -v path/to/your/DINet:/app --gpus all -it --shm-size=8G -e CUDA_LAUNCH_BLOCKING=1 kevia/dinet:latest train_DINet_clip.py \ + --augment_num=3 --mouth_region_size=256 --batch_size=3 \ + --pretrained_syncnet_path=./asserts/syncnet_256mouth.pth --pretrained_frame_DINet_path=./asserts/training_model_weight/frame_training_256/xxxxx.pth \ + --result_path=./asserts/training_model_weight/clip_training_256 + 
您可以在损失收敛时停止训练并选择最佳模型(我们的最佳模型为 160 epoch)。 + +3.推理评估阶段 + + (1)docker run --rm -v path/to/your/DINet:/app --gpus all -it --shm-size=8G -e CUDA_LAUNCH_BLOCKING=1 kevia/dinet:latest run_inference.py --process_num x + process_num 遍历推理个数,结果保存在./asserts/inference_result里面 + 推理阶段默认使用预训练模型,如果使用训练模型加入参数 --model_path ./asserts/training_model_weight/clip_training_256/xxxx.pth + 可以选择音频路径 --audio_path /path/to/your_auido + + + (2)docker run --rm -v path/to/your/DINet:/app --gpus all -it --shm-size=8G -e CUDA_LAUNCH_BLOCKING=1 kevia/dinet:latest run_evaluate.py + 评估结果保存在./asserts/evaluate_result里面。 + + 注意:如果像推理自己的数据: + 和上述数据处理一样的步骤(1),(2),(3)放到./asserts/test_data文件里相应位置。 + From 6c9e06bbec1686ea5930e2a97ba3790055d71361 Mon Sep 17 00:00:00 2001 From: KeviaWang <104704083+KeviaWang@users.noreply.github.com> Date: Sun, 22 Dec 2024 20:24:22 +0800 Subject: [PATCH 2/3] Add files via upload --- DINet/README.md | 78 ++++++++++++++++++++++++------------------------- 1 file changed, 38 insertions(+), 40 deletions(-) diff --git a/DINet/README.md b/DINet/README.md index d093c347..7c23017a 100644 --- a/DINet/README.md +++ b/DINet/README.md @@ -2,87 +2,85 @@ ![在这里插入图片描述](https://img-blog.csdnimg.cn/178c6b3ec0074af7a2dcc9ef26450e75.png) [Paper](https://fuxivirtualhuman.github.io/pdf/AAAI2023_FaceDubbing.pdf)         [demo video](https://www.youtube.com/watch?v=UU344T-9h7M&t=6s)      Supplementary materials -## Inference -##### Download resources (asserts.zip) in [Google drive](https://drive.google.com/drive/folders/1rPtOo9Uuhc59YfFVv4gBmkh0_oG0nCQb?usp=share_link). unzip and put dir in ./. -+ Inference with example videos. 
Run + +这是2024年语音识别课程大作业的仓库,用于[DINet](https://github.com/MRzzm/DINet)的复现 +## 数据获取 +##### 在 [Google drive](https://drive.google.com/drive/folders/1rPtOo9Uuhc59YfFVv4gBmkh0_oG0nCQb?usp=share_link)中下载资源 (asserts.zip)。解压缩并将 dir 放入 ./ 中 ++ 使用示例视频进行推理。运行 ```python python inference.py --mouth_region_size=256 --source_video_path=./asserts/examples/testxxx.mp4 --source_openface_landmark_path=./asserts/examples/testxxx.csv --driving_audio_path=./asserts/examples/driving_audio_xxx.wav --pretrained_clip_DINet_path=./asserts/clip_training_DINet_256mouth.pth ``` -The results are saved in ./asserts/inference_result +结果保存在 ./asserts/inference_result + ++ 使用自定义视频进行推理。 +**Note:** 发布的预训练模型是在 HDTF 数据集上训练的。(视频名称在 ./asserts/training_video_name.txt 中) -+ Inference with custom videos. -**Note:** The released pretrained model is trained on HDTF dataset with 363 training videos (video names are in ./asserts/training_video_name.txt), so the generalization is limited. It would be better to test custom videos with normal lighting, frontal view etc.(see the limitation section in the paper). **We also release the training code**, so if a larger high resolution audio-visual dataset is proposed in the further, you can use the training code to train a model with greater generalization. Besides, we release coarse-to-fine training strategy, **so you can use the training code to train a model in arbitrary resolution** (larger than 416x320 if gpu memory and training dataset are available). +使用 [openface](https://github.com/TadasBaltrusaitis/OpenFace)检测自定义视频的平滑面部特征点。 -Using [openface](https://github.com/TadasBaltrusaitis/OpenFace) to detect smooth facial landmarks of your custom video. 
We run the **OpenFaceOffline.exe** on windows 10 system with this setting: - -| Record | Recording settings | OpenFace setting | View | Face Detector | Landmark Detector | -|--|--|--|--|--|--| -| 2D landmark & tracked videos | Mask aligned image | Use dynamic AU models | Show video | Openface (MTCNN)| CE-CLM | -The detected facial landmarks are saved in "xxxx.csv". Run +检测到的人脸特征点保存在 “xxxx.csv” 中。运行 ```python python inference.py --mouth_region_size=256 --source_video_path= custom video path --source_openface_landmark_path= detected landmark path --driving_audio_path= driving audio path --pretrained_clip_DINet_path=./asserts/clip_training_DINet_256mouth.pth ``` -to realize face visually dubbing on your custom videos. -## Training -### Data Processing -We release the code of video processing on [HDTF dataset](https://github.com/MRzzm/HDTF). You can also use this code to process custom videos. +在您的自定义视频上实现人脸视觉配音。 +## 训练 +### 数据处理 + + 1. 从[HDTF](https://github.com/MRzzm/HDTF)下载视频。根据 xx_annotion_time.txt 分割视频,不裁剪和调整视频大小。 + 2. 将所有分割的视频重新采样为 25fps,并将视频放入 “./asserts/split_video_25fps”。您可以在 “./asserts/split_video_25fps” 中看到两个示例视频。我们使用[软件](http://www.pcfreetime.com/formatfactory/cn/index.html) 对视频进行重新采样。我们在实验中提供了训练视频的名称列表。(请参阅“./asserts/training_video_name.txt”) + 3. 使用 [openface](https://github.com/TadasBaltrusaitis/OpenFace) 检测所有视频的平滑面部特征点。将所有 “.csv” 结果放入 “./asserts/split_video_25fps_landmark_openface” 中。您可以在 “./asserts/split_video_25fps_landmark_openface” 中看到两个示例 csv 文件。 + - 1. Downloading videos from [HDTF dataset](https://github.com/MRzzm/HDTF). Splitting videos according to xx_annotion_time.txt and **do not** crop&resize videos. - 2. Resampling all split videos into **25fps** and put videos into "./asserts/split_video_25fps". You can see the two example videos in "./asserts/split_video_25fps". We use [software](http://www.pcfreetime.com/formatfactory/cn/index.html) to resample videos. We provide the name list of training videos in our experiment. 
(pls see "./asserts/training_video_name.txt") - 3. Using [openface](https://github.com/TadasBaltrusaitis/OpenFace) to detect smooth facial landmarks of all videos. Putting all ".csv" results into "./asserts/split_video_25fps_landmark_openface". You can see the two example csv files in "./asserts/split_video_25fps_landmark_openface". - 4. Extracting frames from all videos and saving frames in "./asserts/split_video_25fps_frame". Run + 4. 从所有视频中提取帧并将帧保存在 “./asserts/split_video_25fps_frame” 中。运行 ```python python data_processing.py --extract_video_frame ``` - 5. Extracting audios from all videos and saving audios in "./asserts/split_video_25fps_audio". Run + 5. 从所有视频中提取音频,并将音频保存在 ./asserts/split_video_25fps_audio 中。运行 ```python python data_processing.py --extract_audio ``` - 6. Extracting deepspeech features from all audios and saving features in "./asserts/split_video_25fps_deepspeech". Run + 6. 从所有音频中提取 deepspeech 特征并将特征保存在 “./asserts/split_video_25fps_deepspeech” 中。运行 ```python python data_processing.py --extract_deep_speech ``` - 7. Cropping faces from all videos and saving images in "./asserts/split_video_25fps_crop_face". Run + 7. 裁剪所有视频的人脸并将图像保存在 “./asserts/split_video_25fps_crop_face” 中。运行 ```python python data_processing.py --crop_face ``` - 8. Generating training json file "./asserts/training_json.json". Run + 8. 生成训练 json 文件 “./asserts/training_json.json”。运行 ```python python data_processing.py --generate_training_json ``` -### Training models -We split the training process into **frame training stage** and **clip training stage**. In frame training stage, we use coarse-to-fine strategy, **so you can train the model in arbitrary resolution**. +### 训练模型 +训练过程分为帧训练阶段和 clip 训练阶段。在帧训练阶段,我们使用从粗到细的策略,因此您可以在任意分辨率下训练模型。 -#### Frame training stage. -In frame training stage, we only use perception loss and GAN loss. +#### 框架训练阶段。 +在帧训练阶段,我们只使用感知损失和 GAN 损失 - 1. Firstly, train the DINet in 104x80 (mouth region is 64x64) resolution. Run + 1. 
首先,以 104x80(嘴部区域为 64x64)分辨率训练 DINet。运行 ```python python train_DINet_frame.py --augment_num=32 --mouth_region_size=64 --batch_size=24 --result_path=./asserts/training_model_weight/frame_training_64 ``` -You can stop the training when the loss converges (we stop in about 270 epoch). - 2. Loading the pretrained model (face:104x80 & mouth:64x64) and train the DINet in higher resolution (face:208x160 & mouth:128x128). Run - ```python + + 2. 加载预训练模型(面部:104x80 & 嘴巴:64x64)并以更高分辨率训练DINet(面部:208x160 & 嘴巴:128x128)。运行 python train_DINet_frame.py --augment_num=100 --mouth_region_size=128 --batch_size=80 --coarse2fine --coarse_model_path=./asserts/training_model_weight/frame_training_64/xxxxxx.pth --result_path=./asserts/training_model_weight/frame_training_128 ``` -You can stop the training when the loss converges (we stop in about 200 epoch). - 3. Loading the pretrained model (face:208x160 & mouth:128x128) and train the DINet in higher resolution (face:416x320 & mouth:256x256). Run + + 3. 加载预训练模型(面部:208x160 & 嘴巴:128x128)并以更高分辨率训练DINet(面部:416x320 & 嘴巴:256x256)。运行 ```python python train_DINet_frame.py --augment_num=20 --mouth_region_size=256 --batch_size=12 --coarse2fine --coarse_model_path=./asserts/training_model_weight/frame_training_128/xxxxxx.pth --result_path=./asserts/training_model_weight/frame_training_256 ``` -You can stop the training when the loss converges (we stop in about 200 epoch). -#### Clip training stage. -In clip training stage, we use perception loss, frame/clip GAN loss and sync loss. Loading the pretrained frame model (face:416x320 & mouth:256x256), pretrained syncnet model (mouth:256x256) and train the DINet in clip setting. 
Run + +#### 剪辑训练阶段。 +在剪辑训练阶段,我们使用感知损失、帧/剪辑 GAN 损失和同步损失。加载预训练的帧模型(面部:416x320 & 嘴巴:256x256),预训练的同步网络模型(嘴巴:256x256)并在剪辑设置中训练DINet。运行 ```python python train_DINet_clip.py --augment_num=3 --mouth_region_size=256 --batch_size=3 --pretrained_syncnet_path=./asserts/syncnet_256mouth.pth --pretrained_frame_DINet_path=./asserts/training_model_weight/frame_training_256/xxxxx.pth --result_path=./asserts/training_model_weight/clip_training_256 ``` -You can stop the training when the loss converges and select the best model (our best model is at 160 epoch). -## Acknowledge -The AdaAT is borrowed from [AdaAT](https://github.com/MRzzm/AdaAT). The deepspeech feature is borrowed from [AD-NeRF](https://github.com/YudongGuo/AD-NeRF). The basic module is borrowed from [first-order](https://github.com/AliaksandrSiarohin/first-order-model). Thanks for their released code. \ No newline at end of file +## 声明 +整体实现思路来自[https://github.com/MRzzm/DINet](https://github.com/MRzzm/DINet) \ No newline at end of file From 6c4e4dcb1f32c8ffdc07f03d57e01bc5408c2355 Mon Sep 17 00:00:00 2001 From: KeviaWang <104704083+KeviaWang@users.noreply.github.com> Date: Mon, 23 Dec 2024 10:12:18 +0800 Subject: [PATCH 3/3] Add files via upload --- DINet/README.md | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/DINet/README.md b/DINet/README.md index 7c23017a..0fce21b7 100644 --- a/DINet/README.md +++ b/DINet/README.md @@ -4,6 +4,9 @@ 这是2024年语音识别课程大作业的仓库,用于[DINet](https://github.com/MRzzm/DINet)的复现 +# 复现注意事项 +首先这里知识对于原项目进行了复述,因此具体操作参考配置文档.txt文件进行使用。 + ## 数据获取 ##### 在 [Google drive](https://drive.google.com/drive/folders/1rPtOo9Uuhc59YfFVv4gBmkh0_oG0nCQb?usp=share_link)中下载资源 (asserts.zip)。解压缩并将 dir 放入 ./ 中 + 使用示例视频进行推理。运行 @@ -82,5 +85,21 @@ python train_DINet_frame.py --augment_num=20 --mouth_region_size=256 --batch_siz python train_DINet_clip.py --augment_num=3 --mouth_region_size=256 --batch_size=3 --pretrained_syncnet_path=./asserts/syncnet_256mouth.pth 
--pretrained_frame_DINet_path=./asserts/training_model_weight/frame_training_256/xxxxx.pth --result_path=./asserts/training_model_weight/clip_training_256 ``` + +# 改进推理和评估 +1.上述推理过程过于繁琐,不便于我们进行测试集的推理和评估,因此这里对于该部分进行了修改,通过一个统一的脚本进行推理,我们只需要将测试集放到test_data文件夹下面我们就可以很方便的进行推理: +``` +docker run --rm -v path/to/your/DINet:/app --gpus all -it --shm-size=8G -e CUDA_LAUNCH_BLOCKING=1 kevia/dinet:latest run_inference.py --process_num x +``` +process_num 遍历推理个数,结果保存在./asserts/inference_result里面 + 推理阶段默认使用预训练模型,如果使用训练模型加入参数 --model_path ./asserts/training_model_weight/clip_training_256/xxxx.pth + 可以选择音频路径 --audio_path /path/to/your_audio + +2.由于该项目并没有评估脚本来对于模型进行评估,因此这里加入了专用的评估脚本,实现了NIQE,PSNR,FID,SSIM,LSE-C,LSE-D指标,通过运行 +``` +docker run --rm -v path/to/your/DINet:/app --gpus all -it --shm-size=8G -e CUDA_LAUNCH_BLOCKING=1 kevia/dinet:latest run_evaluate.py +``` +评估结果保存在./asserts/evaluate_result里面。 + ## 声明 整体实现思路来自[https://github.com/MRzzm/DINet](https://github.com/MRzzm/DINet) \ No newline at end of file